Unified Spatio-Temporal Token Scoring for Efficient Video VLMs
TL;DR Highlight
A lightweight token pruning module that cuts 50% of visual tokens in video AI models with only 0.7% performance loss
Who Should Read
ML engineers building video understanding services who want to reduce inference latency and GPU memory costs. Especially useful for developers tackling VLM (Vision-Language Model) bottlenecks in long-video processing pipelines.
Core Mechanics
- First unified approach that prunes tokens from both ViT (the image encoder that processes patches) and the LLM — existing methods only touch one or the other
- Just add one ultra-lightweight scorer module (3-Layer MLP + Self-Attention). No need for complex text-conditioned selection mechanisms
- Jointly learns spatial importance from the LLM's downstream gradients and temporal redundancy from an MSE loss against adjacent-frame cosine similarity
- At 50% token removal: only 0.7% average performance drop across 13 video QA benchmarks, with 1.62x training/inference speedup (at 128 frames)
- At 256 frames with 50% pruning: 2.25x training / 2.22x inference speedup — gains increase with more frames
- With Test-Time Scaling (increasing frame count at inference), achieves 0.5-1.1% performance improvement over baselines on long video QA
Evidence
- At 50% token removal: 13-benchmark average 62.3 (baseline 63.0) — 0.7% drop, VideoMME dropped by only 0.4 points
- 256-frame setting with 50% pruning: 2.25x training, 2.22x inference speedup (128 frames: 1.62x/1.61x)
- STTS outperforms random pruning by 0.9% at 50% removal (62.3 vs 61.4) and 2.3% at 80% removal (59.8 vs 57.5)
- STTS outperforms a trained version of ToMe (an existing token-merging method) by over a point on both short and long video QA (62.3 vs 61.1)
How to Apply
- Insert the STTS module after the 3rd ViT layer in ViT+LLM architecture VLMs like Molmo2, then fine-tune with task loss + cosine similarity MSE auxiliary loss
- For long video inference services, train with 50% pruning then apply Test-Time Scaling (2x frame count at inference) to achieve higher performance at the same compute cost
- Speed gains increase with frame count, so prioritize adoption in pipelines processing 64+ frames for best ROI — at 256 frames you get 2x+ speedup
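The fine-tuning recipe above (task loss plus the cosine-similarity MSE auxiliary loss) can be sketched as follows. The weighting coefficient lambda_temporal and the function name combined_loss are illustrative assumptions, not values from the paper:

```python
import torch

def combined_loss(task_loss, frame_aux_losses, lambda_temporal=0.1):
    # Total fine-tuning objective: LLM task loss plus the temporal auxiliary
    # MSE loss averaged over frames. lambda_temporal is a hypothetical
    # weighting hyperparameter, not a value reported by the paper.
    if frame_aux_losses:
        aux = torch.stack(frame_aux_losses).mean()
    else:
        aux = torch.tensor(0.0)
    return task_loss + lambda_temporal * aux
```

During training, frame_aux_losses would collect the temporal auxiliary loss computed for each frame t > 0.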
Code Example
# STTS Scorer core structure (PyTorch sketch; the spatial_pool helper is
# a reconstruction -- the paper's exact pooling may differ)
import torch
import torch.nn as nn

def spatial_pool(x, pool_width):
    # Average-pool groups of pool_width^2 tokens. Simplified: pools
    # consecutive token groups rather than true 2-D w x w windows;
    # assumes N is divisible by pool_width^2.
    B, N, D = x.shape
    w2 = pool_width ** 2
    return x.view(B, N // w2, w2, D).mean(dim=2)  # (B, N/w^2, D)

class STTSScorer(nn.Module):
    def __init__(self, dim, pool_width=3):
        super().__init__()
        self.pool_width = pool_width
        # Token Pooler: Self-Attention based (declared per the method
        # description; this simplified forward uses mean pooling instead)
        self.token_pooler = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        # 3-Layer MLP for scoring (input: current + previous frame concat -> 2*dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.GELU(),
            nn.Linear(dim, dim // 2),
            nn.GELU(),
            nn.Linear(dim // 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x_curr, x_prev):
        # x_curr, x_prev: (B, N, D) -- current/previous frame features
        # 1. Spatial pooling (w x w)
        x_curr_pooled = spatial_pool(x_curr, self.pool_width)  # (B, N/w^2, D)
        x_prev_pooled = spatial_pool(x_prev, self.pool_width)  # (B, N/w^2, D)
        # 2. Concat neighboring frames for temporal context
        x_concat = torch.cat([x_curr_pooled, x_prev_pooled], dim=-1)  # (B, N/w^2, 2D)
        # 3. Score each pooled patch
        scores = self.mlp(x_concat).squeeze(-1)  # (B, N/w^2)
        # 4. Expand scores back to the original token resolution
        scores_expanded = scores.repeat_interleave(self.pool_width ** 2, dim=1)  # (B, N)
        return scores_expanded
# Auxiliary Temporal Loss
import torch.nn.functional as F

def temporal_aux_loss(scores, features_l, t, pool_width=3):
    """MSE loss aligning pooled-resolution scores (B, N/w^2) with
    adjacent-frame cosine similarity."""
    if t == 0:  # the first frame has no predecessor and is always kept
        return torch.tensor(0.0)
    feat_curr = spatial_pool(features_l[t], pool_width)      # (B, N/w^2, D)
    feat_prev = spatial_pool(features_l[t - 1], pool_width)  # (B, N/w^2, D)
    # Per-patch cosine similarity between adjacent frames
    feat_curr_norm = F.normalize(feat_curr, dim=-1)
    feat_prev_norm = F.normalize(feat_prev, dim=-1)
    cos_sim = (feat_curr_norm * feat_prev_norm).sum(dim=-1)  # (B, N/w^2)
    # High similarity = redundant = low importance -> use 1 - cos_sim as target
    target = 1 - cos_sim
    return F.mse_loss(scores, target)
# Token Pruning: remove the bottom k% of tokens by score
def prune_tokens(x, scores, k_percent):
    B, N, D = x.shape
    keep_n = int(N * (1 - k_percent / 100))
    _, top_indices = scores.topk(keep_n, dim=1)          # keep the highest-scoring tokens
    top_indices_sorted = top_indices.sort(dim=1).values  # preserve original token order
    pruned = torch.gather(x, 1, top_indices_sorted.unsqueeze(-1).expand(-1, -1, D))
    return pruned  # (B, keep_n, D)
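Putting the pieces together, here is a self-contained sanity check of the pruning step (the function is repeated from above so the snippet runs on its own; the batch size, token count, and feature dim are illustrative):

```python
import torch

def prune_tokens(x, scores, k_percent):
    # Same bottom-k% pruning routine as above
    B, N, D = x.shape
    keep_n = int(N * (1 - k_percent / 100))
    _, top_indices = scores.topk(keep_n, dim=1)
    top_indices_sorted = top_indices.sort(dim=1).values
    return torch.gather(x, 1, top_indices_sorted.unsqueeze(-1).expand(-1, -1, D))

# Dummy frame features: batch 2, 576 tokens (a 24x24 patch grid), dim 64
x = torch.randn(2, 576, 64)
scores = torch.rand(2, 576)  # stand-in for STTSScorer output
pruned = prune_tokens(x, scores, k_percent=50)
print(pruned.shape)  # torch.Size([2, 288, 64])
```

At the paper's 50% setting, exactly half of the 576 tokens survive, and the gather preserves their original spatial order.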
Original Abstract
Token pruning is essential for enhancing the computational efficiency of vision-language models (VLMs), particularly for video-based tasks where temporal redundancy is prevalent. Prior approaches typically prune tokens either (1) within the vision transformer (ViT) exclusively for unimodal perception tasks such as action recognition and object segmentation, without adapting to downstream vision-language tasks; or (2) only within the LLM while leaving the ViT output intact, often requiring complex text-conditioned token selection mechanisms. In this paper, we introduce Spatio-Temporal Token Scoring (STTS), a simple and lightweight module that prunes vision tokens across both the ViT and the LLM without text conditioning or token merging, and is fully compatible with end-to-end training. By learning how to score temporally via an auxiliary loss and spatially via LLM downstream gradients, aided by our efficient packing algorithm, STTS prunes 50% of vision tokens throughout the entire architecture, resulting in a 62% improvement in efficiency during both training and inference with only a 0.7% drop in average performance across 13 short and long video QA tasks. Efficiency gains increase with more sampled frames per video. Applying test-time scaling for long-video QA further yields performance gains of 0.5-1% compared to the baseline. Overall, STTS represents a novel, simple yet effective technique for unified, architecture-wide vision token pruning.