torch.fx 提供了一套轻量、高层、可编程的 IR,这一套 IR 一直沿用到 PyTorch v2.0.主要分为三个部分:

  • torch.fx.node.Node:IR 的基本单位,每一个 Node 表示一个操作 (operation)
  • torch.fx.graph.Graph:一个 IR 的整体容器,其内部按序保存了计算图里的所有 Node 和其拓扑关系.其本身没有执行功能,只是结构表示
  • GraphModule extends nn.Module:负责将 Graph 对象封装为 nn.Module,可以 .forward()、保存、恢复为源码.

Node.op

Node 的操作本质上可以分为以下 66 种:

Node.op Explanation
placeholder 计算图的输入
get_attr 读取模块的参数
call_function 函数调用
call_module 调用模块 .forward()
call_method 调用张量方法
output 计算图输出

torch.fx IR 示例

我们以以下的 Python 代码为例,看看 Node

main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os

os.environ["TORCH_LOGS"] = "dynamo,output_code,graph,guards"

from typing import List
import torch
from torch import _dynamo as td


def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX Graph:")
gm.graph.print_tabular()
return gm.forward


@td.optimize(my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b


for _ in range(1):
toy_example(torch.randn(10), torch.randn(10))
1
2
3
4
5
6
7
8
9
10
11
my_compiler() called with FX Graph:
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ----------- --------
placeholder l_a_ L_a_ () {}
placeholder l_b_ L_b_ () {}
call_function abs_1 <built-in method abs of type object at 0x7f3bb993e180> (l_a_,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function x <built-in function truediv> (l_a_, add) {}
call_method sum_1 sum (l_b_,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((lt, x),) {}