Challenge
模型规模增长放大训练过程中的显存压力.所以需要在不改变模型结构的前提下,减少训练过程中的显存占用.
Memory Consumption Breakdown
我们对显存的占用进行拆解:
- Parameters
- Gradients
- Optimizer State. 例如 Adam 优化器需要为每一个参数维护 两个 tensor
- Activation
- Temp Buffer/Fragment
所以我们以 FP32 精度的 Adam 优化器训练 FP16 精度的 7B 模型为例,显存占用可以达到 7B x (2 + 2 + 4 * 3) = 112GB
Solution
ZeRO 在保持 DP (Data Parallel) 语义不变的前提下,系统性地减少了参数、梯度和优化器状态的冗余复制.具体而言,可以分为三个等级.
- ZeRO-1:每个 rank 上保存完整的 parameters 和 gradient,但是 optimizer state 按 rank 进行 slicing
- ZeRO-2:每个 rank 上保存完整的 parameters,gradients 和 optimizer state 会按 rank 进行 slicing
- ZeRO-3:parameters, gradients, optimizer state 都会进行 slicing
What is “rank”?
其实就是分布式系统语境下,节点的一个别名.AI 训练系统里可以简单地认为是 GPU Server
ZeRO-1
由于每个 rank 上的参数和梯度仍然是完整的,所以 forward propagation 和 backward propagation 没有任何变化(包括梯度在不同 rank 之间的 AllReduce).
不同之处在于,我们用 optimizer.step() 对参数进行更新时,我们让每一个 rank 先用自己的那部分 optimizer state 更新一部分自己的参数切片,优化完毕后,再使用 AllGather 汇总得到更新后的完整参数.