This paper introduces AdaSplash-2, a fast, differentiable sparse attention mechanism based on $\alpha$-entmax. It targets the main computational bottleneck of $\alpha$-entmax attention, which is computing the normalizer $\tau$. AdaSplash-2 uses a novel histogram-based initialization, computed on the fly and stored in on-chip SRAM, that reduces the number of iterations needed to compute $\tau$ to typically 1--2. Experiments show that AdaSplash-2 matches or improves per-step training time relative to FlashAttention-2 at moderate-to-high block sparsity and improves performance on long-context tasks.
By initializing the computation of the $\alpha$-entmax normalizer with on-chip histograms, AdaSplash-2 trains as fast as or faster than FlashAttention-2 at moderate-to-high sparsity, unlocking the potential of $\alpha$-entmax for long-context transformers.
Sparse attention has been proposed as a way to alleviate the quadratic cost of attention in transformers, a central bottleneck in long-context training. A promising line of work is $\alpha$-entmax attention, a differentiable sparse alternative to softmax that enables input-dependent sparsity but has lagged behind softmax in speed because of the overhead of computing the normalizer $\tau$. In this paper, we introduce AdaSplash-2, which addresses this limitation through a novel histogram-based initialization that reduces the number of iterations needed to compute $\tau$ to typically 1--2. The key idea is to compute a coarse histogram of attention scores on the fly and store it in on-chip SRAM, yielding a more accurate initialization that enables fast forward and backward computation. Combined with a sparsity-aware GPU implementation that skips zero blocks with low overhead, AdaSplash-2 matches or improves per-step training time relative to FlashAttention-2 when block sparsity is moderate-to-high (e.g., $>$60\%), which often occurs at long context lengths. On downstream tasks, models trained with our efficient $\alpha$-entmax attention match softmax baselines at short context lengths and achieve substantial gains in long-context settings.
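The histogram-based initialization described above can be illustrated with a small NumPy sketch for a single query row. This is not the AdaSplash-2 kernel. The paper builds the histogram in on-chip SRAM inside a GPU kernel, and the abstract does not specify how the bins are placed or which iterative solver refines $\tau$. The sketch assumes $B = 16$ uniform bins over $[\max_i s_i - 1, \max_i s_i]$ with $s = (\alpha - 1)z$, and it uses Newton's method as a stand-in refinement step. The function names `histogram_init`, `entmax_tau`, and `entmax` are hypothetical. It follows the standard definition $\alpha\text{-entmax}(z)_i = [(\alpha-1)z_i - \tau]_+^{1/(\alpha-1)}$, where $\tau$ is chosen so the probabilities sum to 1.

```python
import numpy as np


def histogram_init(s, p, n_bins=16):
    """Coarse lower bound on tau from a histogram of s over [max(s) - 1, max(s)].

    Hypothetical helper, not the paper's kernel. Scores below max(s) - 1 can never
    be in the support, so np.histogram's `range` simply drops them. Rounding each
    score down to its bin's lower edge underestimates the mass, so the first edge
    (scanning downward) whose estimated mass reaches 1 is a lower bound on tau
    that lies within 2 / n_bins of it.
    """
    hi = s.max()
    lo = hi - 1.0
    counts, edges = np.histogram(s, bins=n_bins, range=(lo, hi))
    lower_edges = edges[:-1]
    for t in edges[::-1]:  # candidate taus from high to low
        gap = lower_edges - t
        pos = gap > 0
        est = np.sum(counts[pos] * gap[pos] ** p)  # estimated mass at tau = t
        if est >= 1.0:
            return t
    return lo  # max(s) - 1 is always a valid lower bound


def entmax_tau(z, alpha=1.5, init="histogram", n_bins=16, tol=1e-9, max_iter=100):
    """Solve sum_i [(alpha-1) z_i - tau]_+^(1/(alpha-1)) = 1 for tau with Newton.

    Assumes 1 < alpha <= 2: the mass is then convex and decreasing in tau, so
    (in exact arithmetic) Newton steps started at a lower bound increase
    monotonically toward tau. Returns (tau, initial tau, Newton steps taken).
    """
    s = (alpha - 1.0) * np.asarray(z, dtype=np.float64)
    p = 1.0 / (alpha - 1.0)
    tau = histogram_init(s, p, n_bins) if init == "histogram" else s.max() - 1.0
    tau0 = tau
    for step in range(max_iter):
        gap = s - tau
        active = gap > 0  # current support
        f = np.sum(gap[active] ** p) - 1.0
        if abs(f) <= tol:
            return tau, tau0, step
        df = -p * np.sum(gap[active] ** (p - 1.0))  # derivative of the mass in tau
        tau -= f / df
    return tau, tau0, max_iter


def entmax(z, alpha=1.5, **kwargs):
    """alpha-entmax probabilities for one row of scores."""
    tau, _, _ = entmax_tau(z, alpha, **kwargs)
    s = (alpha - 1.0) * np.asarray(z, dtype=np.float64)
    return np.clip(s - tau, 0.0, None) ** (1.0 / (alpha - 1.0))


rng = np.random.default_rng(0)
z = 3.0 * rng.standard_normal(4096)  # hypothetical attention scores for one query
for init in ("max-minus-one", "histogram"):
    tau, tau0, steps = entmax_tau(z, alpha=1.5, init=init)
    print(f"{init}: tau0={tau0:.4f} tau={tau:.6f} newton_steps={steps}")
probs = entmax(z, alpha=1.5)
print("sum =", probs.sum(), "nonzeros =", np.count_nonzero(probs), "of", z.size)
```

Rounding scores down to their bin's lower edge keeps the initialization on the safe side of $\tau$. Because of this, the Newton refinement in this sketch never needs a bisection fallback for $1 < \alpha \le 2$. Comparing the printed step counts with those of the classic start $\tau_0 = \max_i s_i - 1$ shows why a tighter starting point can save iterations. Any real speedup depends on the GPU kernel, which this sketch does not model.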