비교 연산과 정렬
기본 비교 연산
섹션 제목: “기본 비교 연산”비교 연산은 원소별 로 수행되며, 결과는 bool 타입의 텐서를 반환합니다.
import torch
a = torch.tensor([1, 2, 3, 4, 5])b = torch.tensor([3, 2, 1, 4, 6])
print(a == b) # tensor([False, True, False, True, False])print(a != b) # tensor([ True, False, True, False, True])print(a > b) # tensor([False, False, True, False, False])print(a < b) # tensor([ True, False, False, False, True])print(a >= b) # tensor([False, True, True, True, False])print(a <= b) # tensor([ True, True, False, True, True])함수형 비교 API
섹션 제목: “함수형 비교 API”| 연산자 | 함수형 API | 반환값 |
|---|---|---|
a == b | torch.eq(a, b) | bool 텐서 |
a != b | torch.ne(a, b) | bool 텐서 |
a > b | torch.gt(a, b) | bool 텐서 |
a < b | torch.lt(a, b) | bool 텐서 |
a >= b | torch.ge(a, b) | bool 텐서 |
a <= b | torch.le(a, b) | bool 텐서 |
x = torch.tensor([1.0, 2.0, 3.0])y = torch.tensor([1.0, 0.0, 3.0])
# 함수형 API 사용print(torch.eq(x, y)) # tensor([ True, False, True])print(torch.gt(x, y)) # tensor([False, True, False])
# 스칼라와 비교도 가능print(x > 1.5) # tensor([False, True, True])print(torch.gt(x, 1.5)) # tensor([False, True, True])비교 결과 bool 텐서는 마스킹(masking)에 자주 활용됩니다.
scores = torch.tensor([72.0, 85.0, 60.0, 91.0, 78.0])
# 80점 이상인 점수만 추출mask = scores >= 80print(mask) # tensor([False, True, False, True, False])print(scores[mask]) # tensor([85., 91.])torch.equal() vs torch.eq()
섹션 제목: “torch.equal() vs torch.eq()”| 함수 | 반환값 | 설명 |
|---|---|---|
torch.equal(a, b) | Python bool (True / False) | 두 텐서가 완전히 동일한지 단일 값으로 반환 |
torch.eq(a, b) | bool 텐서 | 원소별 비교 결과를 텐서로 반환 |
a = torch.tensor([1, 2, 3])b = torch.tensor([1, 2, 3])c = torch.tensor([1, 0, 3])
# torch.equal: 전체 텐서 비교 → 단일 boolprint(torch.equal(a, b)) # Trueprint(torch.equal(a, c)) # False
# torch.eq: 원소별 비교 → bool 텐서print(torch.eq(a, b)) # tensor([True, True, True])print(torch.eq(a, c)) # tensor([ True, False, True])
# torch.eq 결과를 단일 값으로 줄이려면 .all() 사용print(torch.eq(a, c).all()) # tensor(False)print(torch.eq(a, c).all().item()) # False (Python bool)부동소수점 비교: torch.allclose()
섹션 제목: “부동소수점 비교: torch.allclose()”부동소수점 연산은 정밀도 문제로 == 비교가 실패할 수 있습니다. torch.allclose() 를 사용하면 허용 오차 범위 내에서 비교합니다.
a = torch.tensor([0.1 + 0.2])b = torch.tensor([0.3])
# == 비교는 실패할 수 있음print(a == b) # tensor([False]) — 부동소수점 오차
# allclose: 허용 오차(atol, rtol) 기반 비교print(torch.allclose(a, b)) # Trueprint(torch.allclose(a, b, atol=1e-8, rtol=1e-5)) # Trueallclose 의 판정 기준:
|a - b| ≤ atol + rtol × |b|atol(절대 허용 오차, 기본값1e-8)rtol(상대 허용 오차, 기본값1e-5)
x = torch.tensor([1.0, 2.0001, 3.0])y = torch.tensor([1.0, 2.0, 3.0])
print(torch.allclose(x, y)) # False (기본 atol=1e-8로는 차이가 큼)print(torch.allclose(x, y, atol=1e-3)) # True (0.0001 < 0.001)정렬: sort(), argsort(), topk()
섹션 제목: “정렬: sort(), argsort(), topk()”torch.sort()
섹션 제목: “torch.sort()”t = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0])
# 오름차순 정렬 (기본값)values, indices = torch.sort(t)print(values) # tensor([1., 1., 2., 3., 4., 5., 9.])print(indices) # tensor([1, 3, 6, 0, 2, 4, 5])
# 내림차순 정렬values, indices = torch.sort(t, descending=True)print(values) # tensor([9., 5., 4., 3., 2., 1., 1.])2D 텐서에서 dim 파라미터로 정렬 축을 지정합니다.
m = torch.tensor([[3, 1, 2], [9, 5, 7]])
# dim=1: 행 방향으로 정렬print(torch.sort(m, dim=1).values)# tensor([[1, 2, 3],# [5, 7, 9]])
# dim=0: 열 방향으로 정렬print(torch.sort(m, dim=0).values)# tensor([[3, 1, 2],# [9, 5, 7]])torch.argsort()
섹션 제목: “torch.argsort()”값 대신 정렬 후 인덱스 만 반환합니다. 순위를 구하거나 다른 배열을 같은 순서로 재정렬할 때 유용합니다.
scores = torch.tensor([72.0, 85.0, 60.0, 91.0, 78.0])
# 오름차순 인덱스 (낮은 점수 순)rank_asc = torch.argsort(scores)print(rank_asc) # tensor([2, 0, 4, 1, 3])
# 내림차순 인덱스 (높은 점수 순)rank_desc = torch.argsort(scores, descending=True)print(rank_desc) # tensor([3, 1, 4, 0, 2])
names = ["Alice", "Bob", "Charlie", "Dave", "Eve"]print([names[i] for i in rank_desc.tolist()])# ['Dave', 'Bob', 'Eve', 'Alice', 'Charlie']torch.topk()
섹션 제목: “torch.topk()”전체 정렬 없이 상위 k개 만 효율적으로 추출합니다.
t = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0])
# 상위 3개 (내림차순)values, indices = torch.topk(t, k=3)print(values) # tensor([9., 5., 4.])print(indices) # tensor([5, 4, 2])
# 하위 3개 (largest=False)values, indices = torch.topk(t, k=3, largest=False)print(values) # tensor([1., 1., 2.])print(indices) # tensor([1, 3, 6])