FlashTP: Fused, Sparsity-Aware Tensor Product for Machine Learning Interatomic Potentials
Proceedings of the 42nd International Conference on Machine Learning, PMLR 267:33143-33156, 2025.
Abstract
Machine Learning Interatomic Potentials (MLIPs) enable efficient molecular dynamics (MD) simulations with high accuracy. While equivariant MLIPs achieve state-of-the-art accuracy, they face significant computational bottlenecks in their tensor-product layers, which account for up to 75% of training time and incur substantial memory overhead. We present FlashTP, a highly optimized tensor-product library that addresses these inefficiencies through kernel fusion, sparse computation, and path-aggregated execution. FlashTP achieves kernel speedups of up to 41.6× over e3nn and 60.8× over NVIDIA cuEquivariance. For SevenNet-l3i5, it delivers 4.2× and 3.5× speedups while reducing peak memory usage by 6.3× and 6.2× for inference and training, respectively. The code is available at https://github.com/SNU-ARC/flashTP.
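To illustrate why "sparse computation" pays off in a tensor product, here is a minimal pure-Python sketch. It is not FlashTP's implementation (which consists of fused GPU kernels), and the function names `dense_tp`, `to_sparse`, `sparse_tp`, and `levi_civita` are hypothetical. As an assumed example coefficient tensor it uses the Levi-Civita symbol, which (up to a normalization constant and a change of basis) couples two vectors into a vector, i.e. the cross product. Only 6 of its 27 entries are nonzero, so iterating over a precomputed list of nonzero coefficients skips most of the multiply-adds that a dense contraction performs.

```python
from itertools import product


def dense_tp(coeff, x, y):
    """Dense contraction: out[k] = sum_{i,j} coeff[i][j][k] * x[i] * y[j]."""
    n_i, n_j, n_k = len(coeff), len(coeff[0]), len(coeff[0][0])
    out = [0.0] * n_k
    # Visits every (i, j, k) triple, including the zero coefficients.
    for i, j, k in product(range(n_i), range(n_j), range(n_k)):
        out[k] += coeff[i][j][k] * x[i] * y[j]
    return out


def to_sparse(coeff):
    """Keep only the nonzero coupling coefficients as (i, j, k, value) tuples."""
    return [
        (i, j, k, c)
        for i, plane in enumerate(coeff)
        for j, row in enumerate(plane)
        for k, c in enumerate(row)
        if c != 0
    ]


def sparse_tp(nonzeros, x, y, n_k):
    """Same contraction as dense_tp, but only over the precomputed nonzeros."""
    out = [0.0] * n_k
    for i, j, k, c in nonzeros:
        out[k] += c * x[i] * y[j]
    return out


def levi_civita():
    """Illustrative coefficient tensor: eps[i][j][k] gives (x cross y)_k."""
    eps = [[[0] * 3 for _ in range(3)] for _ in range(3)]
    for i, j, k in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        eps[i][j][k] = 1   # even permutations
        eps[j][i][k] = -1  # odd permutations
    return eps


if __name__ == "__main__":
    eps = levi_civita()
    nz = to_sparse(eps)
    x, y = [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]
    print(len(nz), "nonzeros out of 27")
    print(dense_tp(eps, x, y), sparse_tp(nz, x, y, 3))
```

Both paths produce the same result; the sparse loop simply performs fewer operations. In real equivariant models the coefficient tensors (Clebsch-Gordan coefficients) are fixed for a given architecture, so their nonzero pattern can be extracted once ahead of time rather than at every call.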