Improved OOD Generalization via Adversarial Training and Pre-training

Mingyang Yi, Lu Hou, Jiacheng Sun, Lifeng Shang, Xin Jiang, Qun Liu, Zhiming Ma
Proceedings of the 38th International Conference on Machine Learning, PMLR 139:11987-11997, 2021.

Abstract

Recently, learning a model that generalizes well on out-of-distribution (OOD) data has attracted great attention in the machine learning community. In this paper, after defining OOD generalization via Wasserstein distance, we theoretically show that a model robust to input perturbations also generalizes well on OOD data. Inspired by previous findings that adversarial training helps improve input-perturbation robustness, we show that models obtained by adversarial training have convergent excess risk on OOD data. Moreover, in the paradigm of pre-training then fine-tuning, we theoretically show that an input-perturbation-robust model obtained in the pre-training stage provides an initialization that generalizes well on downstream OOD data. Finally, experiments on image classification and natural language understanding tasks verify our theoretical findings.
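To make the setting concrete, the display below gives one plausible formalization of the Wasserstein-ball notion of OOD generalization the abstract refers to. The notation (training distribution P_0, radius r, Wasserstein distance W, loss ell, model f_theta) is our own shorthand for illustration, not the paper's exact definitions.

% Hedged sketch: symbols are illustrative, not copied from the paper.
% OOD risk = worst-case risk over distributions within Wasserstein distance r of the training distribution P_0:
\[
  \mathcal{R}_{\mathrm{ood}}(f_\theta)
  \;=\; \sup_{P \,:\, W(P, P_0) \le r} \; \mathbb{E}_{(x,y)\sim P}\!\left[\ell\big(f_\theta(x), y\big)\right].
\]

Under this reading, robustness to input perturbations of magnitude on the order of r controls the worst case over the ball, since (informally) any distribution within Wasserstein distance r of P_0 can be realized by perturbing inputs drawn from P_0. This is the sense in which input-perturbation robustness implies OOD generalization.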
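The abstract credits adversarial training with the robustness that drives the OOD guarantee. As a minimal sketch only, assuming an L_inf threat model, inputs scaled to [0, 1], and illustrative hyperparameters (this is not the paper's actual training configuration), one PGD-style adversarial training step in PyTorch might look like:

import torch
import torch.nn.functional as F

def pgd_perturb(model, x, y, eps=8/255, alpha=2/255, steps=10):
    # Inner maximization: search for a worst-case perturbation of x
    # within an L_inf ball of radius eps, starting from a random point.
    x_adv = (x + torch.empty_like(x).uniform_(-eps, eps)).clamp(0, 1).detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        grad, = torch.autograd.grad(loss, x_adv)
        # Ascend the loss, then project back onto the eps-ball and valid pixel range.
        x_adv = x_adv.detach() + alpha * grad.sign()
        x_adv = x + (x_adv - x).clamp(-eps, eps)
        x_adv = x_adv.clamp(0, 1).detach()
    return x_adv

def adversarial_training_step(model, optimizer, x, y):
    # Outer minimization: one optimizer step on the adversarially perturbed batch.
    x_adv = pgd_perturb(model, x, y)
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x_adv), y)
    loss.backward()
    optimizer.step()
    return loss.item()

Training on such perturbed inputs minimizes a worst-case loss over a small neighborhood of each example; the paper's theory says this in turn controls the excess risk over distributions within the corresponding Wasserstein ball, which is why the procedure is expected to help OOD generalization rather than only adversarial robustness.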

Cite this Paper


BibTeX
@InProceedings{pmlr-v139-yi21a,
  title     = {Improved {OOD} Generalization via Adversarial Training and Pre-training},
  author    = {Yi, Mingyang and Hou, Lu and Sun, Jiacheng and Shang, Lifeng and Jiang, Xin and Liu, Qun and Ma, Zhiming},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages     = {11987--11997},
  year      = {2021},
  editor    = {Meila, Marina and Zhang, Tong},
  volume    = {139},
  series    = {Proceedings of Machine Learning Research},
  month     = {18--24 Jul},
  publisher = {PMLR},
  pdf       = {http://proceedings.mlr.press/v139/yi21a/yi21a.pdf},
  url       = {https://proceedings.mlr.press/v139/yi21a.html},
  abstract  = {Recently, learning a model that generalizes well on out-of-distribution (OOD) data has attracted great attention in the machine learning community. In this paper, after defining OOD generalization by Wasserstein distance, we theoretically justify that a model robust to input perturbation also generalizes well on OOD data. Inspired by previous findings that adversarial training helps improve robustness, we show that models trained by adversarial training have converged excess risk on OOD data. Besides, in the paradigm of pre-training then fine-tuning, we theoretically justify that the input perturbation robust model in the pre-training stage provides an initialization that generalizes well on downstream OOD data. Finally, various experiments conducted on image classification and natural language understanding tasks verify our theoretical findings.}
}
Endnote
%0 Conference Paper
%T Improved OOD Generalization via Adversarial Training and Pre-training
%A Mingyang Yi
%A Lu Hou
%A Jiacheng Sun
%A Lifeng Shang
%A Xin Jiang
%A Qun Liu
%A Zhiming Ma
%B Proceedings of the 38th International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2021
%E Marina Meila
%E Tong Zhang
%F pmlr-v139-yi21a
%I PMLR
%P 11987--11997
%U https://proceedings.mlr.press/v139/yi21a.html
%V 139
%X Recently, learning a model that generalizes well on out-of-distribution (OOD) data has attracted great attention in the machine learning community. In this paper, after defining OOD generalization by Wasserstein distance, we theoretically justify that a model robust to input perturbation also generalizes well on OOD data. Inspired by previous findings that adversarial training helps improve robustness, we show that models trained by adversarial training have converged excess risk on OOD data. Besides, in the paradigm of pre-training then fine-tuning, we theoretically justify that the input perturbation robust model in the pre-training stage provides an initialization that generalizes well on downstream OOD data. Finally, various experiments conducted on image classification and natural language understanding tasks verify our theoretical findings.
APA
Yi, M., Hou, L., Sun, J., Shang, L., Jiang, X., Liu, Q. & Ma, Z. (2021). Improved OOD Generalization via Adversarial Training and Pre-training. Proceedings of the 38th International Conference on Machine Learning, in Proceedings of Machine Learning Research 139:11987-11997. Available from https://proceedings.mlr.press/v139/yi21a.html.