FlashHead is introduced as a training-free, hardware-friendly drop-in replacement for the dense classification head in language models, which often constitutes a significant bottleneck in inference. It reframes the classification task as a retrieval problem, using balanced clustering, multiprobe retrieval, a novel sampling mechanism, and selective quantization to enhance efficiency. Experiments on Llama-3.2, Gemma-3, and Qwen-3 demonstrate up to 1.75x model-level inference speedups while maintaining accuracy.
Achieve up to 1.75x faster language model inference by swapping the standard classification head with FlashHead, a training-free retrieval-based alternative.
Language models are increasingly adopting smaller architectures optimized for consumer devices. In this setting, inference efficiency is the primary constraint. Meanwhile, vocabulary sizes continue to grow rapidly, making the classification head a critical bottleneck that accounts for up to 60% of model parameters and 50% of inference compute. We introduce FlashHead, the first efficient drop-in replacement for the dense classification head that is both training-free and hardware-friendly. FlashHead builds on principles from information retrieval, reframing the computation at the output head as a retrieval problem rather than a dense classification over the full vocabulary. FlashHead introduces four key innovations: (1) a balanced clustering scheme that structures vocabulary partitions into compact, hardware-efficient tensors; (2) an extension of multiprobe retrieval to language model heads, enabling thousands of clusters to be scored in parallel; (3) a novel inference-time sampling mechanism that extends retrieval beyond the top tokens, enabling probabilistic sampling across the full vocabulary; and (4) selective quantization, enabling effective low-bit computation in the head. Experiments on Llama-3.2, Gemma-3, and Qwen-3 show that FlashHead delivers model-level inference speedups of up to 1.75x while maintaining output accuracy compared to the original head. By overcoming the classification head bottleneck, FlashHead establishes a new benchmark for efficient inference and removes a key barrier to developing smaller, capable models for consumer hardware.
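The retrieval view of the output head can be illustrated with a minimal NumPy sketch. This is not FlashHead's implementation: the capacity-constrained greedy k-means used for balanced clustering, the function names `balanced_clusters` and `retrieval_head`, and the parameter `n_probe` are hypothetical choices made for illustration. The sketch also omits the paper's sampling mechanism and selective quantization. It shows only the core idea: partition the head's weight rows into equal-size clusters (so the partition is a dense `(C, V/C)` index tensor), score the cluster centroids against the hidden state, and compute exact logits only for tokens in the `n_probe` best clusters.

```python
import numpy as np


def balanced_clusters(W, n_clusters, n_iters=5, seed=0):
    """Partition the rows of W (vocab x d) into n_clusters groups of equal size.

    Hypothetical heuristic: k-means-style centroid updates combined with a
    capacity-constrained greedy assignment. Requires V divisible by n_clusters.
    """
    V, _ = W.shape
    assert V % n_clusters == 0, "vocab size must be divisible by n_clusters"
    cap = V // n_clusters
    rng = np.random.default_rng(seed)
    centroids = W[rng.choice(V, n_clusters, replace=False)].astype(float)
    for _ in range(n_iters):
        scores = W @ centroids.T  # (V, C) token-to-centroid similarity
        # Visit (token, cluster) pairs from most to least similar and assign
        # each token to the best cluster that still has free capacity.
        order = np.argsort(-scores, axis=None)
        assign = np.full(V, -1)
        fill = np.zeros(n_clusters, dtype=int)
        for flat in order:
            t, c = divmod(int(flat), n_clusters)
            if assign[t] == -1 and fill[c] < cap:
                assign[t] = c
                fill[c] += 1
        # Recompute each centroid as the mean of its (exactly cap) members.
        for c in range(n_clusters):
            centroids[c] = W[assign == c].mean(axis=0)
    # Dense (C, cap) index tensor: every cluster holds exactly `cap` token ids.
    members = np.stack([np.flatnonzero(assign == c) for c in range(n_clusters)])
    return centroids, members


def retrieval_head(h, W, centroids, members, n_probe):
    """Score only the tokens in the n_probe highest-scoring clusters."""
    cluster_scores = centroids @ h  # (C,) one score per cluster
    # Indices of the n_probe largest cluster scores (unordered).
    probed = np.argpartition(-cluster_scores, n_probe - 1)[:n_probe]
    cand = members[probed].reshape(-1)  # (n_probe * cap,) candidate token ids
    logits = W[cand] @ h  # exact logits for candidates only
    return cand, logits


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    W = rng.standard_normal((1024, 32))  # toy head: V=1024, d=32
    centroids, members = balanced_clusters(W, n_clusters=32)
    h = rng.standard_normal(32)
    cand, logits = retrieval_head(h, W, centroids, members, n_probe=4)
    greedy_token = cand[np.argmax(logits)]
    print(greedy_token, int(np.argmax(W @ h)))
```

With `n_probe` equal to the number of clusters, the sketch scores every token and reproduces the dense head's argmax; smaller values trade exactness for work, scoring `C` centroids plus `n_probe * V / C` tokens instead of all `V`. Because the clusters have equal size, the gathered candidate block has a fixed shape, which is the kind of regularity the balanced clustering scheme targets. A naive sampler restricted to `cand` could never emit tokens outside the probed clusters; the abstract describes FlashHead's sampling mechanism as addressing exactly this limitation.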