콘텐츠로 이동

cat, stack, split

torch.cat(tensors, dim) 은 텐서 목록을 기존 차원 방향으로 이어 붙입니다. 결합하는 차원을 제외한 나머지 차원의 크기는 모두 같아야 합니다.

import torch
a = torch.tensor([[1, 2],
[3, 4]]) # shape: (2, 2)
b = torch.tensor([[5, 6],
[7, 8]]) # shape: (2, 2)
# dim=0: 행 방향으로 이어 붙임 (아래로 쌓기)
cat_0 = torch.cat([a, b], dim=0)
print(cat_0.shape) # torch.Size([4, 2])
print(cat_0)
# tensor([[1, 2],
# [3, 4],
# [5, 6],
# [7, 8]])
# dim=1: 열 방향으로 이어 붙임 (옆으로 붙이기)
cat_1 = torch.cat([a, b], dim=1)
print(cat_1.shape) # torch.Size([2, 4])
print(cat_1)
# tensor([[1, 2, 5, 6],
# [3, 4, 7, 8]])

두 개 이상의 텐서도 한 번에 결합할 수 있습니다.

x = torch.randn(3, 4)
y = torch.randn(3, 4)
z = torch.randn(3, 4)
result = torch.cat([x, y, z], dim=0)
print(result.shape) # torch.Size([9, 4])

torch.stack(tensors, dim) 은 텐서 목록을 새 차원을 만들어 쌓습니다. 모든 텐서의 shape가 완전히 동일 해야 합니다.

a = torch.tensor([1, 2, 3]) # shape: (3,)
b = torch.tensor([4, 5, 6]) # shape: (3,)
c = torch.tensor([7, 8, 9]) # shape: (3,)
# dim=0: 새 첫 번째 차원 생성
stacked_0 = torch.stack([a, b, c], dim=0)
print(stacked_0.shape) # torch.Size([3, 3])
print(stacked_0)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
# dim=1: 새 두 번째 차원 생성
stacked_1 = torch.stack([a, b, c], dim=1)
print(stacked_1.shape) # torch.Size([3, 3])
print(stacked_1)
# tensor([[1, 4, 7],
# [2, 5, 8],
# [3, 6, 9]])

핵심 차이: cat은 기존 차원에서, stack은 새 차원을 만들어 결합합니다.

입력: 두 텐서, 각각 shape (3, 4)
torch.cat([a, b], dim=0) → shape (6, 4) — 차원 수 동일, dim 0이 커짐
torch.cat([a, b], dim=1) → shape (3, 8) — 차원 수 동일, dim 1이 커짐
torch.stack([a, b], dim=0) → shape (2, 3, 4) — 차원 수 +1, 앞에 새 차원
torch.stack([a, b], dim=1) → shape (3, 2, 4) — 차원 수 +1, 중간에 새 차원
torch.stack([a, b], dim=2) → shape (3, 4, 2) — 차원 수 +1, 뒤에 새 차원
a = torch.randn(3, 4)
b = torch.randn(3, 4)
print(torch.cat([a, b], dim=0).shape) # torch.Size([6, 4])
print(torch.cat([a, b], dim=1).shape) # torch.Size([3, 8])
print(torch.stack([a, b], dim=0).shape) # torch.Size([2, 3, 4])
print(torch.stack([a, b], dim=1).shape) # torch.Size([3, 2, 4])
print(torch.stack([a, b], dim=2).shape) # torch.Size([3, 4, 2])
상황권장 함수
긴 시퀀스를 이어 붙이기torch.cat()
여러 샘플을 배치로 묶기torch.stack()
입력 shape가 서로 다를 수 있음torch.cat()
입력 shape가 모두 동일해야 함torch.stack()

실전 예: 배치 구성

# 개별 샘플 텐서 목록을 배치로 변환
samples = [torch.randn(3, 28, 28) for _ in range(8)] # 8개 이미지
# stack으로 배치 차원 추가
batch = torch.stack(samples, dim=0)
print(batch.shape) # torch.Size([8, 3, 28, 28])
# DataLoader 내부도 이와 유사하게 동작함

torch.split(tensor, split_size, dim) 은 텐서를 지정된 크기의 조각으로 분할합니다. 마지막 조각은 나머지 원소를 담습니다.

t = torch.arange(10) # shape: (10,)
# 크기 3으로 분할
pieces = torch.split(t, 3, dim=0)
for p in pieces:
print(p.shape, p)
# torch.Size([3]) tensor([0, 1, 2])
# torch.Size([3]) tensor([3, 4, 5])
# torch.Size([3]) tensor([6, 7, 8])
# torch.Size([1]) tensor([9]) — 나머지
# 각기 다른 크기로 분할 (리스트 전달)
pieces = torch.split(t, [2, 3, 5], dim=0)
print([p.shape for p in pieces])
# [torch.Size([2]), torch.Size([3]), torch.Size([5])]

torch.chunk(): 균등 분할 (개수 지정)

섹션 제목: “torch.chunk(): 균등 분할 (개수 지정)”

torch.chunk(tensor, chunks, dim) 은 텐서를 지정한 개수 로 최대한 균등하게 분할합니다.

t = torch.arange(10)
# 3개로 분할
pieces = torch.chunk(t, 3, dim=0)
for p in pieces:
print(p.shape, p)
# torch.Size([4]) tensor([0, 1, 2, 3])
# torch.Size([3]) tensor([4, 5, 6]) — 마지막은 작을 수 있음
# torch.Size([3]) tensor([7, 8, 9])
# 2D 텐서를 dim=1로 분할
m = torch.randn(4, 6)
parts = torch.chunk(m, 3, dim=1)
print([p.shape for p in parts])
# [torch.Size([4, 2]), torch.Size([4, 2]), torch.Size([4, 2])]
함수인수설명
split(size, dim)조각의 크기크기를 지정, 개수는 자동 결정
chunk(n, dim)조각의 개수개수를 지정, 크기는 자동 결정