PyTorch 关键组件
- TorchDynamo
- 将所有复杂算子简化到 PrimTorch 中的 250 个算子
- 移除未使用的算子
- 确实需要存储、写入内存的中间算子,以及可融合的算子,从而减少开销
- PrimTorch
- 定义了两个算子集合:Aten ops 和 Prim ops
- 将 PyTorch 程序的各种计算用这些算子集里的算子表示
- 简化后端需要编写的算子数量
- AOTAutograd
- 提前获取反向传播
- 基于完整的 forward/backward 根据算子的依赖关系进行算子调度,对算子和层进行融合
- TorchInductor
- 进行算子融合
- 自动生成低级 GPU 上的 Triton 代码(或者 CPU 上的 C++/OpenMP)
编译流程
我们用下面的例子介绍一下大致的编译流程,在运行时加入调试参数 TORCH_LOGS="..." python example.py
查看中间的日志输出
1 2 3 4 5 6 7 8 9 10
| import torch
@torch.compile def toy_example(x: torch.Tensor) -> torch.Tensor: y = x.sin() z = y.cos() return z
if __name__ == "__main__": x = torch.randn(1000, device="cuda", requires_grad=True)
|
Step 1. TorchDynamo
运行 TORCH_LOGS="dynamo" uv run example.py
,我们先来看第一步 TorchDynamo 的输出。
日志输出
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| [torch/_dynamo/symbolic_convert.py:2706] [0/0] Step 1: torchdynamo start tracing toy_example [很长的路径]/example.py:5 [torch/_dynamo/symbolic_convert.py:3028] [0/0] Step 1: torchdynamo done tracing toy_example
[torch/_dynamo/output_graph.py:1458] [0/0] Step 2: calling compiler function inductor [torch/_dynamo/output_graph.py:1463] [0/0] Step 2: done compiler function inductor
[torch/fx/experimental/symbolic_shapes.py:4547] [0/0] produce_guards [torch/_dynamo/pgo.py:636] [0/0] put_code_state: no cache key, skipping [torch/_dynamo/eval_frame.py:398] TorchDynamo attempted to trace the following frames [ toy_example [很长的路径]/example.py:5 ] [torch/_dynamo/utils.py:446] TorchDynamo compilation metrics: Function Runtimes (s) ------------------------------------ -------------- _compile.compile_inner 0.5482 OutputGraph.call_user_compiler 0.4845 _recursive_pre_grad_passes 0.0018 create_aot_dispatcher_function 0.4817 _recursive_joint_graph_passes 0.0684 compile_fx.<locals>.fw_compiler_base 0.3442 compile_fx_inner 0.3437 inductor_codecache_torch_key 0.0523 TritonBundler.read_and_emit 0.0002 PyCodeCache.load_by_key_path 0.0122 async_compile.precompile 0.007 async_compile.wait 0.0001
|
从日志中可以看到,TorchDynamo 的框架流程就是
- 对要编译的模型进行追踪,然后编译并生成中间表示 (FX Graph IR)
- 调用
compiler.inductor
对模型进行化简
1.1 Dynamo 图捕获
Dynamo 首先进行图捕获。这里,__graph_code
将原始代码的 Dataflow 进行捕获,并输出捕获的 DAG,即 FX Graph IR.
FX Graph IR
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| def forward(self, L_x_: "f32[1000]"): l_x_ = L_x_ y: "f32[1000]" = l_x_.sin(); l_x_ = None z: "f32[1000]" = y.cos(); y = None
return (z,)
[torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] def forward(self, L_x_: "f32[1000][1]cuda:0"): l_x_ = L_x_ y: "f32[1000][1]cuda:0" = l_x_.sin(); l_x_ = None z: "f32[1000][1]cuda:0" = y.cos(); y = None
return (z,)
|
1.2 AOTAutograd
Dynamo 的 AOTAutograd
阶段
- 生成正向传播图和反向传播图(也是表示为 FX Graph IR 的形式)
- 会将 FX Graph IR 中的算子替换为 ATen 算子库里的算子
- 基于完整的正向、反向传播图的视角,根据依赖关系,进行算子调度、对算子和层进行融合
- 将复杂的算子根据字典进一步分解为更底层的 Core ATen IR 算子或者 Prim IR 算子
AOTAutograd IR 生成的正向图与反向图
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
|
def forward(self, primals_1: "f32[1000][1]cuda:0"): sin: "f32[1000][1]cuda:0" = torch.ops.aten.sin.default(primals_1) cos: "f32[1000][1]cuda:0" = torch.ops.aten.cos.default(sin); sin = None return (cos, primals_1)
<eval_with_key>.1 class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[1000][1]cuda:0", tangents_1: "f32[1000][1]cuda:0"): sin: "f32[1000][1]cuda:0" = torch.ops.aten.sin.default(primals_1) sin_1: "f32[1000][1]cuda:0" = torch.ops.aten.sin.default(sin); sin = None neg: "f32[1000][1]cuda:0" = torch.ops.aten.neg.default(sin_1); sin_1 = None mul: "f32[1000][1]cuda:0" = torch.ops.aten.mul.Tensor(tangents_1, neg); tangents_1 = neg = None cos_1: "f32[1000][1]cuda:0" = torch.ops.aten.cos.default(primals_1); primals_1 = None mul_1: "f32[1000][1]cuda:0" = torch.ops.aten.mul.Tensor(mul, cos_1); mul = cos_1 = None return (mul_1,)
|
2. Inductor
Triton 的核心: compile()