Flash-KMeans optimizes the K-means algorithm for modern GPUs by addressing the IO bottleneck in the assignment stage and atomic write contention in the centroid update stage. It introduces FlashAssign, which fuses distance computation with an online argmin to avoid materializing the distance matrix, and sort-inverse update, which transforms atomic scatters into localized, segment-level reductions. Evaluations on NVIDIA H200 GPUs show up to 17.9x end-to-end speedup over the best baselines, and speedups of 33x and over 200x over cuML and FAISS, respectively.
K-means runs up to 17.9x faster on modern GPUs thanks to a redesign that avoids materializing the distance matrix in memory and removes atomic write contention in the centroid update.
$k$-means has historically been positioned primarily as an offline processing primitive, typically used for dataset organization or embedding preprocessing rather than as a first-class component in online systems. In this work, we revisit this classical algorithm through the lens of modern AI system design and enable $k$-means as an online primitive. We point out that existing GPU implementations of $k$-means remain fundamentally bottlenecked by low-level system constraints rather than theoretical algorithmic complexity. Specifically, the assignment stage suffers from a severe IO bottleneck due to the explicit materialization of the massive $N \times K$ distance matrix in High Bandwidth Memory (HBM). Simultaneously, the centroid update stage is heavily penalized by hardware-level atomic write contention caused by irregular, scatter-style token aggregations. To bridge this performance gap, we propose flash-kmeans, an IO-aware and contention-free $k$-means implementation for modern GPU workloads. Flash-kmeans introduces two core kernel-level innovations: (1) FlashAssign, which fuses distance computation with an online argmin to completely bypass intermediate memory materialization; (2) sort-inverse update, which explicitly constructs an inverse mapping to transform high-contention atomic scatters into high-bandwidth, segment-level localized reductions. Furthermore, we integrate algorithm-system co-designs, including chunked-stream overlap and cache-aware compile heuristics, to ensure practical deployability. Extensive evaluations on NVIDIA H200 GPUs demonstrate that flash-kmeans achieves up to 17.9$\times$ end-to-end speedup over the best baselines, while outperforming industry-standard libraries like cuML and FAISS by 33$\times$ and over 200$\times$, respectively.
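The two ideas named in the abstract can be illustrated on the CPU with a minimal NumPy sketch. This is not the authors' GPU kernel code. The function names `assign_streaming` and `update_sorted`, the `tile` parameter, and the handling of empty clusters are all assumptions made for illustration. The first function processes centroids in tiles and keeps a running argmin, so only an $N \times \text{tile}$ slab of distances exists at any time instead of the full $N \times K$ matrix. The second sorts points by label so that each cluster becomes a contiguous segment, then reduces each segment once, which is the role the abstract assigns to the inverse mapping in place of scatter-style atomic adds.

```python
import numpy as np


def assign_streaming(x, centroids, tile=256):
    """Nearest-centroid labels, holding only an N x tile slab of distances at a time."""
    n = x.shape[0]
    best_dist = np.full(n, np.inf)
    best_idx = np.zeros(n, dtype=np.int64)
    x_sq = (x * x).sum(axis=1)
    rows = np.arange(n)
    for start in range(0, centroids.shape[0], tile):
        c = centroids[start:start + tile]
        # Squared Euclidean distances to this tile of centroids only.
        d = x_sq[:, None] - 2.0 * (x @ c.T) + (c * c).sum(axis=1)[None, :]
        local_idx = d.argmin(axis=1)
        local_dist = d[rows, local_idx]
        # Online argmin: keep the running best across tiles.
        better = local_dist < best_dist
        best_dist[better] = local_dist[better]
        best_idx[better] = start + local_idx[better]
    return best_idx


def update_sorted(x, labels, k):
    """Per-cluster sums and counts via sort + segmented reduction (no scatter-add)."""
    order = np.argsort(labels, kind="stable")  # sorted position -> original point index
    counts = np.bincount(labels, minlength=k)
    starts = np.cumsum(counts) - counts        # first sorted position of each cluster
    sums = np.zeros((k, x.shape[1]), dtype=x.dtype)
    nonempty = counts > 0
    # Each non-empty cluster is one contiguous segment of the sorted points.
    sums[nonempty] = np.add.reduceat(x[order], starts[nonempty], axis=0)
    return sums, counts


# One Lloyd iteration on synthetic data (sizes chosen arbitrarily).
rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 16))
centroids = x[rng.choice(1000, size=32, replace=False)]
labels = assign_streaming(x, centroids, tile=8)
sums, counts = update_sorted(x, labels, k=32)
# Empty clusters keep their previous centroid (a choice made for this sketch).
new_centroids = np.where(
    counts[:, None] > 0, sums / np.maximum(counts, 1)[:, None], centroids
)
```

On a GPU the tiled loop would presumably live inside a single fused kernel with the running minimum held on-chip, and the segmented reduction would replace per-point atomic adds. The sketch only mirrors the data flow and makes no claim about performance.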