PPO를 활용한 언어 모델의 Tree Search Distillation
Tree Search Distillation for Language Models Using PPO
TL;DR Highlight
AlphaZero 스타일 MCTS로 탐색한 추론 경로를 PPO로 증류한 방법이 표준 RL 방법 GRPO보다 높은 성능을 달성했다.
Who Should Read
LLM의 추론 능력을 강화하는 방법을 연구하거나, RL 기반 학습(GRPO, PPO 등)을 언어 모델에 적용해보려는 ML 엔지니어 및 연구자.
Core Mechanics
- AlphaZero 같은 게임 AI는 '정책(policy) + 탐색(search) + 증류(distillation)' 사이클로 성능을 높이는데, 언어 모델에서는 이 방법이 잘 쓰이지 않는다. DeepSeek-R1 팀도 MCTS를 시도했지만 효과가 제한적이었다고 밝혔는데, 이는 탐색 알고리즘으로 UCT 대신 pUCT를 썼어야 한다는 분석이 있다.
- 실험 모델은 Qwen-2.5-1.5B-Instruct이고, 태스크는 Countdown(주어진 정수 4개로 사칙연산을 이용해 목표 숫자를 만드는 조합 산술 게임)이다. GSM8K로 먼저 시도했지만 GRPO와 MCTS 간 차이가 미미해서 조합적 탐색이 더 유리한 Countdown으로 전환했다.
- 성능 결과: 증류된 모델(탐색 없이 단독 추론)이 mean@16 기준 11.3%를 달성했고, 비교 대상인 CISPO는 8.4%, best-of-N은 7.7%, 사전 RL 모델은 3.1%였다. 절대 수치가 낮은 이유는 1.5B라는 작은 모델 크기 때문이며, 향후 더 큰 모델로 실험을 이어갈 예정이다.
- 보상 함수 설계가 중요했다. 처음에 정답/오답만 주는 sparse reward(0/1)를 쓰면 학습이 불안정했다. 대신 예측값과 정답의 차이에 비례하는 dense reward($1.0 - 2 \cdot \min(|t-p|/t, 1.0)$)를 학습에 사용하고, 평가는 여전히 sparse reward로 진행했다.
- 토큰 단위가 아닌 추론 스텝(reasoning step) 단위로 MCTS를 적용했다. 토큰 단위로 분기하면 'but', 'however', 'yet' 같은 기능어에서도 가지가 갈라져 탐색 트리가 비효율적으로 커진다. 대신 Tree-of-Thoughts 방식처럼 `<step>...</step>` 태그로 묶인 추론 단계를 하나의 노드로 취급했다.
- 탐색 다양성을 높이기 위해 N개의 에이전트가 같은 샘플의 탐색 트리를 공유하는 parallel MCTS를 구현했다. 각 에이전트는 virtual loss를 써서 서로 다른 경로를 탐색하도록 유도했다. pUCT에서 필요한 행동(action) 사전 확률은 시퀀스 로그확률을 합산한 뒤 softmax를 취해 계산했다.
- 모델에 value head(MLP + tanh)를 추가해 현재 상태의 가치를 예측하게 했다. 이 value head는 학습 중 점차 개선되면서 MCTS가 더 좋은 탐색 경로를 찾도록 안내한다. MCTS로 찾은 더 강한 추론 경로를 PPO 루프를 통해 모델 가중치에 온라인으로 증류하는 구조다.
Evidence
- MCTS를 학습 시 탐색에만 쓰고 PPO로 증류하면, 추론 시에는 탐색 없이 모델만 쓰는 것이므로 추론 비용이 GRPO와 동일한 것 아니냐는 질문이 있었다. 원문에서 'MCTS는 샘플당 추론 컴퓨팅을 더 많이 쓰니까 당연히 성능이 좋다'는 표현이 있는데, 증류된 모델은 추론 시 탐색을 쓰지 않으므로 이 표현이 혼란스럽다는 지적이었다.
- MCTS를 증류 없이 test-time compute harness로만 쓸 때 같은 컴퓨팅 예산 기준으로 best-of-N과 비교해봤는지 묻는 댓글이 있었다. 이는 MCTS의 탐색 효율 자체를 검증하는 중요한 비교인데, 원문에서는 이 비교가 명시적으로 다뤄지지 않았다.
- MCTS가 test-time compute 방법으로 왜 더 많이 쓰이지 않는지 의문을 표하는 댓글이 있었다. 언어 모델에서 MCTS가 어려운 이유(토큰 단위 탐색의 비효율, value function 학습의 어려움 등)에 대한 관심이 있었고, 이 글이 그 가능성을 탐구하는 시도라는 점에서 긍정적인 반응이 있었다.
How to Apply
- 소규모 조합 최적화 또는 수학 추론 태스크를 다루는 경우, 토큰 단위가 아닌 추론 스텝 단위로 MCTS를 적용하면 탐색 트리의 크기를 제어하면서도 다양한 추론 경로를 탐색할 수 있다. `<step>...</step>` 같은 구조화된 태그를 프롬프트에 도입하고, 각 스텝 완성 시점을 MCTS 노드 전환점으로 삼으면 된다.
- RL 기반 파인튜닝 시 sparse reward(정답/오답 0/1)로 학습이 불안정하다면, 예측값과 정답 사이의 거리에 비례하는 dense reward 함수를 학습에 사용하고 평가만 sparse reward로 유지하는 방식을 고려해볼 수 있다. 이 글의 공식은 $1.0 - 2 \cdot \min(|t-p|/t, 1.0)$이다.
- MCTS 탐색 다양성이 부족하다면 virtual loss를 도입한 parallel MCTS를 적용할 수 있다. 여러 에이전트가 같은 탐색 트리를 공유하되 서로 다른 노드를 방문하도록 유도하면 같은 컴퓨팅 예산으로 더 넓은 탐색이 가능하다.
- pUCT 선택 확률 계산 시 raw 누적 시퀀스 확률 대신 로그확률 합산 후 softmax를 적용하면 수치적 언더플로 문제를 피할 수 있다. 특히 긴 시퀀스를 다루는 경우 이 처리가 학습 안정성에 중요하다.
Terminology
관련 논문
Neural Particle Automata: 자기조직화 파티클 시스템을 학습하는 신경망 모델
고정된 격자 대신 움직이는 파티클 위에서 동작하는 Neural Cellular Automata의 확장 버전으로, 형태 생성·포인트 클라우드 분류·텍스처 합성 등 다양한 작업에서 자기조직화 동작을 학습할 수 있다.
PyTorch Training Loop 완전 해부: 각 줄이 하는 일과 순서를 바꾸면 생기는 문제
PyTorch 학습 루프의 각 코드 줄이 왜 그 위치에 있어야 하는지, 순서를 바꾸거나 빠뜨렸을 때 어떤 문제가 생기는지를 단계별로 설명한 심층 가이드다.
좋은 Verifier도 망가질 수 있다: Self-Improving VLM이 새로운 태스크에서 오히려 퇴보하는 현상
VLM 자가학습 루프에서 verifier가 특정 태스크에 맞지 않으면 학습할수록 오히려 성능이 떨어지는데, DPO 손실값은 멀쩡히 내려가서 눈치채기도 어렵다.
Self-Distillation에서 Feedback Alignment의 역할
LLM이 스스로를 가르칠 때, 피드백을 모델의 추론 흐름에 단계별로 맞추면 GRPO보다 16점 이상 수학 추론 성능이 오른다.
작고 수정 가능한 CUDA 기반 Language Model 직접 구현체
CUDA로 작성된 GPT(Generative Pretrained Transformer) 미니멀 구현체로, 텍스트뿐 아니라 모든 바이트 스트림을 학습할 수 있어 LLM 내부 구조를 직접 뜯어보고 싶은 개발자에게 유용하다.
Stanford CS336: Language Modeling from Scratch - LLM을 처음부터 직접 만드는 강의
Stanford에서 운영하는 LLM 전 과정 구현 강의로, 토크나이저부터 데이터 수집, 트랜스포머 구현, 분산 학습, RL 기반 정렬까지 직접 코딩하며 배운다. 이론이 아닌 구현 중심이라 실제로 LLM이 어떻게 작동하는지 깊이 이해하고 싶은 개발자에게 가장 체계적인 커리큘럼 중 하나다.