콘텐츠로 이동

브로드캐스팅 실전 패턴

두 벡터의 외적은 (n, 1)(1, m) 의 브로드캐스팅으로 구현합니다. 루프 없이 n×m 행렬을 한 번에 계산합니다.

a: (4, 1) [열 벡터]
b: (1, 5) [행 벡터]
↓ 브로드캐스팅
결과: (4, 5)
import torch
a = torch.tensor([1.0, 2.0, 3.0, 4.0]).unsqueeze(1) # (4, 1)
b = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0]).unsqueeze(0) # (1, 5)
outer = a * b
print(outer.shape) # torch.Size([4, 5])
print(outer)
# tensor([[ 10., 20., 30., 40., 50.],
# [ 20., 40., 60., 80., 100.],
# [ 30., 60., 90., 120., 150.],
# [ 40., 80., 120., 160., 200.]])
# torch.outer로도 동일하게 계산 가능
outer2 = torch.outer(a.squeeze(), b.squeeze())
print(torch.allclose(outer, outer2)) # True

활용 예시: 임베딩 행렬 초기화, 어텐션 스코어 계산, 공분산 행렬 근사 등에 사용됩니다.


배치 정규화(Batch Normalization)는 브로드캐스팅의 핵심 활용 사례입니다. 채널별 평균과 분산으로 배치 전체를 정규화합니다.

입력: (배치, 채널, 높이, 너비) = (8, 16, 32, 32)
채널 평균: (1, 16, 1, 1) ← 브로드캐스팅으로 배치 전체에 적용
import torch
# 배치 입력: (배치=8, 채널=16, H=32, W=32)
x = torch.randn(8, 16, 32, 32)
# 채널 차원(dim=1)을 제외한 모든 차원에서 평균/분산 계산
mean = x.mean(dim=(0, 2, 3), keepdim=True) # shape: (1, 16, 1, 1)
var = x.var(dim=(0, 2, 3), keepdim=True) # shape: (1, 16, 1, 1)
print(mean.shape) # torch.Size([1, 16, 1, 1])
# 브로드캐스팅으로 정규화 — (8,16,32,32) 와 (1,16,1,1) 연산
x_norm = (x - mean) / (var + 1e-5).sqrt()
print(x_norm.shape) # torch.Size([8, 16, 32, 32])
# 학습 가능한 스케일(γ)과 이동(β) 파라미터 적용
gamma = torch.ones(1, 16, 1, 1) # 채널별 스케일
beta = torch.zeros(1, 16, 1, 1) # 채널별 이동
output = gamma * x_norm + beta
print(output.shape) # torch.Size([8, 16, 32, 32])

N개의 점과 M개의 점 사이의 모든 유클리드 거리를 루프 없이 계산합니다. k-NN, 클러스터링, 메트릭 학습에서 핵심 연산입니다.

A: (N, D) — N개의 D차원 점
B: (M, D) — M개의 D차원 점
목표: 모든 (i, j) 쌍의 거리 행렬 (N, M)
import torch
N, M, D = 5, 7, 3
A = torch.randn(N, D) # (5, 3)
B = torch.randn(M, D) # (7, 3)
# 브로드캐스팅 방식:
# A를 (5, 1, 3)으로, B를 (1, 7, 3)으로 만들면
# 차이: (5, 7, 3) — 모든 쌍의 차이 벡터
A_exp = A.unsqueeze(1) # (5, 1, 3)
B_exp = B.unsqueeze(0) # (1, 7, 3)
diff = A_exp - B_exp # (5, 7, 3)
dist = diff.pow(2).sum(dim=2).sqrt() # (5, 7)
print(dist.shape) # torch.Size([5, 7])
print(dist) # 5×7 거리 행렬
# PyTorch 내장 함수로도 동일한 결과
dist2 = torch.cdist(A, B)
print(torch.allclose(dist, dist2, atol=1e-5)) # True

연산 비교:

방법코드속도
이중 루프for i in range(N): for j in range(M)가장 느림
브로드캐스팅A.unsqueeze(1) - B.unsqueeze(0)빠름
torch.cdisttorch.cdist(A, B)가장 빠름

트랜스포머의 어텐션에서 패딩 토큰을 무시하는 마스크 연산도 브로드캐스팅입니다.

# 어텐션 스코어: (배치, 헤드 수, 시퀀스 길이, 시퀀스 길이)
batch_size, n_heads, seq_len = 4, 8, 16
attn_scores = torch.randn(batch_size, n_heads, seq_len, seq_len)
# 패딩 마스크: (배치, 시퀀스 길이) — 패딩이면 True
padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
padding_mask[:, 12:] = True # 마지막 4개 토큰이 패딩
# (4, 16) → (4, 1, 1, 16) 으로 변환하여 브로드캐스팅
mask = padding_mask.unsqueeze(1).unsqueeze(2) # (4, 1, 1, 16)
print(mask.shape) # torch.Size([4, 1, 1, 16])
# (4, 8, 16, 16) 에 (4, 1, 1, 16) 마스크 적용
attn_scores = attn_scores.masked_fill(mask, float('-inf'))
print(attn_scores.shape) # torch.Size([4, 8, 16, 16])

트랜스포머의 위치 임베딩은 배치 전체에 동일하게 더해집니다.

# 입력 임베딩: (배치, 시퀀스, 임베딩 차원)
batch_size, seq_len, embed_dim = 8, 16, 512
token_embeddings = torch.randn(batch_size, seq_len, embed_dim)
# 위치 임베딩: (1, 시퀀스, 임베딩 차원) — 배치에 무관
pos_embedding = torch.randn(1, seq_len, embed_dim)
# (8, 16, 512) + (1, 16, 512) → (8, 16, 512)
output = token_embeddings + pos_embedding
print(output.shape) # torch.Size([8, 16, 512])

  • 외적 패턴: (n, 1) * (1, m)(n, m) — 루프 없이 모든 쌍 계산
  • 배치 정규화: keepdim=True + 브로드캐스팅으로 채널별 정규화
  • 거리 행렬: unsqueeze 로 차원 정렬 후 브로드캐스팅, 대규모에는 torch.cdist
  • 어텐션 마스크: unsqueeze 로 차원을 맞춰 배치·헤드 전체에 동일 마스크 적용
  • 위치 임베딩: (1, seq, dim)(batch, seq, dim) 에 더하면 배치 전체에 적용

퀴즈를 불러오는 중...