Do Transformer World Models Give Better Policy Gradients?

Michel Ma, Tianwei Ni, Clement Gehring, Pierluca D’Oro, Pierre-Luc Bacon
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:33855-33879, 2024.

Abstract

A natural approach for reinforcement learning is to predict future rewards by unrolling a neural network world model, and to backpropagate through the resulting computational graph to learn a control policy. However, this method often becomes impractical for long horizons, since typical world models induce hard-to-optimize loss landscapes. Transformers are known to efficiently propagate gradients over long horizons: could they be the solution to this problem? Surprisingly, we show that commonly-used transformer world models produce circuitous gradient paths, which can be detrimental to long-range policy gradients. To tackle this challenge, we propose a class of world models called Action-conditioned World Models (AWMs), designed to provide more direct routes for gradient propagation. We integrate such AWMs into a policy gradient framework that underscores the relationship between network architectures and the policy gradient updates they inherently represent. We demonstrate that AWMs can generate optimization landscapes that are easier to navigate even when compared to those from the simulator itself. This property allows transformer AWMs to produce better policies than competitive baselines in realistic long-horizon tasks.
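The core idea described above, unrolling a learned, differentiable world model and backpropagating a predicted return into the policy parameters, can be sketched in a few lines. The snippet below is a minimal, hypothetical PyTorch illustration under assumed toy dimensions and untrained networks; it is not the paper's AWM architecture or training setup, and all module names are invented for illustration.

# Minimal sketch (assumptions, not the paper's method): first-order policy
# gradients obtained by backpropagating a predicted return through an
# imagined rollout of a differentiable world model.
import torch
import torch.nn as nn

state_dim, action_dim, horizon = 4, 2, 50

policy = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, action_dim))
# Learned world model components: next-state and reward predictors over (state, action).
dynamics = nn.Sequential(nn.Linear(state_dim + action_dim, 64), nn.Tanh(), nn.Linear(64, state_dim))
reward_model = nn.Sequential(nn.Linear(state_dim + action_dim, 64), nn.Tanh(), nn.Linear(64, 1))

optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

state = torch.zeros(1, state_dim)            # start of an imagined rollout
predicted_return = torch.zeros(1)
for t in range(horizon):
    action = policy(state)
    sa = torch.cat([state, action], dim=-1)
    predicted_return = predicted_return + reward_model(sa).squeeze(-1)
    state = dynamics(sa)                     # gradients flow through every unrolled step

loss = -predicted_return.mean()              # maximize the predicted return
optimizer.zero_grad()
loss.backward()                              # backprop through the full computational graph
optimizer.step()

For long horizons, the loop above is exactly where the paper's concern lies: the quality of the resulting gradient depends on how the world model routes gradients across the unrolled steps.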

Cite this Paper


BibTeX
@InProceedings{pmlr-v235-ma24i,
  title     = {Do Transformer World Models Give Better Policy Gradients?},
  author    = {Ma, Michel and Ni, Tianwei and Gehring, Clement and D'Oro, Pierluca and Bacon, Pierre-Luc},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages     = {33855--33879},
  year      = {2024},
  editor    = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume    = {235},
  series    = {Proceedings of Machine Learning Research},
  month     = {21--27 Jul},
  publisher = {PMLR},
  pdf       = {https://raw.githubusercontent.com/mlresearch/v235/main/assets/ma24i/ma24i.pdf},
  url       = {https://proceedings.mlr.press/v235/ma24i.html}
}
APA
Ma, M., Ni, T., Gehring, C., D’Oro, P., & Bacon, P. (2024). Do Transformer World Models Give Better Policy Gradients? Proceedings of the 41st International Conference on Machine Learning, in Proceedings of Machine Learning Research 235:33855-33879. Available from https://proceedings.mlr.press/v235/ma24i.html.