콘텐츠로 이동

Gather와 Scatter

torch.gather(input, dim, index) 는 지정한 축(dim)을 따라 index 가 가리키는 위치의 값을 수집합니다.

출력 shape = index shape 이며, index의 각 값은 dim 방향으로 이동할 위치를 나타냅니다.

import torch
src = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
# dim=1 (열 방향으로 수집)
# 각 행에서 몇 번째 열을 가져올지 지정
idx = torch.tensor([[2, 0], # 행0: 열2(30), 열0(10)
[1, 2], # 행1: 열1(50), 열2(60)
[0, 1]]) # 행2: 열0(70), 열1(80)
result = torch.gather(src, dim=1, index=idx)
print(result)
# tensor([[30, 10],
# [50, 60],
# [70, 80]])
src: idx (dim=1): result:
[[10, 20, 30], [[2, 0], [[30, 10],
[40, 50, 60], [1, 2], [50, 60],
[70, 80, 90]] [0, 1]] [70, 80]]
행0: src[0, idx[0,0]]=src[0,2]=30 src[0, idx[0,1]]=src[0,0]=10
행1: src[1, idx[1,0]]=src[1,1]=50 src[1, idx[1,1]]=src[1,2]=60
import torch
src = torch.tensor([[1, 2],
[3, 4],
[5, 6]])
# 각 열에서 몇 번째 행을 가져올지 지정
idx = torch.tensor([[2, 0], # 열0: 행2(5), 열1: 행0(2)
[1, 2]]) # 열0: 행1(3), 열1: 행2(6)
result = torch.gather(src, dim=0, index=idx)
print(result)
# tensor([[5, 2],
# [3, 6]])
import torch
# (배치=4, 클래스=5) 예측 확률
logits = torch.tensor([
[0.1, 0.3, 0.5, 0.05, 0.05],
[0.6, 0.1, 0.1, 0.1, 0.1 ],
[0.2, 0.2, 0.2, 0.2, 0.2 ],
[0.05,0.05,0.1, 0.7, 0.1 ],
])
# 각 샘플의 정답 클래스 인덱스
targets = torch.tensor([2, 0, 3, 3])
# gather로 정답 클래스의 확률 추출
# index shape은 (4, 1)이어야 함
target_probs = torch.gather(logits, dim=1, index=targets.unsqueeze(1))
print(target_probs)
# tensor([[0.5000],
# [0.6000],
# [0.2000],
# [0.7000]])
print(target_probs.squeeze())
# tensor([0.5000, 0.6000, 0.2000, 0.7000])

scatter_gather역 연산 입니다. index가 가리키는 위치에 값을 분산(scatter) 합니다. 언더스코어(_)는 인플레이스(in-place) 연산임을 의미합니다.

# tensor.scatter_(dim, index, src)
# index[i][j]가 가리키는 self의 위치에 src[i][j]를 씀
import torch
dst = torch.zeros(3, 5)
src = torch.tensor([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
idx = torch.tensor([[0, 2, 4],
[1, 3, 0],
[2, 4, 1]])
dst.scatter_(dim=1, index=idx, src=src)
print(dst)
# tensor([[1., 0., 2., 0., 3.],
# [6., 4., 0., 5., 0.],
# [0., 9., 7., 0., 8.]])

scatter_ 의 가장 대표적인 활용입니다.

import torch
num_classes = 5
labels = torch.tensor([2, 0, 4, 1, 3]) # 클래스 인덱스
# 원-핫 인코딩
one_hot = torch.zeros(len(labels), num_classes)
one_hot.scatter_(dim=1,
index=labels.unsqueeze(1),
value=1.0)
print(one_hot)
# tensor([[0., 0., 1., 0., 0.],
# [1., 0., 0., 0., 0.],
# [0., 0., 0., 0., 1.],
# [0., 1., 0., 0., 0.],
# [0., 0., 0., 1., 0.]])
# PyTorch 내장 함수로도 가능
import torch.nn.functional as F
one_hot_v2 = F.one_hot(labels, num_classes=num_classes).float()

torch.index_select(input, dim, index) 는 지정 차원에서 동일한 인덱스 집합 을 적용해 선택합니다.

import torch
m = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 특정 열 선택
cols = torch.tensor([0, 2, 3])
result = torch.index_select(m, dim=1, index=cols)
print(result)
# tensor([[ 1, 3, 4],
# [ 5, 7, 8],
# [ 9, 11, 12]])
# 특정 행 선택
rows = torch.tensor([2, 0])
result2 = torch.index_select(m, dim=0, index=rows)
print(result2)
# tensor([[ 9, 10, 11, 12],
# [ 1, 2, 3, 4]])
함수특징index 타입결과 shape주요 용도
gather위치별 다른 인덱스LongTensor, input과 같은 ndimindex와 동일클래스별 값 추출, Q값 선택
scatter_값을 특정 위치에 기록LongTensor, dst와 같은 ndimdst와 동일원-핫 인코딩, 누적 합산
index_select같은 인덱스를 모든 슬라이스에 적용1D LongTensor가변행/열 선택, 임베딩 룩업
import torch
m = torch.tensor([[10, 20, 30],
[40, 50, 60]])
# index_select: 1D 인덱스를 모든 행에 동일 적용
print(torch.index_select(m, 1, torch.tensor([0, 2])))
# tensor([[10, 30],
# [40, 60]])
# gather: 행마다 다른 열 선택 가능
idx = torch.tensor([[0, 2],
[2, 1]])
print(torch.gather(m, 1, idx))
# tensor([[10, 30],
# [60, 50]])