cat, stack, split
torch.cat(): 기존 차원으로 결합
섹션 제목: “torch.cat(): 기존 차원으로 결합”torch.cat(tensors, dim) 은 텐서 목록을 기존 차원 방향으로 이어 붙입니다. 결합하는 차원을 제외한 나머지 차원의 크기는 모두 같아야 합니다.
import torch
a = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # shape: (2, 2)
# dim=0: 행 방향으로 이어 붙임 (아래로 쌓기)cat_0 = torch.cat([a, b], dim=0)print(cat_0.shape) # torch.Size([4, 2])print(cat_0)# tensor([[1, 2],# [3, 4],# [5, 6],# [7, 8]])
# dim=1: 열 방향으로 이어 붙임 (옆으로 붙이기)cat_1 = torch.cat([a, b], dim=1)print(cat_1.shape) # torch.Size([2, 4])print(cat_1)# tensor([[1, 2, 5, 6],# [3, 4, 7, 8]])두 개 이상의 텐서도 한 번에 결합할 수 있습니다.
x = torch.randn(3, 4)y = torch.randn(3, 4)z = torch.randn(3, 4)
result = torch.cat([x, y, z], dim=0)print(result.shape) # torch.Size([9, 4])torch.stack(): 새 차원으로 결합
섹션 제목: “torch.stack(): 새 차원으로 결합”torch.stack(tensors, dim) 은 텐서 목록을 새 차원을 만들어 쌓습니다. 모든 텐서의 shape가 완전히 동일 해야 합니다.
a = torch.tensor([1, 2, 3]) # shape: (3,)b = torch.tensor([4, 5, 6]) # shape: (3,)c = torch.tensor([7, 8, 9]) # shape: (3,)
# dim=0: 새 첫 번째 차원 생성stacked_0 = torch.stack([a, b, c], dim=0)print(stacked_0.shape) # torch.Size([3, 3])print(stacked_0)# tensor([[1, 2, 3],# [4, 5, 6],# [7, 8, 9]])
# dim=1: 새 두 번째 차원 생성stacked_1 = torch.stack([a, b, c], dim=1)print(stacked_1.shape) # torch.Size([3, 3])print(stacked_1)# tensor([[1, 4, 7],# [2, 5, 8],# [3, 6, 9]])cat vs stack 차이 시각화
섹션 제목: “cat vs stack 차이 시각화”핵심 차이: cat은 기존 차원에서, stack은 새 차원을 만들어 결합합니다.
입력: 두 텐서, 각각 shape (3, 4)
torch.cat([a, b], dim=0) → shape (6, 4) — 차원 수 동일, dim 0이 커짐torch.cat([a, b], dim=1) → shape (3, 8) — 차원 수 동일, dim 1이 커짐
torch.stack([a, b], dim=0) → shape (2, 3, 4) — 차원 수 +1, 앞에 새 차원torch.stack([a, b], dim=1) → shape (3, 2, 4) — 차원 수 +1, 중간에 새 차원torch.stack([a, b], dim=2) → shape (3, 4, 2) — 차원 수 +1, 뒤에 새 차원a = torch.randn(3, 4)b = torch.randn(3, 4)
print(torch.cat([a, b], dim=0).shape) # torch.Size([6, 4])print(torch.cat([a, b], dim=1).shape) # torch.Size([3, 8])
print(torch.stack([a, b], dim=0).shape) # torch.Size([2, 3, 4])print(torch.stack([a, b], dim=1).shape) # torch.Size([3, 2, 4])print(torch.stack([a, b], dim=2).shape) # torch.Size([3, 4, 2])사용 시점 기준
섹션 제목: “사용 시점 기준”| 상황 | 권장 함수 |
|---|---|
| 긴 시퀀스를 이어 붙이기 | torch.cat() |
| 여러 샘플을 배치로 묶기 | torch.stack() |
| 입력 shape가 서로 다를 수 있음 | torch.cat() |
| 입력 shape가 모두 동일해야 함 | torch.stack() |
실전 예: 배치 구성
# 개별 샘플 텐서 목록을 배치로 변환samples = [torch.randn(3, 28, 28) for _ in range(8)] # 8개 이미지
# stack으로 배치 차원 추가batch = torch.stack(samples, dim=0)print(batch.shape) # torch.Size([8, 3, 28, 28])
# DataLoader 내부도 이와 유사하게 동작함torch.split(): 균등 분할
섹션 제목: “torch.split(): 균등 분할”torch.split(tensor, split_size, dim) 은 텐서를 지정된 크기의 조각으로 분할합니다. 마지막 조각은 나머지 원소를 담습니다.
t = torch.arange(10) # shape: (10,)
# 크기 3으로 분할pieces = torch.split(t, 3, dim=0)for p in pieces: print(p.shape, p)# torch.Size([3]) tensor([0, 1, 2])# torch.Size([3]) tensor([3, 4, 5])# torch.Size([3]) tensor([6, 7, 8])# torch.Size([1]) tensor([9]) — 나머지
# 각기 다른 크기로 분할 (리스트 전달)pieces = torch.split(t, [2, 3, 5], dim=0)print([p.shape for p in pieces])# [torch.Size([2]), torch.Size([3]), torch.Size([5])]torch.chunk(): 균등 분할 (개수 지정)
섹션 제목: “torch.chunk(): 균등 분할 (개수 지정)”torch.chunk(tensor, chunks, dim) 은 텐서를 지정한 개수 로 최대한 균등하게 분할합니다.
t = torch.arange(10)
# 3개로 분할pieces = torch.chunk(t, 3, dim=0)for p in pieces: print(p.shape, p)# torch.Size([4]) tensor([0, 1, 2, 3])# torch.Size([3]) tensor([4, 5, 6]) — 마지막은 작을 수 있음# torch.Size([3]) tensor([7, 8, 9])
# 2D 텐서를 dim=1로 분할m = torch.randn(4, 6)parts = torch.chunk(m, 3, dim=1)print([p.shape for p in parts])# [torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([4, 2])]split vs chunk 차이
섹션 제목: “split vs chunk 차이”| 함수 | 인수 | 설명 |
|---|---|---|
split(size, dim) | 조각의 크기 | 크기를 지정, 개수는 자동 결정 |
chunk(n, dim) | 조각의 개수 | 개수를 지정, 크기는 자동 결정 |