Muon and Manifold Versions

TLDR

In this post, we take a close look at Muon, build an understanding of why it works, and then move on to Manifold Muon. We cover the math required to understand and implement these optimizers, and then implement Manifold Muon in optax.

Muon

Muon is an optimizer designed by Keller Jordan that has recently become quite popular, especially in the deep-learning speed-running community. What sets it apart from the reigning champion AdamW is its emphasis on orthogonal updates. More specifically, Muon replaces the update to each weight matrix with (approximately) the nearest semi-orthogonal matrix, which it computes with a Newton-Schulz (NS) iteration. We explain all of this precisely with the math below.

Theory and Design

Before we dig into Muon, a quick note: Muon only operates on 2D weight matrices. All scalar and vector parameters, as well as the input and output layers of a model (e.g., the embedding look-up layer and the language-modelling head in transformers), should be optimized by AdamW and not Muon. Convolutional layers can be optimized with Muon, but the last 3 dimensions of the weight tensor must first be collapsed into one so that the resulting tensor is 2D.
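As a small illustration of the flattening described above, here is a NumPy sketch. The kernel layout `(out_channels, in_channels, kH, kW)` and the sizes are assumptions made for the example, not something the post prescribes.

```python
import numpy as np

# Hypothetical conv kernel laid out as (out_channels, in_channels, kH, kW).
kernel = np.zeros((64, 32, 3, 3))

# Collapse the last three dimensions so the tensor becomes a 2D matrix.
kernel_2d = kernel.reshape(kernel.shape[0], -1)
print(kernel_2d.shape)  # (64, 288)

# After the optimizer step, the result is reshaped back to the kernel layout.
restored = kernel_2d.reshape(kernel.shape)
```

The 2D view is what a Muon-style update would operate on; the reshape back is only bookkeeping.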

Now, given a weight matrix $\theta_{t-1} \in \mathbb{R}^{n\times m}$ at timestep $t$ of training, we compute the gradient of the loss $\mathcal{L}$ w.r.t. it as $g_t = \nabla_{\theta_{t-1}}\mathcal{L}(\theta_{t-1}), \ g_t \in \mathbb{R}^{n\times m}$. The weights are then updated to $\theta_{t}$ as

\[\theta_{t} = \theta_{t-1} - \eta \ u_t\]

where $\eta$ is the learning rate and $u_t \in \mathbb{R}^{n\times m}$ is the update that we compute as follows.

  1. Momentum $\mu$ is applied to the gradient as $m_t = \mu m_{t-1} + g_t, \ m_0 = 0$.
  2. Next, NS iteration on $m_t$ for $5$ steps finally yields the update $u_t$. The NS iteration is explained below.

Newton-Schulz Iteration

The NS iteration is used to approximately orthogonalize the pre-update ($m_t$) by solving the constrained optimization problem

\[\operatorname{Orthogonalize}(X) = \arg\min_Z \| Z - X \|_F \qquad \text{s.t.} \qquad \text{either} \quad Z^\top Z = I \quad \text{or}\quad ZZ^\top = I\]

Note that the constraint specifies either $Z^\top Z = I$ or $ZZ^\top = I$. This makes the minimizer a semi-orthogonal matrix and equivalently means that if the singular value decomposition (SVD) of $X$ is $U\Sigma V^\top$, then the SVD of $\operatorname{Orthogonalize}(X)$ will be $UV^\top$ (the matrix of singular values $\Sigma$ of the semi-orthogonal matrix is just the identity matrix $I$). Having established this, we now show the NS iteration algorithm below.
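To make the SVD characterization concrete, the following NumPy sketch computes $UV^\top$ directly and checks the two properties claimed above on a random tall matrix. The function name `orthogonalize_svd` and the test sizes are illustrative choices, and the exact SVD is used here only as a reference; it is not how Muon computes the update.

```python
import numpy as np

def orthogonalize_svd(x):
    """Exact Orthogonalize(X) = U V^T via a reduced SVD."""
    u, _, vt = np.linalg.svd(x, full_matrices=False)
    return u @ vt

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 4))   # tall input, so Z^T Z = I is the relevant constraint
z = orthogonalize_svd(x)

# Semi-orthogonality: the columns of z are orthonormal.
gram_error = np.linalg.norm(z.T @ z - np.eye(4))

# All singular values of z equal 1.
singular_values = np.linalg.svd(z, compute_uv=False)

# z should be at least as close to x as any other matrix with orthonormal columns.
dist_opt = np.linalg.norm(z - x)
others = [np.linalg.qr(rng.standard_normal((6, 4)))[0] for _ in range(100)]
dist_others = min(np.linalg.norm(q - x) for q in others)
print(gram_error, dist_opt, dist_others)
```

Comparing against random orthonormal frames is only a spot check of the minimization property, not a proof.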

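import torch

def ns_iteration(x, num_steps=5, epsilon=1e-7):
  """Approximately orthogonalize a 2D tensor with a quintic Newton-Schulz iteration."""
  assert x.ndim == 2, "Wrong no. of dimensions in input"
  a, b, c = (3.4445, -4.7750, 2.0315)
  # Normalize by the Frobenius norm so that all singular values lie in [0, 1].
  y = x / (x.norm() + epsilon)

  # Work in the wide orientation so that y @ y.T is the smaller Gram matrix.
  transposed = x.shape[0] > x.shape[1]
  if transposed:
    y = y.T

  for _ in range(num_steps):
    A = y @ y.T
    B = b * A + c * A @ A
    y = a * y + B @ y  # y <- a*y + b*(y y^T) y + c*(y y^T)^2 y

  # Undo the transpose so the output has the same shape as the input.
  if transposed:
    y = y.T

  return y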
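To tie the iteration back to the update rule, here is a hedged NumPy sketch of one full Muon step: momentum, then NS orthogonalization, then the parameter update. The names `ns_iteration_np` and `muon_step` are made up for this example, the random `grad` is a stand-in for a real gradient, and the values `lr=0.02` and `mu=0.95` are illustrative.

```python
import numpy as np

def ns_iteration_np(x, num_steps=5, epsilon=1e-7):
    """NumPy port of the Newton-Schulz iteration above."""
    a, b, c = 3.4445, -4.7750, 2.0315
    y = x / (np.linalg.norm(x) + epsilon)
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        y = y.T
    for _ in range(num_steps):
        A = y @ y.T
        y = a * y + (b * A + c * A @ A) @ y
    return y.T if transposed else y

def muon_step(theta, grad, momentum_buf, lr=0.02, mu=0.95):
    """One Muon update: momentum, then orthogonalization, then the step."""
    momentum_buf = mu * momentum_buf + grad      # m_t = mu * m_{t-1} + g_t
    update = ns_iteration_np(momentum_buf)       # u_t
    return theta - lr * update, momentum_buf     # theta_t = theta_{t-1} - eta * u_t

rng = np.random.default_rng(0)
theta = rng.standard_normal((8, 4))
grad = rng.standard_normal((8, 4))               # stand-in for a real gradient
buf = np.zeros_like(theta)
theta_new, buf = muon_step(theta, grad, buf)

# Singular values of the applied update, divided by the learning rate.
sv = np.linalg.svd((theta - theta_new) / 0.02, compute_uv=False)
print(sv)  # expected to be near 1, though not exactly 1
```

The printed singular values show that the update is only approximately semi-orthogonal after five steps, which is the intended trade-off.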

Let us write one step of this iteration to gain a better understanding of why it orthogonalizes its input. Let $X \in \mathbb{R}^{p\times q}$ be the input matrix. Then,

\[Z = aX + b(XX^\top)X + c(XX^\top)^2X\] \[=(aI + bXX^\top + c (XX^\top)^2)X\]

Now substituting the SVD of $X$ as $U\Sigma V^\top$, we get

\[Z=(a\cdot I + b\cdot U\Sigma^2 U^\top + c\cdot U\Sigma^4 U^\top)U\Sigma V^\top\] \[=U(a\cdot\Sigma +b\cdot\Sigma^3 + c\cdot\Sigma^5 )V^\top\]

Clearly, one step of the NS iteration yields a matrix with the same singular vectors as $X$, whose singular values are those of $X$ passed through the odd quintic ($5^{\text{th}}$ order with only odd powers) polynomial $\rho(x) = ax + bx^3 + cx^5$. Applying the iteration for $T$ steps applies the polynomial $T$ times (we denote this by $\rho^T(x)$) to each singular value. Since the singular values all lie in $[0, 1]$ after normalization (dividing by the Frobenius norm bounds the spectral norm by $1$), we only need to choose the coefficients such that $\rho^T(x) \to 1$ as $T \to \infty$ for all $x \in (0, 1]$. This drives all nonzero singular values of the resulting matrix toward $1$, thereby making it (semi-)orthogonal.
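The effect on singular values can be checked without any matrices by iterating the scalar polynomial. The sketch below uses the coefficients from the code above and an arbitrary grid of starting values in $[0.01, 1]$ (the grid is my choice for illustration).

```python
import numpy as np

a, b, c = 3.4445, -4.7750, 2.0315

def rho(x):
    """The odd quintic applied to each singular value by one NS step."""
    return a * x + b * x**3 + c * x**5

x = np.linspace(0.01, 1.0, 200)   # initial singular values after normalization
y = x.copy()
for _ in range(5):                # five NS steps
    y = rho(y)

print(y.min(), y.max())           # both should fall inside roughly [0.68, 1.21]
```

With these particular coefficients the values do not settle exactly at $1$; they land in a band around it.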

That’s it! Choosing good polynomial coefficients gives an appreciably fast iteration. Two considerations guide the choice:

  1. $a$ should be as large as possible: since $\rho'(0)=a$, it controls the rate of convergence for small initial singular values.
  2. For every $x \in [0, 1]$, $\rho^T(x)$ should converge into the interval $[1-\epsilon, 1+\epsilon]$ for some $\epsilon > 0$ as $T \to \infty$, so that the result of the NS iteration is not far from $\operatorname{Orthogonalize}(X)$.

While one can experiment with different ways to solve this constrained problem, the creator of Muon arrives at the coefficients used in the code block above via an ad hoc gradient-based approach.

Why orthogonalize?

The main idea behind orthogonalizing the update is to keep the transformation applied by the parameter matrix from changing too much in a single step. More precisely, if $y = Wx$, then we want to bound how much $y$ changes. This gives us an objective for how $W$ should change:

\[\arg\min_{\Delta W} \ \langle \nabla\mathcal{L}(W), \Delta W\rangle \quad \text{s.t.} \quad \|\Delta y\|_{\text{rms}} \leq \eta\]

Notice that we have used the RMS norm here: $\|\bullet\|_{\text{rms}} = (1/\sqrt{d}) \cdot \|\bullet\|_2$, where $d$ is the dimension of the input $\bullet$. Since $\Delta y = \Delta W x$ for a fixed input $x$, constraining the change in $y$ means constraining the change in the transformation applied to $x$, i.e., $\Delta W$. Concretely,

\[\|\Delta y\|_{\text{rms}} \leq \|\Delta W\|_{\text{rms}\rightarrow\text{rms}} \cdot \|x\|_{\text{rms}}\]

The RMS-to-RMS norm of a matrix is defined as

\[\|\Delta W\|_{\text{rms}\rightarrow\text{rms}} = \frac{\sqrt{\text{fan-in}}}{\sqrt{\text{fan-out}}} \|\Delta W\|_s\]

where $\|\circ\|_s$ is the spectral norm of an input matrix $\circ$. If we assume $\|x\|_{\text{rms}} \leq 1$, then it suffices to require $\|\Delta W\|_s \leq \frac{\sqrt{\text{fan-out}}}{\sqrt{\text{fan-in}}}\cdot \eta$. This quantifies our desire for the change in both the output and the weights to remain small. Thus, our final objective becomes

\[\arg\min_{\Delta W} \ \langle \nabla\mathcal{L}(W), \Delta W\rangle \quad \text{s.t.} \quad \|\Delta W\|_s \leq \frac{\sqrt{\text{fan-out}}}{\sqrt{\text{fan-in}}}\cdot \eta\]

Finding the solution to this optimization problem is called dualizing the gradient in the Muon literature, and its closed form is

\[\Delta W = -\eta \cdot \sqrt{\frac{\text{fan-out}}{\text{fan-in}}}\cdot UV^\top\]

where $U$ and $V$ are the matrices of left and right singular vectors of the gradient, respectively.
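The closed form can be spot-checked numerically. The sketch below drops the $\eta\sqrt{\text{fan-out}/\text{fan-in}}$ scale (it only rescales the ball) and compares $-UV^\top$ against random directions of unit spectral norm. The random `g` is a stand-in for a gradient and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal((5, 3))                  # stand-in gradient
u, s, vt = np.linalg.svd(g, full_matrices=False)

best = -u @ vt                                   # claimed minimizer, spectral norm 1
best_value = np.sum(g * best)                    # <G, dW> = tr(G^T dW)

# Compare against random directions rescaled to unit spectral norm.
values = []
for _ in range(1000):
    d = rng.standard_normal((5, 3))
    d /= np.linalg.norm(d, 2)                    # ord=2 is the spectral norm for matrices
    values.append(np.sum(g * d))

print(best_value, -s.sum(), min(values))
```

The first two printed numbers should agree: the optimal value is minus the nuclear norm of the gradient, and no random direction should go below it.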

There is another way to derive this update. Recall that throughout, we want to descend the loss under the constraint that the change in our weights remains small. Hence, let us use perhaps the most popular and intuitive mathematical tool here: the Taylor expansion. In particular, we adopt proximal descent under the spectral norm, i.e., a first-order expansion plus a quadratic penalty on the step size:

\[\mathcal{L}(W+\Delta W) \approx m(W) = \mathcal{L}(W)+\text{tr}[\nabla\mathcal{L}(W)^\top \Delta W] + \alpha\|\Delta W\|^2_s\]

If we decompose the update as $\Delta W = r\cdot d$ with $r \geq 0$ and $\|d\|_s = 1$, then we need to solve two optimization problems:

\[d^* =\arg\min_{\|d\|_s = 1} \text{tr}[\nabla\mathcal{L}(W)^\top d] \quad \text{and} \quad r^* = \arg\min_r \ r\cdot\text{tr}[\nabla\mathcal{L}(W)^\top d^*] + \alpha r^2\]

Solving for $d^*$ gives us $d^* = -UV^\top$, so that $\text{tr}[\nabla\mathcal{L}(W)^\top d^*] = -\|\nabla\mathcal{L}(W)\|_{\text{nuc}}$ and thus $r^* = \|\nabla\mathcal{L}(W)\|_{\text{nuc}}/2\alpha$. Therefore, constraining the change in weights under the spectral norm yields the orthogonalized gradient in the final parameter update, i.e., Muon.
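The scale subproblem is a one-dimensional parabola, so its minimizer is easy to verify by brute force. In the sketch below, `alpha = 0.5` and the random `g` are illustrative assumptions; with $d^* = -UV^\top$ the minimizing step length comes out as the nuclear norm of the gradient divided by $2\alpha$.

```python
import numpy as np

rng = np.random.default_rng(1)
g = rng.standard_normal((5, 3))                  # stand-in gradient
alpha = 0.5                                      # illustrative proximal coefficient

u, s, vt = np.linalg.svd(g, full_matrices=False)
d_star = -u @ vt
slope = np.sum(g * d_star)                       # tr(G^T d*) = -||G||_nuc

# Scan the 1D model r * slope + alpha * r^2 over a grid of step lengths.
r_grid = np.linspace(0.0, 2.0 * s.sum() / alpha, 20001)
model = r_grid * slope + alpha * r_grid**2
r_numeric = r_grid[np.argmin(model)]
r_closed = s.sum() / (2.0 * alpha)
print(r_numeric, r_closed)
```

Because $d^*$ already points downhill, the optimal step length is positive.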

Manifold Muon


Theory and Design

Manifold Muon is the natural extension of Muon to the setting where weights are explicitly constrained to live on a manifold, rather than merely having their updates orthogonalized. The motivation is two-fold: orthogonal updates are great (as Muon shows), but if the weights themselves drift to poorly conditioned regimes over the course of training, stability can still suffer. Manifold Muon addresses both problems simultaneously.

The Stiefel Manifold

The Stiefel manifold is the natural home for well-conditioned weight matrices. For a tall matrix $W \in \mathbb{R}^{m \times n}$ (with $m \geq n$), the Stiefel manifold is defined as

\[\mathsf{Stiefel}(m, n) = \{ W \in \mathbb{R}^{m \times n} : W^\top W = I_n \}.\]

In words, a matrix lies on the Stiefel manifold if and only if its columns form an orthonormal set, or equivalently, if all its singular values are exactly $1$. For the square case $m = n$ this recovers the orthogonal group $O(n)$. The Stiefel manifold is a smooth, compact submanifold of $\mathbb{R}^{m \times n}$, and every matrix on it has unit $\ell_2 \to \ell_2$ condition number (conversely, a matrix with unit condition number is a positive multiple of a Stiefel matrix).

The tangent space to the Stiefel manifold at a point $W$ is the set of matrices $A \in \mathbb{R}^{m \times n}$ satisfying

\[W^\top A + A^\top W = 0.\]

This is the analogue of the hyperspherical tangent condition $a^\top w = 0$ for vectors. It characterizes the update directions that stay on the manifold to first order: moving along any $A$ that satisfies this condition violates $W^\top W = I_n$ only at second order in the step size.
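A short NumPy sketch of the tangent condition follows. It samples a Stiefel point via QR, builds a tangent direction with one standard projection (subtracting the symmetric part of $W^\top Z$), and checks that the constraint violation shrinks quadratically with the step. The helper name `project_to_tangent` and the sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 6, 3

# A point on Stiefel(m, n): the Q factor of a random tall matrix.
w, _ = np.linalg.qr(rng.standard_normal((m, n)))

def project_to_tangent(w, z):
    """Remove the symmetric part of W^T Z so that W^T A + A^T W = 0."""
    wtz = w.T @ z
    return z - w @ (wtz + wtz.T) / 2.0

a = project_to_tangent(w, rng.standard_normal((m, n)))
tangent_residual = np.linalg.norm(w.T @ a + a.T @ w)

def constraint_violation(t):
    """Distance of (W + tA)^T (W + tA) from the identity."""
    moved = w + t * a
    return np.linalg.norm(moved.T @ moved - np.eye(n))

# Halving the step should roughly quarter the violation (second-order effect).
ratio = constraint_violation(1e-2) / constraint_violation(5e-3)
print(tangent_residual, ratio)
```

A ratio close to 4 indicates that the first-order term vanishes, which is what tangency means here.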

Deriving Manifold Muon via Proximal Descent

Recall how we derived Muon: we wrote down a proximal model of the loss with a quadratic penalty in spectral norm and minimized it to find the optimal unconstrained update. Manifold Muon adds one more ingredient: the update $\Delta W$ must also lie tangent to the Stiefel manifold at $W$, so that the updated weights stay (to first order) on the manifold. With $G = \nabla_W \mathcal{L}$ and the proximal regularizer $\alpha \|\Delta W\|_s^2$, the extended proximal model becomes

\[m(W) = \mathcal{L}(W) + \operatorname{tr}[G^\top \Delta W] + \alpha \|\Delta W\|_s^2 \quad \text{s.t.} \quad W^\top \Delta W + \Delta W^\top W = 0.\]

Just as before, we decompose $\Delta W = r \cdot d$ with $\|d\|_s = 1$ and first solve for the direction $d^*$, then for the scale $r^*$. The direction subproblem now carries the tangent constraint:

\[d^* = \arg\min_{\|d\|_s \leq 1} \ \operatorname{tr}[G^\top d] \quad \text{s.t.} \quad W^\top d + d^\top W = 0.\]

This is a constrained linear minimization over the spectral-norm ball, and we handle the constraint by introducing a symmetric matrix of Lagrange multipliers $\Lambda \in \mathbb{R}^{n \times n}$. The Lagrangian is

\[\mathcal{L}(d, \Lambda) = \operatorname{tr}[G^\top d] + \operatorname{tr}[\Lambda (W^\top d + d^\top W)] = \operatorname{tr}[(G + 2W\Lambda)^\top d],\]

where the second equality follows from the cyclic property of the trace and the symmetry of $\Lambda$. The inner minimization over the spectral-norm ball $\|d\|_s \leq 1$ is now just minimizing a linear function of $d$, whose closed-form minimizer is

\[d^*(\Lambda) = -\operatorname{msign}(G + 2W\Lambda),\]

where $\operatorname{msign}(X) = UV^\top$ is the matrix polar factor of $X = U\Sigma V^\top$ — exactly the NS iteration target. Substituting back and applying Sion’s minimax theorem converts the problem into an unconstrained maximization over $\Lambda$ (the dual problem). We then ascend on $\Lambda$ using the subgradient

\[H(\Lambda) = W^\top d^*(\Lambda) + d^*(\Lambda)^\top W,\]

which measures how far the current $d^*(\Lambda)$ is from satisfying the tangent space condition. This is the dual ascent algorithm:

  1. Initialize $\Lambda = -\tfrac{1}{4}(W^\top G + G^\top W)$.
  2. Compute the candidate direction: $d^* = -\operatorname{msign}(G + 2W\Lambda)$.
  3. Measure tangent deviation: $H = W^\top d^* + {d^*}^\top W$.
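  4. Stop if $\|H\|_F / n < \texttt{tol}$, i.e., the Frobenius norm divided by the square root of the number of entries of the $n \times n$ matrix $H$; otherwise update $\Lambda \gets \Lambda + \alpha \cdot H$ and go to step 2. Here $\alpha$ denotes the dual step size, not the proximal coefficient from the model above.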
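The four steps above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions: `msign` is computed by SVD, the deviation is normalized by the square root of the number of entries of $H$ (as the implementation later in the post does), and the defaults `alpha=0.01`, `steps=30`, `tol=1e-5` as well as the name `dual_ascent` are choices made for the example. The tall case is not guaranteed to reach the tolerance within the budget.

```python
import numpy as np

def msign(x):
    """Matrix polar factor U V^T from a reduced SVD."""
    u, _, vt = np.linalg.svd(x, full_matrices=False)
    return u @ vt

def dual_ascent(w, g, alpha=0.01, steps=30, tol=1e-5):
    """Follow steps 1-4; returns (direction, normalized deviation, iterations used)."""
    lam = -0.25 * (w.T @ g + g.T @ w)                 # step 1
    for step in range(steps):
        d = -msign(g + 2.0 * w @ lam)                 # step 2
        h = w.T @ d + d.T @ w                         # step 3
        deviation = np.linalg.norm(h) / np.sqrt(h.size)
        if deviation < tol:                           # step 4
            return d, deviation, step
        lam = lam + alpha * h
    return d, deviation, steps

rng = np.random.default_rng(0)

# Square case (even dimension, so the projected gradient is generically full rank).
w_sq, _ = np.linalg.qr(rng.standard_normal((4, 4)))
d_sq, dev_sq, iters_sq = dual_ascent(w_sq, rng.standard_normal((4, 4)))

# Tall case: the loop generally needs several iterations.
w_tall, _ = np.linalg.qr(rng.standard_normal((8, 4)))
d_tall, dev_tall, iters_tall = dual_ascent(w_tall, rng.standard_normal((8, 4)))
print(iters_sq, dev_sq, iters_tall, dev_tall)
```

In the square example the very first candidate direction is already tangent, so no dual update is performed; the tall example shows how much deviation remains after the iteration budget.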

The initialization is chosen so that for square $W$ the loop terminates after one step: in that case $G + 2W\Lambda = \tfrac{1}{2}(G - WG^\top W)$, which is the gradient projected onto the tangent space, and (when this matrix has full rank) $d^* = -UV^\top$ is a Muon-style update built from its SVD $U\Sigma V^\top$. Once $d^*$ is found, solving the scale subproblem gives $r^* = \|G + 2W\Lambda_\text{opt}\|_\text{nuc} / 2\alpha$, so the final update is $\Delta W = r^* \cdot d^*$, with the overall learning rate $\eta$ absorbing the scale as usual. The updated weights $W \gets W + \eta \cdot d^*$ are then retracted back to the manifold via $W \gets \operatorname{msign}(W)$, or using the analytic retraction $W \gets W + W \cdot {d^*}^\top d^* \cdot (1/\sqrt{1+\eta^2} - 1)$, where $W$ on the right-hand side is the already-updated matrix.
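To see the two retractions agree, here is a NumPy check on a case where the direction is exactly tangent: a square orthogonal $W$ with $A = -W\operatorname{msign}(S)$ for a skew-symmetric $S$. The construction, the even dimension, and `eta = 0.1` are assumptions made for this example; for directions that are only approximately tangent, the analytic formula would be approximate too.

```python
import numpy as np

def msign(x):
    """Matrix polar factor U V^T from a reduced SVD."""
    u, _, vt = np.linalg.svd(x, full_matrices=False)
    return u @ vt

rng = np.random.default_rng(0)
eta = 0.1

# Square orthogonal W and a tangent direction A whose singular values are all 1:
# A = -W @ msign(S) with S skew-symmetric (even dimension, so S is generically invertible).
w, _ = np.linalg.qr(rng.standard_normal((4, 4)))
g = rng.standard_normal((4, 4))
a = -w @ msign(w.T @ g - g.T @ w)

stepped = w + eta * a                            # tangent step, slightly off the manifold
analytic = stepped + stepped @ (a.T @ a) * (1.0 / np.sqrt(1.0 + eta**2) - 1.0)
polar = msign(stepped)                           # retraction via the polar factor

stiefel_error = np.linalg.norm(analytic.T @ analytic - np.eye(4))
gap = np.linalg.norm(analytic - polar)
print(stiefel_error, gap)
```

Both printed numbers should be at round-off level, since here $(W + \eta A)^\top (W + \eta A) = (1 + \eta^2) I$.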

To summarize: Manifold Muon is just Muon’s proximal descent derivation with one extra Lagrange multiplier enforcing the tangent-space condition. The proximal model, the decomposition into direction and scale, and the closed-form $\operatorname{msign}$ update all carry over unchanged — the only new ingredient is the dual ascent loop that finds the $\Lambda$ making the update tangent to the Stiefel manifold.

A Broader Picture: Gram-Space Manifolds

The Stiefel manifold is not the only game in town. An insightful framing from Tilde Research notes that the Stiefel constraint $W^\top W = I_n$ is really a constraint on the Gram matrix of the columns of $W$. This unlocks a broader family of manifold optimizers by relaxing the identity constraint in different ways:

  • Stiefel: $W^\top W = I_n$ — columns are orthonormal.
  • Diagonal Gram (DGram): $W^\top W$ is required to be diagonal (columns are orthogonal but can have arbitrary norms).
  • Oblique: diagonal entries of $W^\top W$ are fixed to $1$ (columns have unit norm) but off-diagonal entries are free (columns need not be orthogonal).

Remarkably, all three manifolds admit the same efficient dual ascent solution with the same $\operatorname{msign}$ subgradient — there is no new solver to invent. They are all special cases of a family $\mathcal{M}_{P, \mathcal{C}}$ parameterized by a self-adjoint projector $P$ on the space of symmetric matrices and a constraint set $\mathcal{C}$:

| Manifold | $P$ | $\mathcal{C}$ |
| --- | --- | --- |
| Stiefel | $\operatorname{Id}$ | $\{I\}$ |
| Diagonal Gram | $\operatorname{Off}$ | $\{0\}$ |
| Oblique | $\operatorname{Diag}$ | $\{I\}$ |
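One way to read the table is as a membership test on the Gram matrix. The sketch below assumes that $W$ belongs to $\mathcal{M}_{P, \mathcal{C}}$ when $P(W^\top W)$ equals the single element of $\mathcal{C}$; this reading, and the names `diag_part`, `off_part`, `MANIFOLDS`, and `on_manifold`, are my assumptions for illustration and are not taken from the Tilde Research write-up.

```python
import numpy as np

def diag_part(s):
    """Keep only the diagonal of a square matrix."""
    return np.diag(np.diag(s))

def off_part(s):
    """Keep only the off-diagonal entries of a square matrix."""
    return s - diag_part(s)

# (projector P, target element of C), following the table above.
MANIFOLDS = {
    "stiefel": (lambda s: s, lambda n: np.eye(n)),
    "dgram": (off_part, lambda n: np.zeros((n, n))),
    "oblique": (diag_part, lambda n: np.eye(n)),
}

def on_manifold(w, name, atol=1e-10):
    """Assumed membership test: P(W^T W) equals the single element of C."""
    project, target = MANIFOLDS[name]
    n = w.shape[1]
    return bool(np.allclose(project(w.T @ w), target(n), atol=atol))

rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((6, 3)))     # orthonormal columns
scaled = q * np.array([0.5, 2.0, 3.0])               # orthogonal, non-unit columns
unit = rng.standard_normal((6, 3))
unit /= np.linalg.norm(unit, axis=0)                 # unit-norm, non-orthogonal columns

for name in MANIFOLDS:
    print(name, on_manifold(q, name), on_manifold(scaled, name), on_manifold(unit, name))
```

Under this reading, a Stiefel point satisfies all three constraints, while the rescaled and the unit-norm examples each satisfy only one of the relaxed ones.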

Experiments suggest that DGram and Oblique variants converge faster than strict Stiefel Muon (more degrees of freedom), while still maintaining significantly better conditioning than Adam. This hints at a fruitful design space: the right manifold may not be the most constrained one, and searching this Gram-space family is a principled way to explore it.

Modular Manifolds and Learning Rate Budgeting

A complementary question is: how should these per-layer manifold constraints interact across the depth of a network? The modular manifolds framework (from the Thinking Machines blog) addresses this by treating each neural network module as a triple $(f, \mathcal{M}, \|\cdot\|)$: a forward function, a manifold constraint, and a norm. The theory then shows how to compose modules, taking the product manifold and a max-norm across layers weighted by sensitivity coefficients, to automatically derive a globally consistent learning rate budget. In this way, manifold constraints do not just regulate individual weight matrices; they unlock a principled story about why one learning rate can work across all layers, generalizing the per-singular-value equalization argument from Muon.

Optax Implementation

Below we give a clean optax-compatible implementation of Manifold Muon in JAX. The structure mirrors the PyTorch reference from the Modula docs, adapted to JAX’s functional style. We implement msign via SVD (one could swap this for a NS iteration for GPU efficiency), and wrap the dual ascent loop as a pure function for jax.lax.while_loop compatibility.

import jax
import jax.numpy as jnp
import optax
from typing import NamedTuple


def msign(X: jnp.ndarray) -> jnp.ndarray:
    """Matrix polar factor: snap all singular values to 1."""
    U, _, Vt = jnp.linalg.svd(X, full_matrices=False)
    return U @ Vt


def manifold_muon_update(
    W: jnp.ndarray,
    G: jnp.ndarray,
    eta: float = 0.1,
    alpha: float = 0.01,
    steps: int = 30,
    tol: float = 1e-5,
) -> jnp.ndarray:
    """
    Compute the Manifold Muon update for a weight matrix W given gradient G.
    Returns the new weight matrix retracted back to the Stiefel manifold.
    """
    should_transpose = W.shape[0] < W.shape[1]
    if should_transpose:
        W, G = W.T, G.T

    # Dual variable initialization (reduces to standard Muon for square W)
    lam = -0.25 * (W.T @ G + G.T @ W)

    def dual_step(state):
        lam, _, step = state
        A = -msign(G + 2.0 * W @ lam)
        H = W.T @ A + A.T @ W
        norm_H = jnp.linalg.norm(H) / jnp.sqrt(H.size)
        # Linearly decayed step size for the dual variable
        lam_new = lam + alpha * (1.0 - step / steps) * H
        return lam_new, norm_H, step + 1

    def cond_fn(state):
        _, norm_H, step = state
        return (norm_H >= tol) & (step < steps)

    lam, _, _ = jax.lax.while_loop(
        cond_fn, dual_step, (lam, jnp.inf, 0)
    )

    # Final update direction
    A = -msign(G + 2.0 * W @ lam)

    # Tangent space step
    new_W = W + eta * A

    # Retract back to Stiefel manifold
    new_W = new_W + new_W @ (A.T @ A) * (1.0 / jnp.sqrt(1.0 + eta ** 2) - 1.0)

    if should_transpose:
        new_W = new_W.T
    return new_W


class ManifoldMuonState(NamedTuple):
    momentum: optax.Updates  # momentum buffer (same structure as params)


def manifold_muon(
    learning_rate: float = 0.02,
    momentum: float = 0.95,
    alpha: float = 0.01,
    dual_steps: int = 30,
    tol: float = 1e-5,
) -> optax.GradientTransformation:
    """
    Optax-compatible Manifold Muon optimizer.

    Only applies the manifold update to 2-D weight matrices. All other
    parameters (scalars, vectors, embedding/output layers) should be
    optimized by a separate AdamW pass.
    """

    def init_fn(params):
        mom = jax.tree_util.tree_map(jnp.zeros_like, params)
        return ManifoldMuonState(momentum=mom)

    def update_fn(grads, state, params):
        mom = state.momentum

        # Accumulate heavy-ball momentum: m_t = momentum * m_{t-1} + g_t (no Nesterov look-ahead)
        new_mom = jax.tree_util.tree_map(
            lambda m, g: momentum * m + g, mom, grads
        )

        def compute_update(W, m):
            if W.ndim != 2:
                # Fall back to plain gradient for non-matrix params
                return -learning_rate * m
            return manifold_muon_update(W, m, eta=learning_rate,
                                        alpha=alpha, steps=dual_steps, tol=tol) - W

        updates = jax.tree_util.tree_map(compute_update, params, new_mom)
        return updates, ManifoldMuonState(momentum=new_mom)

    return optax.GradientTransformation(init_fn, update_fn)

A few design notes:

  • The dual ascent loop is implemented as a jax.lax.while_loop for XLA compatibility. The stopping condition checks both the tangent space deviation norm and the iteration budget.
  • The retraction map applied after the tangent step is the analytic one derived by Bernstein: a correction proportional to $A^\top A$ that maps the updated weights back onto the Stiefel manifold. It is exact when $A$ is exactly tangent and its singular values are all $0$ or $1$; since the dual loop only reaches tangency up to tol, the result is on the manifold up to that tolerance. Alternatively, one can retract by calling msign(new_W) directly (simpler but slightly more expensive).
  • For parameters that are not 2-D (scalars, vectors) the optimizer simply falls back to scaled momentum, consistent with the convention that Muon-style updates only make sense for 2-D weight matrices. Embedding tables and LM heads are themselves 2-D, so this check does not exclude them; they have to be routed to a separate optimizer explicitly, for example with optax.multi_transform.
  • The dual step size $\alpha$ is linearly decayed over the inner loop iterations (as shown), which empirically stabilizes convergence of the dual variable.

References

  1. Thinking Machines’ Manifold Muon blog
  2. Sam D. Buchanan’s blog
  3. Keller Jordan’s blog
  4. Keller Jordan’s Muon GitHub repo
  5. Google DeepMind’s Optax GitHub repo
  6. Thinking Machines’ Manifolds GitHub repo
  7. Modula Docs — Stiefel Manifold (Bernstein)
  8. Gram-Space Manifold Muon (Tilde Research)
  9. Modular Norm paper (Bernstein & Large et al.)
  10. Modular Duality paper


