Gather와 Scatter
torch.gather()
섹션 제목: “torch.gather()”torch.gather(input, dim, index) 는 지정한 축(dim)을 따라 index 가 가리키는 위치의 값을 수집합니다.
출력 shape = index shape 이며, index의 각 값은 dim 방향으로 이동할 위치를 나타냅니다.
2D 예시로 이해하기
섹션 제목: “2D 예시로 이해하기”import torch
src = torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
# dim=1 (열 방향으로 수집)# 각 행에서 몇 번째 열을 가져올지 지정idx = torch.tensor([[2, 0], # 행0: 열2(30), 열0(10) [1, 2], # 행1: 열1(50), 열2(60) [0, 1]]) # 행2: 열0(70), 열1(80)
result = torch.gather(src, dim=1, index=idx)print(result)# tensor([[30, 10],# [50, 60],# [70, 80]])src: idx (dim=1): result:[[10, 20, 30], [[2, 0], [[30, 10], [40, 50, 60], [1, 2], [50, 60], [70, 80, 90]] [0, 1]] [70, 80]]
행0: src[0, idx[0,0]]=src[0,2]=30 src[0, idx[0,1]]=src[0,0]=10행1: src[1, idx[1,0]]=src[1,1]=50 src[1, idx[1,1]]=src[1,2]=60dim=0 (행 방향 수집)
섹션 제목: “dim=0 (행 방향 수집)”import torch
src = torch.tensor([[1, 2], [3, 4], [5, 6]])
# 각 열에서 몇 번째 행을 가져올지 지정idx = torch.tensor([[2, 0], # 열0: 행2(5), 열1: 행0(2) [1, 2]]) # 열0: 행1(3), 열1: 행2(6)
result = torch.gather(src, dim=0, index=idx)print(result)# tensor([[5, 2],# [3, 6]])실전 예시: 클래스별 확률 추출
섹션 제목: “실전 예시: 클래스별 확률 추출”import torch
# (배치=4, 클래스=5) 예측 확률logits = torch.tensor([ [0.1, 0.3, 0.5, 0.05, 0.05], [0.6, 0.1, 0.1, 0.1, 0.1 ], [0.2, 0.2, 0.2, 0.2, 0.2 ], [0.05,0.05,0.1, 0.7, 0.1 ],])
# 각 샘플의 정답 클래스 인덱스targets = torch.tensor([2, 0, 3, 3])
# gather로 정답 클래스의 확률 추출# index shape은 (4, 1)이어야 함target_probs = torch.gather(logits, dim=1, index=targets.unsqueeze(1))print(target_probs)# tensor([[0.5000],# [0.6000],# [0.2000],# [0.7000]])print(target_probs.squeeze())# tensor([0.5000, 0.6000, 0.2000, 0.7000])torch.scatter_()
섹션 제목: “torch.scatter_()”scatter_ 는 gather 의 역 연산 입니다. index가 가리키는 위치에 값을 분산(scatter) 합니다. 언더스코어(_)는 인플레이스(in-place) 연산임을 의미합니다.
# tensor.scatter_(dim, index, src)# index[i][j]가 가리키는 self의 위치에 src[i][j]를 씀import torch
dst = torch.zeros(3, 5)src = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
idx = torch.tensor([[0, 2, 4], [1, 3, 0], [2, 4, 1]])
dst.scatter_(dim=1, index=idx, src=src)print(dst)# tensor([[1., 0., 2., 0., 3.],# [6., 4., 0., 5., 0.],# [0., 9., 7., 0., 8.]])원-핫 인코딩 구현
섹션 제목: “원-핫 인코딩 구현”scatter_ 의 가장 대표적인 활용입니다.
import torch
num_classes = 5labels = torch.tensor([2, 0, 4, 1, 3]) # 클래스 인덱스
# 원-핫 인코딩one_hot = torch.zeros(len(labels), num_classes)one_hot.scatter_(dim=1, index=labels.unsqueeze(1), value=1.0)
print(one_hot)# tensor([[0., 0., 1., 0., 0.],# [1., 0., 0., 0., 0.],# [0., 0., 0., 0., 1.],# [0., 1., 0., 0., 0.],# [0., 0., 0., 1., 0.]])
# PyTorch 내장 함수로도 가능import torch.nn.functional as Fone_hot_v2 = F.one_hot(labels, num_classes=num_classes).float()torch.index_select()
섹션 제목: “torch.index_select()”torch.index_select(input, dim, index) 는 지정 차원에서 동일한 인덱스 집합 을 적용해 선택합니다.
import torch
m = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# 특정 열 선택cols = torch.tensor([0, 2, 3])result = torch.index_select(m, dim=1, index=cols)print(result)# tensor([[ 1, 3, 4],# [ 5, 7, 8],# [ 9, 11, 12]])
# 특정 행 선택rows = torch.tensor([2, 0])result2 = torch.index_select(m, dim=0, index=rows)print(result2)# tensor([[ 9, 10, 11, 12],# [ 1, 2, 3, 4]])세 함수 비교
섹션 제목: “세 함수 비교”| 함수 | 특징 | index 타입 | 결과 shape | 주요 용도 |
|---|---|---|---|---|
gather | 위치별 다른 인덱스 | LongTensor, input과 같은 ndim | index와 동일 | 클래스별 값 추출, Q값 선택 |
scatter_ | 값을 특정 위치에 기록 | LongTensor, dst와 같은 ndim | dst와 동일 | 원-핫 인코딩, 누적 합산 |
index_select | 같은 인덱스를 모든 슬라이스에 적용 | 1D LongTensor | 가변 | 행/열 선택, 임베딩 룩업 |
import torch
m = torch.tensor([[10, 20, 30], [40, 50, 60]])
# index_select: 1D 인덱스를 모든 행에 동일 적용print(torch.index_select(m, 1, torch.tensor([0, 2])))# tensor([[10, 30],# [40, 60]])
# gather: 행마다 다른 열 선택 가능idx = torch.tensor([[0, 2], [2, 1]])print(torch.gather(m, 1, idx))# tensor([[10, 30],# [60, 50]])