CITER: Collaborative Inference for Efficient Large Language Model Decoding with Token-Level Routing
TL;DR Highlight
A framework that routes tokens individually — important ones to large models, the rest to small models — reducing LLM inference costs by up to 30%.
Who Should Read
ML engineers looking to combine small and large models to reduce LLM inference costs. Developers wanting to cut GPU costs while maintaining response quality in real-time services.
Core Mechanics
- Trains an MLP router that decides per-token: can a small model (SLM) handle this, or does it need a large model (LLM)?
- Formalized router training as RL to reflect long-term impact on future tokens, not just the current one
- A 'shortcut' reward function checks whether the SLM or the LLM predicts the ground-truth token, enabling fast collection of training labels without full response generation (covers 80-90% of tokens)
- Validated on both Qwen2-1.5B + Qwen2-72B and Llama3.1-8B + Llama3.1-70B combos
- Token-level routing is far more flexible than query-level routing (RouteLLM) — handles easy tokens even within complex queries
- Maintains separate KV caches for SLM and LLM — reuses without recomputation when switching models
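The RL formulation above can be sketched with a simple cost-quality reward. The linear penalty form, the `cost_weight` value, and the discounted return below are illustrative assumptions for intuition, not the paper's exact reward:

```python
def routing_reward(token_correct, used_llm, cost_weight=0.1):
    # Quality term: 1 if the generated token matches the reference, else 0
    quality = 1.0 if token_correct else 0.0
    # Cost term: penalize invoking the LLM; SLM calls are treated as free here
    penalty = cost_weight if used_llm else 0.0
    return quality - penalty

def episode_return(steps, gamma=0.99):
    # Discounted return over a generation episode, so a routing decision is
    # credited for its effect on future tokens, not just the current one
    g = 0.0
    for t, (correct, used_llm) in enumerate(steps):
        g += (gamma ** t) * routing_reward(correct, used_llm)
    return g
```

Optimizing a return like this is what lets the router prefer an expensive LLM call now if it keeps the rest of the generation on track.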
Evidence
- vs Co-LLM (existing token routing): up to 27% inference cost reduction at same accuracy, or up to 17% accuracy improvement at same cost
- vs RouteLLM (query-level routing): up to 30% inference cost reduction at same accuracy, or up to 12% accuracy improvement at same cost
- vs CITER-S (ignoring long-term impact): up to 42% cost reduction or 23% accuracy improvement
- The Llama3.1 pair (8B SLM, 70B LLM) also shows up to 32% inference cost reduction vs Co-LLM
How to Apply
- Prepare a large/small model pair (e.g., Qwen2 or Llama3.1 series), then train a 3-layer MLP router using SLM's last hidden state as input.
- For router training data: label as 'SLM preferred' when SLM gets the correct token, 'LLM preferred' when SLM fails but LLM succeeds — generates ~90% of token labels without full response generation.
- Adjust the tau (threshold) hyperparameter to tune the accuracy-cost tradeoff. Lower tau reduces cost, higher tau improves quality.
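The router described in the steps above can be sketched as a plain-Python 3-layer MLP with a sigmoid output. The hidden width, ReLU activations, and random initialization are assumptions for illustration; in practice the router is trained on the preference labels and reads the SLM's last hidden state:

```python
import math
import random

def make_mlp_router(hidden_dim, mlp_dim=16, seed=0):
    """Randomly initialize a 3-layer MLP (untrained; weights are placeholders)."""
    rng = random.Random(seed)
    def layer(n_in, n_out):
        w = [[rng.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)]
        return w, [0.0] * n_out
    # hidden_dim -> mlp_dim -> mlp_dim -> 1
    return [layer(hidden_dim, mlp_dim), layer(mlp_dim, mlp_dim), layer(mlp_dim, 1)]

def router_forward(params, slm_hidden):
    """Map the SLM's last hidden state to P(route to SLM), a value in (0, 1)."""
    x = slm_hidden
    for i, (w, b) in enumerate(params):
        x = [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]
        if i < len(params) - 1:
            x = [max(0.0, v) for v in x]  # ReLU on the two hidden layers
    return 1.0 / (1.0 + math.exp(-x[0]))  # sigmoid over the single logit
```

At inference, the token is routed to the SLM when this probability is at least tau, so sweeping tau traces out the accuracy-cost tradeoff curve.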
Code Example
# CITER inference loop pseudocode (based on Algorithm 2)
# router: trained MLP router
# slm: small language model (e.g., Qwen2-1.5B)
# llm: large language model (e.g., Qwen2-72B)
# tau: routing threshold (0~1, default 0.5)
def citer_generate(prompt, slm, llm, router, tau=0.5):
    context = prompt
    generated = []
    # Maintain separate KV caches so neither model recomputes past tokens
    slm_kv_cache = None
    llm_kv_cache = None
    while True:
        # Use the SLM's last hidden state as router input
        slm_hidden = slm.get_hidden_state(context, kv_cache=slm_kv_cache)
        # Router predicts the probability of choosing the SLM
        p_slm = router(slm_hidden)  # probability in [0, 1]
        if p_slm >= tau:
            # Non-critical token → generate with SLM (low cost)
            next_token, slm_kv_cache = slm.generate_token(context, kv_cache=slm_kv_cache)
        else:
            # Critical token → generate with LLM (high quality)
            next_token, llm_kv_cache = llm.generate_token(context, kv_cache=llm_kv_cache)
        generated.append(next_token)
        context = context + [next_token]
        if next_token == EOS_TOKEN:  # EOS_TOKEN: end-of-sequence id (placeholder)
            break
    return generated

# Collect router training data (shortcut method)
def collect_routing_preference(prompt, ground_truth, slm, llm, router):
    preferences = []
    for h, true_token in enumerate(ground_truth):
        state = prompt + ground_truth[:h]
        slm_pred = slm.predict_next(state)
        if slm_pred == true_token:
            # Case 1: SLM is correct → prefer SLM
            preferences.append((state, 1))  # p=1: prefer SLM
        elif llm.predict_next(state) == true_token:
            # Case 2: LLM is correct → prefer LLM
            preferences.append((state, 0))  # p=0: prefer LLM
        else:
            # Case 3: both wrong → decide via an actual rollout
            rollout = citer_generate(state + [slm_pred], slm, llm, router, tau=0.5)
            is_correct = evaluate(rollout, ground_truth)  # evaluate(): task-specific checker (placeholder)
            preferences.append((state, 1 if is_correct else 0))
    return preferences
Original Abstract
Large language models have achieved remarkable success in various tasks but suffer from high computational costs during inference, limiting their deployment in resource-constrained applications. To address this issue, we propose a novel Collaborative Inference with Token-lEvel Routing (CITER) framework that enables efficient collaboration between small and large language models (SLMs & LLMs) through a token-level routing strategy. Specifically, CITER routes non-critical tokens to an SLM for efficiency and routes critical tokens to an LLM for generalization quality. We formulate router training as a policy optimization, where the router receives rewards based on both the quality of predictions and the inference costs of generation. This allows the router to learn to predict token-level routing scores and make routing decisions based on both the current token and the future impact of its decisions. To further accelerate the reward evaluation process, we introduce a shortcut which significantly reduces the cost of reward estimation and improves the practicality of our approach. Extensive experiments on five benchmark datasets demonstrate that CITER reduces the inference costs while preserving high-quality generation, offering a promising solution for real-time and resource-constrained applications. Our data and code are available at https://github.com/aiming-lab/CITER.