This paper introduces a theoretical framework for modular generative modeling, combining pre-trained experts via a gating mechanism to achieve robust performance across diverse data mixtures without heuristic tuning. The authors formulate the problem as a minimax game to find a robust gate minimizing divergence to the worst-case data mixture and prove the existence of such a gate using Kakutani's fixed-point theorem. They show that this modular approach, coupled with a scalable Stochastic Primal-Dual algorithm and Structural Distillation, can theoretically outperform monolithic models, and they demonstrate empirically that it mitigates gradient conflict.
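The following is a minimal sketch of a primal-dual loop for the robust-gate game. It rests on assumptions of our own: a tiny finite sample space, strictly positive experts given as probability vectors, and a hypothetical lookup-table gate $g_k(x) = \exp(\theta_{x,k})$. The combined model is normalized by dividing by its total mass, a stand-in for the paper's normalized class $G_{1}$, whose exact definition may differ. The objective is $\min_g \max_\lambda \sum_k \lambda_k \mathrm{KL}(D_k \| p_g)$. The sketch uses exact gradients and exponentiated-gradient ascent on $\lambda$ rather than stochastic estimates, so it illustrates the idea and is not the authors' algorithm. The names `kl`, `combined_model`, and `robust_gate` are illustrative.

```python
import math

def kl(p, q):
    """KL(p || q) in nats for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def combined_model(theta, experts):
    """Hypothetical gated model: p(x) proportional to sum_k exp(theta[x][k]) * p_k(x).

    Dividing by the total mass z keeps p a probability distribution
    (our stand-in for a normalized gate; experts assumed strictly positive)."""
    n = len(experts[0])
    u = [sum(math.exp(theta[x][k]) * e[x] for k, e in enumerate(experts))
         for x in range(n)]
    z = sum(u)
    return [ux / z for ux in u], u

def robust_gate(domains, experts, steps=2000, inner=5, lr_primal=1.0, lr_dual=0.5):
    """Primal-dual sketch for min_gate max_lambda sum_k lambda_k KL(D_k || p_gate)."""
    n, k = len(domains[0]), len(domains)
    theta = [[0.0] * len(experts) for _ in range(n)]  # gate logits, one row per x
    lam = [1.0 / k] * k                               # dual weights on the simplex
    for _ in range(steps):
        # Target mixture q = sum_k lambda_k D_k for the current dual weights.
        q = [sum(l * d[x] for l, d in zip(lam, domains)) for x in range(n)]
        # Primal: gradient descent on the lambda-weighted cross-entropy,
        # whose gradient w.r.t. theta[x][j] is resp_j(x) * (p(x) - q(x)).
        for _ in range(inner):
            p, u = combined_model(theta, experts)
            for x in range(n):
                for j, e in enumerate(experts):
                    resp = math.exp(theta[x][j]) * e[x] / u[x]  # expert j's share at x
                    theta[x][j] -= lr_primal * resp * (p[x] - q[x])
        # Dual: exponentiated-gradient ascent shifts weight to poorly fit domains.
        p, _ = combined_model(theta, experts)
        w = [l * math.exp(lr_dual * kl(d, p)) for l, d in zip(lam, domains)]
        s = sum(w)
        lam = [wi / s for wi in w]
    p, _ = combined_model(theta, experts)
    return p, lam

if __name__ == "__main__":
    domains = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
    p, lam = robust_gate(domains, experts=domains)
    print("model:", [round(v, 3) for v in p])
    print("lambda:", [round(v, 3) for v in lam])
    print("worst-case KL:", max(kl(d, p) for d in domains))
```

With the two toy domains above, the worst-case KL should fall below that of the uniform gate, and $\lambda$ should move toward the domain with the larger KL to the current model. For any fixed $\lambda$, no model can beat the $\lambda$-weighted average KL to the mixture $\sum_k \lambda_k D_k$, which gives a convenient lower bound to check against. Because the lookup-table gate can represent almost any distribution on this tiny space, it is far more expressive than the lightweight gates the paper's generalization bounds concern.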
Modular generative models can theoretically and empirically outperform monolithic models, offering a robust alternative to resource-intensive retraining on aggregate data.
Training large-scale generative models is resource-intensive and relies heavily on heuristic dataset weighting. We address two fundamental questions: Can we train Large Language Models (LLMs) modularly, combining small, domain-specific experts to match monolithic performance? And can we do so robustly for any data mixture, eliminating heuristic tuning? We present a theoretical framework for modular generative modeling in which a set of pre-trained experts is combined via a gating mechanism. We define the space of normalized gating functions, $G_{1}$, and formulate the problem as a minimax game to find a single robust gate that minimizes divergence to the worst-case data mixture. We prove the existence of such a robust gate using Kakutani's fixed-point theorem and show that modularity acts as a strong regularizer, with generalization bounds scaling with the lightweight gate's complexity. Furthermore, we prove that this modular approach can theoretically outperform models retrained on aggregate data, with the gap characterized by the Jensen-Shannon Divergence. Finally, we introduce a scalable Stochastic Primal-Dual algorithm and a Structural Distillation method for efficient inference. Empirical results on synthetic and real-world datasets confirm that our modular architecture effectively mitigates gradient conflict and can robustly outperform monolithic baselines.
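The abstract does not give the exact form of the bound, so the following is only a hedged illustration of the quantity it names. For domain distributions $P_k$ with weights $w_k$, the generalized Jensen-Shannon divergence is $\mathrm{JS}_w = H\big(\sum_k w_k P_k\big) - \sum_k w_k H(P_k)$. A standard identity rewrites it as $\sum_k w_k\,\mathrm{KL}\big(P_k \,\|\, \sum_j w_j P_j\big)$, the average divergence of each domain from the aggregate mixture that a model retrained on pooled data targets. The helper names below are our own.

```python
import math

def entropy(p):
    """Shannon entropy in nats of a discrete distribution given as a list."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def kl(p, q):
    """KL(p || q) in nats for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mixture(dists, weights):
    """Aggregate distribution sum_k w_k P_k, i.e. the pooled-data target."""
    n = len(dists[0])
    return [sum(w * d[x] for w, d in zip(weights, dists)) for x in range(n)]

def weighted_jsd(dists, weights):
    """Generalized Jensen-Shannon divergence H(sum_k w_k P_k) - sum_k w_k H(P_k)."""
    mix = mixture(dists, weights)
    return entropy(mix) - sum(w * entropy(d) for w, d in zip(weights, dists))

if __name__ == "__main__":
    domains = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
    weights = [0.4, 0.6]
    mix = mixture(domains, weights)
    # Both forms of the identity should agree up to floating-point error.
    print(weighted_jsd(domains, weights))
    print(sum(w * kl(d, mix) for w, d in zip(weights, domains)))
```

When all domains coincide the value is zero, and for two disjoint domains with equal weights it reaches its maximum of $\ln 2$ nats, so larger values indicate a more heterogeneous mixture that no single distribution fits well.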