This paper introduces MaskAlign, a token-subset representation alignment method for making diffusion transformer training more efficient. It targets the mismatch between the noisy inputs a diffusion model receives and the clean-image representations its features are aligned to. By applying the alignment objective to randomly sampled token subsets, MaskAlign makes alignment depend less on the complete token set, which yields alignment behavior that is more stable across training iterations. The reported results show both faster convergence and higher generation quality than standard full-token alignment.
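To make the token-subset idea concrete, here is a minimal plain-Python sketch. It assumes the alignment loss is a negative cosine similarity between per-token diffusion features and clean-image encoder features, averaged over tokens, and that the diffusion features have already been projected to the encoder's feature size; the exact loss and subset-sampling scheme are not specified in this summary. The names `sample_token_subset`, `subset_alignment_loss`, and the `keep_ratio` parameter are hypothetical.

```python
import math
import random


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors (epsilon avoids division by zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + 1e-8)


def sample_token_subset(num_tokens, keep_ratio, rng):
    """Draw a random subset of token indices (hypothetical helper; keep_ratio is assumed)."""
    num_keep = max(1, round(num_tokens * keep_ratio))
    return sorted(rng.sample(range(num_tokens), num_keep))


def subset_alignment_loss(diffusion_feats, clean_feats, keep_ratio=0.5, rng=None):
    """Negative mean cosine similarity over a random token subset (assumed loss form).

    diffusion_feats: per-token features from an intermediate diffusion layer,
                     assumed already projected to the encoder's feature size.
    clean_feats:     per-token features of the clean image from a frozen encoder.
    Returns (loss, kept_indices).
    """
    rng = rng or random.Random()
    assert len(diffusion_feats) == len(clean_feats)
    kept = sample_token_subset(len(diffusion_feats), keep_ratio, rng)
    sims = [cosine_similarity(diffusion_feats[i], clean_feats[i]) for i in kept]
    return -sum(sims) / len(sims), kept


if __name__ == "__main__":
    rng = random.Random(0)
    h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]  # toy diffusion features
    y = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]   # toy clean-image features
    for step in range(3):
        # A fresh subset is drawn every step, so each token is aligned only part of the time.
        loss, kept = subset_alignment_loss(h, y, keep_ratio=0.5, rng=rng)
        print(step, kept, round(loss, 4))
```

Drawing a new subset at every training step means no single token is always part of the alignment target, which is the property the summary credits with reducing reliance on the full token set. Setting `keep_ratio=1.0` recovers ordinary full-token alignment in this sketch.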
Aligning randomly sampled token subsets, rather than all tokens, during training can speed up diffusion transformer training and improve generation quality by making representation alignment more stable.
Representation alignment with pretrained vision models has recently shown strong potential for accelerating diffusion transformer training. By aligning intermediate diffusion features with clean-image representations from self-supervised vision encoders, existing methods improve convergence and generation quality. However, such alignment also introduces a non-trivial constraint: diffusion models operate on noisy inputs whose usable information varies across timesteps, while the reference features are extracted from clean images. In this paper, we revisit this mismatch from a token-level perspective. We find that, under full-token representation alignment, tokens with large alignment-gradient norms exhibit a stable spatial preference, suggesting that the alignment objective does not affect all tokens uniformly and may encourage the model to rely on the complete set of clean-image tokens. To address this issue, we propose MaskAlign, a token-subset representation alignment method that applies alignment to randomly sampled token subsets during training. By exposing the model to different token subsets across iterations, MaskAlign reduces the dependence of representation alignment on the complete token set and encourages alignment behavior that is more stable under token-subset perturbations. To mitigate the information loss caused by directly dropping tokens, we further introduce a lightweight pre-mask token mixing block that shares information across tokens before masking.
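The role of the pre-mask mixing block can be illustrated with a small sketch: tokens exchange information first, and only then is a random subset kept, so each surviving token carries some trace of the dropped ones. The abstract does not describe the block's architecture, so the mixing used here (blending each token with the mean of all tokens, weighted by a hypothetical `alpha`) is purely an illustrative stand-in, and `token_mixing` and `mix_then_mask` are hypothetical names.

```python
import random


def token_mixing(tokens, alpha=0.5):
    """Hypothetical lightweight mixing: blend each token with the mean of all tokens.

    Stand-in for the paper's pre-mask token mixing block, whose design is not
    given in the abstract. alpha=0 leaves tokens unchanged; larger alpha mixes
    in more global context.
    """
    n, d = len(tokens), len(tokens[0])
    mean = [sum(tok[j] for tok in tokens) / n for j in range(d)]
    return [[(1 - alpha) * tok[j] + alpha * mean[j] for j in range(d)] for tok in tokens]


def mix_then_mask(tokens, keep_ratio, rng, alpha=0.5):
    """Mix tokens first, then keep a random subset.

    Returns the kept (mixed) tokens and their original indices, which would be
    used to select the matching clean-image alignment targets.
    """
    mixed = token_mixing(tokens, alpha)
    num_keep = max(1, round(len(tokens) * keep_ratio))
    kept = sorted(rng.sample(range(len(tokens)), num_keep))
    return [mixed[i] for i in kept], kept


if __name__ == "__main__":
    tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, -1.0]]
    kept_tokens, kept = mix_then_mask(tokens, keep_ratio=0.5, rng=random.Random(0))
    # Each kept token now includes a share of the global mean, which depends on the dropped tokens too.
    print(kept, kept_tokens)
```

A mean-pooling blend is about the cheapest form of token mixing; a depthwise convolution or a single attention layer would be natural heavier alternatives, but which design MaskAlign actually uses is not stated here.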