CITER: Token 단위 라우팅으로 대형/소형 언어모델을 협력시켜 추론 비용 절감
CITER: Collaborative Inference for Efficient Large Language Model Decoding with Token-Level Routing
TL;DR Highlight
중요한 토큰만 큰 모델에게, 나머지는 작은 모델에게 맡기는 토큰 단위 라우터로 LLM 추론 비용을 최대 30% 줄이는 프레임워크.
Who Should Read
LLM 추론 비용이 부담스러워 소형 모델과 대형 모델을 조합하는 방법을 고민하는 ML 엔지니어. 특히 실시간 서비스에서 응답 품질을 유지하면서 GPU 비용을 줄이려는 개발자.
Core Mechanics
- 응답 생성 시 토큰마다 '이건 작은 모델(SLM)이 해도 되는가, 큰 모델(LLM)이 필요한가'를 판단하는 MLP 라우터를 훈련
- 라우터 학습을 강화학습(RL)으로 공식화해서 현재 토큰뿐 아니라 미래 토큰에 미치는 장기적 영향까지 반영
- SLM과 LLM 중 어느 쪽이 맞는 토큰을 생성했는지 비교하는 '숏컷 보상 함수'로 전체 응답 생성 없이 빠르게 학습 데이터 수집 가능 (80~90% 토큰 처리 가능)
- SLM(Qwen2-1.5B)과 LLM(Qwen2-72B) 조합, Llama3.1-8B와 Llama3.1-70B 조합 모두에서 검증
- 기존 쿼리 단위 라우팅(RouteLLM)보다 토큰 단위 라우팅이 훨씬 유연해서 복잡한 쿼리 안의 쉬운 토큰도 효율적으로 처리
- KV cache를 SLM/LLM 각각 따로 유지해서 모델 전환 시 재계산 없이 재사용
Evidence
- 기존 토큰 라우팅 방법인 Co-LLM 대비 동일 정확도에서 최대 27% 추론 비용 절감, 또는 동일 비용에서 최대 17% 정확도 향상
- 쿼리 단위 라우팅 RouteLLM 대비 동일 정확도에서 최대 30% 추론 비용 절감, 또는 동일 비용에서 최대 12% 정확도 향상
- 장기 영향을 무시한 변형 모델 CITER-S 대비 최대 42% 추론 비용 절감 또는 최대 23% 정확도 향상
- Llama3.1 시리즈(8B SLM, 70B LLM)에서도 Co-LLM 대비 최대 32% 추론 비용 절감 확인
How to Apply
- Qwen2나 Llama3.1 시리즈처럼 큰 모델과 작은 모델 쌍을 준비하고, SLM의 마지막 히든 스테이트를 입력으로 받는 3레이어 MLP 라우터를 훈련 데이터로 학습시키면 된다.
- 라우터 학습 데이터 수집 시 SLM이 정답 토큰을 맞추면 SLM 선호, 틀리고 LLM이 맞추면 LLM 선호로 레이블링하는 숏컷 전략을 쓰면 전체 응답 생성 없이 90%가량의 토큰 라벨을 빠르게 만들 수 있다.
- τ(임계값) 하이퍼파라미터를 조정하면 정확도-비용 트레이드오프를 서비스 요구사항에 맞게 튜닝할 수 있다. 비용을 더 줄이려면 τ를 낮추고, 품질을 높이려면 τ를 높이면 된다.
Code Example
# CITER 추론 루프 의사코드 (Algorithm 2 기반)
# router: 훈련된 MLP 라우터
# slm: 소형 언어모델 (예: Qwen2-1.5B)
# llm: 대형 언어모델 (예: Qwen2-72B)
# tau: 라우팅 임계값 (0~1, 기본 0.5)
def citer_generate(prompt, slm, llm, router, tau=0.5):
context = prompt
generated = []
# KV cache 별도 유지
slm_kv_cache = None
llm_kv_cache = None
while True:
# SLM의 히든 스테이트를 라우터 입력으로 사용
slm_hidden = slm.get_hidden_state(context, kv_cache=slm_kv_cache)
# 라우터가 SLM 선택 확률 예측
p_slm = router(slm_hidden) # 0~1 사이 확률
if p_slm >= tau:
# 비중요 토큰 → SLM으로 생성 (저비용)
next_token, slm_kv_cache = slm.generate_token(context, kv_cache=slm_kv_cache)
else:
# 중요 토큰 → LLM으로 생성 (고품질)
next_token, llm_kv_cache = llm.generate_token(context, kv_cache=llm_kv_cache)
generated.append(next_token)
context = context + [next_token]
if next_token == EOS_TOKEN:
break
return generated
# 라우터 학습 데이터 수집 (숏컷 방식)
def collect_routing_preference(prompt, ground_truth, slm, llm):
preferences = []
for h, true_token in enumerate(ground_truth):
state = prompt + ground_truth[:h]
slm_pred = slm.predict_next(state)
if slm_pred == true_token:
# Case 1: SLM이 맞춤 → SLM 선호
preferences.append((state, 1)) # p=1: SLM 선호
elif llm.predict_next(state) == true_token:
# Case 2: LLM이 맞춤 → LLM 선호
preferences.append((state, 0)) # p=0: LLM 선호
else:
# Case 3: 둘 다 틀림 → 실제 롤아웃으로 판단
rollout = citer_generate(state + [slm_pred], slm, llm, router, tau=0.5)
is_correct = evaluate(rollout, ground_truth)
preferences.append((state, 1 if is_correct else 0))
return preferencesTerminology
Original Abstract (Expand)
Large language models have achieved remarkable success in various tasks but suffer from high computational costs during inference, limiting their deployment in resource-constrained applications. To address this issue, we propose a novel Collaborative Inference with Token-lEvel Routing (CITER) framework that enables efficient collaboration between small and large language models (SLMs \&LLMs) through a token-level routing strategy. Specifically, CITER routes non-critical tokens to an SLM for efficiency and routes critical tokens to an LLM for generalization quality. We formulate router training as a policy optimization, where the router receives rewards based on both the quality of predictions and the inference costs of generation. This allows the router to learn to predict token-level routing scores and make routing decisions based on both the current token and the future impact of its decisions. To further accelerate the reward evaluation process, we introduce a shortcut which significantly reduces the costs of the reward estimation and improving the practicality of our approach. Extensive experiments on five benchmark datasets demonstrate that CITER reduces the inference costs while preserving high-quality generation, offering a promising solution for real-time and resource-constrained applications. Our data and code are available at https://github.com/aiming-lab/CITER.