Training Recurrent Neural Networks via Forward Propagation Through Time

Anil Kag, Venkatesh Saligrama
Proceedings of the 38th International Conference on Machine Learning, PMLR 139:5189-5200, 2021.

Abstract

Back-propagation through time (BPTT) has been widely used for training Recurrent Neural Networks (RNNs). BPTT updates RNN parameters on an instance by back-propagating the error in time over the entire sequence length and, as a result, leads to poor trainability due to the well-known gradient explosion/decay phenomena. While a number of prior works have proposed mitigating the vanishing/exploding gradient effect through careful RNN architecture design, these RNN variants still train with BPTT. We propose a novel forward-propagation algorithm, FPTT, where at each time step, for each instance, we update RNN parameters by optimizing an instantaneous risk function. Our proposed risk is a regularization penalty at time $t$ that evolves dynamically based on previously observed losses, and allows RNN parameter updates to converge to a stationary solution of the empirical RNN objective. We consider both sequence-to-sequence and terminal loss problems. Empirically, FPTT outperforms BPTT on a number of well-known benchmark tasks, enabling architectures like LSTMs to solve problems with long-range dependencies.
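The abstract describes FPTT only at a high level. As a hedged illustration of the general idea (update parameters forward in time by minimizing, at each step, an instantaneous risk whose regularizer depends on previously observed losses), the toy below trains a single scalar parameter `w` on a stream of quadratic per-step losses. Everything specific here is an assumption for illustration, not the paper's algorithm: the function name `fptt_toy`, the parameter `alpha`, the quadratic penalty pulling `w` toward an anchor built from a running average `w_bar` and a gradient-like term `g`, and the update rules for `w_bar` and `g`. The paper defines the actual risk and updates, and it applies them to RNN weights rather than to one scalar.

```python
def fptt_toy(targets, alpha=1.0, epochs=200):
    """Toy forward-in-time training of one scalar parameter (hypothetical sketch).

    ASSUMPTIONS (illustrative only, not the paper's verbatim algorithm):
      * the loss revealed at step t is l_t(w) = 0.5 * (w - targets[t]) ** 2;
      * the instantaneous risk at step t is
            l_t(w) + (alpha / 2) * (w - w_bar - g / (2 * alpha)) ** 2,
        where w_bar is a running anchor and g is a running gradient-like term;
      * w_bar and g are refreshed after every step exactly as coded below.
    Returns the list of (w, w_bar) pairs recorded after every step.
    """
    w = 0.0
    w_bar = 0.0
    g = 0.0
    history = []
    for _ in range(epochs):
        for c in targets:
            # Minimize the instantaneous risk exactly. Both terms are quadratic,
            # so setting the derivative to zero gives
            #   (w - c) + alpha * (w - w_bar) - g / 2 = 0.
            w = (c + alpha * w_bar + g / 2) / (1 + alpha)
            # Assumed update of the gradient-like term, using the old anchor.
            g = g - alpha * (w - w_bar)
            # Assumed anchor update: move halfway toward w, corrected by the new g.
            w_bar = 0.5 * (w_bar + w) - g / (2 * alpha)
            history.append((w, w_bar))
    return history


if __name__ == "__main__":
    targets = [1.0, 3.0, 0.0, 4.0]  # mean of targets is 2.0
    hist = fptt_toy(targets)
    last_pass = hist[-len(targets):]
    avg_w_bar = sum(wb for _, wb in last_pass) / len(targets)
    print(round(avg_w_bar, 6))  # prints 2.0
```

Because each per-step loss here is quadratic, the instantaneous risk is minimized in closed form; with real RNN losses one would instead take one or more gradient steps on it. With the default `alpha=1.0`, this toy recurrence is stable, and the average of `w_bar` over a full pass settles at the mean of the targets, which is the minimizer of the summed per-step losses. That behavior is in the spirit of the abstract's claim of convergence to a stationary solution of the empirical objective, but it is a property of this toy only.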

Cite this Paper


BibTeX
@InProceedings{pmlr-v139-kag21a,
  title = {Training Recurrent Neural Networks via Forward Propagation Through Time},
  author = {Kag, Anil and Saligrama, Venkatesh},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages = {5189--5200},
  year = {2021},
  editor = {Meila, Marina and Zhang, Tong},
  volume = {139},
  series = {Proceedings of Machine Learning Research},
  month = {18--24 Jul},
  publisher = {PMLR},
  pdf = {http://proceedings.mlr.press/v139/kag21a/kag21a.pdf},
  url = {https://proceedings.mlr.press/v139/kag21a.html},
  abstract = {Back-propagation through time (BPTT) has been widely used for training Recurrent Neural Networks (RNNs). BPTT updates RNN parameters on an instance by back-propagating the error in time over the entire sequence length, and as a result, leads to poor trainability due to the well-known gradient explosion/decay phenomena. While a number of prior works have proposed to mitigate vanishing/explosion effect through careful RNN architecture design, these RNN variants still train with BPTT. We propose a novel forward-propagation algorithm, FPTT, where at each time, for an instance, we update RNN parameters by optimizing an instantaneous risk function. Our proposed risk is a regularization penalty at time $t$ that evolves dynamically based on previously observed losses, and allows for RNN parameter updates to converge to a stationary solution of the empirical RNN objective. We consider both sequence-to-sequence as well as terminal loss problems. Empirically FPTT outperforms BPTT on a number of well-known benchmark tasks, thus enabling architectures like LSTMs to solve long range dependencies problems.}
}
Endnote
%0 Conference Paper
%T Training Recurrent Neural Networks via Forward Propagation Through Time
%A Anil Kag
%A Venkatesh Saligrama
%B Proceedings of the 38th International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2021
%E Marina Meila
%E Tong Zhang
%F pmlr-v139-kag21a
%I PMLR
%P 5189--5200
%U https://proceedings.mlr.press/v139/kag21a.html
%V 139
%X Back-propagation through time (BPTT) has been widely used for training Recurrent Neural Networks (RNNs). BPTT updates RNN parameters on an instance by back-propagating the error in time over the entire sequence length, and as a result, leads to poor trainability due to the well-known gradient explosion/decay phenomena. While a number of prior works have proposed to mitigate vanishing/explosion effect through careful RNN architecture design, these RNN variants still train with BPTT. We propose a novel forward-propagation algorithm, FPTT, where at each time, for an instance, we update RNN parameters by optimizing an instantaneous risk function. Our proposed risk is a regularization penalty at time $t$ that evolves dynamically based on previously observed losses, and allows for RNN parameter updates to converge to a stationary solution of the empirical RNN objective. We consider both sequence-to-sequence as well as terminal loss problems. Empirically FPTT outperforms BPTT on a number of well-known benchmark tasks, thus enabling architectures like LSTMs to solve long range dependencies problems.
APA
Kag, A., & Saligrama, V. (2021). Training Recurrent Neural Networks via Forward Propagation Through Time. Proceedings of the 38th International Conference on Machine Learning, in Proceedings of Machine Learning Research 139:5189-5200. Available from https://proceedings.mlr.press/v139/kag21a.html.
