This paper addresses the apparent incompatibility between Query-Key (QK) normalization and Multi-head Latent Attention (MLA), showing that it is an implementation artifact rather than an architectural constraint. The authors decompose RMSNorm into a static affine weight and a dynamic scalar RMS statistic: the static key-side weight is absorbed into the MLA query-side projection, and the dynamic key statistic reduces to one inverse-RMS scalar per token and KV group, so QK normalization no longer requires full-key caching. In their 400M runs trained for up to 100B tokens, QK-Normed MLA achieves lower training loss and better downstream accuracy than QK clipping, and H800 decode benchmarks show less than 2% latency overhead up to 256k context.
QK normalization can be integrated into MLA without full-key caching, giving lower training loss and better downstream accuracy than QK clipping at less than 2% decode latency overhead.
Query-key (QK) normalization stabilizes attention by controlling the scale of queries and keys before the dot product, but is not immediately compatible with Multi-head Latent Attention (MLA). MLA achieves efficient decoding by caching low-dimensional latent states instead of full keys, whereas post-projection QK RMSNorm appears to require the fully projected key for every cached token. We show this apparent incompatibility is an implementation artifact, not an architectural constraint. RMSNorm decomposes into a static affine weight and a dynamic scalar RMS statistic. The static key-side weight can be absorbed into the MLA query-side projection; the dynamic key statistic reduces to one inverse-RMS scalar per token and KV group. The resulting formulation is exactly equivalent to explicit post-projection QK RMSNorm in exact arithmetic and preserves MLA's latent decode path. In our 400M runs trained for up to 100B tokens, QK-Normed MLA achieves lower training loss and better downstream accuracy than QK clipping, while H800 decode benchmarks show less than 2% latency overhead up to 256k context. These results make QK normalization a practical stabilization option for MLA models without requiring full-key caching.
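The absorption argument in the abstract can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation: it assumes a single head and a single KV group, omits MLA's rotary (RoPE) key path, and uses hypothetical names (`W_uk`, `g_k`, `g_q`, `inv_rms_k`) and toy dimensions. It compares explicit post-projection QK RMSNorm against a latent-path score that uses only the cached latents plus one inverse-RMS scalar per token.

```python
import numpy as np

EPS = 1e-6  # assumed RMSNorm epsilon; the abstract does not specify one


def rmsnorm(x, weight, eps=EPS):
    """RMSNorm over the last axis with an elementwise (static) weight."""
    inv_rms = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * inv_rms * weight


rng = np.random.default_rng(0)
d_latent, d_head, n_tokens = 16, 32, 5  # toy sizes, chosen for illustration

W_uk = rng.standard_normal((d_head, d_latent))  # key up-projection (hypothetical name)
g_k = rng.standard_normal(d_head)               # static key-side RMSNorm weight
g_q = rng.standard_normal(d_head)               # static query-side RMSNorm weight
C = rng.standard_normal((n_tokens, d_latent))   # cached low-dimensional latents
q = rng.standard_normal(d_head)                 # projected query for the current step

# Reference path: materialize every full key, normalize it, then take dot products.
K = C @ W_uk.T                                  # (n_tokens, d_head)
q_n = rmsnorm(q, g_q)
scores_ref = rmsnorm(K, g_k) @ q_n

# Latent path. The dynamic key statistic is one scalar per token; here it is
# computed once from the projected key, as it could be when the token is cached.
inv_rms_k = 1.0 / np.sqrt(np.mean(K * K, axis=-1) + EPS)  # (n_tokens,)

# The static key weight g_k is folded into the query side, then the query is
# mapped into latent space so scores are taken directly against the latents.
q_absorbed = W_uk.T @ (g_k * q_n)               # (d_latent,)
scores_latent = (C @ q_absorbed) * inv_rms_k

max_abs_diff = np.max(np.abs(scores_ref - scores_latent))
print("max abs difference:", max_abs_diff)
assert np.allclose(scores_ref, scores_latent)
```

The two paths agree up to floating-point rounding because the key weight multiplies each key coordinate independently of the token, so it can be moved onto the query, while the inverse RMS is a per-token scalar that factors out of the dot product. At decode time the latent path reads only `C` and `inv_rms_k` from the cache, so the full keys in `K` never need to be stored.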