DivPrune: Diversity-based Visual Token Pruning for Large Multimodal Models
TL;DR Highlight
A plug-and-play inference optimization technique that removes up to 90% of visual tokens in image/video multimodal models while barely losing performance.
Who Should Read
ML engineers serving LLaVA-style multimodal LLMs who want to cut inference latency and GPU memory costs. Especially relevant for teams whose image/video processing pipelines are latency-bottlenecked.
Core Mechanics
- Existing attention-score-based pruning methods (FastV, VTW) tend to select similar tokens together, causing sharp performance drops at high compression ratios
- DivPrune redefines pruning as MMDP (Max-Min Diversity Problem) to evenly preserve diverse visual information
- Plug-and-play — no fine-tuning or calibration data needed; just plug into existing models
- LLaVA 1.5-7B at 84% TFLOP reduction: POPE F1 actually improves by 0.18% over the original; COCO CIDEr drops only 12.7% (FastV/VTW lose ~95% under the same conditions)
- Works equally for video (LLaVA-NeXT-Video-7B) — better results with larger visual contexts
- Compatible with existing inference optimizations like KV caching
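The MMDP reframing above is the core idea: instead of scoring tokens individually, pick the subset whose *minimum* pairwise distance is largest, which forces the kept tokens to spread across distinct visual content. A toy sketch of the objective (2-D stand-in "tokens" and a brute-force search, purely illustrative; the actual method solves this greedily over real visual features):

```python
import itertools
import torch
import torch.nn.functional as F

# Toy example: 6 "visual tokens" in 2-D -- two pairs of near-duplicates plus two outliers.
tokens = torch.tensor([
    [1.0, 0.0], [0.99, 0.1],   # cluster A (near-duplicates)
    [0.0, 1.0], [0.1, 0.99],   # cluster B (near-duplicates)
    [-1.0, 0.0], [0.7, 0.7],
])

normed = F.normalize(tokens, dim=-1)
dist = 1.0 - normed @ normed.T  # pairwise cosine distance


def min_pairwise(subset):
    """Diversity score of a subset = minimum pairwise distance within it."""
    return min(dist[i, j].item() for i, j in itertools.combinations(subset, 2))


# MMDP objective: among all 3-token subsets, keep the one whose minimum
# pairwise distance is largest. Near-duplicates can never both survive,
# because keeping both would collapse the minimum distance toward zero.
best = max(itertools.combinations(range(6), 3), key=min_pairwise)
print(best)  # (0, 2, 4) -- one token per spread direction, no near-duplicates
```

Attention-score-based selection has no such pairwise term, which is why it can keep several redundant high-attention tokens at once.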
Evidence
- LLaVA 1.6-7B at 89% TFLOP reduction: FastV/VTW POPE F1 -79% vs DivPrune -3.4%
- LLaVA-NeXT-Video-7B ActivityNet accuracy: FastV 33.91% vs DivPrune 45.90% (original 48.10%)
- Video model: E2E latency 22% reduction, prefill time 55% reduction, GPU memory 400MB savings
- LLaVA 1.5-13B POPE F1: DivPrune 83% better than VTW, 53.4% better than FastV, 15.2% better than PruMerge
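The TFLOP figures above follow directly from shrinking the prefill sequence. A rough back-of-the-envelope check (assumptions mine, not the paper's exact setup: prefill FLOPs approximated as 2 × params × tokens, attention cost ignored, ~35 text tokens alongside 576 visual tokens):

```python
P = 7e9                        # parameter count (7B LLM backbone)
flops = lambda n: 2 * P * n    # rough prefill FLOPs: ~2 * params * tokens
n_text, n_visual = 35, 576     # assumed prompt sizes (illustrative)

full = flops(n_text + n_visual)
pruned = flops(n_text + int(n_visual * 0.10))  # keep 10% of visual tokens
reduction = 1 - pruned / full
print(f"{reduction:.0%}")  # prints 85% -- in line with the reported ~84% TFLOP reduction
```

With longer text prompts the relative saving shrinks, which is consistent with the observation that gains are largest when visual tokens dominate the context (e.g. video).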
How to Apply
- For LLaVA-family model serving: apply DivPrune to remove up to 90% of visual tokens before they reach the LLM's first layer, for roughly 84% TFLOP savings. The published GitHub code can be applied directly.
- If you currently use attention-score-based FastV, consider switching to DivPrune; the performance gap widens dramatically at higher compression ratios.
- For video inputs with many tokens per frame (1152+), the effect is even stronger; prioritize testing on video-understanding services.
Code Example
import torch
import torch.nn.functional as F


def divprune(visual_tokens: torch.Tensor, keep_ratio: float = 0.1) -> torch.Tensor:
    """DivPrune: Max-Min Diversity-based visual token selection.

    Args:
        visual_tokens: visual token tensor of shape [M, D]
        keep_ratio: ratio of tokens to keep (0.1 = keep 10%)

    Returns:
        selected tokens [M_kept, D]
    """
    M, D = visual_tokens.shape
    M_keep = max(1, int(M * keep_ratio))

    # Pre-compute cosine distance matrix (1 - cosine similarity)
    normed = F.normalize(visual_tokens, dim=-1)
    sim_matrix = normed @ normed.T              # [M, M]
    dist_matrix = 1.0 - sim_matrix              # cosine distance
    dist_matrix.fill_diagonal_(float('inf'))    # exclude self-distance

    selected_idx = []
    remaining = list(range(M))

    # Stage 1: select the first token -- the one with the largest
    # minimum distance to all other tokens
    min_dists = dist_matrix[remaining][:, remaining].min(dim=1).values
    first = remaining[min_dists.argmax().item()]
    selected_idx.append(first)
    remaining.remove(first)

    # Stage 2: iteratively add the token with the largest minimum
    # distance to the already-selected set
    while len(selected_idx) < M_keep:
        sel_tensor = torch.tensor(selected_idx)
        rem_tensor = torch.tensor(remaining)
        # For each remaining token: its minimum distance to any selected token
        dists_to_sel = dist_matrix[rem_tensor][:, sel_tensor].min(dim=1).values
        next_idx = remaining[dists_to_sel.argmax().item()]
        selected_idx.append(next_idx)
        remaining.remove(next_idx)

    return visual_tokens[torch.tensor(selected_idx)]


# Usage example (LLaVA-style)
# visual_tokens: vision encoder output [576, 4096]
# pruned = divprune(visual_tokens, keep_ratio=0.10)  # 576 -> 57 tokens
# llm_input = torch.cat([text_tokens, pruned], dim=0)
Original Abstract
Large Multimodal Models (LMMs) have emerged as powerful models capable of understanding various data modalities, including text, images, and videos. LMMs encode both text and visual data into tokens that are then combined and processed by an integrated Large Language Model (LLM). Including visual tokens substantially increases the total token count, often by thousands. The increased input length for LLM significantly raises the complexity of inference, resulting in high latency in LMMs. To address this issue, token pruning methods, which remove part of the visual tokens, are proposed. The existing token pruning methods either require extensive calibration and fine-tuning or rely on suboptimal importance metrics which results in increased redundancy among the retained tokens. In this paper, we first formulate token pruning as Max-Min Diversity Problem (MMDP) where the goal is to select a subset such that the diversity among the selected tokens is maximized. Then, we solve the MMDP to obtain the selected subset and prune the rest. The proposed method, DivPrune, reduces redundancy and achieves the highest diversity of the selected tokens. By ensuring high diversity, the selected tokens better represent the original tokens, enabling effective performance even at high pruning ratios without requiring fine-tuning. Extensive experiments with various LMMs show that DivPrune achieves state-of-the-art accuracy over 16 image- and video-language datasets. Additionally, DivPrune reduces both the end-to-end latency and GPU memory usage for the tested models. The code is available here.