콘텐츠로 이동

reshape와 view

텐서의 형태(shape) 를 바꾸더라도 원소의 총 개수는 동일해야 합니다. 형태 변환은 데이터를 복사하지 않고 메모리를 재해석하는 방식으로 동작합니다.

import torch
t = torch.arange(12) # [0, 1, 2, ..., 11], shape: (12,)
# 12개 원소 → 다양한 형태로 변환
print(t.reshape(3, 4).shape) # torch.Size([3, 4])
print(t.reshape(2, 6).shape) # torch.Size([2, 6])
print(t.reshape(2, 2, 3).shape) # torch.Size([2, 2, 3])
print(t.reshape(1, 12).shape) # torch.Size([1, 12])

arange(12) → reshape(3, 4)

변환 전
0
1
2
3
4
5
6
7
8
9
10
11
형태: [12]
변환 후
0
1
2
3
4
5
6
7
8
9
10
11
형태: [3, 4]
메모리 순서 (1D 평탄화)
0
1
2
3
4
5
6
7
8
9
10
11

reshape(3, 4) → reshape(2, 6)

변환 전
0
1
2
3
4
5
6
7
8
9
10
11
형태: [3, 4]
변환 후
0
1
2
3
4
5
6
7
8
9
10
11
형태: [2, 6]
메모리 순서 (1D 평탄화)
0
1
2
3
4
5
6
7
8
9
10
11

view() — 메모리 공유, contiguous 필수

섹션 제목: “view() — 메모리 공유, contiguous 필수”

view() 는 원본 텐서와 메모리를 공유 합니다. 데이터 복사가 없어 매우 빠르지만, 텐서가 contiguous(연속 메모리) 상태일 때만 사용할 수 있습니다.

t = torch.arange(12)
v = t.view(3, 4)
# 메모리 공유 확인
v[0, 0] = 99
print(t[0]) # 99 — 원본도 변경됨
# contiguous 확인
print(t.is_contiguous()) # True
# transpose 후에는 non-contiguous가 됨
t2 = torch.randn(3, 4)
t2_T = t2.T # transpose
print(t2_T.is_contiguous()) # False
# t2_T.view(12) # RuntimeError!

reshape() 는 가능하면 view() 처럼 메모리를 공유하고, 불가능하면 자동으로 데이터를 복사 해 새 텐서를 반환합니다. 대부분의 상황에서 reshape() 가 더 안전합니다.

t = torch.randn(3, 4)
t_T = t.T # non-contiguous
# reshape: non-contiguous에서도 동작
r = t_T.reshape(12) # 내부적으로 복사 후 변환
print(r.shape) # torch.Size([12])
# view: non-contiguous에서 실패
# t_T.view(12) # RuntimeError!
# contiguous()로 명시적 복사 후 view 사용
r2 = t_T.contiguous().view(12)
print(r2.shape) # torch.Size([12])
상황권장
메모리 공유가 중요하고 contiguous 보장됨view()
일반적인 형태 변환reshape()
non-contiguous 텐서 처리reshape() 또는 contiguous().view()
메모리 공유 여부를 명확히 하고 싶을 때view()

형태 변환 시 차원 하나를 -1 로 지정하면 PyTorch가 나머지 원소 수에서 자동으로 계산 합니다.

t = torch.arange(24) # 24개 원소
# 하나의 차원만 -1로 지정 가능
print(t.reshape(4, -1).shape) # torch.Size([4, 6]) — 24/4=6
print(t.reshape(-1, 6).shape) # torch.Size([4, 6]) — 24/6=4
print(t.reshape(2, 3, -1).shape) # torch.Size([2, 3, 4]) — 24/(2*3)=4
print(t.reshape(-1).shape) # torch.Size([24]) — 1D 펼치기

실전 예: 배치 처리에서 자주 사용하는 패턴

# 배치 이미지: (batch, channel, H, W)
images = torch.randn(32, 3, 28, 28)
# Fully Connected 레이어 입력을 위해 2D로 변환
# 배치 크기 유지, 나머지를 1D로 펼침
flat = images.reshape(32, -1)
print(flat.shape) # torch.Size([32, 2352]) — 3*28*28=2352
# 또는 배치 크기도 자동 추론
flat2 = images.reshape(-1, 3 * 28 * 28)
print(flat2.shape) # torch.Size([32, 2352])

flatten() 은 텐서를 1D로 펼치는 가장 명확한 방법입니다. 특정 차원 범위만 펼치는 것도 가능합니다.

t = torch.randn(2, 3, 4) # shape: (2, 3, 4)
# 전체를 1D로
print(t.flatten().shape) # torch.Size([24])
# start_dim, end_dim으로 부분 flatten
print(t.flatten(start_dim=1).shape) # torch.Size([2, 12]) — 배치 유지
print(t.flatten(start_dim=0, end_dim=1).shape) # torch.Size([6, 4])
import torch.nn as nn
# 딥러닝 모델 내에서는 nn.Flatten 레이어 사용
flatten_layer = nn.Flatten() # 기본: start_dim=1
images = torch.randn(32, 3, 28, 28)
output = flatten_layer(images)
print(output.shape) # torch.Size([32, 2352])
# 패턴 1: CNN → FC 레이어 전환
def forward(self, x):
x = self.conv_layers(x) # (B, C, H, W)
x = x.reshape(x.size(0), -1) # (B, C*H*W)
x = self.fc_layers(x)
return x
# 패턴 2: 임베딩 벡터 처리
embeddings = torch.randn(10, 5, 64) # (문장, 단어, 임베딩)
# 문장×단어를 하나의 시퀀스로
flat = embeddings.reshape(-1, 64) # (50, 64)
# 패턴 3: 배치 크기 변환
t = torch.randn(8, 4) # 배치 8, 특성 4
# 배치를 두 배로 쪼개기
t2 = t.reshape(16, 2) # 배치 16, 특성 2