Grokking fast and slow

TLDR

Grokking, or delayed generalisation in neural networks, can

  • help us understand when and how (potentially dangerous) capabilities can emerge in neural networks
  • suggest a way around data limitations, by showing that the right inductive biases can eventually yield strong generalisation
  • be partially explained by current theories of lazy-to-feature learning, although exceptions exist, as shown below.

What is grokking?

Grokking is when a neural network generalises, i.e., achieves near-zero test error, long after it has already achieved near-zero train error. It has been studied extensively in small networks on toy math problems, but it is not restricted to these settings.
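To make the setting concrete, here is a minimal sketch of the canonical toy task, modular addition, in PyTorch. The modulus, train fraction, architecture, and weight-decay strength are illustrative assumptions, not tuned to reproduce any specific paper's grok.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 97  # modulus of the toy task: learn (a + b) mod p
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))
labels = (pairs[:, 0] + pairs[:, 1]) % p

perm = torch.randperm(len(pairs))
n_train = int(0.3 * len(pairs))  # small train fraction: the grokking regime
train_idx, test_idx = perm[:n_train], perm[n_train:]

embed = nn.Embedding(p, 128)
net = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, p))
params = list(embed.parameters()) + list(net.parameters())
opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1.0)
loss_fn = nn.CrossEntropyLoss()

def logits(idx):
    # concatenate the two token embeddings, then classify the sum mod p
    return net(embed(pairs[idx]).flatten(1))

for step in range(50_001):
    opt.zero_grad()
    loss_fn(logits(train_idx), labels[train_idx]).backward()
    opt.step()
    if step % 1_000 == 0:
        with torch.no_grad():
            tr = (logits(train_idx).argmax(-1) == labels[train_idx]).float().mean()
            te = (logits(test_idx).argmax(-1) == labels[test_idx]).float().mean()
        print(f"step {step}: train acc {tr:.2f}, test acc {te:.2f}")

# Typical grokking profile: train accuracy saturates early while test
# accuracy stays near chance for a long stretch, then rises sharply.
```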

What motivates grokking research?

First, grokking shows that a neural network can suddenly acquire strong capabilities, potentially in dangerous contexts, while practitioners have no reason to expect it (because of the persistent train-test gap before the grok). This point has become central for machine learning researchers interested in the safety of neural networks: grokking offers controllable toy models for investigating how and why networks might exhibit rapidly emerging capabilities.

Second, grokking is typically seen under specific train-test splits, often where training data is scarce. This matters for machine learning practitioners: if data becomes the bottleneck to building more capable models (some argue this is already happening), grokking offers hope that strong generalisation can eventually emerge, provided the right inductive biases are baked into the model.

Why grokking happens: weight decay vs. lazy-to-feature learning

Why a network groks on a particular task is still a mystery, although a number of studies have tried to answer this. One popular line of work approaches grokking through weight decay: once the train loss is nearly zero, parameter updates are driven largely by weight decay, which still allows significant exploration of the parameter space and can induce the sudden discovery of the rule (the steep ascent in test accuracy). Newer works along these lines formulate grokking as weight-norm minimisation in the post-memorisation phase.
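A rough way to see this dynamic is to compare, at each step, the norm of the loss gradient against the weight-decay pull λ‖w‖: as the train loss flattens near zero, the decay term stops being negligible and increasingly shapes the update. The sketch below illustrates this on a toy regression; the model, data, and λ are placeholder assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x, y = torch.randn(32, 10), torch.randn(32, 1)  # small, memorisable dataset
model = nn.Sequential(nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 1))
wd = 1e-2  # weight-decay strength (lambda)
opt = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=wd)

for step in range(20_001):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # p.grad holds the pure loss gradient; SGD adds decay in step()
    if step % 4_000 == 0:
        with torch.no_grad():
            g = torch.cat([p.grad.flatten() for p in model.parameters()]).norm()
            w = torch.cat([p.flatten() for p in model.parameters()]).norm()
        print(f"step {step}: loss {loss:.2e}, ||grad_loss|| {g:.3f}, lambda*||w|| {wd * w:.3f}")
    opt.step()

# As the loss gradient shrinks with the train loss, the lambda*||w|| term
# makes up a growing share of each update, steering towards low-norm solutions.
```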

Another theory, however, models grokking as a transition from a "lazy" learning regime to a feature-learning regime. More specifically, it argues that during the memorisation phase the network learns lazily, i.e., its neural tangent kernel (NTK) evolves negligibly, until the network discovers the generalisable rule, after which it enters a "rich" feature-learning phase in which the NTK changes substantially.
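One can probe this empirically by tracking how much the empirical NTK moves between checkpoints: a near-constant kernel indicates lazy dynamics, a drifting one indicates feature learning. Below is a sketch using torch.func; the probe set, architecture, and the cosine-style kernel distance are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)
x_probe = torch.randn(16, 10)  # small probe set on which to evaluate the kernel
model = nn.Sequential(nn.Linear(10, 64), nn.Tanh(), nn.Linear(64, 1))

def f(params, x):
    # scalar network output for one input, as a function of the parameters
    return functional_call(model, params, (x.unsqueeze(0),)).squeeze()

def empirical_ntk(params):
    # each row of J is the gradient of the output at one probe point w.r.t. all params
    jac = vmap(jacrev(f), in_dims=(None, 0))(params, x_probe)
    J = torch.cat([j.flatten(1) for j in jac.values()], dim=1)
    return J @ J.T  # (16, 16) kernel matrix

def kernel_distance(K1, K2):
    # 1 - cosine similarity between flattened kernels
    return 1 - (K1 * K2).sum() / (K1.norm() * K2.norm())

params = {k: v.detach() for k, v in model.named_parameters()}
K0 = empirical_ntk(params)
# ... train the model here, then recompute the kernel:
Kt = empirical_ntk({k: v.detach() for k, v in model.named_parameters()})
print(kernel_distance(K0, Kt))  # ~0 => lazy regime; large => feature learning
```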

While both theories present compelling evidence for their core hypotheses, neither explains grokking across every architecture and task. Grokking profiles can vary strongly with the model, task, and optimisation setup. For instance, as a counter-example to the argument that general solutions have low weight norm, one can run a simple polynomial regression experiment and find the weight norm increasing throughout training, likely because the task requires large degrees of freedom in the weights (a sketch follows below). As for the lazy-to-feature-learning interpretation, it holds only for MLPs trained with an MSE objective under SGD; more complex models, objectives, and optimisers do not exhibit near-static NTK evolution across training.
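As a concrete version of that counter-example, here is a sketch of polynomial regression trained from zero-initialised weights, where the weight norm grows as the fit improves. The target polynomial and feature degree are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
x = torch.linspace(-1, 1, 100)
y = 3 * x**3 - 2 * x  # target needs sizeable coefficients to represent
degree = 5
X = torch.stack([x**k for k in range(degree + 1)], dim=1)  # polynomial features
w = torch.zeros(degree + 1, requires_grad=True)            # start at zero norm

opt = torch.optim.SGD([w], lr=0.1)
for step in range(2_001):
    opt.zero_grad()
    loss = ((X @ w - y) ** 2).mean()
    loss.backward()
    opt.step()
    if step % 500 == 0:
        print(f"step {step}: loss {loss:.4f}, ||w|| {w.detach().norm():.3f}")

# ||w|| grows from 0 towards the value needed to represent the target, so here
# a good fit coincides with *growing*, not shrinking, weight norm.
```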



