메모리 레이아웃
메모리에서 텐서가 저장되는 방식
섹션 제목: “메모리에서 텐서가 저장되는 방식”텐서는 2차원 표처럼 보이지만, 컴퓨터 메모리는 1차원 선형 배열 입니다. PyTorch는 다차원 텐서를 1차원 메모리에 매핑하기 위해 stride(보폭) 를 사용합니다.
2D 텐서 (2행 3열):┌─────┬─────┬─────┐│ 1 │ 2 │ 3 │ ← 행 0├─────┼─────┼─────┤│ 4 │ 5 │ 6 │ ← 행 1└─────┴─────┴─────┘
메모리 상의 실제 배치 (행 우선, Row-major):[ 1, 2, 3, 4, 5, 6 ] ↑ ↑ 인덱스 0 인덱스 3Stride 개념
섹션 제목: “Stride 개념”stride 는 특정 차원에서 다음 원소로 이동하기 위해 메모리상에서 몇 칸 건너뛰어야 하는지를 나타냅니다.
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x.stride()) # (3, 1)# 행 방향(dim=0): 다음 행으로 가려면 3칸 이동# 열 방향(dim=1): 다음 열로 가려면 1칸 이동메모리: [ 1, 2, 3, 4, 5, 6 ]인덱스: 0 1 2 3 4 5
tensor[0, 0] → 메모리[0] (1)tensor[0, 1] → 메모리[0+1] = [1] (2) ← 열 stride = 1tensor[1, 0] → 메모리[0+3] = [3] (4) ← 행 stride = 3tensor[1, 2] → 메모리[1*3+2] = [5] (6)임의 원소의 메모리 오프셋 계산: offset = Σ (인덱스[i] × stride[i])
Contiguous vs Non-Contiguous
섹션 제목: “Contiguous vs Non-Contiguous”Contiguous 텐서
섹션 제목: “Contiguous 텐서”원소들이 메모리상에서 연속적으로 배치된 텐서입니다.
import torch
x = torch.randn(3, 4)print(x.is_contiguous()) # Trueprint(x.stride()) # (4, 1)Non-Contiguous 텐서
섹션 제목: “Non-Contiguous 텐서”전치(transpose)나 슬라이싱 후에는 메모리가 불연속적이 될 수 있습니다.
import torch
x = torch.randn(3, 4)x_t = x.t() # 전치 (뷰 반환, 복사 없음)
print(x.is_contiguous()) # Trueprint(x_t.is_contiguous()) # False
print(x.stride()) # (4, 1)print(x_t.stride()) # (1, 4) ← stride만 바뀜, 메모리는 그대로원본 x (3×4): x.t() (4×3):stride = (4, 1) stride = (1, 4)
메모리: [a, b, c, d, e, f, g, h, i, j, k, l] ↑ ↑ ↑ ↑x[0,:]= a b c d (연속)x_t[:,0]= a e i ... (4칸 간격, 불연속)Storage와 텐서의 관계
섹션 제목: “Storage와 텐서의 관계”여러 텐서가 같은 storage(저장소) 를 공유할 수 있습니다. 뷰(view)는 원본 텐서와 동일한 storage를 가리킵니다.
import torch
x = torch.randn(3, 4)y = x.t() # 전치 뷰z = x[0:2, :] # 슬라이싱 뷰
# storage 주소 비교 (같은 메모리를 공유)print(x.storage().data_ptr() == y.storage().data_ptr()) # Trueprint(x.storage().data_ptr() == z.storage().data_ptr()) # True
# storage 내 시작 위치 (offset)print(x.storage_offset()) # 0print(z.storage_offset()) # 0 (같은 시작점).is_contiguous()와 .contiguous()
섹션 제목: “.is_contiguous()와 .contiguous()”is_contiguous() 확인
섹션 제목: “is_contiguous() 확인”import torch
x = torch.randn(4, 4)print(x.is_contiguous()) # True
x_slice = x[::2, :] # 2행씩 건너뜀print(x_slice.is_contiguous()) # False
x_permuted = x.permute(1, 0)print(x_permuted.is_contiguous()) # False.contiguous() 로 변환
섹션 제목: “.contiguous() 로 변환”non-contiguous 텐서를 새로운 연속 메모리 에 복사합니다.
import torch
x = torch.randn(4, 4)x_t = x.t() # non-contiguous
# contiguous() 호출 시 새 메모리에 복사x_cont = x_t.contiguous()print(x_cont.is_contiguous()) # True
# 이미 contiguous면 복사하지 않고 자신을 반환x_same = x.contiguous()print(x_same.data_ptr() == x.data_ptr()) # True (같은 메모리)언제 contiguous()가 필요한가?
섹션 제목: “언제 contiguous()가 필요한가?”일부 연산은 contiguous 텐서만 허용합니다.
import torch
x = torch.randn(4, 4)x_t = x.t() # non-contiguous
# view()는 contiguous 텐서에서만 동작# x_t.view(16) # RuntimeError!
# 방법 1: contiguous() 후 view()x_t.contiguous().view(16)
# 방법 2: reshape() 사용 (내부적으로 자동 처리)x_t.reshape(16) # 필요 시 자동으로 복사메모리 효율적 텐서 활용 팁
섹션 제목: “메모리 효율적 텐서 활용 팁”| 상황 | 권장 방법 | 이유 |
|---|---|---|
| 형태만 바꿀 때 | reshape() 우선 | 가능하면 복사 없이 뷰 반환 |
| 연속성 보장 필요 | .contiguous().view() | 명시적 제어 |
| 전치 후 연산 | .contiguous() 후 연산 | 일부 커널 최적화 필요 |
| 메모리 절약 | 슬라이싱/전치 뷰 유지 | 복사 비용 없음 |
| 데이터 독립 필요 | .clone() | 원본과 완전히 분리 |
import torch
# 메모리 사용량 비교x = torch.randn(1000, 1000)
# 뷰: 추가 메모리 없음x_view = x.t()print(x_view.storage().data_ptr() == x.storage().data_ptr()) # True
# contiguous(): 새 메모리 할당 (~4MB)x_cont = x.t().contiguous()print(x_cont.storage().data_ptr() == x.storage().data_ptr()) # False
# 성능 팁: reshape가 view보다 안전x_reshaped = x.reshape(-1) # 불연속이어도 동작view vs clone의 is_contiguous() 비교
섹션 제목: “view vs clone의 is_contiguous() 비교”| 연산 | is_contiguous() | 메모리 공유 | 설명 |
|---|---|---|---|
t = torch.randn(3, 4) | True | — | 원본 |
t.view(4, 3) | True | 공유 | view는 contiguous 유지 |
t.T (transpose) | False | 공유 | stride만 변경, 메모리 불연속 |
t.T.contiguous() | True | 별도 | 새 메모리에 복사 |
t.T.clone() | True | 별도 | clone은 항상 복사 |
t.reshape(4, 3) | True | 상황에 따라 | 가능하면 view, 아니면 복사 |
channels_last 메모리 형식 (GPU)
섹션 제목: “channels_last 메모리 형식 (GPU)”GPU에서 이미지 텐서를 처리할 때, 기본 NCHW 형식 대신 channels_last 메모리 형식을 사용하면 일부 연산에서 성능이 향상됩니다.
import torch
# 기본 메모리 형식 (contiguous, NCHW)t = torch.randn(1, 3, 224, 224)print(t.stride()) # (150528, 50176, 224, 1)
# channels_last 메모리 형식 (NHWC 순서로 저장)t_cl = t.to(memory_format=torch.channels_last)print(t_cl.stride()) # (150528, 1, 672, 3)print(t_cl.is_contiguous()) # Falseprint(t_cl.is_contiguous(memory_format=torch.channels_last)) # True퀴즈를 불러오는 중...