콘텐츠로 이동

AI 모델 서빙에 FastAPI가 적합한 이유

AI 엔지니어 관점에서 FastAPI가 Flask나 Django보다 ML 서빙에 더 잘 맞는 이유는 단순히 “빠르기” 때문이 아니다. 비동기 I/O, 타입 기반 스키마, 스트리밍, 의존성 주입이 ML 워크로드의 특성과 정확히 맞아 떨어진다.

가장 흔한 실수는 요청마다 모델을 로드하는 것이다. HuggingFace 파이프라인 하나를 로드하는 데 수 초가 걸릴 수 있다.

# 잘못된 방법 — 요청마다 모델 로드
from fastapi import FastAPI
from transformers import pipeline
app = FastAPI()
@app.post("/predict")
async def predict(text: str):
pipe = pipeline("sentiment-analysis") # 매 요청마다 로드!
return pipe(text)

lifespan 컨텍스트 매니저로 앱 시작 시 한 번만 로드한다.

from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends
from transformers import pipeline
from typing import Annotated
# 모델을 저장할 전역 딕셔너리
ml_models: dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# 앱 시작 시 모델 로드
ml_models["sentiment"] = pipeline(
"sentiment-analysis",
model="distilbert-base-uncased-finetuned-sst-2-english",
device=-1, # CPU; GPU면 0
)
yield
# 앱 종료 시 정리
ml_models.clear()
app = FastAPI(lifespan=lifespan)
def get_sentiment_model():
return ml_models["sentiment"]
@app.post("/predict")
async def predict(
text: str,
model: Annotated[object, Depends(get_sentiment_model)],
):
result = model(text)[0]
return {
"label": result["label"],
"score": round(result["score"], 4),
}

Pydantic으로 ML 입력 스키마 정의하기

섹션 제목: “Pydantic으로 ML 입력 스키마 정의하기”

ML 모델의 입력은 복잡한 경우가 많다. Pydantic이 유효성 검사를 자동으로 처리한다.

from pydantic import BaseModel, Field, field_validator
from typing import Optional, Literal
class PredictRequest(BaseModel):
text: str = Field(..., min_length=1, max_length=512)
language: Literal["ko", "en", "ja"] = "en"
threshold: float = Field(0.5, ge=0.0, le=1.0)
top_k: Optional[int] = Field(None, ge=1, le=10)
@field_validator("text")
@classmethod
def strip_text(cls, v: str) -> str:
return v.strip()
class PredictResponse(BaseModel):
label: str
score: float
model_version: str = "1.0.0"
@app.post("/predict", response_model=PredictResponse)
async def predict(
req: PredictRequest,
model: Annotated[object, Depends(get_sentiment_model)],
) -> PredictResponse:
result = model(req.text)[0]
return PredictResponse(
label=result["label"],
score=result["score"],
)

response_model=PredictResponse를 지정하면 응답에서 내부 필드가 자동으로 제거되고, /docs에 응답 스키마가 표시된다.

ChatGPT처럼 토큰을 생성되는 즉시 전송하려면 StreamingResponse를 쓴다.

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import AsyncGenerator
async def token_stream(prompt: str) -> AsyncGenerator[str, None]:
"""토큰을 생성할 때마다 Server-Sent Events 형식으로 전송"""
tokenizer = ml_models["tokenizer"]
model = ml_models["llm"]
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
with torch.no_grad():
for _ in range(200): # 최대 200 토큰
outputs = model(input_ids)
next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
token_text = tokenizer.decode(next_token_id[0])
# SSE 형식: "data: <내용>\n\n"
yield f"data: {token_text}\n\n"
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=-1)
if next_token_id.item() == tokenizer.eos_token_id:
break
yield "data: [DONE]\n\n"
@app.post("/generate")
async def generate(prompt: str):
return StreamingResponse(
token_stream(prompt),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # nginx 버퍼링 비활성화
},
)

클라이언트에서는 EventSource(브라우저) 또는 httpx의 스트리밍 API로 소비한다.

# Python 클라이언트 예제
import httpx
with httpx.Client() as client:
with client.stream("POST", "http://localhost:8000/generate",
params={"prompt": "FastAPI는"}) as response:
for line in response.iter_lines():
if line.startswith("data: ") and not line.endswith("[DONE]"):
print(line[6:], end="", flush=True)

모델 추론이 수십 초 걸린다면, 즉시 job ID를 반환하고 결과를 비동기로 처리한다.

from fastapi import BackgroundTasks
from uuid import uuid4
# 간단한 인메모리 작업 저장소 (프로덕션에서는 Redis 사용)
job_store: dict = {}
def run_inference(job_id: str, text: str) -> None:
"""백그라운드에서 실행되는 무거운 추론"""
job_store[job_id] = {"status": "running"}
try:
model = ml_models["heavy_model"]
result = model(text)
job_store[job_id] = {"status": "done", "result": result}
except Exception as e:
job_store[job_id] = {"status": "error", "detail": str(e)}
@app.post("/jobs", status_code=202)
async def submit_job(text: str, background_tasks: BackgroundTasks):
job_id = str(uuid4())
job_store[job_id] = {"status": "queued"}
background_tasks.add_task(run_inference, job_id, text)
return {"job_id": job_id, "status": "queued"}
@app.get("/jobs/{job_id}")
async def get_job(job_id: str):
if job_id not in job_store:
raise HTTPException(status_code=404, detail="Job not found")
return job_store[job_id]

FastAPI vs Flask — ML 서빙 관점 비교

섹션 제목: “FastAPI vs Flask — ML 서빙 관점 비교”
항목FastAPIFlask
비동기 지원네이티브 async/await별도 설정 필요
입력 유효성 검사Pydantic 자동수동 구현
자동 문서화/docs 기본 제공플러그인 필요
스트리밍 응답StreamingResponse 내장Response(stream_with_context)
타입 안전성타입 힌트 기반없음
성능 (동시 I/O)높음 (이벤트 루프)낮음 (스레드 기반)
  • 모델은 lifespan에서 한 번만 로드 — 요청마다 재로드는 치명적 성능 저하다
  • Pydantic으로 ML 입력 스키마 정의 — 유효성 검사와 문서화를 동시에 해결한다
  • StreamingResponse로 LLM 토큰 스트리밍text/event-stream과 SSE 형식을 사용한다
  • BackgroundTasks로 무거운 추론 분리 — 즉시 202를 반환하고 폴링으로 결과를 확인한다