Mechanics of Next Token Prediction with Self-Attention

Yingcong Li, Yixiao Huang, Muhammed E. Ildiz, Ankit Singh Rawat, Samet Oymak
Proceedings of The 27th International Conference on Artificial Intelligence and Statistics, PMLR 238:685-693, 2024.

Abstract

Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: What does a single self-attention layer learn from next-token prediction? We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: (1) Hard retrieval: Given input sequence, self-attention precisely selects the high-priority input tokens associated with the last input token. (2) Soft composition: It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCC) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to hard retrieval and soft composition steps respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.
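The two steps named in the abstract can be illustrated with a toy sketch. This is only an illustration under stated assumptions, not the authors' construction: the abstract does not say how the directed graph is built, so the edge rule here (the observed next token points to every token in its input, using only samples that share the same last token) is an assumption. The function names and the `scores` dictionary, which stands in for the finite weight component, are hypothetical.

```python
import math


def build_priority_graph(samples, last_token):
    """Directed graph over tokens; an edge u -> v reads 'u has priority over v'.

    Assumption: for each (input_tokens, next_token) pair whose input ends
    with last_token, the next token points to every other input token.
    """
    graph = {}
    for tokens, label in samples:
        if not tokens or tokens[-1] != last_token:
            continue
        graph.setdefault(label, set())
        for tok in tokens:
            graph.setdefault(tok, set())
            if tok != label:
                graph[label].add(tok)
    return graph


def reachable(graph, start):
    """All nodes reachable from start (including start), by iterative DFS."""
    seen, stack = {start}, [start]
    while stack:
        node = stack.pop()
        for nxt in graph.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen


def strongly_connected_components(graph):
    """SCCs via mutual reachability; quadratic, but fine for a toy graph."""
    reach = {u: reachable(graph, u) for u in graph}
    return {frozenset(v for v in reach[u] if u in reach[v]) for u in graph}


def hard_retrieval(graph, context):
    """Keep context tokens that no other context token strictly dominates.

    s strictly dominates t when s reaches t but t does not reach s.
    If several incomparable SCCs are maximal, all of them are returned.
    """
    present = set(context)
    reach = {t: reachable(graph, t) for t in present}
    return {
        t for t in present
        if not any(t in reach[s] and s not in reach[t] for s in present)
    }


def soft_composition(selected, scores):
    """Softmax weights over the retrieved tokens (a convex combination)."""
    top = max(scores.get(t, 0.0) for t in selected)
    exp = {t: math.exp(scores.get(t, 0.0) - top) for t in selected}
    total = sum(exp.values())
    return {t: v / total for t, v in exp.items()}


# Toy data: every input ends with "c"; "a" and "b" each precede the other.
samples = [
    (["a", "b", "c"], "a"),
    (["a", "b", "c"], "b"),
    (["b", "c"], "b"),
]
graph = build_priority_graph(samples, last_token="c")
components = strongly_connected_components(graph)
selected = hard_retrieval(graph, ["a", "b", "c"])
weights = soft_composition(selected, scores={"a": 1.0, "b": 0.0})
```

In this toy data, "a" and "b" fall into one strongly connected component that dominates "c", so retrieval keeps both and the composition step spreads weight across them.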

Cite this Paper


BibTeX
@InProceedings{pmlr-v238-li24f,
  title = {Mechanics of Next Token Prediction with Self-Attention},
  author = {Li, Yingcong and Huang, Yixiao and Ildiz, Muhammed E. and Singh Rawat, Ankit and Oymak, Samet},
  booktitle = {Proceedings of The 27th International Conference on Artificial Intelligence and Statistics},
  pages = {685--693},
  year = {2024},
  editor = {Dasgupta, Sanjoy and Mandt, Stephan and Li, Yingzhen},
  volume = {238},
  series = {Proceedings of Machine Learning Research},
  month = {02--04 May},
  publisher = {PMLR},
  pdf = {https://proceedings.mlr.press/v238/li24f/li24f.pdf},
  url = {https://proceedings.mlr.press/v238/li24f.html},
  abstract = {Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: What does a single self-attention layer learn from next-token prediction? We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: (1) Hard retrieval: Given input sequence, self-attention precisely selects the high-priority input tokens associated with the last input token. (2) Soft composition: It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCC) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to hard retrieval and soft composition steps respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.}
}
Endnote
%0 Conference Paper
%T Mechanics of Next Token Prediction with Self-Attention
%A Yingcong Li
%A Yixiao Huang
%A Muhammed E. Ildiz
%A Ankit Singh Rawat
%A Samet Oymak
%B Proceedings of The 27th International Conference on Artificial Intelligence and Statistics
%C Proceedings of Machine Learning Research
%D 2024
%E Sanjoy Dasgupta
%E Stephan Mandt
%E Yingzhen Li
%F pmlr-v238-li24f
%I PMLR
%P 685--693
%U https://proceedings.mlr.press/v238/li24f.html
%V 238
%X Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: What does a single self-attention layer learn from next-token prediction? We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: (1) Hard retrieval: Given input sequence, self-attention precisely selects the high-priority input tokens associated with the last input token. (2) Soft composition: It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCC) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to hard retrieval and soft composition steps respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.
APA
Li, Y., Huang, Y., Ildiz, M.E., Singh Rawat, A. & Oymak, S. (2024). Mechanics of Next Token Prediction with Self-Attention. Proceedings of The 27th International Conference on Artificial Intelligence and Statistics, in Proceedings of Machine Learning Research 238:685-693. Available from https://proceedings.mlr.press/v238/li24f.html.
