PyTorch Key Components

  • TorchDynamo
    • Simplifies complex operators down to the set of roughly 250 operators defined in PrimTorch
    • Removes unused operators
    • Identifies which intermediate results genuinely need to be stored and written to memory, and which operators can be fused instead, thereby reducing overhead
  • PrimTorch
    • Defines two operator sets: ATen ops and Prim ops
    • Expresses the computations of a PyTorch program in terms of operators from these sets
    • Reduces the number of operators a backend has to implement (a minimal custom-backend sketch follows this list)
  • AOTAutograd
    • Captures the backward pass ahead of time
    • With a view over the complete forward/backward graph, schedules operators according to their dependencies and fuses operators and layers
  • TorchInductor
    • Performs operator fusion
    • Automatically generates low-level Triton code for GPUs (or C++/OpenMP code for CPUs)
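
To see how these components connect, here is a minimal sketch, assuming a toy function and a hypothetical backend named my_backend: TorchDynamo captures the Python function as an FX GraphModule and hands it to whatever backend callable torch.compile is given.

import torch

# Hypothetical custom backend: receives the FX graph captured by TorchDynamo
# together with example inputs, and must return a callable.
def my_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()  # inspect the captured FX Graph IR
    return gm.forward         # run the captured graph as-is (no real compilation)

@torch.compile(backend=my_backend)
def toy_fn(x: torch.Tensor) -> torch.Tensor:
    return x.sin().cos()

toy_fn(torch.randn(8))  # the first call triggers capture and "compilation"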

Compilation Flow

We will use the example below to walk through the overall compilation flow. Adding the debug environment variable at run time, e.g. TORCH_LOGS="..." python example.py, lets us inspect the intermediate log output.

import torch

@torch.compile
def toy_example(x: torch.Tensor) -> torch.Tensor:
    y = x.sin()
    z = y.cos()
    return z

if __name__ == "__main__":
    x = torch.randn(1000, device="cuda", requires_grad=True)  # enable autograd for the backward pass
    toy_example(x)  # calling the compiled function triggers compilation

Step 1. TorchDynamo

Run TORCH_LOGS="dynamo" uv run example.py; let us first look at the output of the first step, TorchDynamo.

Log output
[torch/_dynamo/symbolic_convert.py:2706]
[0/0] Step 1: torchdynamo start tracing toy_example [long path]/example.py:5
[torch/_dynamo/symbolic_convert.py:3028]
[0/0] Step 1: torchdynamo done tracing toy_example

[torch/_dynamo/output_graph.py:1458]
[0/0] Step 2: calling compiler function inductor
[torch/_dynamo/output_graph.py:1463]
[0/0] Step 2: done compiler function inductor

[torch/fx/experimental/symbolic_shapes.py:4547]
[0/0] produce_guards
[torch/_dynamo/pgo.py:636]
[0/0] put_code_state: no cache key, skipping
[torch/_dynamo/eval_frame.py:398]
TorchDynamo attempted to trace the following frames [
toy_example [long path]/example.py:5
]
[torch/_dynamo/utils.py:446]
TorchDynamo compilation metrics:
Function Runtimes (s)
------------------------------------ --------------
_compile.compile_inner 0.5482
OutputGraph.call_user_compiler 0.4845
_recursive_pre_grad_passes 0.0018
create_aot_dispatcher_function 0.4817
_recursive_joint_graph_passes 0.0684
compile_fx.<locals>.fw_compiler_base 0.3442
compile_fx_inner 0.3437
inductor_codecache_torch_key 0.0523
TritonBundler.read_and_emit 0.0002
PyCodeCache.load_by_key_path 0.0122
async_compile.precompile 0.007
async_compile.wait 0.0001

From the logs we can see that TorchDynamo's overall flow is:

  1. Trace the model to be compiled, then compile it and produce an intermediate representation (FX Graph IR)
  2. Call the compiler function inductor to further compile and simplify the captured graph (see the sketch below)
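
As a quicker way to inspect what Dynamo traced than reading raw logs, the sketch below uses torch._dynamo.explain on the same toy function (the exact calling convention and report fields vary across PyTorch releases):

import torch

def toy_example(x: torch.Tensor) -> torch.Tensor:
    y = x.sin()
    z = y.cos()
    return z

x = torch.randn(1000, requires_grad=True)
# explain() runs Dynamo's tracing and reports the captured graphs,
# the graph break count, and the ops in each graph.
explanation = torch._dynamo.explain(toy_example)(x)
print(explanation)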

1.1 Dynamo Graph Capture

Dynamo first performs graph capture. Here, the __graph_code log records the dataflow of the original code and prints the captured DAG, i.e., the FX Graph IR (a standalone torch.fx tracing sketch follows the log, for comparison).

FX Graph IR
# [torch/fx/passes/runtime_assert.py:118] [0/0] [__graph_code]
def forward(self, L_x_: "f32[1000]"):
    l_x_ = L_x_

    # File: example.py:7 in toy_example, code: y = x.sin()
    y: "f32[1000]" = l_x_.sin(); l_x_ = None

    # File: example.py:8 in toy_example, code: z = y.cos()
    z: "f32[1000]" = y.cos(); y = None

    return (z,)

# [torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code]
def forward(self, L_x_: "f32[1000][1]cuda:0"):
    l_x_ = L_x_

    # File: example.py:7 in toy_example, code: y = x.sin()
    y: "f32[1000][1]cuda:0" = l_x_.sin(); l_x_ = None

    # File: example.py:8 in toy_example, code: z = y.cos()
    z: "f32[1000][1]cuda:0" = y.cos(); y = None

    return (z,)
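
For comparison only, FX's own symbolic tracer produces the same kind of Graph IR without going through Dynamo's bytecode-level capture; this sketch is an illustration of the IR, not part of the torch.compile pipeline:

import torch
import torch.fx

def toy(x: torch.Tensor) -> torch.Tensor:
    return x.sin().cos()

gm = torch.fx.symbolic_trace(toy)  # trace into an FX GraphModule
print(gm.graph)  # node-level view of the DAG
print(gm.code)   # generated Python forward(), similar to the __graph_code output above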

1.2 AOTAutograd

The AOTAutograd stage

  1. Generates the forward graph and the backward graph (also represented as FX Graph IR)
  2. Replaces the operators in the FX Graph IR with operators from the ATen operator library
  3. With a view over the complete forward and backward graphs, schedules operators according to their dependencies and fuses operators and layers
  4. Further decomposes complex operators, via decomposition tables, into lower-level Core ATen IR or Prim IR operators (a small sketch that surfaces these graphs directly follows this list)
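
As a sketch of how these forward and backward graphs can be surfaced directly, the snippet below drives AOTAutograd through aot_function with printing "compilers" (the import path differs across PyTorch versions, e.g. functorch.compile vs. torch._functorch.aot_autograd; the printer names are our own):

import torch
from functorch.compile import aot_function  # location may vary by version

def fn(x):
    return x.sin().cos()

# "Compilers" that just print the FX graph AOTAutograd hands them, then run it unchanged.
def print_fw(gm, example_inputs):
    print("=== forward graph ===")
    print(gm.code)
    return gm

def print_bw(gm, example_inputs):
    print("=== backward graph ===")
    print(gm.code)
    return gm

aot_fn = aot_function(fn, fw_compiler=print_fw, bw_compiler=print_bw)
x = torch.randn(1000, requires_grad=True)
aot_fn(x).sum().backward()  # the backward graph is compiled (and printed) lazily
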
Forward and backward graphs generated by AOTAutograd
# This is the forward graph
# ===== Forward graph 0 =====
# torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[1000][1]cuda:0"):
    # File: example.py:7 in toy_example, code: y = x.sin()
    sin: "f32[1000][1]cuda:0" = torch.ops.aten.sin.default(primals_1)

    # File: example.py:8 in toy_example, code: z = y.cos()
    cos: "f32[1000][1]cuda:0" = torch.ops.aten.cos.default(sin); sin = None
    return (cos, primals_1)

# This is the backward graph
# [torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:603]
# [0/0] [__aot_graphs]
#
# TRACED GRAPH
# ===== Backward graph 0 =====
# <eval_with_key>.1 class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[1000][1]cuda:0", tangents_1: "f32[1000][1]cuda:0"):
    # File: example.py:7 in toy_example, code: y = x.sin()
    sin: "f32[1000][1]cuda:0" = torch.ops.aten.sin.default(primals_1)

    # File: example.py:8 in toy_example, code: z = y.cos()
    sin_1: "f32[1000][1]cuda:0" = torch.ops.aten.sin.default(sin); sin = None
    neg: "f32[1000][1]cuda:0" = torch.ops.aten.neg.default(sin_1); sin_1 = None
    mul: "f32[1000][1]cuda:0" = torch.ops.aten.mul.Tensor(tangents_1, neg); tangents_1 = neg = None

    # File: example.py:7 in toy_example, code: y = x.sin()
    cos_1: "f32[1000][1]cuda:0" = torch.ops.aten.cos.default(primals_1); primals_1 = None
    mul_1: "f32[1000][1]cuda:0" = torch.ops.aten.mul.Tensor(mul, cos_1); mul = cos_1 = None
    return (mul_1,)
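
The backward graph above is simply the chain rule for z = cos(sin(x)), i.e. dz/dx = -sin(sin(x)) * cos(x). A quick sanity check (a sketch using plain eager autograd on CPU, not the compiled path):

import torch

x = torch.randn(1000, requires_grad=True)
z = x.sin().cos()
(grad,) = torch.autograd.grad(z.sum(), x)

# Mirrors the backward graph: sin -> sin_1 -> neg -> mul, then cos_1 -> mul_1
expected = -torch.sin(torch.sin(x)) * torch.cos(x)
print(torch.allclose(grad, expected))  # True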

2. Inductor


The core entry point: torch.compile() and its main parameters (see the sketch below)

  • model
  • fullgraph
  • dynamic
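
A minimal sketch of these parameters: fullgraph requires the whole function to be captured as a single graph (any graph break becomes an error), and dynamic opts into symbolic (dynamic) shapes; backend defaults to inductor.

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return x.sin().cos()

# model: the function or nn.Module to compile
# fullgraph: fail instead of falling back when a graph break occurs
# dynamic: trace with dynamic shapes
compiled = torch.compile(f, fullgraph=True, dynamic=True)
print(compiled(torch.randn(16)))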