데이터 타입 (dtype)
dtype이란?
섹션 제목: “dtype이란?”텐서의 dtype(data type) 은 각 원소가 어떤 숫자 형식으로 저장되는지를 결정합니다. dtype은 메모리 사용량, 연산 속도, 수치 정밀도 모두에 영향을 줍니다.
import torch
x = torch.tensor([1.0, 2.0, 3.0])print(x.dtype) # torch.float32
y = torch.tensor([1, 2, 3])print(y.dtype) # torch.int64PyTorch dtype 전체 목록
섹션 제목: “PyTorch dtype 전체 목록”| dtype | 단축 속성 | 비트 수 | 표현 범위 / 정밀도 | 주요 용도 |
|---|---|---|---|---|
torch.float32 | torch.float | 32 | ±3.4 × 10³⁸, 소수점 ~7자리 | 기본 학습 연산 |
torch.float64 | torch.double | 64 | ±1.8 × 10³⁰⁸, 소수점 ~15자리 | 과학 계산, 고정밀 |
torch.float16 | torch.half | 16 | ±65504, 소수점 ~3자리 | GPU 혼합 정밀도 |
torch.bfloat16 | — | 16 | float32와 동일 지수 범위 | TPU / 최신 GPU |
torch.int8 | — | 8 | -128 ~ 127 | 양자화 모델 |
torch.uint8 | — | 8 | 0 ~ 255 | 이미지 픽셀값 |
torch.int16 | torch.short | 16 | -32768 ~ 32767 | 제한적 사용 |
torch.int32 | torch.int | 32 | -2.1 × 10⁹ ~ 2.1 × 10⁹ | 인덱스, 카운트 |
torch.int64 | torch.long | 64 | -9.2 × 10¹⁸ ~ 9.2 × 10¹⁸ | 기본 정수, 인덱스 |
torch.bool | — | 8 | True / False | 마스크, 조건 |
torch.complex64 | — | 64 | 복소수 (float32 × 2) | 신호 처리 |
torch.complex128 | — | 128 | 복소수 (float64 × 2) | 고정밀 복소수 |
기본 dtype과 변경
섹션 제목: “기본 dtype과 변경”PyTorch는 부동소수점 텐서의 기본 dtype을 float32 로 설정합니다.
import torch
# 기본 dtype 확인print(torch.get_default_dtype()) # torch.float32
# 기본 dtype 변경 (전역 설정)torch.set_default_dtype(torch.float64)x = torch.tensor([1.0, 2.0])print(x.dtype) # torch.float64
# 원래대로 복원torch.set_default_dtype(torch.float32)타입 캐스팅
섹션 제목: “타입 캐스팅”.to() — 범용 변환
섹션 제목: “.to() — 범용 변환”.to() 는 dtype과 device를 동시에 지정할 수 있는 가장 유연한 방법입니다.
import torch
x = torch.tensor([1, 2, 3], dtype=torch.int32)
# dtype만 변경x_float = x.to(torch.float32)print(x_float.dtype) # torch.float32
# dtype과 device를 동시에 변경x_gpu = x.to(device='cuda', dtype=torch.float16)단축 메서드
섹션 제목: “단축 메서드”자주 쓰는 dtype에는 편의 메서드가 있습니다.
import torch
x = torch.tensor([1, 0, 1, 0])
x_float = x.float() # → torch.float32x_double = x.double() # → torch.float64x_half = x.half() # → torch.float16x_long = x.long() # → torch.int64x_int = x.int() # → torch.int32x_bool = x.bool() # → torch.bool
print(x_bool) # tensor([True, False, True, False])캐스팅 시 데이터 손실 주의
섹션 제목: “캐스팅 시 데이터 손실 주의”import torch
# float → int: 소수점 버림 (반올림 아님)x = torch.tensor([1.9, 2.5, 3.1])print(x.int()) # tensor([1, 2, 3])
# 범위 초과: 오버플로우 발생big = torch.tensor([300], dtype=torch.int32)print(big.to(torch.int8)) # tensor([44]) ← 오버플로우!정밀도와 메모리 트레이드오프
섹션 제목: “정밀도와 메모리 트레이드오프”dtype이 다르면 메모리 사용량이 크게 달라집니다.
import torch
shape = (1000, 1000)
t_f64 = torch.zeros(shape, dtype=torch.float64)t_f32 = torch.zeros(shape, dtype=torch.float32)t_f16 = torch.zeros(shape, dtype=torch.float16)t_bf16 = torch.zeros(shape, dtype=torch.bfloat16)
def mb(t): return t.element_size() * t.numel() / 1024 / 1024
print(f"float64: {mb(t_f64):.2f} MB") # float64: 7.63 MBprint(f"float32: {mb(t_f32):.2f} MB") # float32: 3.81 MBprint(f"float16: {mb(t_f16):.2f} MB") # float16: 1.91 MBprint(f"bfloat16: {mb(t_bf16):.2f} MB") # bfloat16: 1.91 MB| dtype | 메모리 (상대) | 정밀도 | 학습 안정성 |
|---|---|---|---|
| float64 | 2x | 최고 | 최고 |
| float32 | 1x (기준) | 높음 | 높음 |
| bfloat16 | 0.5x | 낮음 | float32와 유사 |
| float16 | 0.5x | 낮음 | 주의 필요 (언더플로우) |
float16 / bfloat16 혼합 정밀도
섹션 제목: “float16 / bfloat16 혼합 정밀도”혼합 정밀도(Mixed Precision) 학습은 float16 또는 bfloat16으로 순전파·역전파를 수행하고, 파라미터 업데이트는 float32로 유지하는 기법입니다. 메모리와 속도 모두 개선됩니다.
import torchfrom torch.cuda.amp import autocast, GradScaler
model = MyModel().cuda()optimizer = torch.optim.Adam(model.parameters())scaler = GradScaler() # float16의 언더플로우 방지
for inputs, labels in dataloader: inputs, labels = inputs.cuda(), labels.cuda() optimizer.zero_grad()
# autocast 블록 내에서 float16으로 자동 변환 with autocast(): outputs = model(inputs) loss = criterion(outputs, labels)
# GradScaler로 기울기 스케일링 scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()float16 vs bfloat16 비교
섹션 제목: “float16 vs bfloat16 비교”import torch
# float16: 지수 5비트, 가수 10비트# → 표현 범위 좁음, 언더플로우 위험f16 = torch.tensor(1e-8, dtype=torch.float16)print(f16) # tensor(0., dtype=torch.float16) ← 언더플로우!
# bfloat16: 지수 8비트 (float32와 동일), 가수 7비트# → 표현 범위 float32와 같아 언더플로우 안전bf16 = torch.tensor(1e-8, dtype=torch.bfloat16)print(bf16) # tensor(1.e-08, dtype=torch.bfloat16) ← 정상