Towards an Understanding of Stepwise Inference in Transformers: A Synthetic Graph Navigation Model

Mikail Khona, Maya Okawa, Jan Hula, Rahul Ramesh, Kento Nishi, Robert P. Dick, Ekdeep Singh Lubana, Hidenori Tanaka
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:23758-23780, 2024.

Abstract

Stepwise inference protocols, such as scratchpads and chain-of-thought, help language models solve complex problems by decomposing them into a sequence of simpler subproblems. To unravel the underlying mechanisms of stepwise inference, we propose to study autoregressive Transformer models on a synthetic task that embodies the multi-step nature of problems where stepwise inference is generally most useful. Specifically, we define a graph navigation problem wherein a model is tasked with traversing a path from a start to a goal node on the graph. We find we can empirically reproduce and analyze several phenomena observed at scale: (i) the stepwise inference reasoning gap, the cause of which we find in the structure of the training data; (ii) a diversity-accuracy trade-off in model generations as sampling temperature varies; (iii) a simplicity bias in the model’s output; and (iv) compositional generalization and a primacy bias with in-context exemplars. Overall, our work introduces a grounded, synthetic framework for studying stepwise inference and offers mechanistic hypotheses that can lay the foundation for a deeper understanding of this phenomenon.
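
To make the task concrete, below is a minimal sketch (not the authors' released code) of how graph-navigation training sequences of this kind could be generated: sample a random DAG, pick a reachable (start, goal) pair, and emit a sequence whose prompt names both endpoints and whose completion walks the path node by node. The DAG construction, token format, and helper names (random_dag, find_path, make_example) are illustrative assumptions, not the paper's exact setup.

```python
import random

def random_dag(n, p, rng):
    # Edges only run from lower- to higher-indexed nodes,
    # so the graph is acyclic by construction.
    return {u: [v for v in range(u + 1, n) if rng.random() < p]
            for u in range(n)}

def find_path(adj, start, goal):
    # Depth-first search for any start -> goal path;
    # returns None if goal is unreachable from start.
    stack = [(start, [start])]
    while stack:
        node, path = stack.pop()
        if node == goal:
            return path
        for nxt in adj[node]:
            stack.append((nxt, path + [nxt]))
    return None

def make_example(adj, rng):
    # One training sequence: the prompt names the start and goal
    # nodes, and the completion spells out the path one node at a
    # time (this token format is illustrative, not the paper's).
    nodes = list(adj)
    while True:
        start, goal = rng.sample(nodes, 2)
        path = find_path(adj, start, goal)
        if path is not None:
            return f"{start} {goal} : " + " ".join(map(str, path))

rng = random.Random(0)
adj = random_dag(n=20, p=0.2, rng=rng)
print(make_example(adj, rng))  # e.g. "3 17 : 3 5 11 17"
```

Emitting the intermediate nodes between the prompt and the goal is what makes the task stepwise; decoding from a model trained on such sequences at varying softmax temperatures is then what exposes the diversity-accuracy trade-off the abstract describes.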

Cite this Paper

BibTeX
@InProceedings{pmlr-v235-khona24a,
  title     = {Towards an Understanding of Stepwise Inference in Transformers: A Synthetic Graph Navigation Model},
  author    = {Khona, Mikail and Okawa, Maya and Hula, Jan and Ramesh, Rahul and Nishi, Kento and Dick, Robert P. and Lubana, Ekdeep Singh and Tanaka, Hidenori},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {23758--23780},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/khona24a/khona24a.pdf},
  url       = {https://proceedings.mlr.press/v235/khona24a.html},
  abstract  = {Stepwise inference protocols, such as scratchpads and chain-of-thought, help language models solve complex problems by decomposing them into a sequence of simpler subproblems. To unravel the underlying mechanisms of stepwise inference, we propose to study autoregressive Transformer models on a synthetic task that embodies the multi-step nature of problems where stepwise inference is generally most useful. Specifically, we define a graph navigation problem wherein a model is tasked with traversing a path from a start to a goal node on the graph. We find we can empirically reproduce and analyze several phenomena observed at scale: (i) the stepwise inference reasoning gap, the cause of which we find in the structure of the training data; (ii) a diversity-accuracy trade-off in model generations as sampling temperature varies; (iii) a simplicity bias in the model’s output; and (iv) compositional generalization and a primacy bias with in-context exemplars. Overall, our work introduces a grounded, synthetic framework for studying stepwise inference and offers mechanistic hypotheses that can lay the foundation for a deeper understanding of this phenomenon.}
}
Endnote
%0 Conference Paper
%T Towards an Understanding of Stepwise Inference in Transformers: A Synthetic Graph Navigation Model
%A Mikail Khona
%A Maya Okawa
%A Jan Hula
%A Rahul Ramesh
%A Kento Nishi
%A Robert P. Dick
%A Ekdeep Singh Lubana
%A Hidenori Tanaka
%B Proceedings of the 41st International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2024
%E Ruslan Salakhutdinov
%E Zico Kolter
%E Katherine Heller
%E Adrian Weller
%E Nuria Oliver
%E Jonathan Scarlett
%E Felix Berkenkamp
%F pmlr-v235-khona24a
%I PMLR
%P 23758--23780
%U https://proceedings.mlr.press/v235/khona24a.html
%V 235
%X Stepwise inference protocols, such as scratchpads and chain-of-thought, help language models solve complex problems by decomposing them into a sequence of simpler subproblems. To unravel the underlying mechanisms of stepwise inference, we propose to study autoregressive Transformer models on a synthetic task that embodies the multi-step nature of problems where stepwise inference is generally most useful. Specifically, we define a graph navigation problem wherein a model is tasked with traversing a path from a start to a goal node on the graph. We find we can empirically reproduce and analyze several phenomena observed at scale: (i) the stepwise inference reasoning gap, the cause of which we find in the structure of the training data; (ii) a diversity-accuracy trade-off in model generations as sampling temperature varies; (iii) a simplicity bias in the model’s output; and (iv) compositional generalization and a primacy bias with in-context exemplars. Overall, our work introduces a grounded, synthetic framework for studying stepwise inference and offers mechanistic hypotheses that can lay the foundation for a deeper understanding of this phenomenon.
APA
Khona, M., Okawa, M., Hula, J., Ramesh, R., Nishi, K., Dick, R.P., Lubana, E.S. & Tanaka, H. (2024). Towards an Understanding of Stepwise Inference in Transformers: A Synthetic Graph Navigation Model. Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:23758-23780. Available from https://proceedings.mlr.press/v235/khona24a.html.
