Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning?

Khashayar Gatmiry, Nikunj Saunshi, Sashank J. Reddi, Stefanie Jegelka, Sanjiv Kumar
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:15130-15152, 2024.

Abstract

The ability of Transformers to do reasoning and few-shot learning, without any fine-tuning, is widely conjectured to stem from their ability to implicitly simulate multi-step algorithms – such as gradient descent – with their weights in a single forward pass. Recently, there has been progress in understanding this complex phenomenon from an expressivity point of view, by demonstrating that Transformers can express such multi-step algorithms. However, our knowledge about the more fundamental aspect of learnability, beyond single-layer models, is very limited. In particular, can training Transformers enable convergence to algorithmic solutions? In this work we resolve this question for in-context linear regression with linear looped Transformers – a multi-layer model with weight sharing that is conjectured to have an inductive bias towards learning fixed-point iterative algorithms. More specifically, for this setting we show that the global minimizer of the population training loss implements multi-step preconditioned gradient descent, with a preconditioner that adapts to the data distribution. Furthermore, we show fast convergence of gradient flow on the regression loss, despite the non-convexity of the landscape, by proving a novel gradient dominance condition. To our knowledge, this is the first theoretical analysis of a multi-layer Transformer in this setting. We further validate our theoretical findings through synthetic experiments.
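
To make the main result concrete, the following is a minimal NumPy sketch (an illustration under our own assumptions, not the paper's exact construction) of the algorithm that the trained looped Transformer is shown to emulate: multi-step preconditioned gradient descent on the in-context least-squares objective, with a preconditioner P adapted to the covariate distribution. The function name icl_preconditioned_gd, the choice of P as the inverse covariance, the step size, and the toy data are all illustrative assumptions.

import numpy as np

def icl_preconditioned_gd(X, y, x_query, P, num_loops, eta=0.5):
    """Predict the label of x_query by running num_loops steps of preconditioned
    gradient descent on the in-context least-squares objective, starting from w = 0.
    Here P plays the role of a distribution-adapted preconditioner (illustrative)."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(num_loops):
        grad = X.T @ (X @ w - y) / n   # gradient of the in-context squared loss
        w = w - eta * P @ grad         # preconditioned gradient step
    return x_query @ w

# Toy in-context prompt: anisotropic Gaussian covariates, noiseless linear labels.
rng = np.random.default_rng(0)
d, n = 5, 100
cov = np.diag(rng.uniform(0.5, 2.0, size=d))
X = rng.multivariate_normal(np.zeros(d), cov, size=n)
w_star = rng.normal(size=d)
y = X @ w_star
x_query = rng.multivariate_normal(np.zeros(d), cov)
P = np.linalg.inv(cov)  # one natural preconditioner adapted to the covariate distribution

prediction = icl_preconditioned_gd(X, y, x_query, P, num_loops=50)
print(prediction, x_query @ w_star)  # the two values should nearly match

Each loop of the sketch corresponds to one pass through the shared-weight block of the looped Transformer; increasing num_loops plays the role of iterating the block more times in the forward pass.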

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-gatmiry24b,
  title     = {Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning?},
  author    = {Gatmiry, Khashayar and Saunshi, Nikunj and J. Reddi, Sashank and Jegelka, Stefanie and Kumar, Sanjiv},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {15130--15152},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/gatmiry24b/gatmiry24b.pdf},
  url       = {https://proceedings.mlr.press/v235/gatmiry24b.html}
}
Endnote
%0 Conference Paper
%T Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning?
%A Khashayar Gatmiry
%A Nikunj Saunshi
%A Sashank J. Reddi
%A Stefanie Jegelka
%A Sanjiv Kumar
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-gatmiry24b
%I PMLR
%P 15130--15152
%U https://proceedings.mlr.press/v235/gatmiry24b.html
%V 235
APA
Gatmiry, K., Saunshi, N., J. Reddi, S., Jegelka, S. & Kumar, S. (2024). Can Looped Transformers Learn to Implement Multi-step Gradient Descent for In-context Learning? Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:15130-15152. Available from https://proceedings.mlr.press/v235/gatmiry24b.html.
