PyTorch Training Loop 완전 해부: 각 줄이 하는 일과 순서를 바꾸면 생기는 문제
The annotated PyTorch training loop
TL;DR Highlight
PyTorch 학습 루프의 각 코드 줄이 왜 그 위치에 있어야 하는지, 순서를 바꾸거나 빠뜨렸을 때 어떤 문제가 생기는지를 단계별로 설명한 심층 가이드다.
Who Should Read
PyTorch로 모델을 직접 학습시키는 ML 엔지니어나 딥러닝을 처음 공부하는 개발자 중, 학습이 안 되거나 메모리가 터지는 이유를 제대로 이해하고 싶은 사람.
Core Mechanics
- PyTorch 학습 루프는 겉으로는 단순해 보이지만, 코드 줄의 순서가 조금만 틀려도 에러 없이 조용히 잘못된 결과를 낸다. 수렴 실패, 메모리 폭증, 잘못된 업데이트가 모두 이런 식으로 발생한다.
- model.to(device) 호출은 반드시 optimizer를 만들기 전에 해야 한다. .half() 같은 dtype 변환과 함께 쓸 때 to()가 새 nn.Parameter 객체를 만들어버리는데, optimizer는 이미 버려진 옛 객체를 참조하게 되어 실제 파라미터에 업데이트가 적용되지 않는다.
- optimiser.zero_grad()는 반드시 loss.backward() 전에 호출해야 한다. 뒤에 두면 여러 배치의 gradient가 누적되어 합산된 gradient로 파라미터를 업데이트하게 되고, 의도치 않은 학습이 일어난다.
- clip_grad_norm_()(gradient 폭발을 막기 위해 gradient 크기를 제한하는 함수)은 loss.backward() 이후, optimiser.step() 이전에 넣어야 한다. backward() 전에 두면 gradient가 아직 없으니 아무 효과가 없고, step() 후에 두면 이미 적용된 gradient를 자르는 꼴이라 역시 의미가 없다.
- scheduler.step()(학습률을 epoch마다 조정하는 스케줄러)은 batch 루프 안이 아니라 epoch 루프 안에 있어야 한다. batch 루프 안에 두면 한 epoch에 loader 길이만큼 학습률이 감소해버린다.
- model.eval()을 validation에 쓰고 나서 다시 model.train()으로 돌려놓지 않으면, Dropout이 꺼지고 BatchNorm이 frozen된 상태로 학습이 계속된다. 에러는 전혀 나지 않아서 발견하기 어렵다.
- validation 구간에서 torch.no_grad()를 빠뜨리면, 매 validation batch마다 autograd 계산 그래프가 쌓여서 메모리가 계속 늘어나다가 OOM(Out of Memory)으로 터진다.
- loss를 로깅할 때 loss 텐서 자체를 저장하면(loss.item() 대신), 해당 배치의 전체 계산 그래프가 메모리에 계속 붙잡혀 있게 된다. 반드시 loss.item()으로 Python float 값만 빼서 저장해야 한다.
Evidence
- PyTorch가 이미 복잡한 DL 코드를 꽤 읽기 좋게 만들어준다는 점에서, 이런 breaking point들은 어쩔 수 없는 trade-off라는 의견이 있었다. GPT 같은 모델을 GPU에서 돌릴 수 있는 코드가 이 정도 복잡도면 나쁘지 않다는 시각이다.
- Anthropic이나 OpenAI도 PyTorch를 쓰냐는 질문이 달렸는데, 실제로 두 회사 모두 PyTorch 기반으로 학습 인프라를 운영하는 것으로 알려져 있다.
- 원문 사이트(idlemachines.co.uk)가 courses, workshops, interview prep 등 다양한 컨텐츠를 제공한다는 점에서 관심을 보이는 댓글이 있었다.
- 모바일에서 /courses/foundations 페이지 렌더링이 깨진다는 버그 리포트가 댓글로 올라왔다.
How to Apply
- 처음 PyTorch 학습 루프를 작성할 때, 위의 'TL;DR' 표를 체크리스트로 써라. model.to(device) 위치, zero_grad() 타이밍, scheduler.step() 위치, model.train()/eval() 전환, torch.no_grad() 포함 여부를 순서대로 확인하면 조용한 버그 대부분을 예방할 수 있다.
- 학습이 수렴하지 않거나 메모리가 계속 늘어날 때, 먼저 validation 루프에 torch.no_grad()가 있는지, loss 로깅을 loss.item()으로 하고 있는지 확인해라. 이 두 가지가 메모리 누수의 가장 흔한 원인이다.
- model.half()나 .to(dtype=...)와 model.to(device)를 함께 쓰는 경우, optimizer 생성 코드보다 반드시 먼저 배치해야 한다. 이미 optimizer를 만든 뒤에 dtype 변환을 하면 optimizer가 엉뚱한 파라미터를 업데이트하게 되어 학습이 아예 진행되지 않는다.
- CosineAnnealingLR 같은 epoch 단위 scheduler를 쓸 때는 scheduler.step()을 batch 루프가 아닌 epoch 루프 끝에 두어라. 안 그러면 1 epoch 동안 learning rate가 수백 번 감소해서 초반에 학습률이 거의 0이 되어버린다.
Code Example
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
# --- data ---
dataset = TensorDataset(X_train, y_train)
loader = DataLoader(
dataset,
batch_size=64,
shuffle=True,
num_workers=2,
pin_memory=True,
persistent_workers=True,
)
# --- model, loss, optimiser ---
# ⚠️ model.to(device)는 반드시 optimiser 생성 전에!
model = MLP(in_features=2, hidden=128, out_features=3)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)
# --- loop ---
for epoch in range(100):
model.train() # ⚠️ eval() 후 반드시 다시 train()으로
for X_batch, y_batch in loader:
optimiser.zero_grad() # ⚠️ backward() 전에!
logits = model(X_batch)
loss = criterion(logits, y_batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # ⚠️ backward() 후, step() 전!
optimiser.step()
scheduler.step() # ⚠️ batch 루프 밖 epoch 루프 안에!
model.eval()
with torch.no_grad(): # ⚠️ 빠뜨리면 메모리 누수!
val_logits = model(X_val)
val_loss = criterion(val_logits, y_val)
print(f'Epoch {epoch}, val_loss: {val_loss.item():.4f}') # ⚠️ .item()으로 float만!Terminology
관련 논문
Neural Particle Automata: 자기조직화 파티클 시스템을 학습하는 신경망 모델
고정된 격자 대신 움직이는 파티클 위에서 동작하는 Neural Cellular Automata의 확장 버전으로, 형태 생성·포인트 클라우드 분류·텍스처 합성 등 다양한 작업에서 자기조직화 동작을 학습할 수 있다.
좋은 Verifier도 망가질 수 있다: Self-Improving VLM이 새로운 태스크에서 오히려 퇴보하는 현상
VLM 자가학습 루프에서 verifier가 특정 태스크에 맞지 않으면 학습할수록 오히려 성능이 떨어지는데, DPO 손실값은 멀쩡히 내려가서 눈치채기도 어렵다.
Self-Distillation에서 Feedback Alignment의 역할
LLM이 스스로를 가르칠 때, 피드백을 모델의 추론 흐름에 단계별로 맞추면 GRPO보다 16점 이상 수학 추론 성능이 오른다.
작고 수정 가능한 CUDA 기반 Language Model 직접 구현체
CUDA로 작성된 GPT(Generative Pretrained Transformer) 미니멀 구현체로, 텍스트뿐 아니라 모든 바이트 스트림을 학습할 수 있어 LLM 내부 구조를 직접 뜯어보고 싶은 개발자에게 유용하다.
Stanford CS336: Language Modeling from Scratch - LLM을 처음부터 직접 만드는 강의
Stanford에서 운영하는 LLM 전 과정 구현 강의로, 토크나이저부터 데이터 수집, 트랜스포머 구현, 분산 학습, RL 기반 정렬까지 직접 코딩하며 배운다. 이론이 아닌 구현 중심이라 실제로 LLM이 어떻게 작동하는지 깊이 이해하고 싶은 개발자에게 가장 체계적인 커리큘럼 중 하나다.
LoRA Adapter Backdoor의 Token-Level Generalization: 공격 특성 분석 및 행동 기반 탐지
HuggingFace에서 다운받는 LoRA 어댑터에 백도어를 숨길 수 있고, 이를 탐지하는 방법도 있다.