Layer Sharding for Large‑Scale Training with Muon
May 15, 2025
There’s a lot of excitement around the Muon optimizer as a replacement for Adam and AdamW, but scaling it up to large training runs is challenging. As we showed in our recent paper, Muon remains efficient at larger batch sizes than AdamW, which lets a pre-training run productively use far more hardware and reach a target loss sooner. In this post, we describe our sharding strategy, which scales Muon further than any approach we have seen discussed.
Muon has a larger overhead than Adam
The final test of an optimizer is how fast you can achieve a given loss. Traditionally, we think of this in terms of tokens (or flops), and Muon indeed uses fewer tokens or flops to get that loss. However, the true cost of training a model is hours of compute. With Muon, you can use larger batches and achieve significantly better MFU at large scale.
The catch is that Muon’s parameter update requires far more computation than Adam’s, and in some configurations this cost is a large fraction of total training time.
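Suppose we have a tensor X with dimensions (L, P, Q), where L is the number of layers and P and Q are the two dimensions of each weight matrix, oriented so that P ≤ Q (Muon implementations typically transpose a matrix when needed, which keeps the Gram matrix X Xᵀ as small as possible). For the largest weight tensors (MLP weights), P and Q are the model dimension d and 4d, respectively. We want to compute the following five Newton-Schulz iterations (omitting several cheap elementwise ops):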
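```python
import jax.numpy as jnp
def newton_schulz(X):  # X: (L, P, Q); scalar coefficients and normalization omitted
    for _ in range(5):
        A = X @ jnp.swapaxes(X, -1, -2)  # (L, P, Q) x (L, Q, P) => (L, P, P), cost 2LPQP flops
        B = A @ A                        # (L, P, P) x (L, P, P) => (L, P, P), cost 2LPPP flops
        X = B @ X                        # (L, P, P) x (L, P, Q) => (L, P, Q), cost 2LPQP flops
    return X
```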
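Each iteration costs 2LPQP + 2LPPP + 2LPQP = 2LP²(2Q + P) flops, so five iterations cost 10LP²(2Q + P) flops. Because P ≤ Q, this is at most 30LP²Q flops (with equality when P = Q), which is the figure usually quoted.
By comparison, the forward and backward passes through the same weight tensor cost 6LPQB flops, where B is the number of tokens in the batch (not the matrix B above). Ignoring attention, the ratio of Muon flops to model flops is therefore at most 30LP²Q / 6LPQB = 5P/B. With P ≈ d, the rule of thumb is that Muon costs 5d/B extra flops per model flop.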
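A minimal Python sketch of this accounting follows; the helper names are hypothetical. It assumes each matrix is oriented so that P ≤ Q and counts only the three matmuls per iteration. For MLP shapes (Q = 4P) the exact count is 10·L·P²·(2Q + P) = 90·L·P³, three quarters of the 30·L·P²·Q shorthand, so the 5P/B rule is slightly conservative.

```python
def newton_schulz_flops(L: int, P: int, Q: int, iters: int = 5) -> int:
    """Exact matmul flops for `iters` Newton-Schulz iterations on an (L, P, Q) stack.

    Orients each matrix so that P <= Q (as if transposed), so the Gram matrix
    A = X @ X^T is P x P. Cheap elementwise ops are ignored.
    """
    P, Q = min(P, Q), max(P, Q)
    per_iter = (2 * L * P * Q * P      # A = X @ X^T
                + 2 * L * P * P * P    # B = A @ A
                + 2 * L * P * P * Q)   # X = B @ X
    return iters * per_iter


def model_flops(L: int, P: int, Q: int, batch_tokens: int) -> int:
    """Forward + backward flops through the same weights: 6 * L * P * Q * B."""
    return 6 * L * P * Q * batch_tokens


L, d, batch_tokens = 1, 1024, 1_000_000
exact = newton_schulz_flops(L, d, 4 * d)   # MLP shapes: P = d, Q = 4d
bound = 30 * L * d**2 * (4 * d)            # the 30 * L * P^2 * Q shorthand
print(exact / bound)                       # 0.75
# Exact Muon/model ratio (3.75 * d / B for these shapes) next to the 5 * d / B rule of thumb
print(exact / model_flops(L, d, 4 * d, batch_tokens), 5 * d / batch_tokens)
```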
This leads to claims that Muon’s overhead is negligible at scale: for a Llama-405B-sized run, only about 0.5% of model flops.
However, this assumes the Muon computation can be perfectly sharded across all 16,384 GPUs used to train the model, which is unrealistic. For reference, Llama 405B used 8-way tensor parallelism (TP), 16-way pipeline parallelism (PP), and 128-way fully sharded data parallelism (FSDP). If the computation is duplicated across the FSDP ranks, the overhead balloons from 0.5% to ~64% of model flops, which is far too expensive.
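The overhead figures in this paragraph follow from the 5d/B rule of thumb. The sketch below assumes d = 16,384 and a global batch of 2^24 ≈ 16.8M tokens; these are our assumptions about a Llama-405B-scale run, not numbers given in this post. The hypothetical `duplication` parameter counts how many devices repeat the same Muon work.

```python
def muon_overhead(d: int, batch_tokens: int, duplication: int = 1) -> float:
    """Muon flops per model flop, using the 5d/B rule of thumb.

    `duplication` is how many devices repeat the same Muon computation
    (1 means the work is perfectly sharded across all devices).
    """
    return duplication * 5 * d / batch_tokens


d = 16_384             # assumed model dimension
batch_tokens = 2**24   # assumed global batch (~16.8M tokens)

print(f"{muon_overhead(d, batch_tokens):.2%}")                   # 0.49%: perfectly sharded
print(f"{muon_overhead(d, batch_tokens, duplication=128):.1%}")  # 62.5%: repeated across 128-way FSDP
print(muon_overhead(d, batch_tokens, duplication=8 * 128))       # 5.0: repeated across 8-way TP x 128-way FSDP
```

Rounding the base to 0.5% before multiplying by 128 gives the ~64% quoted above. The 1024× case corresponds to the replicated-computation strategy discussed below.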
We have options. PP naturally shards the L dimension, but FSDP shards P and TP shards Q, and both are contracting dimensions in these matmuls, so sharding them can add heavy communication.
Strategy #1: Replicated computation
Let X be sharded S ways along P (sharding along Q is analogous), and let X′ be our local shard. In Llama 405B, S ranges from 16 to 128. One option (as in this paper) is to all-gather X and replicate the computation on every device (only the last matmul can be restricted to the local shard):
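```python
import jax
import jax.numpy as jnp

def muon_replicated(X_shard, axis_name):  # runs per device (e.g. under shard_map); X_shard = X', shape (L, P/S, Q)
    X = jax.lax.all_gather(X_shard, axis_name, axis=1, tiled=True)  # (L, P, Q) on every device
    for _ in range(5):
        A = X @ jnp.swapaxes(X, -1, -2)  # (L, P, Q) x (L, Q, P) => (L, P, P), cost 2LPQP flops
        B = A @ A                        # (L, P, P) x (L, P, P) => (L, P, P), cost 2LPPP flops
        X = B @ X                        # (L, P, P) x (L, P, Q) => (L, P, Q), cost 2LPQP flops
    # The last iteration could compute only our rows; here we slice them out afterwards.
    rows = X_shard.shape[1]
    start = jax.lax.axis_index(axis_name) * rows
    return jax.lax.dynamic_slice_in_dim(X, start, rows, axis=1)  # (L, P/S, Q)
```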
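The flop count is essentially unchanged from the unsharded computation, which is far too large. With Llama 405B’s configuration, the work would be duplicated 1024× (8-way TP × 128-way FSDP), which works out to roughly 5 Muon flops per model flop. That is not acceptable.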
Here’s a profile of Muon on an MI300X node using this approach. Stream #20 (network) is active mainly during the initial all‑gather.
Strategy #2: Sharded matmul
Another choice is to shard the matmuls directly (JAX’s naive default; also discussed in this post):
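```python
import jax
import jax.numpy as jnp

def muon_sharded_matmul(X_shard, axis_name):  # runs per device; X_shard = X', shape (L, P/S, Q)
    for _ in range(5):
        X = jax.lax.all_gather(X_shard, axis_name, axis=1, tiled=True)  # (L, P, Q)
        A_shard = X_shard @ jnp.swapaxes(X, -1, -2)  # => (L, P/S, P), cost 2LPQP/S
        A = jax.lax.all_gather(A_shard, axis_name, axis=1, tiled=True)  # (L, P, P)
        B_shard = A_shard @ A                        # => (L, P/S, P), cost 2LPPP/S
        X_shard = B_shard @ X                        # => (L, P/S, Q), cost 2LPQP/S
    return X_shard
```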
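This cuts per-device flops by a factor of S, which is great, but the communication is huge: every iteration all-gathers X (model-sized) and A (about a quarter of the model, since A is P × P and Q = 4P). Over 5 iterations that is 5 × 1.25 = 6.25 model-sized all-gathers per step. The minimum batch size at which DP/FSDP communication can be hidden skyrockets, and perfectly overlapping this traffic with compute is unrealistic.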
On MI300X, Stream #20 becomes active many times — communication dominates and isn’t well overlapped.
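The 6.25 figure is simple bookkeeping. This sketch (a hypothetical helper) adds up the full tensors all-gathered in one step, in units of one model-sized (L, P, Q) tensor; L cancels out of the ratio.

```python
def sharded_matmul_gathers(P: int, Q: int, iters: int = 5) -> float:
    """Full-size tensors all-gathered per Muon step by the sharded-matmul strategy,
    in units of one (L, P, Q) tensor. Each iteration gathers X (L*P*Q elements)
    and A (L*P*P elements); the layer count L cancels out of the ratio."""
    per_iter = P * Q + P * P
    return iters * per_iter / (P * Q)


d = 4096
print(sharded_matmul_gathers(d, 4 * d))   # 6.25 for MLP shapes (Q = 4P)
print(sharded_matmul_gathers(d, d))       # 10.0 for square matrices (P = Q)
```

For square matrices, A is as large as X, so the traffic is even worse than for the MLP weights.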
Strategy #3: Layer sharding
Communication is large because we shard contracting dimensions. What about the non-contracting dimension L? Pipeline parallelism already shards across layers, but we often avoid PP because the forward and backward passes have data dependencies between layers. The Muon update, however, has no inter-layer dependency, so we can re-shard by layer, compute, and then re-shard back.
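```python
import jax
import jax.numpy as jnp

def muon_layer_sharded(X_shard, axis_name):  # runs per device; X_shard = X', shape (L, P/S, Q); assumes S divides L
    # (L, P/S, Q) -> (L/S, P, Q): split layers across devices, collect full rows
    X = jax.lax.all_to_all(X_shard, axis_name, split_axis=0, concat_axis=1, tiled=True)
    for _ in range(5):
        A = X @ jnp.swapaxes(X, -1, -2)  # (L/S, P, Q) x (L/S, Q, P) -> (L/S, P, P)
        B = A @ A                        # (L/S, P, P) x (L/S, P, P) -> (L/S, P, P)
        X = B @ X                        # (L/S, P, P) x (L/S, P, Q) -> (L/S, P, Q)
    # (L/S, P, Q) -> (L, P/S, Q): send rows back to their owners, collect all layers
    return jax.lax.all_to_all(X, axis_name, split_axis=1, concat_axis=0, tiled=True)
```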
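This gives the same S-fold flop reduction with far less communication: just two all-to-alls of the (L, P, Q) tensor per step. On TPUs, an all-to-all is roughly 4× faster than an all-gather of the same tensor; on GPUs it is also much faster, especially within a node whose GPUs are fully connected (e.g. via NVLink/NVSwitch).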
Here’s an MI300X profile: Stream #20 is active only occasionally (theoretically twice per tensor).
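The re-sharding in layer sharding is easy to check on one machine. The sketch below simulates S devices as a Python list of NumPy arrays; `all_to_all` is a hypothetical stand-in for the tiled collective (each device splits its array along one axis, sends piece j to device j, and concatenates what it receives along another axis). The Newton-Schulz coefficients and the per-layer Frobenius normalization are taken from the public Muon reference implementation and are not shown above; sizes are toy values.

```python
import numpy as np

# Quintic Newton-Schulz coefficients from the public Muon reference implementation
# (an assumption here; the code above omits these elementwise ops).
a, b, c = 3.4445, -4.7750, 2.0315


def newton_schulz(X, iters=5):
    """Newton-Schulz on a stack of matrices with shape (..., P, Q), P <= Q."""
    for _ in range(iters):
        A = X @ X.swapaxes(-1, -2)
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X


def all_to_all(arrays, split_axis, concat_axis):
    """Simulated tiled all-to-all: device i splits its array into S pieces along
    `split_axis` and sends piece j to device j; device j concatenates what it
    receives (in sender order) along `concat_axis`."""
    S = len(arrays)
    pieces = [np.split(x, S, axis=split_axis) for x in arrays]
    return [np.concatenate([pieces[i][j] for i in range(S)], axis=concat_axis)
            for j in range(S)]


def layer_sharded_muon(shards):
    """Layer sharding: shards[i] is device i's (L, P/S, Q) slice of X."""
    by_layer = all_to_all(shards, split_axis=0, concat_axis=1)  # -> (L/S, P, Q) each
    updated = [newton_schulz(x) for x in by_layer]               # local, no communication
    return all_to_all(updated, split_axis=1, concat_axis=0)      # -> (L, P/S, Q) each


rng = np.random.default_rng(0)
L, P, Q, S = 8, 8, 32, 4          # toy sizes; L and P must be divisible by S
X = rng.standard_normal((L, P, Q))
X /= np.linalg.norm(X, axis=(1, 2), keepdims=True)  # per-layer Frobenius normalization
shards = np.split(X, S, axis=1)                     # FSDP-style: each device holds (L, P/S, Q)

out = layer_sharded_muon(shards)
print(out[0].shape)                                                # (8, 2, 32)
print(np.allclose(np.concatenate(out, axis=1), newton_schulz(X)))  # True
```

Between the two all-to-alls, every device holds whole matrices, which is why the Newton-Schulz iterations themselves need no communication at all.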
Eventually you run out of layers (Llama 405B has 126), which caps layer sharding. With effort, you can likely gain another ~4× by placing each of the three MLP weight tensors on a separate GPU, plus another for the attention projections. With all of this, Muon’s effective overhead comes down to about 16%: borderline, but feasible.
Muon also tolerates larger batches. Doubling the batch doubles model flops but leaves Muon flops unchanged, halving the relative overhead (from ~16% to ~8%). Llama reports MFU rising from 41% to 43% when the batch is doubled (about 4.8% more throughput), which leaves a net penalty of roughly 3.2%: possibly acceptable.
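One way to reconstruct the ~16% and ~3.2% figures (an interpretation, not spelled out above): with 16-way pipeline parallelism each stage holds about 126 / 16 ≈ 8 layers, and splitting the MLP and attention tensors multiplies that by 4, so each stage’s Muon work can be divided about 32 ways among the 8 × 128 = 1,024 TP × FSDP devices that would otherwise all repeat it. The sketch assumes d = 16,384 and a 2^24-token batch for the 5d/B baseline (both assumptions) and uses the same first-order subtraction as the text for the net penalty.

```python
import math

base = 5 * 16_384 / 2**24        # 5d/B with assumed d = 16,384 and B = 2**24 tokens (~0.49%)
layers, pp, tp, fsdp = 126, 16, 8, 128

layers_per_stage = math.ceil(layers / pp)   # 8 layers per pipeline stage
ways = layers_per_stage * 4                 # x4 from splitting MLP / attention tensors
duplication = tp * fsdp / ways              # 1024 / 32 = 32x repeated work
overhead = base * duplication
print(f"{overhead:.1%}")                    # 15.6%, i.e. the ~16% above

# Doubling the batch halves Muon's share; the higher MFU (41% -> 43%) claws some back.
doubled = overhead / 2
mfu_gain = 0.43 / 0.41 - 1
print(f"{doubled:.1%} - {mfu_gain:.1%} = {doubled - mfu_gain:.1%}")  # 7.8% - 4.9% = 2.9%
```

Unrounded, the net penalty comes out near 2.9%; starting from the rounded 8% and 4.8% gives the ~3.2% quoted above.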
Summary and observed times
Summary of options:

| strategy | flops (relative to replicated) | comms (# all-gathers of the model) |
|---|---|---|
| replicated computation | 1 | 1 (or 0 if the weights are already replicated under DP) |
| sharded matmul | 1/S | 6.25 |
| layer sharding | 1/S | 0.5 on TPU, ~0.1 on GPU |
We implemented all three strategies in JAX and ran them on TPUs and MI300X GPUs, using a 2B-parameter model with a per-device batch of 24k tokens (for comparison, Llama 405B used about 1k tokens per device).
On v5p-8 TPU (4-way distributed):

| strategy | Muon time | total step time |
|---|---|---|
| replicated | 386 ms | 2.388 s |
| sharded matmul | 255 ms | 2.161 s |
| layer sharding | 155 ms | 2.105 s |
Ideally, layer sharding would take ~25% of the replicated Muon time (1/S with S = 4); we observed ~40% (155 ms vs. 386 ms) because of poor communication/compute overlap, which could be improved.
Scaling to v5p‑16 for layer sharding cuts Muon time to ~75 ms, as expected (2× bandwidth + 2× compute).
On MI300X (8-way distributed):

| strategy | Muon time | total step time |
|---|---|---|
| replicated | 1171 ms | 2.720 s (43%) |
| sharded matmul | 522 ms | 2.200 s (26%) |
| layer sharding | 189 ms | 1.840 s (10%) |
Ideally, layer sharding would take 12.5% of the replicated Muon time (1/S with S = 8); we observed ~16% (189 ms vs. 1171 ms) because the model had 13 layers, so 8-way layer sharding left 3 of 16 layer slots empty (after adjusting for this, the ratio was ~13.7%).
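The tiled all-to-all used in layer sharding needs L to be divisible by S. One simple way to handle a 13-layer model is to pad the layer axis with zero matrices, which Newton-Schulz leaves at zero, and slice them off afterwards; `pad_layers` below is a hypothetical helper illustrating this. With 13 layers on 8 devices, each device gets 2 layer slots, 3 of the 16 slots are padding, and the best achievable ratio is about 2/13 ≈ 15% rather than 1/8 = 12.5%.

```python
import numpy as np


def pad_layers(X: np.ndarray, S: int) -> tuple[np.ndarray, int]:
    """Pad the layer axis of an (L, P, Q) stack up to a multiple of S.

    Returns the padded stack and the number of padding (wasted) layer slots.
    Zero matrices stay zero under Newton-Schulz, so the padding is harmless;
    slice it off after the update.
    """
    L = X.shape[0]
    pad = -L % S                       # smallest pad making L + pad divisible by S
    return np.pad(X, ((0, pad), (0, 0), (0, 0))), pad


X = np.ones((13, 4, 16))
padded, wasted = pad_layers(X, S=8)
print(padded.shape, wasted)            # (16, 4, 16) 3
print(padded.shape[0] // 8 / 13)       # ~0.154: each device does 2 of 13 layers, vs. the ideal 1/8
```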
These measurements used fp32 for Muon, which is expensive and likely unnecessary. Since layer-sharded Muon is mostly compute-bound, real costs should be lower with reduced precision.
Future concerns
Layer sharding for a 405B-scale run is at the edge of feasibility. Training on fewer GPUs (e.g. ~2k, as DeepSeek-V3 did) helps a lot.
Mixture-of-Experts (MoE) models make things worse: if only 1/10 of the parameters are active per token, model flops drop 10× but Muon still has to process every parameter matrix, so the Muon-to-model flop ratio is ~10× worse than for a dense model. MoE models can use larger batches, but not 10× larger.
All approaches above compute the full Muon update. Scaling further may require modifying Muon to shrink the matrices that enter the Newton-Schulz iterations.
Meanwhile, layer sharding lets Muon scale to significantly larger training runs than before.