브로드캐스팅 함정과 디버깅
함정 1 — 조용한 shape 불일치
섹션 제목: “함정 1 — 조용한 shape 불일치”브로드캐스팅은 오류 없이 조용히 동작하기 때문에 버그를 발견하기 어렵습니다. 의도한 shape가 아님에도 연산이 성공하면 잘못된 결과가 코드 깊숙이 전파됩니다.
import torch
# 의도: 4×3 행렬의 각 행에서 3차원 벡터를 빼기matrix = torch.ones(4, 3)bias = torch.ones(4) # ← 실수! (3,)이어야 하는데 (4,)로 생성
# 오류가 발생하지 않음 — 하지만 결과가 의도와 다름try: result = matrix - bias # RuntimeError 발생except RuntimeError as e: print(f"오류 발생: {e}") # 4 ≠ 3 이므로 오류
# 이 경우는 더 위험 — 오류 없이 잘못된 결과matrix2 = torch.ones(4, 1)result2 = matrix2 - bias # (4, 1) - (4,) → (4, 1) - (1, 4) → (4, 4)print(result2.shape) # torch.Size([4, 4]) ← 의도한 (4,)이 아님!비호환 shape: (3, 4) + (3, 5)
브로드캐스팅 불가: shape [3,4]와 [3,5]는 호환되지 않습니다.
함정 2 — 배치 차원 혼동
섹션 제목: “함정 2 — 배치 차원 혼동”딥러닝에서 가장 흔한 실수입니다. 배치 차원과 특성 차원을 혼동하면 브로드캐스팅이 의도와 반대 방향으로 동작합니다.
# 의도: 각 샘플(행)마다 다른 가중치를 곱하기data = torch.randn(8, 16) # (배치=8, 특성=16)weights = torch.randn(8) # 각 샘플의 가중치
# 잘못된 코드 — (8, 16) * (8,) 는 오류# weights가 (1, 8)로 해석되어 (8, 8) → 호환 불가try: result = data * weightsexcept RuntimeError as e: print(f"오류: {e}")
# 올바른 코드 — unsqueeze로 차원을 명확히weights_col = weights.unsqueeze(1) # (8,) → (8, 1)result = data * weights_col # (8, 16) * (8, 1) → (8, 16)print(result.shape) # torch.Size([8, 16])함정 3 — expand vs repeat 메모리 차이
섹션 제목: “함정 3 — expand vs repeat 메모리 차이”브로드캐스팅을 명시적으로 수행할 때 expand()와 repeat() 중 어느 것을 쓰느냐에 따라 메모리 사용량이 크게 달라집니다.
| 방법 | 메모리 복사 | 쓰기 가능 | 용도 |
|---|---|---|---|
expand() | 없음 (뷰 반환) | 불가 | 읽기 전용 확장 |
repeat() | 있음 (실제 복사) | 가능 | 쓰기 가능한 복사본 |
import torch
base = torch.tensor([[1.0, 2.0, 3.0]]) # shape: (1, 3)
# expand: 메모리 복사 없이 뷰 반환expanded = base.expand(1000, 3)print(expanded.shape) # torch.Size([1000, 3])print(expanded.is_contiguous()) # False — 실제 메모리는 (1, 3)
# repeat: 실제로 데이터를 복사repeated = base.repeat(1000, 1)print(repeated.shape) # torch.Size([1000, 3])print(repeated.is_contiguous()) # True — 실제 1000×3 메모리 사용
# 메모리 크기 비교import sysprint(f"expand 메모리: {expanded.element_size() * expanded.nelement()} bytes")# 실제로는 3 * 4 = 12 bytes만 사용 (뷰이므로)print(f"repeat 메모리: {repeated.element_size() * repeated.nelement()} bytes")# 1000 * 3 * 4 = 12,000 bytes 사용디버깅 방법 1 — shape 출력 습관
섹션 제목: “디버깅 방법 1 — shape 출력 습관”버그가 의심될 때는 연산 전후에 shape를 출력하는 것이 가장 빠른 디버깅 방법입니다.
def debug_op(name, a, b): """브로드캐스팅 디버그 헬퍼""" print(f"[{name}]") print(f" A shape: {a.shape}") print(f" B shape: {b.shape}") try: result = a + b print(f" 결과 shape: {result.shape}") return result except RuntimeError as e: print(f" 오류: {e}") return None
x = torch.randn(4, 3)y = torch.randn(3)debug_op("행렬 + 벡터", x, y)# [행렬 + 벡터]# A shape: torch.Size([4, 3])# B shape: torch.Size([3])# 결과 shape: torch.Size([4, 3])디버깅 방법 2 — RuntimeError 메시지 읽기
섹션 제목: “디버깅 방법 2 — RuntimeError 메시지 읽기”PyTorch의 브로드캐스팅 오류 메시지는 충분한 정보를 담고 있습니다.
try: a = torch.randn(3, 4) b = torch.randn(3, 5) c = a + bexcept RuntimeError as e: print(e)오류 메시지 예시:
The size of tensor a (4) must match the size of tensor b (5)at non-singleton dimension 1읽는 법:
tensor a (4)— A의 문제 차원 크기는 4tensor b (5)— B의 문제 차원 크기는 5non-singleton dimension 1— 1번 차원(0-indexed)에서 충돌 발생
# 오류 없이 shape 호환성만 미리 확인try: result_shape = torch.broadcast_shapes(a.shape, b.shape) print(f"호환: {result_shape}")except RuntimeError as e: print(f"비호환: {e}")디버깅 방법 3 — assert로 shape 명시
섹션 제목: “디버깅 방법 3 — assert로 shape 명시”중요한 연산 앞에 shape를 단언(assert)하면 문제를 조기에 발견할 수 있습니다.
def safe_normalize(batch, mean, std): """shape 검증이 포함된 배치 정규화""" assert batch.ndim == 2, f"batch는 2D여야 합니다. 현재: {batch.ndim}D" n_features = batch.shape[1] assert mean.shape == (n_features,), \ f"mean shape 불일치. 기대: ({n_features},), 실제: {mean.shape}" assert std.shape == (n_features,), \ f"std shape 불일치. 기대: ({n_features},), 실제: {std.shape}"
return (batch - mean) / std
# 올바른 사용batch = torch.randn(32, 16)mean = torch.zeros(16)std = torch.ones(16)result = safe_normalize(batch, mean, std)print(result.shape) # torch.Size([32, 16])
# 잘못된 shape — 즉시 AssertionError 발생bad_mean = torch.zeros(32)try: safe_normalize(batch, bad_mean, std)except AssertionError as e: print(f"shape 오류: {e}")핵심 요약
섹션 제목: “핵심 요약”- 브로드캐스팅은 오류 없이 조용히 잘못 동작할 수 있으므로 shape를 항상 확인
expand()는 메모리 복사 없는 뷰,repeat()는 실제 복사 — 읽기 전용이면expand()선호- 연산 전후에
.shape를 출력하는 습관이 가장 효과적인 디버깅 torch.broadcast_shapes()로 연산 전에 호환성을 확인- 중요한 함수에는
assert로 shape를 명시적으로 검증
다음 장에서는 브로드캐스팅을 활용한 실전 딥러닝 패턴을 다룹니다.