from typing import List

import torch
from torch import _dynamo as td
defmy_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): print("my_compiler() called with FX Graph:") gm.graph.print_tabular() return gm.forward
# torch.compile(backend=...) is the public entry point for plugging in a
# custom backend; torch._dynamo.optimize is part of a private module.
@torch.compile(backend=my_compiler)
def toy_example(a, b):
    """Scale ``a`` into (-1, 1) and multiply it by ``b``, flipping ``b`` if its sum is negative.

    Args:
        a: Input tensor (the driver passes ``torch.randn(10)``).
        b: Input tensor combined elementwise with ``a``; presumably the same
            or a broadcastable shape -- TODO confirm for other callers.

    Returns:
        ``a / (|a| + 1) * b``, with ``b`` negated first when ``b.sum() < 0``.
    """
    x = a / (torch.abs(a) + 1)
    # The branch depends on tensor data. In the example output, the captured
    # graph stops at this comparison and returns (lt, x) instead of tracing
    # through the ``if``.
    if b.sum() < 0:
        b = b * -1
    return x * b
# Drive the compiled function once. Dynamo compiles on this first call, so
# my_compiler prints the captured graph(s) here.
for _ in range(1):
    # Tuple elements are evaluated left to right, so the RNG draws happen
    # in the same order as passing the two randn calls inline.
    lhs, rhs = torch.randn(10), torch.randn(10)
    toy_example(lhs, rhs)
# Example output (the object address in the abs target will differ on your machine):
#
# my_compiler() called with FX Graph:
# opcode         name    target                                                  args         kwargs
# -------------  ------  ------------------------------------------------------  -----------  --------
# placeholder    l_a_    L_a_                                                    ()           {}
# placeholder    l_b_    L_b_                                                    ()           {}
# call_function  abs_1   <built-in method abs of type object at 0x7f3bb993e180>  (l_a_,)      {}
# call_function  add     <built-in function add>                                 (abs_1, 1)   {}
# call_function  x       <built-in function truediv>                             (l_a_, add)  {}
# call_method    sum_1   sum                                                     (l_b_,)      {}
# call_function  lt      <built-in function lt>                                  (sum_1, 0)   {}
# output         output  output                                                  ((lt, x),)   {}