콘텐츠로 이동

비교 연산과 정렬

비교 연산은 원소별 로 수행되며, 결과는 bool 타입의 텐서를 반환합니다.

import torch
a = torch.tensor([1, 2, 3, 4, 5])
b = torch.tensor([3, 2, 1, 4, 6])
print(a == b) # tensor([False, True, False, True, False])
print(a != b) # tensor([ True, False, True, False, True])
print(a > b) # tensor([False, False, True, False, False])
print(a < b) # tensor([ True, False, False, False, True])
print(a >= b) # tensor([False, True, True, True, False])
print(a <= b) # tensor([ True, True, False, True, True])
연산자함수형 API반환값
a == btorch.eq(a, b)bool 텐서
a != btorch.ne(a, b)bool 텐서
a > btorch.gt(a, b)bool 텐서
a < btorch.lt(a, b)bool 텐서
a >= btorch.ge(a, b)bool 텐서
a <= btorch.le(a, b)bool 텐서
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0, 0.0, 3.0])
# 함수형 API 사용
print(torch.eq(x, y)) # tensor([ True, False, True])
print(torch.gt(x, y)) # tensor([False, True, False])
# 스칼라와 비교도 가능
print(x > 1.5) # tensor([False, True, True])
print(torch.gt(x, 1.5)) # tensor([False, True, True])

비교 결과 bool 텐서는 마스킹(masking)에 자주 활용됩니다.

scores = torch.tensor([72.0, 85.0, 60.0, 91.0, 78.0])
# 80점 이상인 점수만 추출
mask = scores >= 80
print(mask) # tensor([False, True, False, True, False])
print(scores[mask]) # tensor([85., 91.])
함수반환값설명
torch.equal(a, b)Python bool (True / False)두 텐서가 완전히 동일한지 단일 값으로 반환
torch.eq(a, b)bool 텐서원소별 비교 결과를 텐서로 반환
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
c = torch.tensor([1, 0, 3])
# torch.equal: 전체 텐서 비교 → 단일 bool
print(torch.equal(a, b)) # True
print(torch.equal(a, c)) # False
# torch.eq: 원소별 비교 → bool 텐서
print(torch.eq(a, b)) # tensor([True, True, True])
print(torch.eq(a, c)) # tensor([ True, False, True])
# torch.eq 결과를 단일 값으로 줄이려면 .all() 사용
print(torch.eq(a, c).all()) # tensor(False)
print(torch.eq(a, c).all().item()) # False (Python bool)

부동소수점 연산은 정밀도 문제로 == 비교가 실패할 수 있습니다. torch.allclose() 를 사용하면 허용 오차 범위 내에서 비교합니다.

a = torch.tensor([0.1 + 0.2])
b = torch.tensor([0.3])
# == 비교는 실패할 수 있음
print(a == b) # tensor([False]) — 부동소수점 오차
# allclose: 허용 오차(atol, rtol) 기반 비교
print(torch.allclose(a, b)) # True
print(torch.allclose(a, b, atol=1e-8, rtol=1e-5)) # True

allclose 의 판정 기준:

|a - b| ≤ atol + rtol × |b|
  • atol (절대 허용 오차, 기본값 1e-8)
  • rtol (상대 허용 오차, 기본값 1e-5)
x = torch.tensor([1.0, 2.0001, 3.0])
y = torch.tensor([1.0, 2.0, 3.0])
print(torch.allclose(x, y)) # False (기본 atol=1e-8로는 차이가 큼)
print(torch.allclose(x, y, atol=1e-3)) # True (0.0001 < 0.001)
t = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0])
# 오름차순 정렬 (기본값)
values, indices = torch.sort(t)
print(values) # tensor([1., 1., 2., 3., 4., 5., 9.])
print(indices) # tensor([1, 3, 6, 0, 2, 4, 5])
# 내림차순 정렬
values, indices = torch.sort(t, descending=True)
print(values) # tensor([9., 5., 4., 3., 2., 1., 1.])

2D 텐서에서 dim 파라미터로 정렬 축을 지정합니다.

m = torch.tensor([[3, 1, 2],
[9, 5, 7]])
# dim=1: 행 방향으로 정렬
print(torch.sort(m, dim=1).values)
# tensor([[1, 2, 3],
# [5, 7, 9]])
# dim=0: 열 방향으로 정렬
print(torch.sort(m, dim=0).values)
# tensor([[3, 1, 2],
# [9, 5, 7]])

값 대신 정렬 후 인덱스 만 반환합니다. 순위를 구하거나 다른 배열을 같은 순서로 재정렬할 때 유용합니다.

scores = torch.tensor([72.0, 85.0, 60.0, 91.0, 78.0])
# 오름차순 인덱스 (낮은 점수 순)
rank_asc = torch.argsort(scores)
print(rank_asc) # tensor([2, 0, 4, 1, 3])
# 내림차순 인덱스 (높은 점수 순)
rank_desc = torch.argsort(scores, descending=True)
print(rank_desc) # tensor([3, 1, 4, 0, 2])
names = ["Alice", "Bob", "Charlie", "Dave", "Eve"]
print([names[i] for i in rank_desc.tolist()])
# ['Dave', 'Bob', 'Eve', 'Alice', 'Charlie']

전체 정렬 없이 상위 k개 만 효율적으로 추출합니다.

t = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0])
# 상위 3개 (내림차순)
values, indices = torch.topk(t, k=3)
print(values) # tensor([9., 5., 4.])
print(indices) # tensor([5, 4, 2])
# 하위 3개 (largest=False)
values, indices = torch.topk(t, k=3, largest=False)
print(values) # tensor([1., 1., 2.])
print(indices) # tensor([1, 3, 6])