콘텐츠로 이동

그래디언트 계산 심화

PyTorch의 .backward() 는 호출할 때마다 기존 .grad더합니다. 초기화 없이 여러 번 호출하면 그래디언트가 누적되어 잘못된 업데이트가 일어납니다.

import torch
x = torch.tensor(2.0, requires_grad=True)
# 첫 번째 역전파
loss1 = x ** 2
loss1.backward()
print(f"1회 후 x.grad: {x.grad}") # tensor(4.) — 2x = 4
# 두 번째 역전파 (초기화 없이)
loss2 = x ** 2
loss2.backward()
print(f"2회 후 x.grad: {x.grad}") # tensor(8.) — 4 + 4 누적!
# 세 번째 역전파 (초기화 없이)
loss3 = x ** 2
loss3.backward()
print(f"3회 후 x.grad: {x.grad}") # tensor(12.) — 계속 누적

매 학습 스텝 시작 전에 .grad0 으로 초기화합니다.

x = torch.tensor(2.0, requires_grad=True)
for step in range(3):
# 매 스텝 초기화
if x.grad is not None:
x.grad.zero_()
loss = x ** 2 + x
loss.backward()
print(f"스텝 {step+1}: x.grad = {x.grad}") # 항상 tensor(5.)
# 출력:
# 스텝 1: x.grad = tensor(5.)
# 스텝 2: x.grad = tensor(5.)
# 스텝 3: x.grad = tensor(5.)

실제 학습에서는 옵티마이저의 zero_grad() 를 사용합니다:

import torch
import torch.nn as nn
import torch.optim as optim
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
for epoch in range(3):
optimizer.zero_grad() # ← 매 에폭마다 그래디언트 초기화
pred = model(x)
loss = nn.MSELoss()(pred, y)
loss.backward()
optimizer.step() # 파라미터 업데이트
print(f"에폭 {epoch+1}: loss = {loss.item():.4f}")

loss = a²b + b³ 계산 그래프 (a=2, b=3)
a값: 2.0b값: 3.0a²b값: 12.0값: 27.0loss값: 39.0

추론(inference) 또는 검증(validation) 단계에서는 그래디언트 계산이 필요 없습니다. torch.no_grad() 컨텍스트 안에서는 연산이 추적되지 않아 메모리와 속도를 절약합니다.

x = torch.tensor(3.0, requires_grad=True)
# 추적 활성화 상태
y = x ** 2
print(y.requires_grad) # True
print(y.grad_fn) # <PowBackward0>
# 추적 비활성화
with torch.no_grad():
y_no_grad = x ** 2
print(y_no_grad.requires_grad) # False
print(y_no_grad.grad_fn) # None

추론 시 성능 비교:

import time
model = nn.Linear(1000, 1000)
x = torch.randn(256, 1000)
# 추적 활성화 (학습 모드)
start = time.time()
for _ in range(100):
_ = model(x)
print(f"추적 활성화: {time.time() - start:.3f}초")
# 추적 비활성화 (추론 모드)
start = time.time()
with torch.no_grad():
for _ in range(100):
_ = model(x)
print(f"추적 비활성화: {time.time() - start:.3f}초")
# 추론 모드가 일반적으로 더 빠름

detach() 는 텐서를 계산 그래프에서 분리하여 새 텐서를 반환합니다. no_grad() 와 달리, 텐서 단위로 선택적으로 분리할 수 있습니다.

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 # 계산 그래프에 연결됨
# detach: 같은 값이지만 그래프에서 분리된 텐서
y_detached = y.detach()
print(y_detached.requires_grad) # False
print(y_detached.grad_fn) # None
print(y_detached.item()) # 9.0 — 값은 동일

실전 활용 — 타겟 네트워크(Target Network):

강화학습에서 타겟 네트워크의 출력은 그래디언트 전파 없이 손실 계산에만 사용합니다.

# 메인 네트워크와 타겟 네트워크
main_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
state = torch.randn(8, 4)
main_q = main_net(state) # 그래디언트 추적됨
target_q = target_net(state).detach() # 그래디언트 전파 차단
loss = nn.MSELoss()(main_q, target_q)
loss.backward() # main_net만 업데이트, target_net은 그대로

backward() 를 호출하면 계산 그래프가 메모리에서 해제됩니다. 같은 그래프로 여러 번 역전파해야 할 때는 retain_graph=True 를 사용합니다.

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3 # y = x³
# 첫 번째 backward — 그래프가 해제됨
y.backward(retain_graph=True)
print(f"1회: x.grad = {x.grad}") # tensor(12.) — 3x² = 12
x.grad.zero_()
# 두 번째 backward — retain_graph=True 덕분에 가능
y.backward(retain_graph=True)
print(f"2회: x.grad = {x.grad}") # tensor(12.)
# retain_graph=True 없이 두 번 호출하면 RuntimeError 발생
x2 = torch.tensor(2.0, requires_grad=True)
y2 = x2 ** 3
y2.backward() # 그래프 해제됨
try:
y2.backward()
except RuntimeError as e:
print(f"오류: {e}")
# Trying to backward through the graph a second time...

항목torch.no_grad().detach()
적용 범위블록 내 모든 연산특정 텐서 하나
사용 방식컨텍스트 관리자텐서 메서드
반환값동일 텐서새 텐서
주요 용도추론/검증 전체일부 경로 차단

  • .backward().grad누적하므로 매 스텝 zero_grad() 필수
  • torch.no_grad() 는 추론 시 메모리·속도 최적화
  • detach() 는 텐서 단위로 선택적 그래프 분리
  • retain_graph=True 는 동일 그래프로 여러 번 역전파할 때만 사용

다음 장에서는 GPU 프로그래밍 기초와 텐서를 GPU로 이동하는 방법을 다룹니다.