This paper introduces Flash-MaxSim, an IO-aware fused GPU kernel that avoids materializing the full query-token × document-token similarity tensor in late-interaction retrieval models such as ColBERT and ColPali. It streams query and document tiles through on-chip SRAM and folds the row-maximum reduction into the same pass, so it computes the same scores while using far less GPU memory. Experiments show up to a 4.7× speedup on an H100 and up to 16× lower inference memory than naive PyTorch implementations, while preserving the exact ranking.
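Here is a minimal pure-Python sketch of the fused reduction, assuming dot-product similarity and embeddings stored as plain lists of floats. `maxsim_naive` builds the whole similarity matrix before reducing it. `maxsim_tiled` walks the document tokens in tiles and folds each tile into a single running maximum per query token, so it only ever holds `len(q)` partial results. The function names and the `TILE` size are hypothetical stand-ins for what fits in on-chip SRAM. A real kernel would also tile the query tokens and process tiles in parallel on the GPU, which this sketch omits.

```python
from typing import List

Vec = List[float]
TILE = 4  # hypothetical tile size, standing in for what fits in on-chip SRAM


def dot(a: Vec, b: Vec) -> float:
    """Dot-product similarity between two token embeddings."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_naive(q: List[Vec], d: List[Vec]) -> float:
    # Materializes the full len(q) x len(d) similarity matrix, then reduces it.
    sim = [[dot(qi, dj) for dj in d] for qi in q]
    return sum(max(row) for row in sim)


def maxsim_tiled(q: List[Vec], d: List[Vec], tile: int = TILE) -> float:
    # Holds only one running maximum per query token. Document tokens are
    # visited tile by tile; each tile's similarities are folded into the
    # running maxima and never stored.
    running_max = [float("-inf")] * len(q)
    for start in range(0, len(d), tile):
        d_tile = d[start:start + tile]
        for i, qi in enumerate(q):
            for dj in d_tile:
                s = dot(qi, dj)
                if s > running_max[i]:
                    running_max[i] = s
    return sum(running_max)


q = [[1.0, 0.0], [0.0, 1.0]]
d = [[0.5, 0.5], [1.0, -1.0], [-1.0, 2.0], [0.2, 0.1], [3.0, 0.0]]
print(maxsim_naive(q, d), maxsim_tiled(q, d))  # 5.0 5.0
```

Every query token still sees every document token exactly once, and taking a maximum is exact in floating point. The tiled version therefore returns the same score for any tile size. Only the peak memory changes.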
Late-interaction retrieval just got a whole lot faster and cheaper: Flash-MaxSim cuts inference memory by up to 16× and runs up to 4.7× faster on an H100 by ditching the massive similarity tensor.
Late-interaction retrieval (ColBERT, ColPali) scores a query against a document with the MaxSim operator: for each query token, take the maximum similarity over all document tokens, then sum these maxima over the query tokens. The standard implementation materializes the full query-token × document-token similarity tensor in GPU memory. For visual ColPali at 10K documents, this tensor alone occupies 21 GB in FP16, only to be reduced to one score per document and discarded. Materializing it exhausts a 40 GB GPU and caps the achievable batch size in both inference and training. We present Flash-MaxSim, an IO-aware fused GPU kernel that computes exactly the same scores without ever materializing the tensor: it streams query and document tiles through on-chip SRAM and folds the row-maximum reduction into the same pass. We carry the IO-aware principle into the training backward pass, where an inverse-grid CSR construction reuses the forward argmax to perform an atomic-free, destination-owned gradient reduction, and further into INT8×INT8 quantization and variable-length (padding-free) scoring. Flash-MaxSim is up to 3.9× faster on an A100 (4.7× on an H100) than naive PyTorch at matched precision, uses up to 16× less inference memory and ~28× less training memory, unlocks corpus and batch sizes that exhaust PyTorch entirely, preserves the exact ranking (100% top-20 agreement with an FP32 reference)
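A small CPU sketch shows why reusing the forward argmax makes the backward reduction atomic-free. It assumes dot-product similarity and first-index tie-breaking in the argmax. For S = Σ_i max_j q_i·d_j, the gradient with respect to query token q_i is its winning document token d_{a(i)}. The gradient with respect to document token d_j is the sum of the query tokens whose argmax is j. On a GPU, a naive scatter (`grad_d[a(i)] += q_i`) would need atomics whenever several query tokens share a winner. Inverting the argmax into a CSR layout lets each document token own its output row and sum only its own segment. All names below are hypothetical. This is not the paper's inverse-grid construction or kernel, only an illustration of the destination-owned idea.

```python
from typing import List, Tuple

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def forward_with_argmax(q: List[Vec], d: List[Vec]) -> Tuple[float, List[int]]:
    # MaxSim forward pass that also records, per query token, the index of the
    # document token that achieved the maximum (first index wins on ties).
    # Assumes d is non-empty.
    score, argmax = 0.0, []
    for qi in q:
        best_j, best_s = 0, dot(qi, d[0])
        for j in range(1, len(d)):
            s = dot(qi, d[j])
            if s > best_s:
                best_j, best_s = j, s
        argmax.append(best_j)
        score += best_s
    return score, argmax


def build_inverse_csr(argmax: List[int], n_doc: int) -> Tuple[List[int], List[int]]:
    # Groups query tokens by their winning document token:
    # sources[offsets[j]:offsets[j + 1]] are the query tokens with argmax == j.
    counts = [0] * n_doc
    for j in argmax:
        counts[j] += 1
    offsets = [0] * (n_doc + 1)
    for j in range(n_doc):
        offsets[j + 1] = offsets[j] + counts[j]
    sources = [0] * len(argmax)
    cursor = offsets[:-1]  # slicing copies, so offsets is left untouched
    for i, j in enumerate(argmax):
        sources[cursor[j]] = i
        cursor[j] += 1
    return offsets, sources


def backward(q: List[Vec], d: List[Vec], argmax: List[int],
             grad_score: float = 1.0) -> Tuple[List[Vec], List[Vec]]:
    dim = len(q[0])
    # dS/dq_i = d_{argmax[i]}: each query token reads exactly one document row.
    grad_q = [[grad_score * x for x in d[j]] for j in argmax]
    # dS/dd_j = sum of q_i with argmax[i] == j. Each destination j sums only
    # its own CSR segment, so one worker per j never conflicts with another.
    offsets, sources = build_inverse_csr(argmax, len(d))
    grad_d = []
    for j in range(len(d)):
        acc = [0.0] * dim
        for k in range(offsets[j], offsets[j + 1]):
            qi = q[sources[k]]
            for t in range(dim):
                acc[t] += grad_score * qi[t]
        grad_d.append(acc)
    return grad_q, grad_d


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
d = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]
score, am = forward_with_argmax(q, d)
print(score, am)  # 8.0 [0, 1, 1]
grad_q, grad_d = backward(q, d, am)
```

In this example, query tokens 1 and 2 both pick document token 1. A per-source scatter would write to that row twice. With the CSR layout, document token 1 reads both sources and writes its row once.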