This paper introduces FlashCP, a framework for context parallelism (CP) in training large language models. Existing CP methods suffer from workload imbalance, inefficient kernels, and redundant key-value (KV) tensor communication. FlashCP addresses these problems with a sharding-aware communication mechanism and a Whole-Doc sharding strategy. By balancing workloads and avoiding redundant KV communication, FlashCP speeds up training by up to 1.63x over state-of-the-art CP frameworks across diverse datasets.
FlashCP achieves up to 1.63x speedup over state-of-the-art context parallelism frameworks for large language model training by eliminating redundant KV communication and balancing workloads.
Context parallelism (CP) is essential for training large-scale, long-context language models, as it partitions sequences to reduce memory overhead. However, existing CP methods suffer from workload imbalance, inefficient kernels, and redundant communication due to static sequence sharding and key-value (KV) tensor communication. We present FlashCP, a load-balanced and communication-efficient framework for CP training. FlashCP introduces a sharding-aware communication mechanism to eliminate redundant KV communication and proposes a novel Whole-Doc sharding strategy that maximizes communication savings while maintaining balanced workloads. To efficiently combine Whole-Doc and Per-Doc sharding, FlashCP further designs a heuristic algorithm to search for near-optimal sharding plans. Extensive experiments show that FlashCP achieves up to 1.63x speedup over state-of-the-art CP frameworks across diverse datasets.
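The abstract does not spell out FlashCP's cost model or search algorithm. The following is a minimal, hypothetical Python sketch of the trade-off it describes. It assumes the following: attention is causal and masked per document, compute is counted as query-key pairs, Per-Doc sharding is a head-tail ("zigzag") split of each document into 2P chunks, "sharding-aware" communication means a rank fetches only the earlier KV tokens of documents it actually holds, and plans are built by a simple longest-first greedy rule. The names `greedy_plan`, `zigzag_chunks`, `naive_ring_recv`, and `tol` are illustrative and not taken from the paper.

```python
from dataclasses import dataclass, field
from typing import Union


def causal_pairs(q0: int, q1: int) -> int:
    """Query-key pairs for causal attention of queries [q0, q1) over keys [0, q + 1)."""
    # sum_{q=q0}^{q1-1} (q + 1) = (q1*(q1+1) - q0*(q0+1)) / 2
    return (q1 * (q1 + 1) - q0 * (q0 + 1)) // 2


def zigzag_chunks(length: int, num_ranks: int) -> list[list[tuple[int, int]]]:
    """Assumed Per-Doc sharding: 2P chunks, rank r gets chunks r and 2P-1-r."""
    n = 2 * num_ranks
    bounds = [i * length // n for i in range(n + 1)]
    chunks = [(bounds[i], bounds[i + 1]) for i in range(n)]
    return [[chunks[r], chunks[n - 1 - r]] for r in range(num_ranks)]


@dataclass
class Plan:
    assignment: dict[int, Union[int, str]] = field(default_factory=dict)  # doc -> rank or "per_doc"
    work: list[int] = field(default_factory=list)     # attention pairs computed per rank
    tokens: list[int] = field(default_factory=list)   # tokens (and KV) owned per rank
    kv_recv: list[int] = field(default_factory=list)  # KV tokens received per rank


def greedy_plan(doc_lens: list[int], num_ranks: int, tol: float = 0.05) -> Plan:
    """Hypothetical heuristic mixing Whole-Doc and Per-Doc sharding."""
    plan = Plan(work=[0] * num_ranks, tokens=[0] * num_ranks, kv_recv=[0] * num_ranks)
    total_work = sum(causal_pairs(0, length) for length in doc_lens)
    cap = total_work / num_ranks * (1 + tol)  # allowed per-rank load
    # Longest documents first: they are the hardest to place whole.
    for doc in sorted(range(len(doc_lens)), key=lambda d: -doc_lens[d]):
        length = doc_lens[doc]
        w = causal_pairs(0, length)
        r = min(range(num_ranks), key=lambda i: plan.work[i])  # least-loaded rank
        if plan.work[r] + w <= cap:
            # Whole-Doc: the document stays on one rank and needs no KV exchange.
            plan.assignment[doc] = r
            plan.work[r] += w
            plan.tokens[r] += length
        else:
            # Per-Doc: split across all ranks; each fetches only earlier KV of this doc.
            plan.assignment[doc] = "per_doc"
            for rank, ((a0, a1), (b0, b1)) in enumerate(zigzag_chunks(length, num_ranks)):
                plan.work[rank] += causal_pairs(a0, a1) + causal_pairs(b0, b1)
                plan.tokens[rank] += (a1 - a0) + (b1 - b0)
                # Needs keys [0, b1) of this doc, owns both chunks: b1 - |a| - |b|.
                plan.kv_recv[rank] += b0 - (a1 - a0)
    return plan


def naive_ring_recv(plan: Plan) -> list[int]:
    """Simplified document-unaware baseline: each rank receives all KV it does not own."""
    total = sum(plan.tokens)
    return [total - t for t in plan.tokens]


if __name__ == "__main__":
    plan = greedy_plan([8, 2, 2], num_ranks=2, tol=0.0)
    print(plan.assignment)        # {0: 'per_doc', 1: 0, 2: 1}
    print(plan.work)              # [21, 21]
    print(plan.kv_recv)           # [4, 2]
    print(naive_ring_recv(plan))  # [6, 6]
```

On this toy input the long document is split across both ranks, while the two short ones are placed whole. This balances compute at 21 pairs per rank and halves the received KV tokens (6 in total versus 12 for the document-unaware baseline). The longest-first order is borrowed from classic greedy bin packing. FlashCP's actual search heuristic and cost model may differ.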