PPO를 활용한 언어 모델의 Tree Search Distillation
Tree Search Distillation for Language Models Using PPO
TL;DR Highlight
AlphaZero 스타일 MCTS로 탐색한 추론 경로를 PPO로 증류한 방법이 표준 RL 방법 GRPO보다 높은 성능을 달성했다.
Who Should Read
LLM의 추론 능력을 강화하는 방법을 연구하거나, RL 기반 학습(GRPO, PPO 등)을 언어 모델에 적용해보려는 ML 엔지니어 및 연구자.
Core Mechanics
- AlphaZero 같은 게임 AI는 '정책(policy) + 탐색(search) + 증류(distillation)' 사이클로 성능을 높이는데, 언어 모델에서는 이 방법이 잘 쓰이지 않는다. DeepSeek-R1 팀도 MCTS를 시도했지만 효과가 제한적이었다고 밝혔는데, 이는 탐색 알고리즘으로 UCT 대신 pUCT를 썼어야 한다는 분석이 있다.
- 실험 모델은 Qwen-2.5-1.5B-Instruct이고, 태스크는 Countdown(주어진 정수 4개로 사칙연산을 이용해 목표 숫자를 만드는 조합 산술 게임)이다. GSM8K로 먼저 시도했지만 GRPO와 MCTS 간 차이가 미미해서 조합적 탐색이 더 유리한 Countdown으로 전환했다.
- 성능 결과: 증류된 모델(탐색 없이 단독 추론)이 mean@16 기준 11.3%를 달성했고, 비교 대상인 CISPO는 8.4%, best-of-N은 7.7%, 사전 RL 모델은 3.1%였다. 절대 수치가 낮은 이유는 1.5B라는 작은 모델 크기 때문이며, 향후 더 큰 모델로 실험을 이어갈 예정이다.
- 보상 함수 설계가 중요했다. 처음에 정답/오답만 주는 sparse reward(0/1)를 쓰면 학습이 불안정했다. 대신 예측값과 정답의 차이에 비례하는 dense reward($1.0 - 2 \cdot \min(|t-p|/t, 1.0)$)를 학습에 사용하고, 평가는 여전히 sparse reward로 진행했다.
- 토큰 단위가 아닌 추론 스텝(reasoning step) 단위로 MCTS를 적용했다. 토큰 단위로 분기하면 'but', 'however', 'yet' 같은 기능어에서도 가지가 갈라져 탐색 트리가 비효율적으로 커진다. 대신 Tree-of-Thoughts 방식처럼 `<step>...</step>` 태그로 묶인 추론 단계를 하나의 노드로 취급했다.
- 탐색 다양성을 높이기 위해 N개의 에이전트가 같은 샘플의 탐색 트리를 공유하는 parallel MCTS를 구현했다. 각 에이전트는 virtual loss를 써서 서로 다른 경로를 탐색하도록 유도했다. pUCT에서 필요한 행동(action) 사전 확률은 시퀀스 로그확률을 합산한 뒤 softmax를 취해 계산했다.
- 모델에 value head(MLP + tanh)를 추가해 현재 상태의 가치를 예측하게 했다. 이 value head는 학습 중 점차 개선되면서 MCTS가 더 좋은 탐색 경로를 찾도록 안내한다. MCTS로 찾은 더 강한 추론 경로를 PPO 루프를 통해 모델 가중치에 온라인으로 증류하는 구조다.
Evidence
- MCTS를 학습 시 탐색에만 쓰고 PPO로 증류하면, 추론 시에는 탐색 없이 모델만 쓰는 것이므로 추론 비용이 GRPO와 동일한 것 아니냐는 질문이 있었다. 원문에서 'MCTS는 샘플당 추론 컴퓨팅을 더 많이 쓰니까 당연히 성능이 좋다'는 표현이 있는데, 증류된 모델은 추론 시 탐색을 쓰지 않으므로 이 표현이 혼란스럽다는 지적이었다.
- MCTS를 증류 없이 test-time compute harness로만 쓸 때 같은 컴퓨팅 예산 기준으로 best-of-N과 비교해봤는지 묻는 댓글이 있었다. 이는 MCTS의 탐색 효율 자체를 검증하는 중요한 비교인데, 원문에서는 이 비교가 명시적으로 다뤄지지 않았다.
- MCTS가 test-time compute 방법으로 왜 더 많이 쓰이지 않는지 의문을 표하는 댓글이 있었다. 언어 모델에서 MCTS가 어려운 이유(토큰 단위 탐색의 비효율, value function 학습의 어려움 등)에 대한 관심이 있었고, 이 글이 그 가능성을 탐구하는 시도라는 점에서 긍정적인 반응이 있었다.
How to Apply
- 소규모 조합 최적화 또는 수학 추론 태스크를 다루는 경우, 토큰 단위가 아닌 추론 스텝 단위로 MCTS를 적용하면 탐색 트리의 크기를 제어하면서도 다양한 추론 경로를 탐색할 수 있다. `<step>...</step>` 같은 구조화된 태그를 프롬프트에 도입하고, 각 스텝 완성 시점을 MCTS 노드 전환점으로 삼으면 된다.
- RL 기반 파인튜닝 시 sparse reward(정답/오답 0/1)로 학습이 불안정하다면, 예측값과 정답 사이의 거리에 비례하는 dense reward 함수를 학습에 사용하고 평가만 sparse reward로 유지하는 방식을 고려해볼 수 있다. 이 글의 공식은 $1.0 - 2 \cdot \min(|t-p|/t, 1.0)$이다.
- MCTS 탐색 다양성이 부족하다면 virtual loss를 도입한 parallel MCTS를 적용할 수 있다. 여러 에이전트가 같은 탐색 트리를 공유하되 서로 다른 노드를 방문하도록 유도하면 같은 컴퓨팅 예산으로 더 넓은 탐색이 가능하다.
- pUCT 선택 확률 계산 시 raw 누적 시퀀스 확률 대신 로그확률 합산 후 softmax를 적용하면 수치적 언더플로 문제를 피할 수 있다. 특히 긴 시퀀스를 다루는 경우 이 처리가 학습 안정성에 중요하다.
Terminology
관련 논문
PyTorch Lightning AI 학습 라이브러리에서 Shai-Hulud 테마 악성코드 발견
널리 쓰이는 딥러닝 프레임워크 PyTorch Lightning의 PyPI 패키지 버전 2.6.2와 2.6.3이 공급망 공격으로 침해되어, import 시 자격증명 탈취 악성코드가 실행된다.
Alignment Whack-a-Mole: 파인튜닝이 LLM 내부의 저작권 도서 암기를 활성화한다
안전 정렬(alignment)된 LLM도 파인튜닝을 거치면 억제됐던 저작권 책 내용을 그대로 출력하게 된다는 연구로, LLM의 저작권 침해 위험이 단순히 프롬프트 필터링으로는 해결되지 않음을 보여준다.
MacMind – 1989년 Macintosh의 HyperCard로 구현한 Transformer 신경망
HyperTalk으로 1,216개 파라미터짜리 단일 레이어 Transformer를 Macintosh SE/30에서 학습시켜 현대 LLM의 핵심 수학이 30년 전 하드웨어에서도 동일하게 동작함을 증명했다.
MegaTrain: 단일 GPU로 100B+ 파라미터 LLM을 Full Precision으로 학습하기
MegaTrain은 CPU 메모리를 주 저장소로, GPU를 연산 엔진으로만 활용함으로써 H200 GPU 단 한 장으로 120B 파라미터 모델을 풀 정밀도로 학습할 수 있다.
9M 파라미터짜리 초소형 LLM으로 언어 모델 작동 원리 직접 이해하기
물고기 Guppy를 학습한 870만 파라미터 미니 LLM이 Colab 노트북 하나로 5분 만에 처음부터 구현되어, LLM의 블랙박스 이미지를 완전히 걷어낸다.
Nanocode: $200로 TPU에서 JAX로 구현하는 나만의 Claude Code 학습 라이브러리
이 오픈소스 라이브러리는 Constitutional AI 방식으로 $200 TPU에서 1.3B 파라미터 규모의 coding agent 모델을 처음부터 학습하게 하며 개발자가 AI 학습 파이프라인 전체를 직접 이해하고 실습할 수 있는 환경을 제공한다.