FastAV: Audio-Visual Large Language Model 추론을 위한 효율적인 Token Pruning
FastAV: Efficient Token Pruning for Audio-Visual Large Language Model Inference
TL;DR Highlight
영상+오디오 멀티모달 LLM에서 추가 학습 없이 연산량을 40% 이상 줄이면서 성능은 유지하거나 오히려 올리는 토큰 가지치기 프레임워크.
Who Should Read
VideoLLaMA2나 video-SALMONN2 같은 오디오-비주얼 LLM을 프로덕션에 배포하면서 추론 비용과 지연 시간을 줄이고 싶은 ML 엔지니어. 멀티모달 모델의 GPU 메모리·속도 최적화를 고민하는 인프라/서빙 개발자에게도 유용.
Core Mechanics
- VideoLLaMA2와 video-SALMONN2 두 모델에서 레이어 중간(14번째)을 지나면 attention이 앞쪽 토큰에 집중되는 패턴이 생김 — 뒤쪽 토큰은 사실상 불필요
- 2단계 pruning 전략: 중간 레이어에서 attention rollout 기반 global pruning → 이후 레이어에서 마지막 query 토큰 기준 fine pruning
- FlashAttention과 호환 — 전체 attention map이 필요 없고 마지막 query만 씀
- 오디오 토큰이 1,496개에서 10개로 99% 이상 줄어도 성능 유지 (오히려 AV matching +11%↑)
- raw attention weight만 보면 중요 토큰 패턴이 안 보이지만, attention rollout(레이어 누적 집계)을 쓰면 명확히 드러남
- 추가 학습(fine-tuning) 없이 inference-time에만 적용 가능
Evidence
- VideoLLaMA2: FLOPs 100→56 (44% 감소), 추론 latency 0.43→0.32초, 메모리 22G→19G, AV matching 57.8→69.0 (+11.2%p)
- video-SALMONN2: FLOPs 100→58 (42% 감소), latency 0.44→0.29초, 메모리 28G→21G, AVQA 57.6→58.4
- global pruning 비교실험: random(69.0%), low attentive(70.5%), FastAV rollout 기반(74.5%) — 동일 FLOPs에서 최고 성능
- fine pruning 비율 P=20%일 때 FLOPs 56으로 평균 accuracy 74.9%로 최고 (P=30%보다도 높음)
How to Apply
- VideoLLaMA2 같은 AV-LLM 서빙 시 중간 레이어(전체 레이어 수의 절반)에서 attention rollout을 계산해 position 750 이후 토큰을 일괄 제거하면 즉시 메모리·연산 절감 가능
- 그 이후 레이어에서는 마지막 query 토큰의 attention score 하위 20%를 매 레이어마다 제거 — 전체 attention map 불필요해서 FlashAttention 환경에도 바로 적용
- 오디오 토큰이 많은 모델(1,000개 이상)이라면 앞쪽 10~20개만 남기는 aggressive pruning도 고려 — 논문 결과상 성능 손실 없거나 오히려 개선됨
Code Example
# FastAV 핵심 로직 의사코드 (PyTorch 기반)
import torch
import torch.nn.functional as F
def attention_rollout(attention_matrices, alpha=0.5):
"""
attention_matrices: list of [batch, heads, seq, seq] tensors (layer별)
"""
rollout = None
for attn in attention_matrices:
# 헤드 평균
attn_mean = attn.mean(dim=1) # [batch, seq, seq]
# residual connection 반영
identity = torch.eye(attn_mean.size(-1), device=attn.device).unsqueeze(0)
attn_mod = alpha * attn_mean + (1 - alpha) * identity
rollout = attn_mod if rollout is None else torch.bmm(attn_mod, rollout)
return rollout # [batch, seq, seq]
def global_pruning(token_indices, rollout, keep_position=750):
"""중간 레이어에서 position 750 이후 low-informative 토큰 제거"""
# rollout 기준으로 중요도 낮은 뒤쪽 토큰 제거
importance = rollout.sum(dim=1) # [batch, seq]
mask = (token_indices < keep_position) | (importance > importance.median())
return mask
def fine_pruning(query_last, keys, prune_ratio=0.2):
"""
query_last: [batch, heads, 1, dim] — 마지막 query 토큰
keys: [batch, heads, seq, dim]
"""
scores = torch.softmax(
torch.matmul(query_last, keys.transpose(-2, -1)) / (keys.size(-1) ** 0.5),
dim=-1
).mean(dim=1).squeeze(1) # [batch, seq]
k = int(scores.size(-1) * (1 - prune_ratio))
_, keep_indices = scores.topk(k, dim=-1)
return keep_indicesTerminology
Related Resources
Original Abstract (Expand)
In this work, we present FastAV, the first token pruning framework tailored for audio-visual large language models (AV-LLMs). While token pruning has been actively explored in standard large language models (LLMs) and vision-language models (LVLMs), its application to AV-LLMs has received little attention, even though multimodal integration substantially increases their token demands. To address this gap, we introduce a pruning strategy that utilizes attention weights to identify tokens emphasized at different stages and estimates their importance. Building on this analysis, FastAV applies a two-stage pruning strategy: (1) global pruning in intermediate layers to remove broadly less influential tokens, and (2) fine pruning in later layers considering the impact on next token generation. Notably, our method does not rely on full attention maps, which makes it fully compatible with efficient attention mechanisms such as FlashAttention. Extensive experiments demonstrate that FastAV reduces FLOPs by more than 40% on two representative AV-LLMs, while preserving or even improving model performance.