torch.fx 的技术亮点在于三个方面:
- Symbolic Tracing
- 高度简洁的 FX Graph IR, 这一部分可以参考这篇介绍
- Python 源码转换
Symbolic Tracing
torch.fx.symbolic_trace() 内部会构建一个 Tracer 来实现 IR 的构建.
具体来说,Tracer 会将输入替换为 Proxy 对象,每一个与 Tensor 相关的调用都会转发给 Tracer.create_proxy(),每一次 Proxy 运算都会对应一个 Node 加入计算图 Graph.
对于 nn.Module 来说,Tracer 会重写 nn.Module.__getattr__ 方法,使得调用 nn.Module 的时候会生成一个 Node,其 Node.op = call_module,并记录其访问路径.
然后 Tracer 以拓扑序构建图中所有 Node,并为每一个 Node 赋予唯一名称,构建为 Graph.最后将 Graph 包装为可执行的 GraphModule,并且在其上做 IR 重写、变换、可视化、代码生成等等.
Limitations
由于 torch.fx.symbolic_trace() 仍然采取的是符号运算的路线,Proxy 不包含实际数据,所以无法支持与输入数据强相关的动态控制流.并且也仅支持对 nn.Module() 进行 tracing,因为 symbolic_trace() 的实现机制依赖于对 nn.Module.__getattr__ 的重写.
在 PyTorch 2.0 以后引入了 torch.fx.experimental.proxy_tensor.make_fx 支持了对函数的追踪,但它是基于执行的 trace