Retrieval-Augmented Reinforcement Learning

Anirudh Goyal, Abram Friesen, Andrea Banino, Theophane Weber, Nan Rosemary Ke, Adrià Puigdomènech Badia, Arthur Guez, Mehdi Mirza, Peter C Humphreys, Ksenia Konyushova, Michal Valko, Simon Osindero, Timothy Lillicrap, Nicolas Heess, Charles Blundell
Proceedings of the 39th International Conference on Machine Learning, PMLR 162:7740-7765, 2022.

Abstract

Most deep reinforcement learning (RL) algorithms distill experience into parametric behavior policies or value functions via gradient updates. While effective, this approach has several disadvantages: (1) it is computationally expensive, (2) it can take many updates to integrate experiences into the parametric model, (3) experiences that are not fully integrated do not appropriately influence the agent's behavior, and (4) behavior is limited by the capacity of the model. In this paper we explore an alternative paradigm in which we train a network to map a dataset of past experiences to optimal behavior. Specifically, we augment an RL agent with a retrieval process (parameterized as a neural network) that has direct access to a dataset of experiences. This dataset can come from the agent's past experiences, expert demonstrations, or any other relevant source. The retrieval process is trained to retrieve information from the dataset that may be useful in the current context, to help the agent achieve its goal faster and more efficiently. The proposed method facilitates learning agents that at test time can condition their behavior on the entire dataset, not only on the current state or trajectory. We integrate our method into two different RL agents: an offline DQN agent and an online R2D2 agent. In offline multi-task problems, we show that the retrieval-augmented DQN agent avoids task interference and learns faster than the baseline DQN agent. On Atari, we show that retrieval-augmented R2D2 learns significantly faster than the baseline R2D2 agent and achieves higher scores. We run extensive ablations to measure the contributions of the components of our proposed method.
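To make the mechanism concrete, below is a minimal PyTorch sketch of the general idea described in the abstract: a learned query encodes the current observation, attends over an external dataset of pre-encoded experiences, and the retrieved summary conditions the agent's Q-values. All class and parameter names, the dimensions, and the dot-product top-k attention are illustrative assumptions for this sketch; the abstract does not specify the paper's exact parameterization.

```python
# Minimal sketch of a retrieval-augmented value function, assuming a
# precomputed dataset of encoded experiences (keys/values). Hypothetical
# architecture; not the paper's exact method.
import torch
import torch.nn as nn
import torch.nn.functional as F


class RetrievalAugmentedQNet(nn.Module):
    def __init__(self, obs_dim, key_dim, value_dim, num_actions, top_k=8):
        super().__init__()
        self.top_k = top_k
        # Encodes the current observation into a retrieval query.
        self.query_net = nn.Linear(obs_dim, key_dim)
        # Maps the observation plus the retrieved summary to Q-values.
        self.q_head = nn.Sequential(
            nn.Linear(obs_dim + value_dim, 256), nn.ReLU(),
            nn.Linear(256, num_actions),
        )

    def retrieve(self, query, keys, values):
        # Dot-product top-k retrieval over the experience dataset,
        # followed by soft attention over the retrieved items.
        scores = query @ keys.T                      # (B, N)
        top_scores, idx = scores.topk(self.top_k, dim=-1)
        weights = F.softmax(top_scores, dim=-1)      # (B, k)
        retrieved = values[idx]                      # (B, k, value_dim)
        return (weights.unsqueeze(-1) * retrieved).sum(dim=1)

    def forward(self, obs, keys, values):
        query = self.query_net(obs)
        summary = self.retrieve(query, keys, values)
        return self.q_head(torch.cat([obs, summary], dim=-1))


# Usage: in the paper's setting, keys/values would be encodings of past
# trajectories or demonstrations; random tensors here keep the sketch runnable.
net = RetrievalAugmentedQNet(obs_dim=32, key_dim=64, value_dim=64, num_actions=6)
obs = torch.randn(4, 32)
keys, values = torch.randn(10_000, 64), torch.randn(10_000, 64)
q_values = net(obs, keys, values)  # shape (4, 6)
```

Because the dataset enters only through retrieval, the agent can in principle be conditioned on new or larger experience datasets at test time without retraining the parametric model, which is the property the abstract emphasizes.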

Cite this Paper


BibTeX
@InProceedings{pmlr-v162-goyal22a,
  title     = {Retrieval-Augmented Reinforcement Learning},
  author    = {Goyal, Anirudh and Friesen, Abram and Banino, Andrea and Weber, Theophane and Ke, Nan Rosemary and Badia, Adri{\`a} Puigdom{\`e}nech and Guez, Arthur and Mirza, Mehdi and Humphreys, Peter C and Konyushova, Ksenia and Valko, Michal and Osindero, Simon and Lillicrap, Timothy and Heess, Nicolas and Blundell, Charles},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  pages     = {7740--7765},
  year      = {2022},
  editor    = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
  volume    = {162},
  series    = {Proceedings of Machine Learning Research},
  month     = {17--23 Jul},
  publisher = {PMLR},
  pdf       = {https://proceedings.mlr.press/v162/goyal22a/goyal22a.pdf},
  url       = {https://proceedings.mlr.press/v162/goyal22a.html},
}
APA
Goyal, A., Friesen, A., Banino, A., Weber, T., Ke, N.R., Badia, A.P., Guez, A., Mirza, M., Humphreys, P.C., Konyushova, K., Valko, M., Osindero, S., Lillicrap, T., Heess, N. & Blundell, C. (2022). Retrieval-Augmented Reinforcement Learning. Proceedings of the 39th International Conference on Machine Learning, in Proceedings of Machine Learning Research 162:7740-7765. Available from https://proceedings.mlr.press/v162/goyal22a.html.
