Efficiently Dispatching Flash Attention For Partially Filled Attention Masks

Agniv Sharma, Jonas A. Geiping
Proceedings of The 4th NeurIPS Efficient Natural Language and Speech Processing Workshop, PMLR 262:423-442, 2024.

Abstract

Transformers are widely used across various applications, many of which yield sparse or partially filled attention matrices. Examples include attention masks designed to reduce the quadratic complexity of attention, sequence packing techniques, and recent innovations like tree masking for fast validation in MEDUSA. Despite the inherent sparsity in these matrices, the state-of-the-art algorithm Flash Attention still processes them with quadratic complexity as though they were dense. In this paper, we introduce Binary Block Masking, a highly efficient modification that enhances Flash Attention by making it mask-aware. We further propose two optimizations: one tailored for masks with contiguous non-zero patterns and another for extremely sparse masks. Our experiments on attention masks derived from real-world scenarios demonstrate up to a 9x runtime improvement. The implementation will be publicly released to foster further research and application.
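
To make the core idea concrete, below is a minimal PyTorch sketch of block-level mask handling: a binary block mask records which (block_m x block_n) tiles of the dense attention mask contain any allowed positions, and a blockwise attention loop skips the all-zero tiles. This is an illustrative reconstruction under stated assumptions, not the authors' released kernel; the function names `build_block_mask` and `blockwise_masked_attention`, the block sizes, and the single-head (L, d) tensor layout are hypothetical choices made for clarity.

```python
# Illustrative sketch of block-skipping attention with a binary block mask.
# NOT the paper's released implementation (which targets a fused Flash Attention
# kernel); names and block sizes are assumptions for exposition.
import torch


def build_block_mask(mask: torch.Tensor, block_m: int, block_n: int) -> torch.Tensor:
    """Reduce a dense (L_q, L_k) boolean mask to a binary block mask that is True
    wherever a (block_m x block_n) tile contains at least one allowed position."""
    mask = mask.bool()
    L_q, L_k = mask.shape
    n_m = -(-L_q // block_m)  # ceil division
    n_n = -(-L_k // block_n)
    padded = torch.zeros(n_m * block_m, n_n * block_n, dtype=torch.bool, device=mask.device)
    padded[:L_q, :L_k] = mask
    tiles = padded.view(n_m, block_m, n_n, block_n)
    return tiles.any(dim=3).any(dim=1)  # (n_m, n_n) binary block mask


def blockwise_masked_attention(q, k, v, mask, block_m=64, block_n=64):
    """Reference blockwise attention (single head, q: (L_q, d), k/v: (L_k, d))
    that skips key/value tiles whose mask block is entirely zero."""
    L_q, d = q.shape
    L_k = k.shape[0]
    mask = mask.bool()
    block_mask = build_block_mask(mask, block_m, block_n)
    out = torch.zeros_like(q)
    scale = d ** -0.5
    for i in range(0, L_q, block_m):
        kept_scores, kept_values = [], []
        for j in range(0, L_k, block_n):
            if not block_mask[i // block_m, j // block_n]:
                continue  # the whole tile is masked out: no work needed
            s = (q[i:i + block_m] * scale) @ k[j:j + block_n].T
            s = s.masked_fill(~mask[i:i + block_m, j:j + block_n], float("-inf"))
            kept_scores.append(s)
            kept_values.append(v[j:j + block_n])
        if not kept_scores:
            continue  # this query block attends to nothing
        probs = torch.softmax(torch.cat(kept_scores, dim=1), dim=1)
        probs = torch.nan_to_num(probs)  # rows with no allowed keys -> zero output
        out[i:i + block_m] = probs @ torch.cat(kept_values, dim=0)
    return out
```

In a fused Flash Attention kernel the same tile-skipping decision is made on-chip before loading each key/value block, so the savings show up as fewer memory loads and matmuls rather than fewer Python iterations; the sparser the mask (e.g. sequence packing or tree decoding), the larger the fraction of tiles that can be skipped.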

Cite this Paper

BibTeX
@InProceedings{pmlr-v262-sharma24a,
  title     = {Efficiently Dispatching Flash Attention For Partially Filled Attention Masks},
  author    = {Sharma, Agniv and A. Geiping, Jonas},
  booktitle = {Proceedings of The 4th NeurIPS Efficient Natural Language and Speech Processing Workshop},
  pages     = {423--442},
  year      = {2024},
  editor    = {Rezagholizadeh, Mehdi and Passban, Peyman and Samiee, Soheila and Partovi Nia, Vahid and Cheng, Yu and Deng, Yue and Liu, Qun and Chen, Boxing},
  volume    = {262},
  series    = {Proceedings of Machine Learning Research},
  month     = {14 Dec},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v262/main/assets/sharma24a/sharma24a.pdf},
  url       = {https://proceedings.mlr.press/v262/sharma24a.html},
  abstract  = {Transformers are widely used across various applications, many of which yield sparse or partially filled attention matrices. Examples include attention masks designed to reduce the quadratic complexity of attention, sequence packing techniques, and recent innovations like tree masking for fast validation in MEDUSA. Despite the inherent sparsity in these matrices, the state-of-the-art algorithm Flash Attention still processes them with quadratic complexity as though they were dense. In this paper, we introduce \textbf{Binary Block Masking}, a highly efficient modification that enhances Flash Attention by making it mask-aware. We further propose two optimizations: one tailored for masks with contiguous non-zero patterns and another for extremely sparse masks. Our experiments on attention masks derived from real-world scenarios demonstrate up to a 9x runtime improvement. The implementation will be publicly released to foster further research and application.}
}
Endnote
%0 Conference Paper
%T Efficiently Dispatching Flash Attention For Partially Filled Attention Masks
%A Agniv Sharma
%A Jonas A. Geiping
%B Proceedings of The 4th NeurIPS Efficient Natural Language and Speech Processing Workshop
%C Proceedings of Machine Learning Research
%D 2024
%E Mehdi Rezagholizadeh
%E Peyman Passban
%E Soheila Samiee
%E Vahid Partovi Nia
%E Yu Cheng
%E Yue Deng
%E Qun Liu
%E Boxing Chen
%F pmlr-v262-sharma24a
%I PMLR
%P 423--442
%U https://proceedings.mlr.press/v262/sharma24a.html
%V 262
%X Transformers are widely used across various applications, many of which yield sparse or partially filled attention matrices. Examples include attention masks designed to reduce the quadratic complexity of attention, sequence packing techniques, and recent innovations like tree masking for fast validation in MEDUSA. Despite the inherent sparsity in these matrices, the state-of-the-art algorithm Flash Attention still processes them with quadratic complexity as though they were dense. In this paper, we introduce \textbf{Binary Block Masking}, a highly efficient modification that enhances Flash Attention by making it mask-aware. We further propose two optimizations: one tailored for masks with contiguous non-zero patterns and another for extremely sparse masks. Our experiments on attention masks derived from real-world scenarios demonstrate up to a 9x runtime improvement. The implementation will be publicly released to foster further research and application.
APA
Sharma, A., & Geiping, J. A. (2024). Efficiently Dispatching Flash Attention For Partially Filled Attention Masks. Proceedings of The 4th NeurIPS Efficient Natural Language and Speech Processing Workshop, in Proceedings of Machine Learning Research 262:423-442. Available from https://proceedings.mlr.press/v262/sharma24a.html.