효율적인 Video VLM을 위한 통합 Spatio-Temporal Token Scoring (STTS)
Unified Spatio-Temporal Token Scoring for Efficient Video VLMs
TL;DR Highlight
비디오 AI 모델에서 시각 토큰을 50% 줄이면서 성능 손실은 0.7%에 불과한 경량 토큰 제거 모듈
Who Should Read
비디오 이해 AI 서비스를 개발하면서 추론 속도와 GPU 메모리 비용을 줄이고 싶은 ML 엔지니어. 특히 긴 영상 처리 파이프라인에서 VLM(Vision-Language Model) 병목을 해결하려는 개발자.
Core Mechanics
- ViT(이미지를 패치 조각으로 처리하는 인코더)와 LLM 양쪽 모두에서 토큰을 제거하는 최초의 통합 접근법 — 기존 방법들은 둘 중 하나만 건드림
- 3-Layer MLP + Self-Attention으로 구성된 초경량 scorer 모듈 하나만 추가하면 됨. 복잡한 text-conditioned 선택 메커니즘 불필요
- 공간(spatial) 중요도는 LLM downstream gradient로, 시간(temporal) 중복성은 인접 프레임 cosine similarity MSE loss로 동시에 학습
- 50% 토큰 제거 시 13개 video QA 벤치마크 평균 0.7% 성능 하락, 학습/추론 속도 1.62x 향상 (128프레임 기준)
- 256프레임 처리 시 동일 50% 제거로 2.25x 학습 / 2.22x 추론 속도 향상 — 프레임 수 늘수록 이득이 커짐
- Test-Time Scaling(추론 시 프레임 수 늘리기) 적용 시 기존 베이스라인 대비 long video QA에서 0.5~1.1% 성능 향상 달성
Evidence
- 50% 토큰 제거 기준: 13개 벤치마크 평균 62.3 (baseline 63.0) — 0.7% 하락, VideoMME는 0.4점 하락에 그침
- 256프레임 설정에서 50% pruning 시 학습 2.25x, 추론 2.22x 속도 향상 (128프레임에서는 1.62x/1.61x)
- Random pruning 대비 STTS는 50% 제거에서 평균 0.9% 우위 (62.3 vs 61.4), 80% 제거에서는 2.3% 우위 (59.8 vs 57.5)
- ToMe(기존 token merging 기법) 훈련 버전 대비 STTS가 Short/Long QA 모두에서 약 1점 이상 우위 (62.3 vs 61.1)
How to Apply
- Molmo2 같은 ViT+LLM 구조의 VLM에 STTS 모듈을 ViT 3번째 레이어 이후에 삽입하고, task loss + cosine similarity MSE auxiliary loss로 함께 파인튜닝하면 됨
- 긴 영상 추론 서비스라면 50% pruning 모델로 학습 후 추론 시 프레임 수를 2배로 늘리는 Test-Time Scaling을 적용하면 동일 연산량으로 더 높은 성능 달성 가능
- 프레임 수가 많을수록 속도 이득이 커지므로, 64프레임 이상 처리하는 파이프라인에서 먼저 적용하는 게 ROI가 높음 — 256프레임에서는 2배 이상 속도 향상
Code Example
# STTS Scorer 핵심 구조 (PyTorch 의사코드)
import torch
import torch.nn as nn
class STTSScorer(nn.Module):
def __init__(self, dim, pool_width=3):
super().__init__()
self.pool_width = pool_width
# Token Pooler: Self-Attention 기반
self.token_pooler = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
# 3-Layer MLP for scoring (입력: 현재+이전 프레임 concat -> 2*dim)
self.mlp = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.GELU(),
nn.Linear(dim, dim // 2),
nn.GELU(),
nn.Linear(dim // 2, 1),
nn.Sigmoid()
)
def forward(self, x_curr, x_prev):
# x_curr, x_prev: (B, N, D) — 현재/이전 프레임 features
# 1. Spatial pooling (w x w)
x_curr_pooled = self.spatial_pool(x_curr) # (B, N/w^2, D)
x_prev_pooled = self.spatial_pool(x_prev) # (B, N/w^2, D)
# 2. Concat neighboring frames for temporal context
x_concat = torch.cat([x_curr_pooled, x_prev_pooled], dim=-1) # (B, N/w^2, 2D)
# 3. Score each pooled patch
scores = self.mlp(x_concat).squeeze(-1) # (B, N/w^2)
# 4. Expand scores back to original resolution
scores_expanded = scores.repeat_interleave(self.pool_width**2, dim=1) # (B, N)
return scores_expanded
# Auxiliary Temporal Loss
def temporal_aux_loss(scores, features_l, t, pool_width=3):
"""인접 프레임 cosine similarity와 score를 정렬하는 MSE loss"""
if t == 0: # 첫 프레임은 항상 유지
return torch.tensor(0.0)
feat_curr = spatial_pool(features_l[t], pool_width) # (B, N/w^2, D)
feat_prev = spatial_pool(features_l[t-1], pool_width) # (B, N/w^2, D)
# 패치별 cosine similarity 계산
feat_curr_norm = F.normalize(feat_curr, dim=-1)
feat_prev_norm = F.normalize(feat_prev, dim=-1)
cos_sim = (feat_curr_norm * feat_prev_norm).sum(dim=-1) # (B, N/w^2)
# 높은 similarity = 중복 = 낮은 중요도 -> 1 - cos_sim을 target으로
target = 1 - cos_sim
loss = F.mse_loss(scores, target)
return loss
# Token Pruning: bottom-k% 제거
def prune_tokens(x, scores, k_percent):
B, N, D = x.shape
keep_n = int(N * (1 - k_percent / 100))
_, top_indices = scores.topk(keep_n, dim=1) # 중요한 토큰 유지
top_indices_sorted = top_indices.sort(dim=1).values
pruned = torch.gather(x, 1, top_indices_sorted.unsqueeze(-1).expand(-1, -1, D))
return pruned # (B, keep_n, D)Terminology
Related Resources
Original Abstract (Expand)
Token pruning is essential for enhancing the computational efficiency of vision-language models (VLMs), particularly for video-based tasks where temporal redundancy is prevalent. Prior approaches typically prune tokens either (1) within the vision transformer (ViT) exclusively for unimodal perception tasks such as action recognition and object segmentation, without adapting to downstream vision-language tasks; or (2) only within the LLM while leaving the ViT output intact, often requiring complex text-conditioned token selection mechanisms. In this paper, we introduce Spatio-Temporal Token Scoring (STTS), a simple and lightweight module that prunes vision tokens across both the ViT and the LLM without text conditioning or token merging, and is fully compatible with end-to-end training. By learning how to score temporally via an auxiliary loss and spatially via LLM downstream gradients, aided by our efficient packing algorithm, STTS prunes 50% of vision tokens throughout the entire architecture, resulting in a 62% improvement in efficiency during both training and inference with only a 0.7% drop in average performance across 13 short and long video QA tasks. Efficiency gains increase with more sampled frames per video. Applying test-time scaling for long-video QA further yields performance gains of 0.5-1% compared to the baseline. Overall, STTS represents a novel, simple yet effective technique for unified, architecture-wide vision token pruning.