在 PyTorch 代码执行计算的时候,AutoGrad 会构建一张由 Function 对象组成的 DAG 计算图,用于反向传播.每一个 Function 对象表示操作,通过其 .apply() 计算前向传播结果,并记录其反向传播的逻辑

.grad_fn 保存 Function 对象,在反向传播阶段分析计算图并基于链式法则自动求导.

前向传播会记录中间结果,便于反向传播的计算.所以通常来说,y.grad_fn._saved_self == x 但并非总是成立.例如 y=exy=e^x,此时的 y.grad_fn._saved_result != y 来避免循环引用导致内存泄露.

在 PyTorch 里,我们通过 pack/unpack 避免循环引用:

  • 在保存张量时进行 pack,存储必要的数据到磁盘
  • 读取时进行 unpack 返回新的张量对象

那么不可导(如 ReLU(0)ReLU(0))的情况怎么办呢?

y=x2y=x^2 为例,根据链式法则,假设上游传过来的梯度为 dwdy\frac{dw}{dy},那么

dwdx=dwdydydx=dwdy(2x)\frac{dw}{dx}=\frac{dw}{dy}\frac{dy}{dx}=\frac{dw}{dy}\cdot(2x)

requires_grad 属性

可以使用 requires_grad 属性,来细粒度地控制计算操作是否纳入 AutoGrad 计算图中.(只对叶子张量有效)

原地操作

不允许原地操作

给每一个 Tensor 维护

混合精度训练 torch.amp

核心包括 autocast, GradScaler

torch.amp.autocast 自动将某些算子用 FP16,用 with torch.amp.autocast()

torch.amp.GradScaler 解决 FP16 训练时梯度过小的问题

  • 缩放 loss 来放大反向传播的梯度
  • 恢复梯度
  • 动态调整 scale 大小

混合精度训练策略