In-context Learning on Function Classes Unveiled for Transformers

Zhijie Wang, Bo Jiang, Shuai Li
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:50726-50745, 2024.

Abstract

Transformer-based neural sequence models exhibit a remarkable ability to perform in-context learning. Given some training examples, a pre-trained model can make accurate predictions on an unseen input. This paper studies why transformers can learn different types of function classes in-context. We first show by construction that there exists a family of transformers (with different activation functions) that implement approximate gradient descent on the parameters of neural networks, and we provide upper bounds on the number of heads, the hidden dimension, and the number of layers of the transformer. We also show that a transformer can learn linear functions, the indicator function of a unit ball, and smooth functions in-context by learning neural networks that approximate them. The above instances mainly concern a transformer pre-trained on a single task. We also prove that a transformer pre-trained on two tasks, linear regression and classification, can make accurate predictions on both tasks simultaneously. Our results move beyond linear in-context learning instances and provide a comprehensive understanding of why transformers can learn many types of function classes through the bridge of neural networks.
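
To make the setting concrete, the following is a minimal sketch (not taken from the paper; written in plain NumPy with illustrative dimensions, step size, and iteration count) of the in-context learning setup the abstract describes for the linear function class: a prompt supplies example pairs (x_i, f(x_i)) with f(x) = <w, x> together with a query point, and the learner, here ordinary gradient descent on a small two-layer ReLU network (the procedure the paper shows a transformer can approximately implement), predicts f at the query.

import numpy as np

rng = np.random.default_rng(0)
d, n_examples = 5, 40          # input dimension, number of in-context examples
w_true = rng.normal(size=d)    # unknown linear task f(x) = <w_true, x>

# In-context examples and a held-out query
X = rng.normal(size=(n_examples, d))
y = X @ w_true
x_query = rng.normal(size=d)

# Two-layer ReLU network: f_theta(x) = a^T relu(W x)
hidden = 64
W = rng.normal(size=(hidden, d)) / np.sqrt(d)
a = rng.normal(size=hidden) / np.sqrt(hidden)

lr = 0.05
for _ in range(1000):                      # gradient descent on the squared loss
    H = np.maximum(X @ W.T, 0.0)           # (n_examples, hidden) ReLU features
    err = H @ a - y                        # residuals on the in-context examples
    grad_a = H.T @ err / n_examples
    grad_W = ((err[:, None] * (H > 0)) * a).T @ X / n_examples
    a -= lr * grad_a
    W -= lr * grad_W

pred_query = np.maximum(x_query @ W.T, 0.0) @ a
print("prediction:", pred_query, " true value:", x_query @ w_true)

In the paper's constructions the transformer receives the same kind of prompt and emulates such gradient steps internally across its layers, rather than updating explicit network weights.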

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-wang24ae,
  title     = {In-context Learning on Function Classes Unveiled for Transformers},
  author    = {Wang, Zhijie and Jiang, Bo and Li, Shuai},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {50726--50745},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/wang24ae/wang24ae.pdf},
  url       = {https://proceedings.mlr.press/v235/wang24ae.html},
  abstract  = {Transformer-based neural sequence models exhibit a remarkable ability to perform in-context learning. Given some training examples, a pre-trained model can make accurate predictions on an unseen input. This paper studies why transformers can learn different types of function classes in-context. We first show by construction that there exists a family of transformers (with different activation functions) that implement approximate gradient descent on the parameters of neural networks, and we provide an upper bound for the number of heads, hidden dimensions, and layers of the transformer. We also show that a transformer can learn linear functions, the indicator function of a unit ball, and smooth functions in-context by learning neural networks that approximate them. The above instances mainly focus on a transformer pre-trained on single tasks. We also prove that when pre-trained on two tasks: linear regression and classification, a transformer can make accurate predictions on both tasks simultaneously. Our results move beyond linearity in terms of in-context learning instances and provide a comprehensive understanding of why transformers can learn many types of function classes through the bridge of neural networks.}
}
Endnote
%0 Conference Paper
%T In-context Learning on Function Classes Unveiled for Transformers
%A Zhijie Wang
%A Bo Jiang
%A Shuai Li
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-wang24ae
%I PMLR
%P 50726--50745
%U https://proceedings.mlr.press/v235/wang24ae.html
%V 235
%X Transformer-based neural sequence models exhibit a remarkable ability to perform in-context learning. Given some training examples, a pre-trained model can make accurate predictions on an unseen input. This paper studies why transformers can learn different types of function classes in-context. We first show by construction that there exists a family of transformers (with different activation functions) that implement approximate gradient descent on the parameters of neural networks, and we provide an upper bound for the number of heads, hidden dimensions, and layers of the transformer. We also show that a transformer can learn linear functions, the indicator function of a unit ball, and smooth functions in-context by learning neural networks that approximate them. The above instances mainly focus on a transformer pre-trained on single tasks. We also prove that when pre-trained on two tasks: linear regression and classification, a transformer can make accurate predictions on both tasks simultaneously. Our results move beyond linearity in terms of in-context learning instances and provide a comprehensive understanding of why transformers can learn many types of function classes through the bridge of neural networks.
APA
Wang, Z., Jiang, B., & Li, S. (2024). In-context Learning on Function Classes Unveiled for Transformers. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:50726-50745. Available from https://proceedings.mlr.press/v235/wang24ae.html.
