Training Discrete Deep Generative Models via Gapped Straight-Through Estimator

Ting-Han Fan, Ta-Chung Chi, Alexander I. Rudnicky, Peter J. Ramadge
Proceedings of the 39th International Conference on Machine Learning, PMLR 162:6059-6073, 2022.

Abstract

While deep generative models have succeeded in image processing, natural language processing, and reinforcement learning, training models that involve discrete random variables remains challenging due to the high variance of the gradient estimation process. Monte Carlo sampling underlies most variance-reduction approaches, but it entails time-consuming resampling and multiple function evaluations. We propose the Gapped Straight-Through (GST) estimator, which reduces variance without incurring resampling overhead. The estimator is inspired by the essential properties of Straight-Through Gumbel-Softmax; we identify these properties and show via an ablation study that they are indeed essential. Experiments demonstrate that the proposed GST estimator outperforms strong baselines on two discrete deep generative modeling tasks, MNIST-VAE and ListOps.
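For a concrete reference point, below is a minimal PyTorch sketch of the Straight-Through Gumbel-Softmax (ST-GS) baseline that the abstract contrasts with. It is illustrative only and is not the paper's GST estimator; all function and variable names are ours, not the paper's. The Gumbel-noise draw inside the function is the per-sample Monte Carlo step whose resampling overhead GST is designed to remove.

    import torch

    def st_gumbel_softmax(logits, tau=1.0):
        # Gumbel(0, 1) noise via -log(Exp(1)): the Monte Carlo sampling
        # step whose resampling cost the GST estimator avoids.
        gumbels = -torch.empty_like(logits).exponential_().log()
        soft = ((logits + gumbels) / tau).softmax(dim=-1)
        # Hard one-hot sample from the perturbed logits (Gumbel-Max trick).
        index = soft.argmax(dim=-1, keepdim=True)
        hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        # Straight-through: the forward pass emits the hard one-hot sample,
        # while the backward pass routes gradients through the tempered softmax.
        return hard - soft.detach() + soft

    # Usage: a differentiable one-hot sample over 5 categories.
    logits = torch.randn(2, 5, requires_grad=True)
    sample = st_gumbel_softmax(logits, tau=0.5)
    sample.sum().backward()  # gradients reach logits via the soft relaxation

In ST-GS, gradient variance enters through the fresh Gumbel noise drawn for every sample; per the abstract, GST keeps the straight-through structure while eliminating this resampling overhead.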

Cite this Paper


BibTeX
@InProceedings{pmlr-v162-fan22a,
  title     = {Training Discrete Deep Generative Models via Gapped Straight-Through Estimator},
  author    = {Fan, Ting-Han and Chi, Ta-Chung and Rudnicky, Alexander I. and Ramadge, Peter J},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  pages     = {6059--6073},
  year      = {2022},
  editor    = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
  volume    = {162},
  series    = {Proceedings of Machine Learning Research},
  month     = {17--23 Jul},
  publisher = {PMLR},
  pdf       = {https://proceedings.mlr.press/v162/fan22a/fan22a.pdf},
  url       = {https://proceedings.mlr.press/v162/fan22a.html},
  abstract  = {While deep generative models have succeeded in image processing, natural language processing, and reinforcement learning, training that involves discrete random variables remains challenging due to the high variance of its gradient estimation process. Monte Carlo is a common solution used in most variance reduction approaches. However, this involves time-consuming resampling and multiple function evaluations. We propose a Gapped Straight-Through (GST) estimator to reduce the variance without incurring resampling overhead. This estimator is inspired by the essential properties of Straight-Through Gumbel-Softmax. We determine these properties and show via an ablation study that they are essential. Experiments demonstrate that the proposed GST estimator enjoys better performance compared to strong baselines on two discrete deep generative modeling tasks, MNIST-VAE and ListOps.}
}
EndNote
%0 Conference Paper
%T Training Discrete Deep Generative Models via Gapped Straight-Through Estimator
%A Ting-Han Fan
%A Ta-Chung Chi
%A Alexander I. Rudnicky
%A Peter J Ramadge
%B Proceedings of the 39th International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2022
%E Kamalika Chaudhuri
%E Stefanie Jegelka
%E Le Song
%E Csaba Szepesvari
%E Gang Niu
%E Sivan Sabato
%F pmlr-v162-fan22a
%I PMLR
%P 6059--6073
%U https://proceedings.mlr.press/v162/fan22a.html
%V 162
%X While deep generative models have succeeded in image processing, natural language processing, and reinforcement learning, training that involves discrete random variables remains challenging due to the high variance of its gradient estimation process. Monte Carlo is a common solution used in most variance reduction approaches. However, this involves time-consuming resampling and multiple function evaluations. We propose a Gapped Straight-Through (GST) estimator to reduce the variance without incurring resampling overhead. This estimator is inspired by the essential properties of Straight-Through Gumbel-Softmax. We determine these properties and show via an ablation study that they are essential. Experiments demonstrate that the proposed GST estimator enjoys better performance compared to strong baselines on two discrete deep generative modeling tasks, MNIST-VAE and ListOps.
APA
Fan, T., Chi, T., Rudnicky, A.I. & Ramadge, P.J. (2022). Training Discrete Deep Generative Models via Gapped Straight-Through Estimator. Proceedings of the 39th International Conference on Machine Learning, in Proceedings of Machine Learning Research 162:6059-6073. Available from https://proceedings.mlr.press/v162/fan22a.html.
