Muon and Manifold Versions

In this blog, we take a close look at Muon, try to understand it, and then move towards Manifold Muon. We cover the math required to understand and implement these optimizers, and then implement Manifold Muon in optax; this implementation will form the basis of a pull request adding Manifold Muon to the official optax library. Additionally, we run a tiny CIFAR-10 speed-run ourselves to see the improvements of Manifold Muon over the original.

Muon

Muon is an optimizer designed by Keller Jordan that has recently become quite popular, especially in the deep-learning speed-running community. What makes it special in comparison to the reigning champion AdamW is its emphasis on orthogonality of updates. More specifically, instead of applying the (momentum-smoothed) gradient directly, Muon replaces it with an approximation of its nearest semi-orthogonal matrix, computed via something called a Newton-Schulz (NS) iteration. We’ll explain all of this precisely with the math below.

Theory and Design

Before we begin understanding Muon in detail, a quick note: Muon only operates on 2D weight matrices. All scalar and vector parameters, as well as the input and output layers of a model (e.g. the embedding look-up layer and the language-modelling head in transformers), should be optimized by AdamW and not Muon. Convolutional layers can be optimized with Muon, but the last three dimensions of the weight tensor must first be collapsed into one so that the resulting tensor is 2D, as sketched below.
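For example, a 4D convolution kernel can be flattened into a matrix before the Muon update and reshaped back afterwards. A minimal sketch (PyTorch; the shapes are arbitrary and chosen only for illustration):

import torch

# Hypothetical conv weight of shape (out_channels, in_channels, kH, kW).
w = torch.randn(64, 32, 3, 3)

# Collapse the last three dimensions so that Muon sees a 2D matrix.
w_2d = w.reshape(w.shape[0], -1)  # shape (64, 32 * 3 * 3) = (64, 288)

# ... compute the Muon update on w_2d here ...

# Restore the original 4D shape afterwards.
w = w_2d.reshape(64, 32, 3, 3)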

Now, given a weight matrix $\theta_{t-1} \in \mathbb{R}^{n\times m}$ at timestep $t$ of training, we compute its gradient w.r.t. the loss $\mathcal{L}$ as $g_t = \nabla_{\theta_{t-1}}\mathcal{L}(\theta_{t-1}), \ g_t \in \mathbb{R}^{n\times m}$. The weights $\theta$ are then updated to $\theta_{t}$ as

\[\theta_{t} = \theta_{t-1} - \eta \ u_t\]

where $\eta$ is the learning rate and $u_t \in \mathbb{R}^{n\times m}$ is the update that we compute as follows.

  1. Momentum $\mu$ is applied to the gradient as $m_t = \mu m_{t-1} + g_t, \ m_0 = 0$.
  2. Next, running the NS iteration on $m_t$ for $5$ steps yields the update $u_t$. The NS iteration is explained in the next subsection, after which we sketch the full update step in code.

Newton-Schulz Iteration

The NS iteration approximately orthogonalizes the pre-update $m_t$, i.e. it approximately solves the constrained optimization problem

\[\operatorname{Orthogonalize}(X) = \arg\min_Z \| Z - X \|_F \qquad \text{s.t.} \qquad \text{either} \quad Z^\top Z = I \quad \text{or}\quad ZZ^\top = I\]

Note that the constraint requires only one of $Z^\top Z = I$ or $ZZ^\top = I$ to hold (for a non-square matrix only the smaller of the two products can equal the identity). This makes the minimizer a semi-orthogonal matrix and, equivalently, means that if the singular value decomposition (SVD) of $X$ is $U\Sigma V^\top$, then $\operatorname{Orthogonalize}(X) = UV^\top$ (all singular values of the semi-orthogonal minimizer equal $1$). Having established this, we now show the NS iteration algorithm below.
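As a quick reference point first, the exact minimizer can be computed directly with an SVD. The sketch below (assuming PyTorch) is useful as a ground truth to compare the NS output against, but a full SVD is comparatively expensive to run at every optimizer step, which is what motivates the cheaper approximation.

import torch

def orthogonalize_via_svd(x):
  # Exact solution: keep the singular vectors, drop the singular values.
  u, s, vh = torch.linalg.svd(x, full_matrices=False)
  return u @ vh

The NS iteration itself, which reaches approximately the same result using only matrix multiplications, is the following.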

def ns_iteration(x, num_steps=5, epsilon=1e-7):
  # Approximately orthogonalize a 2D matrix (e.g. a torch.Tensor) via the
  # quintic Newton-Schulz iteration used by Muon.
  assert x.ndim == 2, "Wrong no. of dimensions in input"
  a, b, c = (3.4445, -4.7750, 2.0315)  # coefficients of the quintic polynomial

  # Normalize so that all singular values lie in [0, 1].
  y = x / (x.norm() + epsilon)

  # Work with the wide orientation so that y @ y.T is the smaller Gram matrix.
  if x.shape[0] > x.shape[1]:
    y = y.T

  for _ in range(num_steps):
    A = y @ y.T
    B = b * A + c * A @ A
    y = a * y + B @ y  # applies the quintic polynomial to the singular values

  # Undo the transpose if it was applied.
  if x.shape[0] > x.shape[1]:
    y = y.T

  return y

Let us write out one step of this iteration to better understand why it orthogonalizes its input. Let $X \in \mathbb{R}^{p\times q}$ be the input matrix. Then,

\[Z = aX + b(XX^\top)X + c(XX^\top)^2X\] \[=(aI + bXX^\top + c (XX^\top)^2)X\]

Now substituting the SVD of $X$ as $U\Sigma V^\top$, we get

\[=(a\cdot I + b\cdot U\Sigma^2 U^\top + c\cdot U\Sigma^4 U^\top)\,U\Sigma V^\top\] \[=U(a\cdot\Sigma + b\cdot\Sigma^3 + c\cdot\Sigma^5)V^\top\]

Clearly, one step of the NS iteration yields a matrix whose singular values are those of the input passed through the quintic ($5^{\text{th}}$ order, with only odd powers) polynomial $\rho(x) = ax + bx^3 + cx^5$. Applying the iteration for $T$ steps therefore applies this polynomial $T$ times (we denote this by $\rho^T(x)$) to the singular values. Since the singular values all lie in $[0, 1]$ due to the normalization of the input, we only need to choose the polynomial coefficients such that $\rho^T(x) \to 1$ as $T \to \infty$ for all $x \in (0, 1]$. This would drive all nonzero singular values of the resulting matrix towards $1$, thereby making it (semi-)orthogonal.
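To see this concretely, the short sketch below (plain PyTorch, reusing the coefficients from ns_iteration above) applies $\rho$ five times to a few sample singular values. With these particular coefficients the values settle into a loose band around $1$ rather than converging exactly, a trade-off discussed next.

import torch

a, b, c = 3.4445, -4.7750, 2.0315
sigma = torch.tensor([0.01, 0.1, 0.3, 0.5, 1.0])  # sample singular values in (0, 1]

for _ in range(5):  # T = 5, matching num_steps in ns_iteration
  sigma = a * sigma + b * sigma**3 + c * sigma**5

print(sigma)  # every entry ends up roughly between 0.7 and 1.1, i.e. near 1 but not exactly 1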

That’s it! Choosing good coefficients for the polynomial gives us an appreciably fast iteration; the following constraints must be kept in mind while choosing them:

  1. $a$ must be large, as $\rho'(0)=a$. This implies that $a$ controls the rate of convergence for small initial singular values.
  2. For every $x \in (0, 1]$ we want $\rho^T(x)$ to land in the interval $[1-\epsilon, 1+\epsilon]$ for some small $\epsilon > 0$ as $T$ grows, so that the result of the NS iteration is close to semi-orthogonal.

While one can experiment with different ways to solve this optimization problem, the creator of Muon arrived at the coefficients used in the code above by employing a gradient-based search.
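Putting the momentum step and the NS iteration together, a single Muon update for one weight matrix might look like the following sketch (illustrative PyTorch-style code; the function name, hyperparameter values, and the momentum-buffer handling are assumptions, not the reference implementation):

def muon_step(theta, grad, momentum_buf, lr=0.02, mu=0.95):
  # m_t = mu * m_{t-1} + g_t
  momentum_buf = mu * momentum_buf + grad
  # u_t: approximately orthogonalize the pre-update via Newton-Schulz.
  update = ns_iteration(momentum_buf)
  # theta_t = theta_{t-1} - eta * u_t
  return theta - lr * update, momentum_buf

The official implementation layers a few more details on top of this (such as Nesterov-style momentum and a shape-dependent scale factor on the update), which we omit here to keep the core structure visible.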

Why orthogonalize?

Keller Jordan provides a concise and intuitive explanation for why orthogonality is desirable: when optimizing with SGD-momentum or Adam/AdamW, the updates to weight matrices typically have high condition numbers, implying that only a few dominant directions drive the change in the transformation. Orthogonalization is thus speculated to amplify the effect of the remaining damped “rare” directions in order to speed up learning.

Let’s try and understand the need for orthogonality in more detail. First, we begin by emphasizing that Muon explicitly enforces orthogonal updates. This is different from explicitly enforcing orthogonal weights. Now, recall that for a given timestep

\[\theta_{t} = \theta_{t-1} - \eta \ u_t \ ; \qquad u_t^\top u_t = I\]

Let us write the change in the output of the linear transform

\[\Delta h = \theta_t x - \theta_{t-1}x\]

given some input $x$. Clearly, $\Delta h = -\eta \ u_t \ x$, and so

\[\|\Delta h\|^2 = \eta^2 \ (u_t \ x)^\top (u_t \ x) = \eta^2 \ x^\top u_t^\top u_t \ x\] \[= \eta^2 \ x^\top x = \eta^2 \ \|x\|^2\]

This result is quite important: it shows that the singular value structure of the gradient does not govern the size of the update to the activations. Recall that in default SGD, the effective learning rate along the direction of the $i$-th singular vector of the gradient is $\eta_{\text{effective}, i} = \eta \ \sigma_i(g_t)$. With Muon, all directions move by the same effective step size, since Muon pushes $\sigma_i(u_t) \to 1$. This removes the bias towards a few high-variance directions, preventing the update matrix from developing a high condition number and, correspondingly, preventing optimization from collapsing into a rank-deficient regime. This is especially important for deep networks, where singular value spectra tend to sharpen with depth and optimization can lock into a few modes. It also means that a single learning rate tends to work across all layers.
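To make this concrete, here is a small numerical check (a sketch assuming PyTorch and the ns_iteration function from earlier; the shapes and seed are arbitrary) comparing the effective step sizes along the gradient's strongest and weakest singular directions, before and after orthogonalization:

import torch

torch.manual_seed(0)
g = torch.randn(128, 64)  # a stand-in for the pre-update m_t
u = ns_iteration(g)       # its approximately semi-orthogonal counterpart

# Input directions aligned with g's largest and smallest right singular vectors.
_, _, vh = torch.linalg.svd(g, full_matrices=False)
x_top, x_bottom = vh[0], vh[-1]

print((g @ x_top).norm(), (g @ x_bottom).norm())  # noticeably different magnitudes
print((u @ x_top).norm(), (u @ x_bottom).norm())  # both of order 1: comparable effective steps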

The Thinking Machines’ blog makes the same point. The goal in neural network optimization is to have well-behaved updates: not too large and not too small. Various normalization schemes attempt to realise this, and Muon in particular normalizes the impact of the gradient across its singular value spectrum, thus making updates more well-behaved and controllable.

CIFAR-10 speedrun

asdf

Manifold Muon


Theory and design

asdf

Optax implementation

asdf

CIFAR-10 speedrun

asdf

References

  1. Thinking Machines’ Manifold Muon blog
  2. Sam D. Buchanan’s blog
  3. Keller Jordan’s blog
  4. Keller Jordan’s Muon GitHub repo
  5. Google DeepMind’s Optax GitHub repo
  6. Thinking Machines’ Manifolds GitHub repo


