Challenge

模型规模增长放大训练过程中的显存压力.所以需要在不改变模型结构的前提下,减少训练过程中的显存占用.

Memory Consumption Breakdown

我们对显存的占用进行拆解:

  • Parameters
  • Gradients
  • Optimizer State. 例如 Adam 优化器需要为每一个参数维护 m,vm,v 两个 tensor
  • Activation
  • Temp Buffer/Fragment

所以我们以 FP32 精度的 Adam 优化器训练 FP16 精度的 7B 模型为例,显存占用可以达到 7B x (2 + 2 + 4 * 3) = 112GB


Solution

ZeRO 在保持 DP (Data Parallel) 语义不变的前提下,系统性地减少了参数、梯度和优化器状态的冗余复制.具体而言,可以分为三个等级.

  • ZeRO-1:每个 rank 上保存完整的 parameters 和 gradient,但是 optimizer state 按 rank 进行 slicing
  • ZeRO-2:每个 rank 上保存完整的 parameters,gradients 和 optimizer state 会按 rank 进行 slicing
  • ZeRO-3:parameters, gradients, optimizer state 都会进行 slicing
What is “rank”?

其实就是分布式系统语境下,节点的一个别名.AI 训练系统里可以简单地认为是 GPU Server

Variant Memory Consumed P=2,G=2,A=12,Φ=7.5B,N=64P=2,G=2,A=12,\Phi=7.5B,N=64
Baseline (P+G+K)Φ(P+G+K)\Phi 120GB
ZeRO-1 (P+G)Φ+KΦ/N(P+G)\Phi+K\Phi/N 31.4GB
ZeRO-2 PΦ+(G+K)Φ/NP\Phi +(G+K)\Phi/N 16.6GB
ZeRO-3 (P+G+K)Φ/N(P+G+K)\Phi/N 1.9GB

这里,PP 表示一个参数字节大小(FP16=22 bytes),GG 表示梯度的字节数,AA 表示 Adam 优化器需要的字节数,Φ\Phi 表示参数量,NN 表示 rank 数量。

通常来说,Adam 需要为每个模型参数维护 m,vm,v 两个参数;并且混合精度训练下,为了数值稳定,还会保存 FP32 的参数副本。

ZeRO-1

由于每个 rank 上的参数和梯度仍然是完整的,所以 forward propagation 和 backward propagation 没有任何变化(包括梯度在不同 rank 之间的 AllReduce).

不同之处在于,我们用 optimizer.step() 对参数进行更新时,我们让每一个 rank 先用自己的那部分 optimizer state 更新一部分自己的参数切片,优化完毕后,再使用 AllGather 汇总得到更新后的完整参数.

虽然 ZeRO-1 节省了显存,但是流程末尾引入了 AllGather 的通信开销. 若总参数量为 Φ\Phi,DP Size 为 NN,参与 AllGather 的是全体参数,一共 2Φ2\Phi bytes. 根据 Ring AllGather,单 Rank 发送 N1N2Φ\frac{N-1}{N}\cdot 2\Phi bytes,接收 N1N2Φ\frac{N-1}{N}\cdot 2\Phi bytes,通信量合计 N1N4Φ\frac{N-1}{N}\cdot 4\Phi

ZeRO-2

每一个 Rank 上仍然有全量参数,但是会按照 DP size 切分 optimizer state 和 gradient. 前向计算流程依旧不变,但是在反向传播的过程中,不同 rank 之间的 AllReduce 会替换为 ReduceScatter.

optimizer.step() 时,每个 rank 仅使用自己持有的 gradient 和 optimizer state shard 更新部分参数. 优化器更新完毕后,通过 AllGather 汇总完整参数.

通信开销:如果使用 Ring 算法,那么 Reduce-Scatter + All-Gather 的总通信量与 All-Reduce 是相同的. ZeRO-2 的通信量与 Pure DDP without ZeRO 是一样的.