This paper investigates the performance gap between trainable INT8 attention (SageBwd) and full-precision attention (FPA) during pre-training. Through experiments and theoretical analysis, the authors find that QK-normalization is necessary for stable training when tokens per step are large, and that quantization errors arise mainly from the score gradient dS in the backward pass. With fewer tokens per step and K-smoothing retained, SageBwd matches FPA in pre-training.
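QK-normalization and K-smoothing both shape the scores that INT8 quantization must represent. The following is a minimal NumPy sketch, not SageBwd's kernel. It shows why K-smoothing costs nothing in exact arithmetic but helps INT8. Subtracting the mean key over tokens shifts every score in a query's row by the same constant (q · mean_k), which softmax ignores, while shrinking the range the INT8 scale has to cover. Assumptions: QK-norm is modeled as a gain-free RMSNorm, quantization is symmetric per-tensor (SageAttention-style kernels use finer-grained scales), and the function names and the synthetic key offset are hypothetical.

```python
import numpy as np

def qk_norm(x, eps=1e-6):
    # QK-norm, assumed here to be RMSNorm without a learnable gain:
    # rescale each query/key vector to unit RMS along the head dimension.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def quantize_int8(x):
    # Symmetric per-tensor INT8 quantization: x ≈ scale * x8, x8 in [-127, 127].
    scale = float(np.max(np.abs(x))) / 127.0
    if scale == 0.0:
        scale = 1.0
    x8 = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return x8, scale

def scores(q, k, int8):
    if not int8:
        return q @ k.T
    q8, sq = quantize_int8(q)
    k8, sk = quantize_int8(k)
    # Integer matmul with int32 accumulation, then dequantize.
    return (q8.astype(np.int32) @ k8.astype(np.int32).T) * (sq * sk)

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_probs(q, k, smooth_k, int8):
    q, k = qk_norm(q), qk_norm(k)
    if smooth_k:
        # K-smoothing: subtract the mean key over tokens. Each query's row of
        # scores shifts by one constant, so the exact softmax is unchanged.
        k = k - k.mean(axis=0, keepdims=True)
    return softmax(scores(q, k, int8) / np.sqrt(q.shape[-1]))

rng = np.random.default_rng(0)
n, d = 64, 64
q = rng.standard_normal((n, d))
k = rng.standard_normal((n, d)) + 20.0  # synthetic keys sharing a large common offset
ref = attention_probs(q, k, smooth_k=False, int8=False)
err_plain = np.abs(attention_probs(q, k, smooth_k=False, int8=True) - ref).max()
err_smooth = np.abs(attention_probs(q, k, smooth_k=True, int8=True) - ref).max()
print(f"max |P error|  no smoothing: {err_plain:.2e}  K-smoothing: {err_smooth:.2e}")
```

On keys with a shared offset, the smoothed INT8 probabilities stay closer to the full-precision reference. The magnitudes reflect only this synthetic data, not the paper's measurements.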
Trainable INT8 attention can match full-precision attention during pre-training, but only with QK-norm and fewer tokens per step.
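Tokens per step is one of the two conditions stated above. Here is a minimal arithmetic sketch, assuming tokens per step means global batch size × sequence length (the abstract does not say how the authors reduced it). It shows that at a fixed token budget, fewer tokens per step means proportionally more optimizer steps. The budget, batch sizes, and sequence length are hypothetical illustration values, not the paper's settings.

```python
def tokens_per_step(global_batch_size: int, seq_len: int) -> int:
    # Tokens consumed by one optimizer step = sequences per step * tokens per sequence.
    return global_batch_size * seq_len

def optimizer_steps(token_budget: int, global_batch_size: int, seq_len: int) -> int:
    # Whole optimizer steps needed to consume a fixed token budget.
    return token_budget // tokens_per_step(global_batch_size, seq_len)

budget = 10_000_000_000  # hypothetical 10B-token budget
seq_len = 4096           # hypothetical sequence length
for bsz in (1024, 256):  # hypothetical global batch sizes
    print(bsz, tokens_per_step(bsz, seq_len), optimizer_steps(budget, bsz, seq_len))
```

Quartering the batch here quarters tokens per step and quadruples the step count (2384 vs 9536 steps).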
Low-bit attention, such as SageAttention, has emerged as an effective way to accelerate model inference, but its applicability to training remains poorly understood. In prior work, we introduced SageBwd, a trainable INT8 attention that quantizes six of the seven matrix multiplications in attention while preserving fine-tuning performance. However, SageBwd showed a persistent performance gap to full-precision attention (FPA) during pre-training. In this work, we investigate why this gap occurs and show that SageBwd can match FPA in pre-training. Through experiments and theoretical analysis, we arrive at four main findings: (i) QK-norm is necessary for stable training at large tokens per step; (ii) quantization errors arise primarily from the backward-pass score gradient dS; (iii) reducing tokens per step enables SageBwd to match FPA performance in pre-training; and (iv) K-smoothing remains essential for training stability, while Q-smoothing provides limited benefit during pre-training.
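Finding (ii) places most of the quantization error in the score gradient dS. To show where dS sits, the following NumPy sketch writes out single-head attention backward: dV = P^T dO, dP = dO V^T, dS = P * (dP - rowsum(dO * O)), then dQ and dK from dS. The rowsum(dO * O) form is the standard identity also used by FlashAttention. The `quant` argument is a hypothetical knob that passes chosen operands through a per-tensor INT8 quantize-dequantize round trip, so error can be attributed to individual operands. It does not reproduce which matmuls SageBwd quantizes or its quantization granularity.

```python
import numpy as np

def fake_int8(x):
    # Symmetric per-tensor INT8 quantize-dequantize round trip (simulated error only).
    scale = float(np.max(np.abs(x))) / 127.0
    if scale == 0.0:
        return x
    return np.clip(np.round(x / scale), -127, 127) * scale

def attention_fwd_bwd(q, k, v, do, quant=()):
    """Single-head softmax attention: forward pass, then backward pass given dO.

    `quant` (hypothetical knob) names operands to fake-quantize: "P", "dS".
    """
    c = 1.0 / np.sqrt(q.shape[-1])
    s = c * (q @ k.T)                                # S = c Q K^T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)               # P = softmax(S)
    o = p @ v                                        # O = P V
    p_b = fake_int8(p) if "P" in quant else p
    dv = p_b.T @ do                                  # dV = P^T dO
    dp = do @ v.T                                    # dP = dO V^T
    delta = np.sum(do * o, axis=-1, keepdims=True)   # rowsum(dO * O) == rowsum(dP * P)
    ds = p * (dp - delta)                            # dS = P * (dP - delta)
    ds_b = fake_int8(ds) if "dS" in quant else ds
    dq = c * (ds_b @ k)                              # dQ = c dS K
    dk = c * (ds_b.T @ q)                            # dK = c dS^T Q
    return o, dq, dk, dv

def loss(q, k, v, do):
    # Scalar whose gradient w.r.t. O is exactly dO (for finite-difference checks).
    return float(np.sum(attention_fwd_bwd(q, k, v, do)[0] * do))

rng = np.random.default_rng(0)
n, d = 16, 8
q, k, v, do = (rng.standard_normal((n, d)) for _ in range(4))
_, dq, dk, dv = attention_fwd_bwd(q, k, v, do)
_, dq_8, dk_8, dv_8 = attention_fwd_bwd(q, k, v, do, quant={"dS"})
rel = np.linalg.norm(dq_8 - dq) / np.linalg.norm(dq)
print(f"relative dQ error with only dS fake-quantized: {rel:.2e}")
```

Because dS feeds both dQ and dK, any error introduced there reaches the query and key gradients directly, while dV is untouched. Toggling `"P"` instead gives a point of comparison on the same inputs.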