Improving Sharpness-Aware Minimization by Lookahead
Proceedings of the 41st International Conference on Machine Learning, PMLR 235:57776-57802, 2024.
Abstract
Sharpness-Aware Minimization (SAM), which performs gradient descent on adversarially perturbed weights, can improve generalization by identifying flatter minima. However, recent studies have shown that SAM may suffer from convergence instability and oscillate around saddle points, resulting in slow convergence and inferior performance. To address this problem, we propose using a lookahead mechanism that gathers more information about the loss landscape by looking further ahead, and thus finds a better trajectory toward convergence. By examining the nature of SAM, we simplify the extrapolation procedure, resulting in a more efficient algorithm. Theoretical results show that the proposed method converges to a stationary point and is less prone to getting stuck at saddle points. Experiments on standard benchmark datasets also verify that the proposed method outperforms state-of-the-art methods and converges more effectively to flat minima.
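
For readers unfamiliar with the baseline, the following is a minimal sketch of one plain SAM update, that is, "gradient descent on adversarially perturbed weights". It is not the lookahead method proposed in the paper, whose details the abstract does not give. The function name `sam_step`, the toy loss, and the values of `lr` and `rho` are all assumptions chosen for illustration.

```python
import math

def sam_step(w, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    """One plain SAM update on a list of floats (no lookahead)."""
    g = grad_fn(w)
    g_norm = math.sqrt(sum(gi * gi for gi in g))
    # Ascent: move to the first-order worst-case point within radius rho.
    w_adv = [wi + rho * gi / (g_norm + eps) for wi, gi in zip(w, g)]
    # Descent: update the original weights using the gradient taken at w_adv.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

# Hypothetical toy loss L(w) = 0.5 * (w[0]**2 + 10 * w[1]**2) and its gradient.
def toy_grad(w):
    return [w[0], 10.0 * w[1]]

w = [1.0, 1.0]
for _ in range(200):
    w = sam_step(w, toy_grad)
```

Each step costs two gradient evaluations, one at `w` and one at the perturbed point `w_adv`. On this toy problem with a constant `rho`, the iterate ends up close to the minimizer at the origin but keeps hopping around it instead of settling exactly.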