Simple linear attention language models balance the recall-throughput tradeoff

Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, James Zou, Atri Rudra, Christopher Re
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:1763-1840, 2024.

Abstract

Recent work has shown that attention-based language models excel at "recall", the ability to ground generations in tokens previously seen in context. However, the efficiency of attention-based models is bottlenecked during inference by the KV-cache’s aggressive memory consumption. In this work, we explore whether we can improve language model efficiency (e.g. by reducing memory consumption) without compromising on recall. By applying experiments and theory to a broad set of architectures, we identify a key tradeoff between a model’s recurrent state size and recall ability. We show that efficient alternatives to attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but struggle at recall. We propose BASED, a simple architecture combining linear and sliding window attention. By varying the BASED window size and linear attention feature dimension, we can dial the state size and traverse the Pareto frontier of the recall-memory tradeoff curve, recovering the full quality of attention on one end and the small state size of attention alternatives on the other. We train language models up to 1.3b parameters and show that BASED matches the strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them on real-world recall-intensive tasks by 10.36 accuracy points. We further develop IO-aware algorithms that enable BASED to provide 24× higher throughput on language generation than FlashAttention-2 when generating 1024 tokens using 1.3b parameter models. Overall, BASED expands the Pareto frontier of the throughput-recall tradeoff space beyond prior architectures.
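The state-size tradeoff described in the abstract can be illustrated with a rough back-of-the-envelope sketch. This is not code from the paper: the function names, the per-head accounting, and the example dimensions below are assumptions chosen for illustration. It counts only the number of stored elements per attention head, using the standard facts that a KV-cache keeps one key and one value vector per past token, while linear attention keeps a fixed-size running sum of outer products plus a normalizer.

```python
def kv_cache_elems(seq_len: int, d_head: int) -> int:
    """Softmax attention: one key and one value vector per past token."""
    return 2 * seq_len * d_head


def linear_attention_state_elems(feature_dim: int, d_head: int) -> int:
    """Linear attention: running sum of phi(k) v^T (feature_dim x d_head)
    plus a running sum of phi(k) used as the normalizer (feature_dim)."""
    return feature_dim * d_head + feature_dim


def sliding_window_state_elems(window: int, d_head: int) -> int:
    """Sliding window attention: keys and values for the last `window` tokens only."""
    return 2 * window * d_head


def hybrid_state_elems(feature_dim: int, window: int, d_head: int) -> int:
    """Hypothetical hybrid: linear attention state plus a sliding window cache.
    Note that the result does not depend on the sequence length."""
    return (linear_attention_state_elems(feature_dim, d_head)
            + sliding_window_state_elems(window, d_head))


if __name__ == "__main__":
    d_head = 64  # assumed head dimension, for illustration only
    for seq_len in (1024, 8192):
        full = kv_cache_elems(seq_len, d_head)
        hybrid = hybrid_state_elems(feature_dim=16, window=64, d_head=d_head)
        print(f"seq_len={seq_len}: kv_cache={full}, hybrid_state={hybrid}")
```

Under these assumptions, increasing `feature_dim` or `window` grows the fixed state (and, per the abstract, recall ability), while the KV-cache grows linearly with the number of tokens seen.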

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-arora24a,
  title = {Simple linear attention language models balance the recall-throughput tradeoff},
  author = {Arora, Simran and Eyuboglu, Sabri and Zhang, Michael and Timalsina, Aman and Alberti, Silas and Zou, James and Rudra, Atri and Re, Christopher},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages = {1763--1840},
  year = {2024},
  editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume = {235},
  series = {Proceedings of Machine Learning Research},
  month = {21--27 Jul},
  publisher = {PMLR},
  pdf = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/arora24a/arora24a.pdf},
  url = {https://proceedings.mlr.press/v235/arora24a.html},
  abstract = {Recent work has shown that attention-based language models excel at "recall", the ability to ground generations in tokens previously seen in context. However, the efficiency of attention-based models is bottle-necked during inference by the KV-cache’s aggressive memory consumption. In this work, we explore whether we can improve language model efficiency (e.g. by reducing memory consumption) without compromising on recall. By applying experiments and theory to a broad set of architectures, we identify a key tradeoff between a model’s recurrent state size and recall ability. We show that efficient alternatives to attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but struggle at recall. We propose BASED a simple architecture combining linear and sliding window attention. By varying BASED window size and linear attention feature dimension, we can dial the state size and traverse the Pareto frontier of the recall-memory tradeoff curve, recovering the full quality of attention on one end and the small state size of attention-alternatives on the other. We train language models up to $1.3$b parameters and show that BASED matches the strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them on real-world recall-intensive tasks by 10.36 accuracy points. We further develop IO-aware algorithms that enable BASED to provide 24× higher throughput on language generation than FlashAttention-2, when generating 1024 tokens using 1.3b parameter models. Overall, BASED expands the Pareto frontier of the throughput-recall tradeoff space beyond prior architectures.}
}
Endnote
%0 Conference Paper
%T Simple linear attention language models balance the recall-throughput tradeoff
%A Simran Arora
%A Sabri Eyuboglu
%A Michael Zhang
%A Aman Timalsina
%A Silas Alberti
%A James Zou
%A Atri Rudra
%A Christopher Re
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-arora24a
%I PMLR
%P 1763--1840
%U https://proceedings.mlr.press/v235/arora24a.html
%V 235
%X Recent work has shown that attention-based language models excel at "recall", the ability to ground generations in tokens previously seen in context. However, the efficiency of attention-based models is bottle-necked during inference by the KV-cache’s aggressive memory consumption. In this work, we explore whether we can improve language model efficiency (e.g. by reducing memory consumption) without compromising on recall. By applying experiments and theory to a broad set of architectures, we identify a key tradeoff between a model’s recurrent state size and recall ability. We show that efficient alternatives to attention (e.g. H3, Mamba, RWKV) maintain a fixed-size recurrent state, but struggle at recall. We propose BASED a simple architecture combining linear and sliding window attention. By varying BASED window size and linear attention feature dimension, we can dial the state size and traverse the Pareto frontier of the recall-memory tradeoff curve, recovering the full quality of attention on one end and the small state size of attention-alternatives on the other. We train language models up to $1.3$b parameters and show that BASED matches the strongest sub-quadratic models (e.g. Mamba) in perplexity and outperforms them on real-world recall-intensive tasks by 10.36 accuracy points. We further develop IO-aware algorithms that enable BASED to provide 24× higher throughput on language generation than FlashAttention-2, when generating 1024 tokens using 1.3b parameter models. Overall, BASED expands the Pareto frontier of the throughput-recall tradeoff space beyond prior architectures.
APA
Arora, S., Eyuboglu, S., Zhang, M., Timalsina, A., Alberti, S., Zou, J., Rudra, A., & Re, C. (2024). Simple linear attention language models balance the recall-throughput tradeoff. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:1763-1840. Available from https://proceedings.mlr.press/v235/arora24a.html.

Related Material