Simplicity Bias and Optimization Threshold in Two-Layer ReLU Networks
Proceedings of the 42nd International Conference on Machine Learning, PMLR 267:5241-5275, 2025.
Abstract
Understanding the generalization of overparametrized models remains a fundamental challenge in machine learning. Most of the literature studies generalization from an interpolation point of view, taking convergence towards a global minimum of the training loss for granted. This interpolation paradigm does not seem to hold for complex tasks such as in-context learning or diffusion. Instead, it has been empirically observed that trained models go from global minima to spurious local minima of the training loss once the number of training samples exceeds a certain level, which we call the optimization threshold. This paper theoretically explores this phenomenon in the context of two-layer ReLU networks. We demonstrate that, despite overparametrization, networks might converge towards simpler solutions rather than interpolating the training data, which leads to a drastic improvement in the test loss. Our analysis relies on the so-called early alignment phase, during which neurons align toward specific directions. This directional alignment leads to a simplicity bias, wherein the network approximates the ground truth model without converging to the global minimum of the training loss. Our results suggest that this bias, which gives rise to an optimization threshold beyond which interpolation is no longer reached, is beneficial and enhances the generalization of trained models.
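The early alignment phase can be illustrated with a toy simulation. What follows is a minimal sketch under our own assumptions, not the paper's experimental setup. It assumes Gaussian inputs, a single-ReLU teacher y = relu(x_1), squared loss, plain gradient descent, and a tiny balanced initialization with |a_j| = ||w_j||. The function name `early_alignment_demo` and all hyperparameters are hypothetical. Because the initialization is small, the network output stays near zero for many steps. During that time, neurons with a positive output weight rotate toward the teacher direction e_1 while their norms remain small.

```python
import numpy as np


def early_alignment_demo(d=5, n=500, m=50, init_scale=1e-6, lr=0.1, steps=150, seed=0):
    """Toy GD on f(x) = sum_j a_j * relu(w_j . x) (hypothetical setup, not the paper's).

    Assumptions: Gaussian inputs, teacher y = relu(x_1), squared loss,
    tiny balanced initialization |a_j| = ||w_j|| = init_scale.
    Returns, for neurons with a positive output weight, their cosine with
    the teacher direction e_1 before and after training, plus the largest
    final neuron norm (to check that we are still in the small-norm regime).
    """
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, d))
    y = np.maximum(X[:, 0], 0.0)  # teacher output relu(<e_1, x>)

    # Random directions rescaled to norm init_scale; balanced output weights.
    W = rng.standard_normal((m, d))
    W *= init_scale / np.linalg.norm(W, axis=1, keepdims=True)
    a = rng.choice([-1.0, 1.0], size=m) * np.linalg.norm(W, axis=1)

    def cos_to_teacher(weights):
        # Cosine between each neuron's input weights and e_1.
        return weights[:, 0] / np.linalg.norm(weights, axis=1)

    cos_before = cos_to_teacher(W)
    for _ in range(steps):
        pre = X @ W.T                  # (n, m) pre-activations
        act = np.maximum(pre, 0.0)     # ReLU activations
        resid = act @ a - y            # f(x_i) - y_i
        # Gradients of L = 1/(2n) * sum_i (f(x_i) - y_i)^2.
        grad_a = act.T @ resid / n
        grad_W = ((resid[:, None] * (pre > 0)) * a).T @ X / n
        a -= lr * grad_a
        W -= lr * grad_W

    pos = a > 0  # balanced small init keeps each a_j's sign during this phase
    return cos_before[pos], cos_to_teacher(W)[pos], np.linalg.norm(W, axis=1).max()


before, after, max_norm = early_alignment_demo()
print("median cosine before:", np.median(before))
print("median cosine after: ", np.median(after))
print("largest neuron norm: ", max_norm)
```

In this sketch, the median cosine of the positive neurons should move from near zero to close to one, while the largest neuron norm stays far below one. This is consistent with alignment happening before the network starts fitting the data. Neurons with negative output weights instead shrink and rotate away from the region where the teacher is active.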