Dependency-Aware Parallel Decoding via Attention for Diffusion LLMs
TL;DR Highlight
DAPD (Dependency-Aware Parallel Decoding) speeds up parallel token generation in diffusion language models by using self-attention to identify independent tokens that can safely be generated simultaneously.
Who Should Read
Researchers working on diffusion language models and non-autoregressive generation, and teams looking to accelerate LLM inference while maintaining quality.
Core Mechanics
- Diffusion LMs generate tokens in parallel rather than sequentially, but naive parallel generation ignores inter-token dependencies, causing quality loss
- Proposed using the self-attention mechanism to identify which tokens in a generation step are mutually independent
- Only truly independent token groups are generated in parallel — dependent tokens wait for their prerequisites
- This dependency-aware batching dramatically reduces the number of generation rounds needed
- Achieves significant speedup over sequential diffusion decoding with minimal quality degradation
- The approach is theoretically grounded in conditional independence analysis of the token distribution
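The failure mode behind the first bullet can be seen in a tiny toy example (mine, not from the paper): each denoising step exposes only per-token marginals, and independently taking each position's argmax can pick a combination that is jointly unlikely.

```python
# Toy joint distribution over two masked positions. The two tokens are
# strongly dependent: "New" goes with "York", "Los" with "Angeles".
joint = {
    ("New", "York"): 0.40,
    ("Los", "Angeles"): 0.35,
    ("Los", "York"): 0.20,
    ("New", "Angeles"): 0.05,
}

# Per-position marginals -- all a single denoising step provides.
p_first, p_second = {}, {}
for (a, b), p in joint.items():
    p_first[a] = p_first.get(a, 0.0) + p
    p_second[b] = p_second.get(b, 0.0) + p

# Naive parallel unmasking: per-position argmax, taken independently.
greedy = (max(p_first, key=p_first.get), max(p_second, key=p_second.get))
# What dependency-aware (or sequential) decoding can recover.
best = max(joint, key=joint.get)

print(greedy, joint[greedy])  # ('Los', 'York') -- joint prob only 0.20
print(best, joint[best])      # ('New', 'York') -- true joint argmax, 0.40
```

Here the marginals individually favor "Los" (0.55) and "York" (0.60), so naive parallel decoding emits the incoherent pair; this is exactly the coupling that the dependency graph is meant to detect so that such token pairs are not unmasked together.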
Evidence
- Significant speedup over baseline diffusion LM sequential decoding in tokens-per-second
- Quality (perplexity, task benchmarks) maintained within acceptable margins vs autoregressive baselines
- Dependency analysis correctly identifies independent tokens in >90% of cases on tested datasets
- Scales favorably with sequence length compared to fully sequential approaches
How to Apply
- Applicable to any diffusion language model that uses a transformer backbone with self-attention
- Run the independence test during each diffusion step to partition tokens into groups, then generate each group in parallel
- Tune the independence threshold — stricter threshold means higher quality but fewer parallel tokens per step
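The threshold trade-off in the last bullet can be sketched with a hypothetical edge-score matrix; the greedy conflict-free selection below is a simplified stand-in for the paper's Welsh-Powell routine, and the tau values are illustrative.

```python
import torch

# Hypothetical symmetric edge scores among 4 masked tokens (zero diagonal).
# Higher score = stronger pairwise dependency.
edge_scores = torch.tensor([
    [0.00, 0.30, 0.02, 0.01],
    [0.30, 0.00, 0.04, 0.25],
    [0.02, 0.04, 0.00, 0.03],
    [0.01, 0.25, 0.03, 0.00],
])

def greedy_independent_set(edge_scores, tau):
    """Pick tokens that have no above-threshold edge to any already-picked token."""
    picked = []
    for i in range(edge_scores.shape[0]):
        if all(edge_scores[i, j].item() <= tau for j in picked):
            picked.append(i)
    return picked

print(greedy_independent_set(edge_scores, tau=0.10))   # loose: tokens 0, 2, 3 in parallel
print(greedy_independent_set(edge_scores, tau=0.015))  # strict: only tokens 0 and 3
```

A lower (stricter) tau turns more weak attention links into graph edges, shrinking the independent set: fewer tokens per step, but less risk of co-updating coupled tokens.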
Code Example
import torch

def compute_edge_scores(attention_maps, masked_indices):
    """
    attention_maps: [num_layers, num_heads, seq_len, seq_len]
    masked_indices: list of masked token positions
    Returns a symmetric edge-score matrix for the masked tokens
    """
    # Use only the top 25% of layers (e.g., the last 8 layers of a 32-layer model)
    top_layers = attention_maps[-len(attention_maps) // 4:]
    # Average across those layers and all heads
    avg_attn = top_layers.mean(dim=(0, 1))  # [seq_len, seq_len]
    # Symmetric edge score: s_ij = (a_ij + a_ji) / 2
    n = len(masked_indices)
    edge_scores = torch.zeros(n, n)
    for ii, i in enumerate(masked_indices):
        for jj, j in enumerate(masked_indices):
            if ii != jj:
                edge_scores[ii, jj] = (avg_attn[i, j] + avg_attn[j, i]) / 2
    return edge_scores
def welsh_powell_independent_set(edge_scores, masked_indices, tau, confidences):
    """
    Welsh-Powell-based independent-set selection
    """
    n = len(masked_indices)
    # Compute confidence-weighted degree in the tau-thresholded graph
    degrees = (edge_scores > tau).float().sum(dim=1)  # [n]
    weighted_degrees = degrees * confidences
    # Visit nodes in descending order of weighted degree
    order = torch.argsort(weighted_degrees, descending=True)
    selected = []
    selected_set = set()
    for idx in order.tolist():
        # Add a node only if it has no edge to any already-selected node
        conflict = any(
            edge_scores[idx, s] > tau for s in selected_set
        )
        if not conflict:
            selected.append(masked_indices[idx])
            selected_set.add(idx)
    return selected
# Usage example
# mask_ratio = masked_count / total_length
# if mask_ratio > 0.5:  # Early stage: dependency-aware parallel selection
#     tokens_to_unmask = welsh_powell_independent_set(...)
# else:  # Later stage: fast completion based on confidence
#     tokens_to_unmask = [i for i in masked_indices if confidences[i] > 0.9]
Original Abstract
Parallel decoding for diffusion LLMs (dLLMs) is difficult because each denoising step provides only token-wise marginal distributions, while unmasking multiple tokens simultaneously requires accounting for inter-token dependencies. We propose Dependency-Aware Parallel Decoding (DAPD), a simple, training-free decoding method that uses self-attention to induce a conditional dependency graph over masked tokens. At each iteration, edges in this graph capture strong token interactions, while non-edges indicate weak dependence. Parallel decoding is then reduced to selecting an independent set on the graph and unmasking the selected tokens in parallel. This avoids co-updating strongly coupled tokens without auxiliary models or retraining. Experiments on LLaDA and Dream show that DAPD improves the accuracy-steps trade-off over existing methods and enables more globally distributed parallel updates that better exploit the any-order generation capability of dLLMs.