The paper introduces MDM-Prime-v2, a masked diffusion language model that improves upon the MDM-Prime framework by incorporating Binary Encoding and Index Shuffling to address limitations in hyperparameter selection and likelihood estimation. Through a scaling analysis, the authors demonstrate that MDM-Prime-v2 achieves a 21.8x improvement in compute efficiency compared to autoregressive models. Empirically, MDM-Prime-v2 achieves a perplexity of 7.77 on OpenWebText and exhibits superior zero-shot accuracy on commonsense reasoning tasks when scaled to 1.1B parameters.
Masked diffusion language models can now achieve 21.8x better compute efficiency than autoregressive models, thanks to binary encoding and index shuffling.
Masked diffusion models (MDM) exhibit superior generalization when learned using a partial masking scheme (Prime). This approach converts tokens into sub-tokens and models the diffusion process at the sub-token level. We identify two limitations of the MDM-Prime framework. First, there are no tools to guide the choice of token granularity in the sub-tokenizer. Second, we find that the functional form of the sub-tokenizer significantly degrades likelihood estimation when paired with commonly used Byte-Pair Encoding (BPE) tokenizers. To address these limitations, we study the tightness of the variational bound in MDM-Prime and develop MDM-Prime-v2, a masked diffusion language model that incorporates Binary Encoding and Index Shuffling. Our scaling analysis reveals that MDM-Prime-v2 is 21.8$\times$ more compute-efficient than autoregressive models (ARM). In compute-optimal comparisons, MDM-Prime-v2 achieves 7.77 perplexity on OpenWebText, outperforming ARM (12.99), MDM (18.94), and MDM-Prime (13.41). When scaled to 1.1B parameters, our model further demonstrates superior zero-shot accuracy on various commonsense reasoning tasks.
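To make the sub-token idea concrete, here is a minimal sketch of a binary sub-tokenizer: each BPE token index is decomposed into a fixed number of binary sub-tokens, and the decomposition is invertible. The function names, the fixed-width scheme, and the choice of base are illustrative assumptions for exposition, not the paper's actual implementation.

```python
def to_subtokens(token_id: int, num_subtokens: int, base: int = 2) -> list[int]:
    """Decompose a token index into a fixed-length sequence of base-`base`
    digits (most significant first). With base=2 this is a binary encoding,
    so a vocabulary of size V needs ceil(log2(V)) sub-tokens per token.
    Illustrative sketch only; not the paper's implementation."""
    digits = []
    for _ in range(num_subtokens):
        digits.append(token_id % base)
        token_id //= base
    return digits[::-1]


def from_subtokens(digits: list[int], base: int = 2) -> int:
    """Invert to_subtokens: recombine the digit sequence into a token index."""
    value = 0
    for d in digits:
        value = value * base + d
    return value
```

The diffusion process then masks and unmasks individual sub-tokens rather than whole tokens, which is what allows partially masked intermediate states.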