$$\tilde{\bm{\mathsfit{V}}}=\operatorname{Reshape}\left(\bm{C}^{\text{KV}},\,\left[n,\,1,\,d_{c}\right]\right)\in\mathbb{R}^{n\times 1\times d_{c}}.$$

Under this formulation, decoding reduces to an MQA-style attention mechanism in which the attention logits (i.e., query–key inner products before softmax) are computed in a $(d_{c}+d_{h}^{R})$-dimensional space using these shared KV states. Incorporating the concatenated query $\tilde{\bm{\mathsfit{Q}}}_{n-1,:,:}=\operatorname{Concat}\left(\left[\tilde{\bm{\mathsfit{Q}}}_{n-1,:,:}^{\text{NoPE}},\,\bm{\mathsfit{Q}}_{n-1,:,:}^{\text{RoPE}}\right],\,\text{dim=1}\right)\in\mathbb{R}^{h\times(d_{c}+d_{h}^{R})}$, the attention output is calculated as follows:

$$\bm{\mathsfit{Z}}_{n-1,:,:}=\operatorname{Attention}\left(\tilde{\bm{\mathsfit{Q}}}_{n-1,:,:},\,\operatorname{RepeatInterleave}\left(\tilde{\bm{\mathsfit{K}}},\,h,\,\text{dim}=1\right),\,\operatorname{RepeatInterleave}\left(\tilde{\bm{\mathsfit{V}}},\,h,\,\text{dim}=1\right)\right),\tag{1}$$

where $\bm{\mathsfit{Z}}_{n-1,:,:}\in\mathbb{R}^{h\times d_{c}}$. FlashAttention-3 (Shah et al., 2024) and FlashMLA (Jiashi Li, 2025) provide highly optimized kernels that implement this Step-2 decoding computation directly.

Step 3 (Output Up-Projection). Finally, the up-projection tensor $\tilde{\bm{\mathsfit{W}}}^{\text{UV}}$ maps the intermediate attention output to the final attention output:

$$\bm{\mathsfit{O}}_{n-1,:,:}=\text{einsum}\left(\texttt{"hc,hcp->hp"},\,\bm{\mathsfit{Z}}_{n-1,:,:},\,\tilde{\bm{\mathsfit{W}}}^{\text{UV}}\right),\quad c=d_{c},\,p=d_{h},\quad\bm{\mathsfit{O}}_{n-1,:,:}\in\mathbb{R}^{h\times d_{h}}.$$

Block Multiplications.
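To make Steps 2 and 3 concrete, the following NumPy sketch checks numerically that the absorbed, MQA-style decode of Eq. (1) followed by the output up-projection reproduces naive MLA decoding, which first decompresses per-head keys and values from $\bm{C}^{\text{KV}}$. Several details are assumptions rather than definitions from this section: Step 1 is taken to absorb $\bm{W}^{\text{UK}}$ into the NoPE query, i.e. $\tilde{\bm{\mathsfit{Q}}}^{\text{NoPE}}_{n-1,i,:}=\bm{\mathsfit{Q}}^{\text{NoPE}}_{n-1,i,:}(\bm{W}^{\text{UK}}_{:,(i)})^{\top}$; $\tilde{\bm{\mathsfit{K}}}$ is taken to be $[\bm{C}^{\text{KV}},\bm{K}^{\text{RoPE}}]$ viewed as a single shared head; $\tilde{\bm{\mathsfit{W}}}^{\text{UV}}$ is $\bm{W}^{\text{UV}}\in\mathbb{R}^{d_c\times hd_h}$ reshaped to $h\times d_c\times d_h$; and the helper names, sizes, and scale $\tau=1/\sqrt{d_h+d_h^R}$ are illustrative.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def naive_mla_decode(q_nope, q_rope, c_kv, k_rope, w_uk, w_uv, tau):
    """Reference path: decompress per-head keys/values from C^KV, then run MHA.

    q_nope: [h, d_h], q_rope: [h, d_r], c_kv: [n, d_c], k_rope: [n, d_r],
    w_uk, w_uv: [d_c, h * d_h] (assumed layout). Returns O: [h, d_h].
    """
    h, d_h = q_nope.shape
    k_nope = (c_kv @ w_uk).reshape(-1, h, d_h)  # [n, h, d_h]
    v = (c_kv @ w_uv).reshape(-1, h, d_h)       # [n, h, d_h]
    logits = tau * (np.einsum("hd,nhd->hn", q_nope, k_nope) + q_rope @ k_rope.T)
    return np.einsum("hn,nhd->hd", softmax(logits), v)


def absorbed_mla_decode(q_nope, q_rope, c_kv, k_rope, w_uk, w_uv, tau):
    """Absorbed path: MQA over the shared latent (Eq. 1), then Step 3."""
    h, d_h = q_nope.shape
    d_c = c_kv.shape[1]
    # Per-head weight tensors of shape [h, d_c, d_h]; w_uv_t plays the role of W~UV.
    w_uk_t = w_uk.reshape(d_c, h, d_h).transpose(1, 0, 2)
    w_uv_t = w_uv.reshape(d_c, h, d_h).transpose(1, 0, 2)
    # Assumed Step 1: absorb W^UK into the NoPE query -> [h, d_c], then append RoPE part.
    q_tilde = np.concatenate([np.einsum("hp,hcp->hc", q_nope, w_uk_t), q_rope], axis=1)
    # K~ is one shared head [n, d_c + d_r]; using it for all h query heads is
    # what RepeatInterleave(..., h, dim=1) expresses in Eq. (1).
    k_tilde = np.concatenate([c_kv, k_rope], axis=1)
    z = softmax(tau * q_tilde @ k_tilde.T) @ c_kv  # Z: [h, d_c], values are V~ = C^KV
    # Step 3: einsum("hc,hcp->hp", Z, W~UV) -> O: [h, d_h]
    return np.einsum("hc,hcp->hp", z, w_uv_t)


rng = np.random.default_rng(0)
h, d_h, d_r, n = 4, 8, 4, 16
d_c = 4 * d_h
args = (
    rng.standard_normal((h, d_h)), rng.standard_normal((h, d_r)),
    rng.standard_normal((n, d_c)), rng.standard_normal((n, d_r)),
    rng.standard_normal((d_c, h * d_h)), rng.standard_normal((d_c, h * d_h)),
    1.0 / np.sqrt(d_h + d_r),
)
print(np.abs(naive_mla_decode(*args) - absorbed_mla_decode(*args)).max())
```

The printed maximum difference should sit at floating-point round-off level. The absorbed path only ever reads the $n\times(d_c+d_h^R)$ latent cache, never $h$ decompressed key and value heads.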
For each head $i$, we define the constituent sub-blocks $\bm{W}_{(b),(i)}^{\cdot}\in\mathbb{R}^{d_{h}\times d_{h}}$ by partitioning the up-projection matrices into $d_{h}$-sized row blocks for $b\in\{0,1,2,3\}$ (four blocks, since $d_{c}=4d_{h}$):

$$\begin{aligned}\bm{W}_{(b),(i)}^{\text{UK}}&:=\bm{W}^{\text{UK}}\left[bd_{h}:(b+1)d_{h},\,id_{h}:(i+1)d_{h}\right],\\ \bm{W}_{(b),(i)}^{\text{UV}}&:=\bm{W}^{\text{UV}}\left[bd_{h}:(b+1)d_{h},\,id_{h}:(i+1)d_{h}\right].\end{aligned}$$

Consequently, each head-specific up-projection matrix can be expressed as a vertical stack of these four row blocks:

$$\bm{W}_{:,(i)}^{\text{UK}}=\begin{bmatrix}\bm{W}_{(0),(i)}^{\text{UK}}\\ \bm{W}_{(1),(i)}^{\text{UK}}\\ \bm{W}_{(2),(i)}^{\text{UK}}\\ \bm{W}_{(3),(i)}^{\text{UK}}\end{bmatrix},\quad\bm{W}_{:,(i)}^{\text{UV}}=\begin{bmatrix}\bm{W}_{(0),(i)}^{\text{UV}}\\ \bm{W}_{(1),(i)}^{\text{UV}}\\ \bm{W}_{(2),(i)}^{\text{UV}}\\ \bm{W}_{(3),(i)}^{\text{UV}}\end{bmatrix}.$$

Similarly, we partition the KV latent matrix $\bm{C}^{\text{KV}}\in\mathbb{R}^{n\times d_{c}}$ into $d_{h}$-wide column (channel) blocks $\bm{C}_{:,(b)}^{\text{KV}}:=\bm{C}^{\text{KV}}\left[:,\,bd_{h}:(b+1)d_{h}\right]$, such that $\bm{C}^{\text{KV}}=\left[\bm{C}_{:,(0)}^{\text{KV}},\dots,\bm{C}_{:,(3)}^{\text{KV}}\right]$. This block decomposition allows the key and value projections for head $i$ to be rewritten as a sum of four sub-block products:

$$\bm{\mathsfit{K}}_{:,(i),:}^{\text{NoPE}}=\sum_{b=0}^{3}\bm{C}_{:,(b)}^{\text{KV}}\bm{W}_{(b),(i)}^{\text{UK}},\qquad\bm{\mathsfit{V}}_{:,(i),:}=\sum_{b=0}^{3}\bm{C}_{:,(b)}^{\text{KV}}\bm{W}_{(b),(i)}^{\text{UV}}.\tag{2}$$

2.2 Grouped Latent Attention

Grouped Latent Attention (GLA-2) (Zadouri et al., 2025) splits MLA's single latent head into two latent heads, using the first latent head $(\bm{C}_{:,(0)}^{\text{KV}},\bm{C}_{:,(1)}^{\text{KV}})$ for the first half of the attention heads and the second latent head $(\bm{C}_{:,(2)}^{\text{KV}},\bm{C}_{:,(3)}^{\text{KV}})$ for the second half. We define the group-mapping function as:

$$\gamma(i)=\begin{cases}0,&i<h/2,\\ 1,&i\geq h/2,\end{cases}\qquad\bar{i}=i-\frac{\gamma(i)\,h}{2}\in\{0,\dots,h/2-1\}.\tag{3}$$

Let $\bm{W}^{(\gamma(i)),\text{UK}},\bm{W}^{(\gamma(i)),\text{UV}}\in\mathbb{R}^{2d_{h}\times(h/2)\,d_{h}}$ denote the up-projection matrices for latent group $\gamma(i)\in\{0,1\}$. We extract the head-specific slices for head $i$ by indexing into these matrices:

$$\bm{W}_{:,(i)}^{(\gamma(i)),\text{UK}}=\bm{W}^{(\gamma(i)),\text{UK}}\left[:,\,\bar{i}d_{h}:(\bar{i}+1)d_{h}\right],\qquad\bm{W}_{:,(i)}^{(\gamma(i)),\text{UV}}=\bm{W}^{(\gamma(i)),\text{UV}}\left[:,\,\bar{i}d_{h}:(\bar{i}+1)d_{h}\right],$$

where $\bm{W}_{:,(i)}^{(\gamma(i)),\text{UK}},\bm{W}_{:,(i)}^{(\gamma(i)),\text{UV}}\in\mathbb{R}^{2d_{h}\times d_{h}}$.
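Returning briefly to the MLA block multiplications, Eq. (2) is an instance of block matrix multiplication and can be verified directly. The NumPy sketch below uses hypothetical helpers `row_blocks`, `channel_blocks`, and `block_sum_kv`, assumes $\bm{W}^{\text{UK}},\bm{W}^{\text{UV}}\in\mathbb{R}^{d_c\times hd_h}$ with $d_c=4d_h$, slices the weights and $\bm{C}^{\text{KV}}$ exactly as in the definitions above, and checks that the four sub-block products sum to $\bm{C}^{\text{KV}}\bm{W}_{:,(i)}^{\text{UK}}$ and $\bm{C}^{\text{KV}}\bm{W}_{:,(i)}^{\text{UV}}$ for every head.

```python
import numpy as np


def row_blocks(w, i, d_h, num_blocks=4):
    """Sub-blocks W_(b),(i) = w[b*d_h:(b+1)*d_h, i*d_h:(i+1)*d_h] for b = 0..num_blocks-1."""
    return [w[b * d_h:(b + 1) * d_h, i * d_h:(i + 1) * d_h] for b in range(num_blocks)]


def channel_blocks(c_kv, d_h, num_blocks=4):
    """Channel blocks C_{:,(b)} = c_kv[:, b*d_h:(b+1)*d_h]."""
    return [c_kv[:, b * d_h:(b + 1) * d_h] for b in range(num_blocks)]


def block_sum_kv(c_kv, w_uk, w_uv, i, d_h):
    """K^NoPE_{:,(i),:} and V_{:,(i),:} as sums of four sub-block products (Eq. 2)."""
    c_b = channel_blocks(c_kv, d_h)
    k = sum(c @ w for c, w in zip(c_b, row_blocks(w_uk, i, d_h)))
    v = sum(c @ w for c, w in zip(c_b, row_blocks(w_uv, i, d_h)))
    return k, v


rng = np.random.default_rng(0)
n, h, d_h = 5, 3, 8
d_c = 4 * d_h  # four d_h-wide blocks, as assumed by Eq. (2)
c_kv = rng.standard_normal((n, d_c))
w_uk = rng.standard_normal((d_c, h * d_h))
w_uv = rng.standard_normal((d_c, h * d_h))

for i in range(h):
    head = slice(i * d_h, (i + 1) * d_h)
    k, v = block_sum_kv(c_kv, w_uk, w_uv, i, d_h)
    # The block sum matches the direct product with the head slice W_{:,(i)}.
    assert np.allclose(k, c_kv @ w_uk[:, head])
    assert np.allclose(v, c_kv @ w_uv[:, head])
    # Vertically stacking the four row blocks recovers W_{:,(i)} exactly.
    assert np.array_equal(np.vstack(row_blocks(w_uk, i, d_h)), w_uk[:, head])
print("Eq. (2) block identity holds for all heads")
```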
To further facilitate block-wise computation, we partition these slices into $d_{h}$-row blocks $\bm{W}_{(b),(i)}^{(\gamma(i)),\cdot}$ for $b\in\{0,1\}$, defined as:

$$\begin{aligned}\bm{W}_{(b),(i)}^{(\gamma(i)),\text{UK}}&:=\bm{W}^{(\gamma(i)),\text{UK}}\left[bd_{h}:(b+1)d_{h},\,\bar{i}d_{h}:(\bar{i}+1)d_{h}\right],\\ \bm{W}_{(b),(i)}^{(\gamma(i)),\text{UV}}&:=\bm{W}^{(\gamma(i)),\text{UV}}\left[bd_{h}:(b+1)d_{h},\,\bar{i}d_{h}:(\bar{i}+1)d_{h}\right],\end{aligned}$$

where each block $\bm{W}_{(b),(i)}^{(\gamma(i)),\text{UK}},\bm{W}_{(b),(i)}^{(\gamma(i)),\text{UV}}\in\mathbb{R}^{d_{h}\times d_{h}}$. This partitioning allows us to decompose the head-specific up-projection matrices into two row-wise blocks:

$$\bm{W}_{:,(i)}^{(\gamma(i)),\text{UK}}=\begin{bmatrix}\bm{W}_{(0),(i)}^{(\gamma(i)),\text{UK}}\\ \bm{W}_{(1),(i)}^{(\gamma(i)),\text{UK}}\end{bmatrix},\qquad\bm{W}_{:,(i)}^{(\gamma(i)),\text{UV}}=\begin{bmatrix}\bm{W}_{(0),(i)}^{(\gamma(i)),\text{UV}}\\ \bm{W}_{(1),(i)}^{(\gamma(i)),\text{UV}}\end{bmatrix}.$$
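The group mapping of Eq. (3) and the two-row-block partition above amount to simple index arithmetic. The NumPy sketch below is a hypothetical illustration (helper names `gamma`, `local_index`, and `gla_row_blocks` are not from the source, and an even head count $h$ is assumed); it computes $\gamma(i)$ and $\bar{i}$ and checks that stacking $\bm{W}_{(0),(i)}^{(\gamma(i)),\text{UK}}$ over $\bm{W}_{(1),(i)}^{(\gamma(i)),\text{UK}}$ recovers the head slice $\bm{W}_{:,(i)}^{(\gamma(i)),\text{UK}}$.

```python
import numpy as np


def gamma(i, h):
    """Latent-group index from Eq. (3): 0 for heads i < h/2, 1 otherwise."""
    return 0 if i < h // 2 else 1


def local_index(i, h):
    """Within-group head index i_bar = i - gamma(i) * h / 2."""
    return i - gamma(i, h) * (h // 2)


def gla_row_blocks(w_group, i, h, d_h):
    """Row blocks W_(b),(i)^(gamma(i)) for b in {0, 1}.

    w_group: [W^(0), W^(1)], each of shape [2 * d_h, (h / 2) * d_h].
    """
    w = w_group[gamma(i, h)]
    ib = local_index(i, h)
    cols = slice(ib * d_h, (ib + 1) * d_h)
    return [w[b * d_h:(b + 1) * d_h, cols] for b in range(2)]


rng = np.random.default_rng(0)
h, d_h = 8, 4  # h is assumed even
w_uk_groups = [rng.standard_normal((2 * d_h, (h // 2) * d_h)) for _ in range(2)]

print([gamma(i, h) for i in range(h)])        # [0, 0, 0, 0, 1, 1, 1, 1]
print([local_index(i, h) for i in range(h)])  # [0, 1, 2, 3, 0, 1, 2, 3]
for i in range(h):
    g, ib = gamma(i, h), local_index(i, h)
    head_slice = w_uk_groups[g][:, ib * d_h:(ib + 1) * d_h]  # W_{:,(i)}^(gamma(i)),UK
    assert np.array_equal(np.vstack(gla_row_blocks(w_uk_groups, i, h, d_h)), head_slice)
```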
Consequently, the NoPE key and value computations for head $i$ can be expressed as the sum of two block products:

$$\begin{aligned}\bm{\mathsfit{K}}_{:,(i),:}^{\text{NoPE}}&=\bm{C}_{:,(2\gamma(i))}^{\text{KV}}\bm{W}_{(0),(i)}^{(\gamma(i)),\text{UK}}+\bm{C}_{:,(2\gamma(i)+1)}^{\text{KV}}\bm{W}_{(1),(i)}^{(\gamma(i)),\text{UK}},\\ \bm{\mathsfit{V}}_{:,(i),:}&=\bm{C}_{:,(2\gamma(i))}^{\text{KV}}\bm{W}_{(0),(i)}^{(\gamma(i)),\text{UV}}+\bm{C}_{:,(2\gamma(i)+1)}^{\text{KV}}\bm{W}_{(1),(i)}^{(\gamma(i)),\text{UV}}.\end{aligned}\tag{4}$$

3 Multi-Head Low-Rank Attention

Building on the block decompositions in Sections 2.1 and 2.2, we propose MLRA. By moving the summation from the KV computation to the attention output, MLRA treats each block projection as an independent low-rank branch and sums the branch outputs. MLRA is illustrated in Figures 8 and 9.

3.1 MLRA-4

By substituting the block-partitioned identities from Eq. (2) into the attention mechanism, the output for head $i$ can be expressed as:

$$\bm{\mathsfit{O}}_{:,i,:}=\operatorname{Softmax}\left(\tau\bm{\mathsfit{Q}}_{:,i,:}^{\text{NoPE}}\left(\sum_{b=0}^{3}\bm{C}_{:,(b)}^{\text{KV}}\bm{W}_{(b),(i)}^{\text{UK}}\right)^{\top}+\tau\bm{\mathsfit{Q}}_{:,i,:}^{\text{RoPE}}\left(\bm{K}^{\text{RoPE}}\right)^{\top}\right)\left(\sum_{b=0}^{3}\bm{C}_{:,(b)}^{\text{KV}}\bm{W}_{(b),(i)}^{\text{UV}}\right).$$

Motivated by Eq. (2), we propose MLRA-4, which computes attention independently on each blockwise branch and sums the resulting outputs:

$$\bm{\mathsfit{O}}_{:,i,:}=\sum_{b=0}^{3}\cdots$$
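The GLA-2 identity in Eq. (4) can be checked the same way as Eq. (2). In this NumPy sketch, `gla_kv_head` is a hypothetical helper that assumes $\bm{C}^{\text{KV}}\in\mathbb{R}^{n\times 4d_h}$ and two group matrices of shape $2d_h\times(h/2)d_h$; it forms head $i$'s NoPE key and value from the two channel blocks $\bm{C}_{:,(2\gamma(i))}^{\text{KV}}$ and $\bm{C}_{:,(2\gamma(i)+1)}^{\text{KV}}$ and compares them with a direct product against the group's $2d_h$-wide latent head. It covers only the GLA-2 identity, not the MLRA-4 branch computation.

```python
import numpy as np


def gla_kv_head(c_kv, w_uk_groups, w_uv_groups, i, h, d_h):
    """NoPE key and value of head i in GLA-2 via the two-block sum of Eq. (4).

    c_kv: [n, 4 * d_h]; w_*_groups: two matrices, each [2 * d_h, (h / 2) * d_h].
    """
    g = 0 if i < h // 2 else 1           # gamma(i), Eq. (3)
    ib = i - g * (h // 2)                # i_bar
    cols = slice(ib * d_h, (ib + 1) * d_h)
    c_lo = c_kv[:, 2 * g * d_h:(2 * g + 1) * d_h]        # C_{:,(2 gamma(i))}
    c_hi = c_kv[:, (2 * g + 1) * d_h:(2 * g + 2) * d_h]  # C_{:,(2 gamma(i) + 1)}
    wk, wv = w_uk_groups[g], w_uv_groups[g]
    k = c_lo @ wk[:d_h, cols] + c_hi @ wk[d_h:, cols]
    v = c_lo @ wv[:d_h, cols] + c_hi @ wv[d_h:, cols]
    return k, v


rng = np.random.default_rng(0)
n, h, d_h = 6, 4, 8
c_kv = rng.standard_normal((n, 4 * d_h))
w_uk_groups = [rng.standard_normal((2 * d_h, (h // 2) * d_h)) for _ in range(2)]
w_uv_groups = [rng.standard_normal((2 * d_h, (h // 2) * d_h)) for _ in range(2)]

for i in range(h):
    g = 0 if i < h // 2 else 1
    ib = i - g * (h // 2)
    latent_head = c_kv[:, 2 * g * d_h:(2 * g + 2) * d_h]  # this group's 2*d_h-wide latent
    k, v = gla_kv_head(c_kv, w_uk_groups, w_uv_groups, i, h, d_h)
    assert np.allclose(k, latent_head @ w_uk_groups[g][:, ib * d_h:(ib + 1) * d_h])
    assert np.allclose(v, latent_head @ w_uv_groups[g][:, ib * d_h:(ib + 1) * d_h])

# Heads in the first group never read channel blocks 2 and 3.
c_perturbed = c_kv.copy()
c_perturbed[:, 2 * d_h:] += 1.0
k0, _ = gla_kv_head(c_kv, w_uk_groups, w_uv_groups, 0, h, d_h)
k0_p, _ = gla_kv_head(c_perturbed, w_uk_groups, w_uv_groups, 0, h, d_h)
assert np.allclose(k0, k0_p)
```

The last check reflects GLA-2's grouping: heads in the first half read only $(\bm{C}_{:,(0)}^{\text{KV}},\bm{C}_{:,(1)}^{\text{KV}})$, so perturbing blocks 2 and 3 leaves their keys unchanged.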
MLRA unlocks 2.8x faster LLM decoding by enabling efficient tensor parallelism for latent attention, sidestepping the memory traffic bottlenecks that plague existing methods.