콘텐츠로 이동

브로드캐스팅 함정과 디버깅

브로드캐스팅은 오류 없이 조용히 동작하기 때문에 버그를 발견하기 어렵습니다. 의도한 shape가 아님에도 연산이 성공하면 잘못된 결과가 코드 깊숙이 전파됩니다.

import torch
# 의도: 4×3 행렬의 각 행에서 3차원 벡터를 빼기
matrix = torch.ones(4, 3)
bias = torch.ones(4) # ← 실수! (3,)이어야 하는데 (4,)로 생성
# 오류가 발생하지 않음 — 하지만 결과가 의도와 다름
try:
result = matrix - bias # RuntimeError 발생
except RuntimeError as e:
print(f"오류 발생: {e}")
# 4 ≠ 3 이므로 오류
# 이 경우는 더 위험 — 오류 없이 잘못된 결과
matrix2 = torch.ones(4, 1)
result2 = matrix2 - bias # (4, 1) - (4,) → (4, 1) - (1, 4) → (4, 4)
print(result2.shape) # torch.Size([4, 4]) ← 의도한 (4,)이 아님!
비호환 shape: (3, 4) + (3, 5)
브로드캐스팅 불가: shape [3,4][3,5]는 호환되지 않습니다.

딥러닝에서 가장 흔한 실수입니다. 배치 차원과 특성 차원을 혼동하면 브로드캐스팅이 의도와 반대 방향으로 동작합니다.

# 의도: 각 샘플(행)마다 다른 가중치를 곱하기
data = torch.randn(8, 16) # (배치=8, 특성=16)
weights = torch.randn(8) # 각 샘플의 가중치
# 잘못된 코드 — (8, 16) * (8,) 는 오류
# weights가 (1, 8)로 해석되어 (8, 8) → 호환 불가
try:
result = data * weights
except RuntimeError as e:
print(f"오류: {e}")
# 올바른 코드 — unsqueeze로 차원을 명확히
weights_col = weights.unsqueeze(1) # (8,) → (8, 1)
result = data * weights_col # (8, 16) * (8, 1) → (8, 16)
print(result.shape) # torch.Size([8, 16])

함정 3 — expand vs repeat 메모리 차이

섹션 제목: “함정 3 — expand vs repeat 메모리 차이”

브로드캐스팅을 명시적으로 수행할 때 expand()repeat() 중 어느 것을 쓰느냐에 따라 메모리 사용량이 크게 달라집니다.

방법메모리 복사쓰기 가능용도
expand()없음 (뷰 반환)불가읽기 전용 확장
repeat()있음 (실제 복사)가능쓰기 가능한 복사본
import torch
base = torch.tensor([[1.0, 2.0, 3.0]]) # shape: (1, 3)
# expand: 메모리 복사 없이 뷰 반환
expanded = base.expand(1000, 3)
print(expanded.shape) # torch.Size([1000, 3])
print(expanded.is_contiguous()) # False — 실제 메모리는 (1, 3)
# repeat: 실제로 데이터를 복사
repeated = base.repeat(1000, 1)
print(repeated.shape) # torch.Size([1000, 3])
print(repeated.is_contiguous()) # True — 실제 1000×3 메모리 사용
# 메모리 크기 비교
import sys
print(f"expand 메모리: {expanded.element_size() * expanded.nelement()} bytes")
# 실제로는 3 * 4 = 12 bytes만 사용 (뷰이므로)
print(f"repeat 메모리: {repeated.element_size() * repeated.nelement()} bytes")
# 1000 * 3 * 4 = 12,000 bytes 사용

디버깅 방법 1 — shape 출력 습관

섹션 제목: “디버깅 방법 1 — shape 출력 습관”

버그가 의심될 때는 연산 전후에 shape를 출력하는 것이 가장 빠른 디버깅 방법입니다.

def debug_op(name, a, b):
"""브로드캐스팅 디버그 헬퍼"""
print(f"[{name}]")
print(f" A shape: {a.shape}")
print(f" B shape: {b.shape}")
try:
result = a + b
print(f" 결과 shape: {result.shape}")
return result
except RuntimeError as e:
print(f" 오류: {e}")
return None
x = torch.randn(4, 3)
y = torch.randn(3)
debug_op("행렬 + 벡터", x, y)
# [행렬 + 벡터]
# A shape: torch.Size([4, 3])
# B shape: torch.Size([3])
# 결과 shape: torch.Size([4, 3])

디버깅 방법 2 — RuntimeError 메시지 읽기

섹션 제목: “디버깅 방법 2 — RuntimeError 메시지 읽기”

PyTorch의 브로드캐스팅 오류 메시지는 충분한 정보를 담고 있습니다.

try:
a = torch.randn(3, 4)
b = torch.randn(3, 5)
c = a + b
except RuntimeError as e:
print(e)

오류 메시지 예시:

The size of tensor a (4) must match the size of tensor b (5)
at non-singleton dimension 1

읽는 법:

  • tensor a (4) — A의 문제 차원 크기는 4
  • tensor b (5) — B의 문제 차원 크기는 5
  • non-singleton dimension 1 — 1번 차원(0-indexed)에서 충돌 발생
# 오류 없이 shape 호환성만 미리 확인
try:
result_shape = torch.broadcast_shapes(a.shape, b.shape)
print(f"호환: {result_shape}")
except RuntimeError as e:
print(f"비호환: {e}")

디버깅 방법 3 — assert로 shape 명시

섹션 제목: “디버깅 방법 3 — assert로 shape 명시”

중요한 연산 앞에 shape를 단언(assert)하면 문제를 조기에 발견할 수 있습니다.

def safe_normalize(batch, mean, std):
"""shape 검증이 포함된 배치 정규화"""
assert batch.ndim == 2, f"batch는 2D여야 합니다. 현재: {batch.ndim}D"
n_features = batch.shape[1]
assert mean.shape == (n_features,), \
f"mean shape 불일치. 기대: ({n_features},), 실제: {mean.shape}"
assert std.shape == (n_features,), \
f"std shape 불일치. 기대: ({n_features},), 실제: {std.shape}"
return (batch - mean) / std
# 올바른 사용
batch = torch.randn(32, 16)
mean = torch.zeros(16)
std = torch.ones(16)
result = safe_normalize(batch, mean, std)
print(result.shape) # torch.Size([32, 16])
# 잘못된 shape — 즉시 AssertionError 발생
bad_mean = torch.zeros(32)
try:
safe_normalize(batch, bad_mean, std)
except AssertionError as e:
print(f"shape 오류: {e}")

  • 브로드캐스팅은 오류 없이 조용히 잘못 동작할 수 있으므로 shape를 항상 확인
  • expand() 는 메모리 복사 없는 , repeat()실제 복사 — 읽기 전용이면 expand() 선호
  • 연산 전후에 .shape 를 출력하는 습관이 가장 효과적인 디버깅
  • torch.broadcast_shapes() 로 연산 전에 호환성을 확인
  • 중요한 함수에는 assert 로 shape를 명시적으로 검증

다음 장에서는 브로드캐스팅을 활용한 실전 딥러닝 패턴을 다룹니다.