Predict then Interpolate: A Simple Algorithm to Learn Stable Classifiers

Yujia Bao, Shiyu Chang, Regina Barzilay
Proceedings of the 38th International Conference on Machine Learning, PMLR 139:640-650, 2021.

Abstract

We propose Predict then Interpolate (PI), a simple algorithm for learning correlations that are stable across environments. The algorithm follows from the intuition that when using a classifier trained on one environment to make predictions on examples from another environment, its mistakes are informative as to which correlations are unstable. In this work, we prove that by interpolating the distributions of the correct predictions and the wrong predictions, we can uncover an oracle distribution where the unstable correlation vanishes. Since the oracle interpolation coefficients are not accessible, we use group distributionally robust optimization to minimize the worst-case risk across all such interpolations. We evaluate our method on both text classification and image classification. Empirical results demonstrate that our algorithm is able to learn robust classifiers (outperforming IRM by 23.85% on synthetic environments and 12.41% on natural environments). Our code and data are available at https://github.com/YujiaBao/Predict-then-Interpolate.
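The predict and interpolate steps described in the abstract can be illustrated on a toy problem. What follows is a minimal NumPy sketch under our own assumptions, not the authors' implementation (their code is in the GitHub repository given in the abstract). It makes these assumptions: two hypothetical synthetic environments share a stable feature but differ in how often a spurious feature agrees with the label; a logistic-regression model trained on one environment splits the other environment into correctly and wrongly predicted examples; and a linear model is then trained with an online group DRO update in the style of Sagawa et al. (2020). Every function name, dataset, and hyperparameter below (`make_env`, `predict_then_split`, `fit_group_dro`, `eta`, and so on) is hypothetical.

```python
# Toy sketch of the Predict then Interpolate idea; NOT the authors' code.
# All names, data, and hyperparameters here are hypothetical illustrations.
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def logistic_loss(w, X, y):
    """Per-example logistic loss for labels y in {0, 1}."""
    p = np.clip(sigmoid(X @ w), 1e-12, 1 - 1e-12)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def fit_erm(X, y, steps=500, lr=0.5):
    """Logistic regression trained by full-batch gradient descent on one environment."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        w -= lr * X.T @ (sigmoid(X @ w) - y) / len(y)
    return w


def predict_then_split(w_src, X_tgt, y_tgt):
    """Predict step: split a target environment by whether the source model is right."""
    correct = (sigmoid(X_tgt @ w_src) >= 0.5).astype(int) == y_tgt
    return [(X_tgt[correct], y_tgt[correct]), (X_tgt[~correct], y_tgt[~correct])]


def worst_case_interpolation(risk_correct, risk_wrong, num=101):
    """Largest risk over mixtures lam * P_correct + (1 - lam) * P_wrong, lam in [0, 1]."""
    lams = np.linspace(0.0, 1.0, num)
    return float(np.max(lams * risk_correct + (1 - lams) * risk_wrong))


def fit_group_dro(groups, steps=500, lr=0.5, eta=0.1):
    """Interpolate step: online group DRO with exponentiated-gradient group weights."""
    w = np.zeros(groups[0][0].shape[1])
    q = np.full(len(groups), 1.0 / len(groups))  # adversary's weights over groups
    for _ in range(steps):
        losses = np.array([logistic_loss(w, X, y).mean() for X, y in groups])
        q = q * np.exp(eta * losses)  # upweight the groups that currently do worst
        q /= q.sum()
        grad = sum(qg * X.T @ (sigmoid(X @ w) - y) / len(y) for qg, (X, y) in zip(q, groups))
        w -= lr * grad
    return w, q


def make_env(n, p_spurious, rng):
    """Hypothetical toy environment: a stable feature (75% agreement with the label)
    and a spurious feature whose agreement rate p_spurious varies by environment."""
    y = rng.integers(0, 2, n)
    s = 2 * y - 1  # label mapped to {-1, +1}
    stable = s * np.where(rng.random(n) < 0.75, 1, -1) + 0.1 * rng.normal(size=n)
    spurious = s * np.where(rng.random(n) < p_spurious, 1, -1) + 0.1 * rng.normal(size=n)
    return np.column_stack([stable, spurious, np.ones(n)]), y  # last column = bias


rng = np.random.default_rng(0)
(X_a, y_a), (X_b, y_b) = make_env(2000, 0.9, rng), make_env(2000, 0.7, rng)
w_a, w_b = fit_erm(X_a, y_a), fit_erm(X_b, y_b)
# Cross-environment predictions: model A judges env B, and model B judges env A.
groups = predict_then_split(w_a, X_b, y_b) + predict_then_split(w_b, X_a, y_a)
groups = [g for g in groups if len(g[1]) > 0]  # drop any empty group
w_pi, q = fit_group_dro(groups)
print("ERM (env A) weights [stable, spurious, bias]:", np.round(w_a, 2))
print("PI-style group DRO weights:", np.round(w_pi, 2))
```

Because the risk of a mixture λP_correct + (1 − λ)P_wrong is linear in λ, its worst case over λ ∈ [0, 1] is always attained at λ = 0 or λ = 1, which the `worst_case_interpolation` helper checks on a grid. That is why, in this sketch, minimizing the largest group risk also minimizes the worst case over every interpolation of the groups, without needing to know the oracle coefficient.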

Cite this Paper


BibTeX
@InProceedings{pmlr-v139-bao21a,
  title = {Predict then Interpolate: A Simple Algorithm to Learn Stable Classifiers},
  author = {Bao, Yujia and Chang, Shiyu and Barzilay, Regina},
  booktitle = {Proceedings of the 38th International Conference on Machine Learning},
  pages = {640--650},
  year = {2021},
  editor = {Meila, Marina and Zhang, Tong},
  volume = {139},
  series = {Proceedings of Machine Learning Research},
  month = {18--24 Jul},
  publisher = {PMLR},
  pdf = {http://proceedings.mlr.press/v139/bao21a/bao21a.pdf},
  url = {https://proceedings.mlr.press/v139/bao21a.html},
  abstract = {We propose Predict then Interpolate (PI), a simple algorithm for learning correlations that are stable across environments. The algorithm follows from the intuition that when using a classifier trained on one environment to make predictions on examples from another environment, its mistakes are informative as to which correlations are unstable. In this work, we prove that by interpolating the distributions of the correct predictions and the wrong predictions, we can uncover an oracle distribution where the unstable correlation vanishes. Since the oracle interpolation coefficients are not accessible, we use group distributionally robust optimization to minimize the worst-case risk across all such interpolations. We evaluate our method on both text classification and image classification. Empirical results demonstrate that our algorithm is able to learn robust classifiers (outperforming IRM by 23.85% on synthetic environments and 12.41% on natural environments). Our code and data are available at https://github.com/YujiaBao/Predict-then-Interpolate.}
}
Endnote
%0 Conference Paper
%T Predict then Interpolate: A Simple Algorithm to Learn Stable Classifiers
%A Yujia Bao
%A Shiyu Chang
%A Regina Barzilay
%B Proceedings of the 38th International Conference on Machine Learning
%C Proceedings of Machine Learning Research
%D 2021
%E Marina Meila
%E Tong Zhang
%F pmlr-v139-bao21a
%I PMLR
%P 640--650
%U https://proceedings.mlr.press/v139/bao21a.html
%V 139
%X We propose Predict then Interpolate (PI), a simple algorithm for learning correlations that are stable across environments. The algorithm follows from the intuition that when using a classifier trained on one environment to make predictions on examples from another environment, its mistakes are informative as to which correlations are unstable. In this work, we prove that by interpolating the distributions of the correct predictions and the wrong predictions, we can uncover an oracle distribution where the unstable correlation vanishes. Since the oracle interpolation coefficients are not accessible, we use group distributionally robust optimization to minimize the worst-case risk across all such interpolations. We evaluate our method on both text classification and image classification. Empirical results demonstrate that our algorithm is able to learn robust classifiers (outperforming IRM by 23.85% on synthetic environments and 12.41% on natural environments). Our code and data are available at https://github.com/YujiaBao/Predict-then-Interpolate.
APA
Bao, Y., Chang, S. & Barzilay, R. (2021). Predict then Interpolate: A Simple Algorithm to Learn Stable Classifiers. Proceedings of the 38th International Conference on Machine Learning, in Proceedings of Machine Learning Research 139:640-650. Available from https://proceedings.mlr.press/v139/bao21a.html.
