Beyond Uniform Smoothness: A Stopped Analysis of Adaptive SGD
Proceedings of Thirty Sixth Conference on Learning Theory, PMLR 195:89-160, 2023.
Abstract
This work considers the problem of finding a first-order stationary point of a non-convex function with potentially unbounded smoothness constant using a stochastic gradient oracle. We focus on the class of $(L_0,L_1)$-smooth functions proposed by Zhang et al. (ICLR’20). Empirical evidence suggests that these functions capture practical machine learning problems more closely than the pervasive $L_0$-smoothness. This class is rich enough to include highly non-smooth functions, such as $\exp(L_1 x)$, which is $(0,\mathcal{O}(L_1))$-smooth. Despite this richness, an emerging line of work achieves the $\widetilde{\mathcal{O}}(\frac{1}{\sqrt{T}})$ rate of convergence when the noise of the stochastic gradients is deterministically and uniformly bounded. This noise restriction is not required in the $L_0$-smooth setting, and in many practical settings it is either not satisfied or leads to convergence rates with a weaker dependence on the noise scale. We develop a technique that allows us to prove $\mathcal{O}(\frac{\mathrm{poly}\log(T)}{\sqrt{T}})$ convergence rates for $(L_0,L_1)$-smooth functions without assuming uniform bounds on the noise support. The key innovation behind our results is a carefully constructed stopping time $\tau$ which is simultaneously “large” on average, yet also allows us to treat the adaptive step sizes before $\tau$ as (roughly) independent of the gradients. For general $(L_0,L_1)$-smooth functions, our analysis requires the mild restriction that the multiplicative noise parameter $\sigma_1 < 1$. For a broad subclass of $(L_0,L_1)$-smooth functions, our convergence rate continues to hold when $\sigma_1 \geq 1$. By contrast, we prove that many algorithms analyzed by prior works on $(L_0,L_1)$-smooth optimization diverge with constant probability even for smooth and strongly-convex functions when $\sigma_1 > 1$.
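As a brief worked check of the $\exp(L_1 x)$ example, assume the second-order form of the condition from Zhang et al., $\|\nabla^2 f(x)\| \leq L_0 + L_1 \|\nabla f(x)\|$. This form is an assumption here: the paper's own definition may differ by constants, which would be consistent with the $\mathcal{O}(L_1)$ stated in the abstract. For $f(x) = \exp(L_1 x)$ with $L_1 > 0$:

$$
f'(x) = L_1 e^{L_1 x}, \qquad f''(x) = L_1^2 e^{L_1 x} = L_1 \, |f'(x)| .
$$

Under this assumed form the condition therefore holds with $L_0 = 0$ and the same $L_1$, whereas $f''(x) \to \infty$ as $x \to \infty$, so no finite uniform smoothness constant exists on $\mathbb{R}$.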