Challenge

模型规模增长放大训练过程中的显存压力.所以需要在不改变模型结构的前提下,减少训练过程中的显存占用.

Memory Consumption Breakdown

我们对显存的占用进行拆解:

  • Parameters
  • Gradients
  • Optimizer State. 例如 Adam 优化器需要为每一个参数维护 m,vm,v 两个 tensor
  • Activation
  • Temp Buffer/Fragment

所以我们以 FP32 精度的 Adam 优化器训练 FP16 精度的 7B 模型为例,显存占用可以达到 7B x (2 + 2 + 4 * 3) = 112GB


Solution

ZeRO 在保持 DP (Data Parallel) 语义不变的前提下,系统性地减少了参数、梯度和优化器状态的冗余复制.具体而言,可以分为三个等级.

  • ZeRO-1:每个 rank 上保存完整的 parameters 和 gradient,但是 optimizer state 按 rank 进行 slicing
  • ZeRO-2:每个 rank 上保存完整的 parameters,gradients 和 optimizer state 会按 rank 进行 slicing
  • ZeRO-3:parameters, gradients, optimizer state 都会进行 slicing
What is “rank”?

其实就是分布式系统语境下,节点的一个别名.AI 训练系统里可以简单地认为是 GPU Server

ZeRO-1

由于每个 rank 上的参数和梯度仍然是完整的,所以 forward propagation 和 backward propagation 没有任何变化(包括梯度在不同 rank 之间的 AllReduce).

不同之处在于,我们用 optimizer.step() 对参数进行更新时,我们让每一个 rank 先用自己的那部分 optimizer state 更新一部分自己的参数切片,优化完毕后,再使用 AllGather 汇总得到更新后的完整参数.