AI 모델 서빙에 FastAPI가 적합한 이유
AI 엔지니어 관점에서 FastAPI가 Flask나 Django보다 ML 서빙에 더 잘 맞는 이유는 단순히 “빠르기” 때문이 아니다. 비동기 I/O, 타입 기반 스키마, 스트리밍, 의존성 주입이 ML 워크로드의 특성과 정확히 맞아 떨어진다.
모델을 싱글턴으로 로드하기
섹션 제목: “모델을 싱글턴으로 로드하기”가장 흔한 실수는 요청마다 모델을 로드하는 것이다. HuggingFace 파이프라인 하나를 로드하는 데 수 초가 걸릴 수 있다.
# 잘못된 방법 — 요청마다 모델 로드from fastapi import FastAPIfrom transformers import pipeline
app = FastAPI()
@app.post("/predict")async def predict(text: str): pipe = pipeline("sentiment-analysis") # 매 요청마다 로드! return pipe(text)lifespan 컨텍스트 매니저로 앱 시작 시 한 번만 로드한다.
from contextlib import asynccontextmanagerfrom fastapi import FastAPI, Dependsfrom transformers import pipelinefrom typing import Annotated
# 모델을 저장할 전역 딕셔너리ml_models: dict = {}
@asynccontextmanagerasync def lifespan(app: FastAPI): # 앱 시작 시 모델 로드 ml_models["sentiment"] = pipeline( "sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=-1, # CPU; GPU면 0 ) yield # 앱 종료 시 정리 ml_models.clear()
app = FastAPI(lifespan=lifespan)
def get_sentiment_model(): return ml_models["sentiment"]
@app.post("/predict")async def predict( text: str, model: Annotated[object, Depends(get_sentiment_model)],): result = model(text)[0] return { "label": result["label"], "score": round(result["score"], 4), }Pydantic으로 ML 입력 스키마 정의하기
섹션 제목: “Pydantic으로 ML 입력 스키마 정의하기”ML 모델의 입력은 복잡한 경우가 많다. Pydantic이 유효성 검사를 자동으로 처리한다.
from pydantic import BaseModel, Field, field_validatorfrom typing import Optional, Literal
class PredictRequest(BaseModel): text: str = Field(..., min_length=1, max_length=512) language: Literal["ko", "en", "ja"] = "en" threshold: float = Field(0.5, ge=0.0, le=1.0) top_k: Optional[int] = Field(None, ge=1, le=10)
@field_validator("text") @classmethod def strip_text(cls, v: str) -> str: return v.strip()
class PredictResponse(BaseModel): label: str score: float model_version: str = "1.0.0"
@app.post("/predict", response_model=PredictResponse)async def predict( req: PredictRequest, model: Annotated[object, Depends(get_sentiment_model)],) -> PredictResponse: result = model(req.text)[0] return PredictResponse( label=result["label"], score=result["score"], )response_model=PredictResponse를 지정하면 응답에서 내부 필드가 자동으로 제거되고, /docs에 응답 스키마가 표시된다.
LLM 토큰 스트리밍
섹션 제목: “LLM 토큰 스트리밍”ChatGPT처럼 토큰을 생성되는 즉시 전송하려면 StreamingResponse를 쓴다.
from fastapi import FastAPIfrom fastapi.responses import StreamingResponsefrom transformers import AutoTokenizer, AutoModelForCausalLMimport torchfrom typing import AsyncGenerator
async def token_stream(prompt: str) -> AsyncGenerator[str, None]: """토큰을 생성할 때마다 Server-Sent Events 형식으로 전송""" tokenizer = ml_models["tokenizer"] model = ml_models["llm"]
inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"]
with torch.no_grad(): for _ in range(200): # 최대 200 토큰 outputs = model(input_ids) next_token_id = outputs.logits[:, -1, :].argmax(dim=-1) token_text = tokenizer.decode(next_token_id[0])
# SSE 형식: "data: <내용>\n\n" yield f"data: {token_text}\n\n"
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=-1)
if next_token_id.item() == tokenizer.eos_token_id: break
yield "data: [DONE]\n\n"
@app.post("/generate")async def generate(prompt: str): return StreamingResponse( token_stream(prompt), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "X-Accel-Buffering": "no", # nginx 버퍼링 비활성화 }, )클라이언트에서는 EventSource(브라우저) 또는 httpx의 스트리밍 API로 소비한다.
# Python 클라이언트 예제import httpx
with httpx.Client() as client: with client.stream("POST", "http://localhost:8000/generate", params={"prompt": "FastAPI는"}) as response: for line in response.iter_lines(): if line.startswith("data: ") and not line.endswith("[DONE]"): print(line[6:], end="", flush=True)BackgroundTasks로 긴 추론 처리
섹션 제목: “BackgroundTasks로 긴 추론 처리”모델 추론이 수십 초 걸린다면, 즉시 job ID를 반환하고 결과를 비동기로 처리한다.
from fastapi import BackgroundTasksfrom uuid import uuid4
# 간단한 인메모리 작업 저장소 (프로덕션에서는 Redis 사용)job_store: dict = {}
def run_inference(job_id: str, text: str) -> None: """백그라운드에서 실행되는 무거운 추론""" job_store[job_id] = {"status": "running"} try: model = ml_models["heavy_model"] result = model(text) job_store[job_id] = {"status": "done", "result": result} except Exception as e: job_store[job_id] = {"status": "error", "detail": str(e)}
@app.post("/jobs", status_code=202)async def submit_job(text: str, background_tasks: BackgroundTasks): job_id = str(uuid4()) job_store[job_id] = {"status": "queued"} background_tasks.add_task(run_inference, job_id, text) return {"job_id": job_id, "status": "queued"}
@app.get("/jobs/{job_id}")async def get_job(job_id: str): if job_id not in job_store: raise HTTPException(status_code=404, detail="Job not found") return job_store[job_id]FastAPI vs Flask — ML 서빙 관점 비교
섹션 제목: “FastAPI vs Flask — ML 서빙 관점 비교”| 항목 | FastAPI | Flask |
|---|---|---|
| 비동기 지원 | 네이티브 async/await | 별도 설정 필요 |
| 입력 유효성 검사 | Pydantic 자동 | 수동 구현 |
| 자동 문서화 | /docs 기본 제공 | 플러그인 필요 |
| 스트리밍 응답 | StreamingResponse 내장 | Response(stream_with_context) |
| 타입 안전성 | 타입 힌트 기반 | 없음 |
| 성능 (동시 I/O) | 높음 (이벤트 루프) | 낮음 (스레드 기반) |
핵심 정리
섹션 제목: “핵심 정리”- 모델은
lifespan에서 한 번만 로드 — 요청마다 재로드는 치명적 성능 저하다 - Pydantic으로 ML 입력 스키마 정의 — 유효성 검사와 문서화를 동시에 해결한다
StreamingResponse로 LLM 토큰 스트리밍 —text/event-stream과 SSE 형식을 사용한다BackgroundTasks로 무거운 추론 분리 — 즉시 202를 반환하고 폴링으로 결과를 확인한다