This paper analyzes the generalization error of Adam and AdamW, proving that their square-root-free variants (Adam(W)-srf) have a generalization error of O(ρ̂^(-2T)/N), where ρ̂ is the smallest element of the second-order momentum plus a small positive constant. To improve generalization, the authors introduce HomeAdam(W), a class of algorithms that occasionally revert to momentum-based SGD. They prove that HomeAdam(W) achieves a smaller generalization error of O(1/N) and a faster convergence rate of O(1/T^(1/4)) than Adam(W)-srf, and support these results with experiments.
Adam's generalization problem? This paper shows how periodically "going home" to momentum-based SGD can provably beat standard Adam and AdamW in generalization error and convergence speed.
Adam and AdamW are default optimizers for training deep learning models. These adaptive algorithms converge faster than SGD but generalize worse. Indeed, their proven generalization error $O(\frac{1}{\sqrt{N}})$ is larger than the $O(\frac{1}{N})$ of SGD, where $N$ denotes the training sample size. Although several variants of Adam have recently been proposed to improve its generalization, their generalization remains unexplored in theory. To fill this gap, in this paper we revisit the generalization of Adam and AdamW via algorithmic stability, and first prove that Adam and AdamW without the square root (i.e., Adam(W)-srf) have a generalization error $O(\frac{\hat{\rho}^{-2T}}{N})$, where $T$ denotes the number of iterations and $\hat{\rho}>0$ denotes the smallest element of the second-order momentum plus a small positive number. To improve generalization, we propose a class of efficient clever Adam algorithms (i.e., HomeAdam(W)) that sometimes return to momentum-based SGD. Moreover, we prove that our HomeAdam(W) algorithms have a smaller generalization error $O(\frac{1}{N})$ than the $O(\frac{\hat{\rho}^{-2T}}{N})$ of Adam(W)-srf, since $\hat{\rho}$ is generally very small (so $\hat{\rho}^{-2T}$ grows rapidly with $T$). In particular, it is also smaller than the existing $O(\frac{1}{\sqrt{N}})$ bound of Adam(W). Meanwhile, we prove that our HomeAdam(W) algorithms have a faster convergence rate of $O(\frac{1}{T^{1/4}})$ than the $O(\frac{\breve{\rho}^{-1}}{T^{1/4}})$ of Adam(W)-srf, where $\breve{\rho}\leq\hat{\rho}$ is also very small. Extensive numerical experiments demonstrate the efficiency of our HomeAdam(W) algorithms.
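The sketch below is a rough illustration, not the authors' algorithm. It does two things. First, it evaluates the leading terms of the quoted bounds with constants dropped, which shows how quickly $\hat{\rho}^{-2T}/N$ outgrows $1/N$ and $1/\sqrt{N}$ once $\hat{\rho}<1$. Second, it implements one step of an Adam-srf update, where the denominator is $\hat{v}+\epsilon$ instead of $\sqrt{\hat{v}}+\epsilon$, together with a HomeAdam-style fallback to momentum SGD. The abstract does not state when HomeAdam(W) "goes home". The threshold rule on $\hat{\rho}$ and the names `home_adam_srf_step`, `bound_factors`, and `rho_min` are hypothetical. Bias correction, the learning rates, and the omission of AdamW's decoupled weight decay are also assumptions.

```python
import math


def bound_factors(rho_hat, T, N):
    """Leading terms of the bounds quoted in the abstract (constants dropped)."""
    return {
        "adam_srf": rho_hat ** (-2 * T) / N,   # O(rho_hat^{-2T} / N), Adam(W)-srf
        "home_adam": 1.0 / N,                  # O(1 / N), HomeAdam(W)
        "adam_sqrt": 1.0 / math.sqrt(N),       # O(1 / sqrt(N)), earlier Adam(W) bound
    }


def home_adam_srf_step(theta, grad, m, v, t, lr_adam=1e-3, lr_sgd=1e-2,
                       beta1=0.9, beta2=0.999, eps=1e-8, rho_min=1e-2):
    """One step of a HomeAdam-style optimizer on lists of floats.

    HYPOTHETICAL switching rule: take a momentum-SGD step whenever the
    smallest element of v_hat + eps falls below rho_min; otherwise take an
    Adam-srf step. Returns (theta, m, v, went_home).
    """
    # First- and second-moment estimates, as in Adam.
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, grad)]
    # Bias correction (assumed; the abstract does not specify it).
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    # rho: smallest element of the second-order momentum plus eps.
    rho = min(vh + eps for vh in v_hat)
    if rho < rho_min:
        # "Go home": momentum-SGD step, avoiding division by a tiny rho.
        theta = [th - lr_sgd * mh for th, mh in zip(theta, m_hat)]
        return theta, m, v, True
    # Adam-srf step: denominator is v_hat + eps, with no square root.
    theta = [th - lr_adam * mh / (vh + eps)
             for th, mh, vh in zip(theta, m_hat, v_hat)]
    return theta, m, v, False


# Usage on f(theta) = 0.5 * ||theta||^2, whose gradient is theta itself.
theta, m, v = [1.0, -2.0], [0.0, 0.0], [0.0, 0.0]
went_home = False
for t in range(1, 201):
    theta, m, v, went_home = home_adam_srf_step(theta, list(theta), m, v, t)
print(theta, went_home)
```

With $\hat{\rho}=0.5$, $T=10$, and $N=1000$, `bound_factors` gives roughly $1048.6$ for the Adam(W)-srf term against $0.001$ for $1/N$, which is the gap the abstract attributes to small $\hat{\rho}$. In the step function, dividing by $\hat{v}+\epsilon$ makes the Adam-srf step scale like $1/\hat{\rho}$ when the second moment is tiny, which is why this sketch switches to a plain momentum step in that regime.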