PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels

Praneeth Kacham, Vahab Mirrokni, Peilin Zhong
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:22748-22770, 2024.

Abstract

The quadratic time and memory complexity inherent to self-attention mechanisms, with respect to sequence length, presents a critical computational bottleneck in the training and deployment of large-scale Transformer-based language models. Recent theoretical results indicate the intractability of sub-quadratic softmax attention approximation under reasonable complexity assumptions. This paper addresses this challenge by first demonstrating that polynomial attention with high degree can effectively replace softmax without sacrificing model quality. Next, we develop polynomial sketching techniques from numerical linear algebra to achieve linear-time polynomial attention with approximation guarantees. Crucially, our approach achieves this speedup without requiring the sparsification of attention matrices. We also present a block-based algorithm to apply causal masking efficiently. Combining these techniques, we obtain PolySketchFormer, a practical linear-time Transformer architecture for language modeling with provable guarantees. We validate PolySketchFormer empirically by training language models capable of handling long contexts. These experiments use both synthetic and real-world datasets (PG19, Wikipedia, and C4) on Google Cloud TPUs. For context lengths of 32k and GPT-2 style models, our model achieves a 2x training speedup compared to FlashAttention in its fastest configuration, with no observed degradation in quality across our experiments.
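To make the core idea concrete, below is a minimal JAX sketch (assumed names; not the authors' implementation) of linear-time causal attention with a degree-2 polynomial kernel. It uses the exact tensor-product feature map, whose dimension is d^2, so that the inner product of the features of q and k equals (q.k)^2; the paper instead compresses this feature map with randomized polynomial sketches, and replaces the simple running prefix sum below with a block-based algorithm for causal masking.

# Minimal illustrative sketch (assumed names; not the authors' code) of
# linear-time causal attention with a degree-2 polynomial kernel.
import jax
import jax.numpy as jnp

def poly_features(x):
    # Exact degree-2 feature map phi(x) = vec(x x^T), dimension d^2,
    # so that <phi(q), phi(k)> = (q . k)^2. PolySketchFormer replaces this
    # exact map with a low-dimensional randomized polynomial sketch.
    return jnp.einsum('...i,...j->...ij', x, x).reshape(*x.shape[:-1], -1)

def causal_poly_attention(q, k, v, eps=1e-6):
    # q, k, v: (n, d). Row i of the output is
    #   sum_{j<=i} (q_i . k_j)^2 v_j / sum_{j<=i} (q_i . k_j)^2,
    # computed in O(n d^2 d_v) time via running prefix sums, so the n x n
    # attention matrix is never materialized. (The paper uses a block-based
    # variant of this causal computation.)
    phi_q = poly_features(q)                    # (n, d^2)
    phi_k = poly_features(k)                    # (n, d^2)
    kv = jnp.einsum('nf,nd->nfd', phi_k, v)     # phi(k_j) v_j^T for each j
    S = jnp.cumsum(kv, axis=0)                  # S_i = sum_{j<=i} phi(k_j) v_j^T
    z = jnp.cumsum(phi_k, axis=0)               # z_i = sum_{j<=i} phi(k_j)
    num = jnp.einsum('nf,nfd->nd', phi_q, S)
    den = jnp.einsum('nf,nf->n', phi_q, z)
    return num / (den[:, None] + eps)

# Usage (shapes only; random inputs for illustration):
q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 8, 4))
out = causal_poly_attention(q, k, v)            # shape (8, 4)

Because the running sums S_i and z_i only grow with the position i, the causal mask comes for free, and the overall cost is linear in the sequence length; the sketching in the paper keeps the feature dimension manageable when the polynomial degree is raised beyond 2.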

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-kacham24a,
  title     = {{P}oly{S}ketch{F}ormer: Fast Transformers via Sketching Polynomial Kernels},
  author    = {Kacham, Praneeth and Mirrokni, Vahab and Zhong, Peilin},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {22748--22770},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/kacham24a/kacham24a.pdf},
  url       = {https://proceedings.mlr.press/v235/kacham24a.html}
}
Endnote
%0 Conference Paper
%T PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels
%A Praneeth Kacham
%A Vahab Mirrokni
%A Peilin Zhong
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-kacham24a
%I PMLR
%P 22748--22770
%U https://proceedings.mlr.press/v235/kacham24a.html
%V 235
APA
Kacham, P., Mirrokni, V., & Zhong, P. (2024). PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:22748-22770. Available from https://proceedings.mlr.press/v235/kacham24a.html.
