The paper introduces the Recurrent Transformer, an architecture in which each layer attends to key-value pairs computed from its own activations, creating layer-wise recurrent memory. Under mild assumptions, this design can emulate both a standard Transformer and token-to-token recurrent updates, while avoiding the optimization instability associated with traditional recurrent models. The authors also give an exact tiling-based algorithm that reduces HBM traffic during prefill and training from $\Theta(N^2)$ to $\Theta(N\log N)$. On 150M and 300M parameter C4 pretraining, the architecture improves cross-entropy over parameter-matched Transformers while using fewer layers, suggesting that recurrence can trade depth for width.
Recurrent Transformers may let you trade model depth for width: at a fixed parameter count they improve C4 cross-entropy with fewer layers, which points to a smaller KV cache memory footprint and lower inference latency.
Transformers process tokens in parallel but are temporally shallow: at position $t$, each layer attends to key-value pairs computed from the previous layer's activations, so temporal depth is capped by the number of layers. Recurrent models offer unbounded temporal depth but suffer from optimization instability and have historically underutilized modern accelerators. We introduce the Recurrent Transformer, a simple architectural change where each layer attends to key-value pairs computed from its own activations, yielding layerwise recurrent memory while preserving standard autoregressive decoding cost. We show that the architecture can emulate both (i) a conventional Transformer and (ii) token-to-token recurrent updates under mild assumptions, while avoiding optimization instability. Naively, prefill/training appears bandwidth-bound with effective arithmetic intensity near $1$ because keys and values are revealed sequentially; we give an exact tiling-based algorithm that preserves the mathematical computation while reducing HBM traffic from $\Theta(N^2)$ to $\Theta(N\log N)$, increasing effective arithmetic intensity to $\Theta(N/\log N)$ for sequence length $N$. On 150M and 300M parameter C4 pretraining, Recurrent Transformers improve cross-entropy over a parameter-matched Transformer baseline and achieve the improvement with fewer layers (at fixed parameter count), suggesting that recurrence can trade depth for width, thus reducing KV cache memory footprint and inference latency.
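To make the central idea concrete, here is a minimal NumPy sketch of one attention layer whose keys and values are computed from its own outputs rather than from its inputs. This is an illustration under stated assumptions, not the authors' implementation. The single head, the missing normalization and MLP, the strictly-causal memory (position $t$ sees only positions $s < t$), the residual form, and the names `recurrent_layer`, `Wq`, `Wk`, `Wv` are all hypothetical choices made for this example. The sketch uses the naive sequential loop, not the tiling-based algorithm.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over a 1-D vector."""
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def recurrent_layer(x, Wq, Wk, Wv):
    """One single-head layer with layerwise recurrent memory (illustrative).

    x: (N, d) layer inputs. Returns h: (N, d) layer outputs.
    Assumption: queries come from the layer input x[t], while keys and
    values come from the layer's OWN earlier outputs h[s], s < t.
    """
    N, d = x.shape
    keys, values = [], []  # grows by one entry per position, like a KV cache
    h = np.zeros_like(x)
    for t in range(N):
        if keys:
            K = np.stack(keys)    # (t, d)
            V = np.stack(values)  # (t, d)
            q = x[t] @ Wq
            weights = softmax(K @ q / np.sqrt(d))
            attn = weights @ V
        else:
            attn = np.zeros(d)    # no memory yet at the first position
        h[t] = x[t] + attn        # residual connection (assumed)
        # Keys/values are only "revealed" once h[t] is known, which is
        # why this naive loop is sequential in t.
        keys.append(h[t] @ Wk)
        values.append(h[t] @ Wv)
    return h


rng = np.random.default_rng(0)
N, d = 6, 4
x = rng.normal(size=(N, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) * 0.1 for _ in range(3))
h = recurrent_layer(x, Wq, Wk, Wv)
```

In this sketch each position appends exactly one key-value pair, so per-token decoding work matches that of a standard attention layer with a KV cache. The sequential dependence of `keys` and `values` on earlier outputs is what makes naive prefill and training bandwidth-bound, which is the problem the tiling-based algorithm addresses.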