Ultra-Scale Playbook vol-3 - DeepSpeed ZeRO
Zero Redundancy Optimiser
TLDR: Instead of naive data parallelism,
- partition the optimiser states, gradients, and parameters across the DP dimension,
- while still allowing computation with the full parameter set. This can require more communication between DP ranks (think: overlapping communication with computation).
3 stages of ZeRO
- ZeRO-1: only partition optimiser states
- ZeRO-2: partition optimiser states + gradients
- ZeRO-3: partition optimiser states + gradients + params
Given the model parameter count $N$, mixed precision training with Adam dictates the following memory usage:
- params in bf16 = $2N$
- grads in bf16 = $2N$ (if we want to accumulate grads in fp32, then $4N$)
- optimiser states in fp32 (momentum + variance) = $4N + 4N$, and fp32 master params = $4N$

For efficiency let's keep grad accumulation in bf16, so the total memory usage becomes $2N + 2N + 12N$. Now given data parallel degree $N_d$,
- Baseline: $2N + 2N + 12N$ across each DP rank
- ZeRO-1: $2N + 2N + 12N/N_d$
- ZeRO-2: $2N + (2N + 12N)/N_d$
- ZeRO-3: $(2N + 2N + 12N)/N_d$
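These per-rank footprints can be sanity-checked with a small helper. This is just a sketch: the function name `per_rank_memory` and the example model size are made up, while the 2/2/12 multipliers follow the mixed-precision accounting above.

```python
def per_rank_memory(N, N_d, stage):
    """Bytes per DP rank for mixed-precision Adam:
    2N bf16 params, 2N bf16 grads, 12N fp32 optimiser
    states (momentum + variance + master params)."""
    params, grads, opt = 2, 2, 12
    if stage >= 1:       # ZeRO-1: shard optimiser states
        opt /= N_d
    if stage >= 2:       # ZeRO-2: also shard gradients
        grads /= N_d
    if stage >= 3:       # ZeRO-3: also shard parameters
        params /= N_d
    return (params + grads + opt) * N

# e.g. a hypothetical 8B-parameter model on 8 DP ranks:
N, N_d = 8e9, 8
print(per_rank_memory(N, N_d, 0) / 1e9)  # baseline: 128.0 GB per rank
print(per_rank_memory(N, N_d, 3) / 1e9)  # ZeRO-3:    16.0 GB per rank
```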
ZeRO-1: only shard optimiser states
- Each DP rank runs the forward pass with the same full set of bf16 params but on a different microbatch of data.
- Each DP rank back-props, producing the same full-sized set of bf16 grads, again on its own microbatch.
- Since each DP rank stores only $1/N_d$ of the optimiser states, and we want the update to reflect all the data via a mean/sum reduction, we perform a reduce_scatter on the gradients, i.e., $grad_k^j \leftarrow \sum_l grad^j_l$. This accumulates the gradient of each $N/N_d$-th parameter chunk over all the data $\equiv \nabla_w \ell(X) = \sum_{i=1}^n \nabla_w \ell(X_i)$.
- Now that each rank holds its $N/N_d$-th parameter chunk's gradient accumulated across the data, it performs an optimiser step using only its $1/N_d$ states $\implies$ $1/N_d$ fp32 updated params, converted to $1/N_d$ bf16 params.
- At this point, each DP rank has $1/N_d$ updated bf16 params. It receives the other $(N_d-1)/N_d$ fraction of params from the other DP ranks via an all_gather operation.
- Now each DP rank has the same full set of updated bf16 params, ready for the next step of training.
- Final memory footprint: $2N + 2N + 12N / N_d$.
💡 To update each chunk during reduce_scatter, only that chunk (across different microbatches) is needed per machine.
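The ZeRO-1 step above can be simulated end-to-end in a few lines of NumPy. This is a toy sketch: plain SGD stands in for the Adam step, and the reduce_scatter/all_gather collectives are emulated with slicing and concatenation.

```python
import numpy as np

rng = np.random.default_rng(0)
N_d, N = 4, 8                      # 4 DP ranks, 8 params (chunk size 2)
w = rng.normal(size=N)             # full param set, replicated on every rank
# each rank back-props on its own microbatch -> one full gradient per rank
grads = [rng.normal(size=N) for _ in range(N_d)]
lr = 0.1
chunk = N // N_d

# reduce_scatter: rank k ends up with the SUM over ranks of its own chunk
chunks = [sum(g[k * chunk:(k + 1) * chunk] for g in grads) for k in range(N_d)]
# optimiser step on the local 1/N_d chunk (plain SGD stands in for Adam)
updated = [w[k * chunk:(k + 1) * chunk] - lr * chunks[k] for k in range(N_d)]
# all_gather: every rank reassembles the full updated param set
w_new = np.concatenate(updated)

# matches a single-machine step on the gradient summed over microbatches
assert np.allclose(w_new, w - lr * sum(grads))
```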
What motivates ZeRO-2: Why not accumulate a chunk on all data during back-prop, and then only store the grads required for the optimiser step? That eliminates the need to always store all the grads.
ZeRO-2: shard optimiser states + grad
Only perform the reduce_scatter during back-prop: as each layer's gradients are produced, reduce-scatter them immediately and keep only the local $1/N_d$ shard. Now only $1/N_d$-th of the gradients need to live in memory, giving us a much better memory footprint of $2N + (2N + 12N)/N_d$.
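A toy simulation of the ZeRO-2 idea from rank 0's point of view (all sizes are made up, and the reduce_scatter is again emulated with slicing): each layer's full gradient buffer exists only transiently during back-prop, and only the local shard stays resident.

```python
import numpy as np

N_d, layers, layer_size = 4, 6, 8       # hypothetical toy sizes
rng = np.random.default_rng(1)
my_rank, shard = 0, layer_size // 4     # shard = layer_size // N_d

grad_shards = {}                        # rank 0 keeps only 1/N_d of each layer's grads
for layer in reversed(range(layers)):
    # back-prop through this layer: each rank briefly holds one full grad buffer
    full_grads = [rng.normal(size=layer_size) for _ in range(N_d)]
    # reduce_scatter as soon as the layer's grads exist, then free the full buffer
    grad_shards[layer] = sum(g[my_rank * shard:(my_rank + 1) * shard]
                             for g in full_grads)

resident = sum(v.size for v in grad_shards.values())
print(resident, layers * layer_size)    # 12 48 -> only 1/N_d of grads stay resident
```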
What motivates ZeRO-3: the params themselves can also be distributed across DP ranks, with the forward pass made possible by an all_gather per microbatch on each DP rank. Think of it this way: we temporarily “stitch” all the shards of a layer together, forward pass a microbatch through it, then flush the gathered shards to keep only $1/N_d$ params in memory.
ZeRO-3: shard everything
- Forward pass: distribute the bf16 param set across DP ranks $\implies$ $2N / N_d$ memory on each rank. Given a microbatch, use all_gather to stitch a sharded layer together, apply the layer to the microbatch, then flush the gathered param shards from memory.
- Back-prop: use all_gather to stitch sharded params together, compute the gradient of the gathered params per microbatch, then reduce_scatter to accumulate all microbatches into each chunk.
- Final memory footprint: $(2N + 2N + 12N)/N_d$.
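The gather-apply-flush pattern of the ZeRO-3 forward pass can be sketched in NumPy (a toy model: elementwise "layers", the all_gather emulated with concatenation, and all sizes invented):

```python
import numpy as np

N_d, layers, layer_size = 4, 3, 8   # toy sizes
rng = np.random.default_rng(2)
shard = layer_size // N_d
full_layers = [rng.normal(size=layer_size) for _ in range(layers)]
# each rank permanently stores only a 1/N_d shard of every layer
shards = [[W[k * shard:(k + 1) * shard] for k in range(N_d)]
          for W in full_layers]

x0 = rng.normal(size=layer_size)

# ZeRO-3 forward: gather one layer at a time, apply it, flush the gathered copy
x = x0.copy()
for layer in range(layers):
    W = np.concatenate(shards[layer])    # all_gather across the N_d ranks
    x = np.tanh(W * x)                   # toy elementwise "layer"
    del W                                # flush: only 1/N_d params stay resident

# identical to a forward pass that keeps every layer resident
x_ref = x0.copy()
for W in full_layers:
    x_ref = np.tanh(W * x_ref)
assert np.allclose(x, x_ref)
```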
💡 ZeRO-3 requires $2 \cdot \text{num\_layers} - 1$ additional all_gather calls relative to ZeRO-2, and each comes with a small base latency. We also gather all the shards once in the forward pass and once during back-prop, incurring a communication tax of $N$ each time. Adding the communication tax of $N$ from the reduce_scatter during back-prop, our total communication cost is $3N$, compared to $2N$ in ZeRO-2.
💡 While this may seem like a lot of communication overhead, in practice we use prefetching to hide it: simply all_gather the weights of the next layer while forward-passing through the current layer, and similarly all_gather the weights of the previous layer while back-propping through the current one.
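The prefetch schedule can be sketched as a simple event trace (illustrative only; in a real implementation the gather is issued asynchronously, e.g. on a separate stream, rather than appended to a list):

```python
# trace a ZeRO-3 forward pass: the gather for layer i+1 is issued
# while layer i computes, so only layer 0's gather is on the critical path
layers = 4
trace = [("gather", 0)]                    # layer 0 cannot be overlapped
for i in range(layers):
    if i + 1 < layers:
        trace.append(("gather", i + 1))    # prefetch next layer's shards...
    trace.append(("compute", i))           # ...while this layer computes

print(trace)
```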
When will this fail: once our DP dimension exceeds 512, the communication overhead becomes too large due to ring latency, and the overlap can no longer hide it. We need a different strategy at those scales.