In my original Grokking Grokking post, I argued that grokking could be caused simply by diffusive dynamics on the optimal manifold. The idea is that, during the pretraining phase of an overparametrized network, the weight dynamics minimize the loss until they hit an optimal manifold of solutions at zero training loss. Because the model is overparametrized, there is not a single optimal point but a manifold of equally good solutions. As the model becomes more overparametrized, the volume of this manifold relative to the total weight space increases, and the manifold becomes more likely to be connected, i.e. to contain no isolated regions. Then, the argument goes, since the true gradients are zero on the optimal manifold, the only source of movement during the grokking phase, once zero training loss has been reached, is the intrinsic noise of SGD. If this noise is assumed to be approximately Gaussian, it induces an OU process in weight space along the optimal manifold. Grokking then occurs as the diffusion tends towards regions of better generalization than the random, possibly overfitted part of the optimal loss manifold that the initial pretraining reached.
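To make the picture concrete, here is a minimal sketch in my own notation (not taken from the original post): expand the loss to second order around a point \(\theta^*\) on the optimal manifold, so that SGD looks like a discretized OU process whose deterministic, mean-reverting part acts only orthogonally to the manifold.

```latex
% Quadratic approximation of the loss around a point \theta^* on the optimal
% manifold, with Hessian H (its null space is tangent to the manifold):
\[
  L(\theta) \approx L(\theta^*) + \tfrac{1}{2}(\theta - \theta^*)^{\top} H (\theta - \theta^*)
\]
% One SGD step with learning rate \eta and (assumed) Gaussian minibatch noise:
\[
  \theta_{t+1} = \theta_t - \eta H (\theta_t - \theta^*) + \eta\, \xi_t,
  \qquad \xi_t \sim \mathcal{N}(0, \Sigma)
\]
```

Directions orthogonal to the manifold are pulled back by \(H\), while directions in its null space feel only the noise term and diffuse freely, which is the along-manifold wandering the rest of this post is about.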
The key question that the post did not answer is why we should expect this random OU diffusion to tend towards regions of better generalization rather than away from them, and to stay there reliably. A priori, this is highly unclear. The diffusion could just as well move away from generalizing regions, or drift in and out of generalizing and overfitting solutions at random. Empirically, however, this seems not to be the case: networks reliably grok and, once grokked, rarely seem to ‘un-grok’.
It has now become clear to me that there are two reasons why we should expect this directionality in grokking, both of which relate to the fact that the OU process is biased towards regions of low curvature, and we should expect low-curvature regions to exhibit better generalization.
Firstly, a number of works have proposed that SGD noise does in fact induce an average drift towards regions of lower curvature. This has been argued for here and here, and can be demonstrated mathematically for a quadratic approximation around the optimal loss manifold. My intuition for this is quite simple. Think of the weights as starting out on the manifold. SGD noise will periodically bounce the weights ‘off’ the manifold, and the next update steps will correct this by projecting the weights back onto, or near to, the manifold. Now suppose the manifold sits in a region of high curvature. When the noise bounces the weights off the manifold, the high curvature means this bounce travels a greater effective distance in loss space than it would in a lower-curvature region, and when the weights are projected back onto the manifold they land further from where they started. Conversely, when the curvature is very low, the bounces caused by the SGD noise barely move the model off the loss manifold at all, so when it is projected back it stays in nearly the same place. Effectively, the displacements are larger in high-curvature regions than in low-curvature ones, and hence, on average, the model exhibits a net drift towards the lower-curvature regions of the optimal manifold, which have better generalization properties.
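As a sanity check on this intuition, here is a toy two-dimensional example of my own construction (not code from the post or the papers linked above), where the ‘manifold’ is the floor of a valley whose curvature varies along its length:

```python
# Toy 2D illustration of the drift-to-flatness intuition. The loss is a valley
# L(x, y) = 0.5 * k(x) * y**2 whose floor (y = 0) plays the role of the optimal
# manifold, and k(x) = exp(x) sets the curvature across the valley: large x
# means a sharp valley, small x a flat one. Gaussian noise is injected only
# across the valley (the off-manifold direction); noise along the valley would
# just superimpose a random walk on top of the same average drift.
import numpy as np

rng = np.random.default_rng(0)

eta, sigma, steps = 0.05, 0.5, 20_000
x, y = 1.0, 0.0                      # start on the valley floor, sharp region

for _ in range(steps):
    k = np.exp(x)                    # curvature across the valley at this x
    grad_x = 0.5 * k * y ** 2        # dL/dx  (k'(x) = k(x) for k = exp)
    grad_y = k * y                   # dL/dy
    x -= eta * grad_x                # plain gradient step along the valley
    y -= eta * grad_y                # restoring step across the valley...
    y += eta * sigma * rng.normal()  # ...plus SGD-like noise across it

print(f"x: 1.0 -> {x:.2f},  curvature k(x): {np.exp(1.0):.2f} -> {np.exp(x):.3f}")
# The noise keeps kicking y off the valley floor, and because the restoring
# gradient also has a component along x pointing towards lower curvature,
# x drifts steadily towards the flat end of the valley.
```

Noise is injected only across the valley to isolate the mechanism: the along-valley coordinate is updated by pure gradient descent, so any systematic movement in x comes entirely from the asymmetry of the restoring force described above.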
Another factor, which I proposed here, relates to a simple geometric fact: the volume of low-curvature regions near the optimal manifold (and presumably of the optimal manifold itself) is much larger than the volume of high-curvature regions. Even if we assume no intrinsic bias in SGD, the equilibrium distribution of the OU process on this manifold is uniform across the manifold (assuming ergodicity). In that case, the relative volumes of low- vs high-curvature, or grokking vs non-grokking, solutions determine the observed behaviour of grokking. As the volume of low-curvature regions in high-dimensional space grows relative to the volume of high-curvature regions, grokking should become more and more deterministic and easy to achieve in more overparametrized models, due simply to this volume effect. Whether we see this effect in practice depends on the mixing time of the OU process, but the fact that grokking appears to take many orders of magnitude longer than pretraining makes it not insane that this is simply the timescale of the mixing of the OU process.
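As a rough back-of-the-envelope version of this volume argument (my own sketch, using a quadratic approximation rather than anything from the original post): the volume of weight space within a loss tolerance of a minimum scales inversely with the square root of the product of the Hessian eigenvalues, so flatter regions win by a factor that is exponential in dimension.

```latex
% Volume of the basin within loss tolerance \epsilon of a minimum, under a
% quadratic approximation with Hessian eigenvalues h_1, ..., h_d
% (V_d is the volume of the unit d-ball):
\[
  \mathrm{Vol}\{\theta : \tfrac{1}{2}\theta^{\top} H \theta \le \epsilon\}
  = V_d\, (2\epsilon)^{d/2} \prod_{i=1}^{d} h_i^{-1/2}
\]
% Ratio of a "flat" basin with eigenvalues c\,h_i (c < 1) to the "sharp" basin:
\[
  \frac{\mathrm{Vol}_{\mathrm{flat}}}{\mathrm{Vol}_{\mathrm{sharp}}} = c^{-d/2}
\]
```

This ratio grows exponentially with the dimension d, so even a modest curvature advantage translates into an overwhelming share of the near-optimal volume, and hence of the occupancy time under a uniform stationary distribution.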