Challenge
模型规模增长放大训练过程中的显存压力.所以需要在不改变模型结构的前提下,减少训练过程中的显存占用.
Memory Consumption Breakdown
我们对显存的占用进行拆解:
- Parameters
- Gradients
- Optimizer State. 例如 Adam 优化器需要为每一个参数维护 两个 tensor
- Activation
- Temp Buffer/Fragment
所以我们以 FP32 精度的 Adam 优化器训练 FP16 精度的 7B 模型为例,显存占用可以达到 7B x (2 + 2 + 4 * 3) = 112GB
Solution
ZeRO 在保持 DP (Data Parallel) 语义不变的前提下,系统性地减少了参数、梯度和优化器状态的冗余复制.具体而言,可以分为三个等级.
- ZeRO-1:每个 rank 上保存完整的 parameters 和 gradient,但是 optimizer state 按 rank 进行 slicing
- ZeRO-2:每个 rank 上保存完整的 parameters,gradients 和 optimizer state 会按 rank 进行 slicing
- ZeRO-3:parameters, gradients, optimizer state 都会进行 slicing
其实就是分布式系统语境下,节点的一个别名.AI 训练系统里可以简单地认为是 GPU Server
| Variant | Memory Consumed | |
|---|---|---|
| Baseline | 120GB | |
| ZeRO-1 | 31.4GB | |
| ZeRO-2 | 16.6GB | |
| ZeRO-3 | 1.9GB |
这里, 表示一个参数字节大小(FP16= bytes), 表示梯度的字节数, 表示 Adam 优化器需要的字节数, 表示参数量, 表示 rank 数量。
通常来说,Adam 需要为每个模型参数维护 两个参数;并且混合精度训练下,为了数值稳定,还会保存 FP32 的参数副本。
ZeRO-1
由于每个 rank 上的参数和梯度仍然是完整的,所以 forward propagation 和 backward propagation 没有任何变化(包括梯度在不同 rank 之间的 AllReduce).
不同之处在于,我们用 optimizer.step() 对参数进行更新时,我们让每一个 rank 先用自己的那部分 optimizer state 更新一部分自己的参数切片,优化完毕后,再使用 AllGather 汇总得到更新后的完整参数.
虽然 ZeRO-1 节省了显存,但是流程末尾引入了 AllGather 的通信开销. 若总参数量为 ,DP Size 为 ,参与 AllGather 的是全体参数,一共 bytes. 根据 Ring AllGather,单 Rank 发送 bytes,接收 bytes,通信量合计
ZeRO-2
每一个 Rank 上仍然有全量参数,但是会按照 DP size 切分 optimizer state 和 gradient. 前向计算流程依旧不变,但是在反向传播的过程中,不同 rank 之间的 AllReduce 会替换为 ReduceScatter.
optimizer.step() 时,每个 rank 仅使用自己持有的 gradient 和 optimizer state shard 更新部分参数. 优化器更新完毕后,通过 AllGather 汇总完整参数.
通信开销:如果使用 Ring 算法,那么 Reduce-Scatter + All-Gather 的总通信量与 All-Reduce 是相同的. ZeRO-2 的通信量与 Pure DDP without ZeRO 是一样的.