[fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)

Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds
After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds
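(The before/after lines are in cProfile's summary format. A minimal sketch for collecting comparable numbers on an FX-heavy workload follows; the actual benchmark driver is not part of this commit, so the stand-in module and the sort order below are assumptions.)

import cProfile
import pstats
import torch.nn as nn
from torch.fx import symbolic_trace

# Stand-in workload: a module large enough that tracing/codegen cost dominates.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(200)])

with cProfile.Profile() as prof:
    gm = symbolic_trace(model)   # repeatedly hits TracerBase.create_arg
    gm.recompile()               # runs CodeGen._gen_python_code via Graph

pstats.Stats(prof).sort_stats("cumulative").print_stats(20)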

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260, #148261, #148303, #148288
Jason Ansel 2025-03-03 09:32:45 -08:00 committed by PyTorch MergeBot
parent 5eb0337cfd
commit 8531d247ba
3 changed files with 127 additions and 99 deletions

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

@@ -1,20 +1,20 @@
-add_loop_eager,compile_time_instruction_count,2859000000,0.015
+add_loop_eager,compile_time_instruction_count,2799000000,0.015
-add_loop_eager_dynamic,compile_time_instruction_count,6250000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,6131000000,0.025
-add_loop_inductor,compile_time_instruction_count,27960000000,0.015
+add_loop_inductor,compile_time_instruction_count,27570000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43820000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43120000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,24280000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,23910000000,0.015
@@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,954100000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17380000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17150000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15580000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15410000000,0.015
@@ -34,15 +34,15 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,9874000000,0.2
-update_hint_regression,compile_time_instruction_count,1629000000,0.02
+update_hint_regression,compile_time_instruction_count,1615000000,0.02
-sum_floordiv_regression,compile_time_instruction_count,1036000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1029000000,0.015
-symint_sum,compile_time_instruction_count,3094000000,0.015
+symint_sum,compile_time_instruction_count,3038000000,0.015
@@ -54,12 +54,12 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5734000000,0.015
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,7980000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,7814000000,0.015
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3606000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3573000000,0.015
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9862000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9747000000,0.015

row benchmark metric before after noise_threshold
1 add_loop_eager compile_time_instruction_count 2859000000 2799000000 0.015
2 add_loop_eager_dynamic compile_time_instruction_count 6250000000 6131000000 0.025
3 add_loop_inductor compile_time_instruction_count 27960000000 27570000000 0.015
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 43820000000 43120000000 0.025
5 add_loop_inductor_gpu compile_time_instruction_count 24280000000 23910000000 0.015
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 954100000 954100000 0.015
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 17380000000 17150000000 0.015
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 15580000000 15410000000 0.015
9 basic_modules_ListOfLinears_inductor_gpu compile_time_instruction_count 9874000000 9874000000 0.2
10 update_hint_regression compile_time_instruction_count 1629000000 1615000000 0.02
11 sum_floordiv_regression compile_time_instruction_count 1036000000 1029000000 0.015
12 symint_sum compile_time_instruction_count 3094000000 3038000000 0.015
13 aotdispatcher_inference_nosubclass_cpu compile_time_instruction_count 1992000000 1992000000 0.015
14 aotdispatcher_inference_subclass_cpu compile_time_instruction_count 5734000000 5734000000 0.015
15 aotdispatcher_partitioner_cpu compile_time_instruction_count 7980000000 7814000000 0.015
16 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3606000000 3573000000 0.015
17 aotdispatcher_training_subclass_cpu compile_time_instruction_count 9862000000 9747000000 0.015
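Each row pairs an expected compile_time_instruction_count with a relative noise threshold; a measured value outside that band fails the PR-time benchmark check and requires updating this file. A minimal sketch of that comparison, assuming the threshold is applied as a relative tolerance (the real harness lives in the benchmark infrastructure and may differ in detail):

def within_noise(actual: int, expected: int, noise: float) -> bool:
    # noise=0.015 accepts anything within +/-1.5% of the expected count
    return abs(actual - expected) <= noise * expected

# add_loop_eager moved 2859000000 -> 2799000000, a ~2.1% drop, which is outside
# the 1.5% band -- hence the expected value is updated in this commit.
print(within_noise(2799000000, 2859000000, 0.015))  # False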

torch/fx/graph.py

@@ -479,38 +479,24 @@ class CodeGen:
             # Common case: this is a regular module name like 'foo.bar.baz'
             return add_global(typename, o)
-        codes = {
-            "yellow": "\033[33m",
-            "cyan": "\033[36m",
-            "green": "\033[32m",
-            "blue": "\033[34m",
-            "red": "\033[31m",
-            "dim": "\033[2m",
-            "dim_blue": "\033[2m\033[34m",
-            "dim_green": "\033[2m\033[32m",
-            "reset": "\033[0m",
-        }
-        def make_wrapper_func(name):
-            def f(s):
-                if colored:
-                    return f"{codes[name]}{s}{codes['reset']}"
-                return s
-            return f
-        yellow = make_wrapper_func("yellow")  # noqa: F841
-        cyan = make_wrapper_func("cyan")  # noqa: F841
-        red = make_wrapper_func("red")
-        green = make_wrapper_func("green")  # noqa: F841
-        dim_green = make_wrapper_func("dim_green")
-        dim = make_wrapper_func("dim")
-        dim_blue = make_wrapper_func("dim_blue")
-        blue = make_wrapper_func("blue")
+        if colored:
+            red = _color_fns["red"]
+            dim_green = _color_fns["dim_green"]
+            dim = _color_fns["dim"]
+            dim_blue = _color_fns["dim_blue"]
+            blue = _color_fns["blue"]
+        else:
+            red = _identity
+            dim_green = _identity
+            dim = _identity
+            dim_blue = _identity
+            blue = _identity
         def _get_repr(arg: Any) -> str:
-            # Handle NamedTuples (if it has `_fields`) via add_global.
-            if isinstance(arg, tuple) and hasattr(arg, "_fields"):
+            if isinstance(arg, Node):  # first because common
+                return repr(arg)
+            elif isinstance(arg, tuple) and hasattr(arg, "_fields"):
+                # Handle NamedTuples (if it has `_fields`) via add_global.
                 qualified_name = _get_qualified_name(type(arg))
                 global_name = add_global(qualified_name, type(arg))
                 return f"{global_name}{repr(tuple(arg))}"
@@ -524,8 +510,6 @@ class CodeGen:
                 cls = arg.__class__
                 clsname = add_global(cls.__name__, cls)
                 return f"{clsname}.{arg.name}"
-            elif isinstance(arg, Node):
-                return repr(arg)
             elif isinstance(arg, torch.Tensor):
                 size = list(arg.size())
                 dtype = str(arg.dtype).split(".")[-1]
@@ -545,11 +529,9 @@ class CodeGen:
         def _format_args(
             args: tuple[Argument, ...], kwargs: dict[str, Argument]
         ) -> str:
-            args_s = ", ".join(_get_repr(a) for a in args)
-            kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
-            if args_s and kwargs_s:
-                return f"{args_s}, {kwargs_s}"
-            return args_s or kwargs_s
+            res = [_get_repr(a) for a in args]
+            res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()])
+            return ", ".join(res)
         # Run through reverse nodes and record the first instance of a use
         # of a given node. This represents the *last* use of the node in the
@@ -564,8 +546,8 @@ class CodeGen:
                 user_to_last_uses.setdefault(user, []).append(n)
         for node in reversed(nodes):
-            map_arg(node.args, lambda n: register_last_uses(n, node))
-            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+            for input_node in node._input_nodes:
+                register_last_uses(input_node, node)
         def delete_unused_values(user: Node):
             """
@@ -603,23 +585,23 @@ class CodeGen:
             """
            nonlocal prev_stacktrace
-            if node.op not in {"placeholder", "output"}:
-                if node.stack_trace:
-                    if node.stack_trace != prev_stacktrace:
-                        prev_stacktrace = node.stack_trace
-                        summary_str = ""
-                        if parsed_stack_trace := _parse_stack_trace(node.stack_trace):
+            if node.op not in ("placeholder", "output"):
+                stack_trace = node.stack_trace
+                if stack_trace:
+                    if stack_trace != prev_stacktrace:
+                        prev_stacktrace = stack_trace
+                        if parsed_stack_trace := _parse_stack_trace(stack_trace):
                             summary_str = parsed_stack_trace.get_summary_str()
-                        body.append(f'\n {dim("# " + summary_str)}\n')
+                        else:
+                            summary_str = ""
+                        body.append(f'\n {dim(f"# {summary_str}")}\n')
                 elif prev_stacktrace != "":
                     prev_stacktrace = ""
                     no_stacktrace_msg = "# No stacktrace found for following nodes"
                     body.append(f"\n{dim(no_stacktrace_msg)}\n")
         def stringify_shape(shape: Iterable) -> str:
-            return f"[{', '.join(str(x) for x in shape)}]"
+            return f"[{', '.join([str(x) for x in shape])}]"
         def emit_node(node: Node):
             maybe_type_annotation = (
@@ -777,8 +759,8 @@ class CodeGen:
         new_lines: list[str] = []
         cur_idx = None
         for line in "".join(body).split("\n"):
-            counter = re.search(r"# COUNTER: (\d+)", line)
-            if counter and counter.group(1) is not None:
+            counter = _counter_regexp.search(line)
+            if counter is not None:
                 cur_idx = int(counter.group(1))
             else:
                 lineno_map[len(new_lines) + prologue_len] = cur_idx
@@ -1207,12 +1189,10 @@ class Graph:
         # Null out this Node's argument nodes so that the Nodes referred to
         # can update their ``users`` accordingly
-        new_args = map_arg(to_erase.args, lambda n: None)
-        assert isinstance(new_args, tuple)
-        to_erase.args = new_args
-        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
-        assert isinstance(new_kwargs, dict)
-        to_erase.kwargs = new_kwargs
+        to_erase._update_args_kwargs(
+            map_arg(to_erase._args, lambda n: None),
+            map_arg(to_erase._kwargs, lambda n: None),
+        )
     @compatibility(is_backward_compatible=True)
     def inserting_before(self, n: Optional[Node] = None):
@@ -1723,21 +1703,14 @@ class Graph:
         seen_names: set[str] = set()
         seen_values: set[Node] = set()
         for node in self.nodes:
-            if node.op not in [
-                "placeholder",
-                "call_method",
-                "call_module",
-                "call_function",
-                "get_attr",
-                "output",
-            ]:
+            if node.op not in _legal_ops:
                 raise RuntimeError(f"Node {node} had unknown opcode {node.op}!")
             if node.graph is not self:
                 raise RuntimeError(f"Node '{node}' does not belong to this Graph!")
             if node not in self._find_nodes_lookup_table:
                 raise RuntimeError(f"Node '{node}' is not added to the side table")
-            map_arg(node.args, lambda arg: check_arg(arg, node))
-            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
+            for arg in node._input_nodes:
+                check_arg(arg, node)
             seen_values.add(node)
             if node.name in seen_names:
@@ -1956,6 +1929,32 @@ class Graph:
         return on_generate_code_context_manager()
+def _identity(x):
+    return x
+def _make_color_fn(code):
+    def f(s):
+        reset = "\033[0m"
+        return f"{code}{s}{reset}"
+    return f
+_color_codes = {
+    "yellow": "\033[33m",
+    "cyan": "\033[36m",
+    "green": "\033[32m",
+    "blue": "\033[34m",
+    "red": "\033[31m",
+    "dim": "\033[2m",
+    "dim_blue": "\033[2m\033[34m",
+    "dim_green": "\033[2m\033[32m",
+}
+_color_fns = {k: _make_color_fn(v) for k, v in _color_codes.items()}
+_counter_regexp = re.compile(r"# COUNTER: (\d+)")
 reflectable_magic_methods = {
     "add": "{} + {}",
     "sub": "{} - {}",

torch/fx/proxy.py

@@ -15,11 +15,12 @@ from typing import Any, Callable, Optional
 import torch
 import torch.fx.traceback as fx_traceback
-from torch._C import _fx_map_aggregate as map_aggregate
+from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
 from torch.utils._traceback import CapturedTraceback
 from ._compatibility import compatibility
 from .graph import Graph, magic_methods, reflectable_magic_methods
+from .immutable_collections import immutable_dict, immutable_list
 from .node import Argument, base_types, Node, Target
 from .operator_schemas import check_for_mutable_operation
@@ -302,6 +303,13 @@ class TracerBase:
         # into the graph. In particular, Tensor operations should go into the graph,
         # but non-Tensor operations shouldn't. What that means is that constructors
         # for new types *SHOULD NOT* become nodes in the FX graph.
+        handler = _create_arg_bypass.get(type(a))
+        if handler is not None:
+            # this is just a performance optimization and can be removed if needed
+            # for common types, we have a fast path to avoid isinstance() overhead
+            # this doesn't remove the checks below since we need to handle subclasses
+            return handler(self, a)
         if isinstance(a, Proxy):
             return a.node  # most common arg type goes first
         elif hasattr(a, "__fx_create_arg__"):
@@ -318,24 +326,7 @@ class TracerBase:
         elif isinstance(a, list):
             return [self.create_arg(elem) for elem in a]
         elif isinstance(a, dict):
-            def no_node(arg):
-                if isinstance(arg, Node):
-                    raise RuntimeError(
-                        "Keys for dictionaries used as an argument cannot contain a "
-                        f"Node. Got key: {k}"
-                    )
-            r = {}
-            for k, v in a.items():
-                # Check for invalid dict keys. We do not want a Proxy to appear
-                # anywhere within the key. Since keys can be collection types,
-                # we iterate through the key with map_aggregate
-                k = self.create_arg(k)
-                map_aggregate(k, no_node)
-                r[k] = self.create_arg(v)
-            return r
+            return _create_arg_dict(self, a)
         elif isinstance(a, slice):
             return slice(
                 self.create_arg(a.start),
@@ -746,3 +737,41 @@ def _define_reflectable(orig_method_name):
 for orig_method_name in reflectable_magic_methods:
     _define_reflectable(orig_method_name)
+def _no_nodes_error(arg):
+    raise RuntimeError(
+        "Keys for dictionaries used as an argument cannot contain a "
+        f"Node. Got key: {arg}"
+    )
+def _create_arg_dict(self, a):
+    r = {}
+    for k, v in a.items():
+        if not isinstance(k, str):
+            # Check for invalid dict keys. We do not want a Proxy to appear
+            # anywhere within the key. Since keys can be collection types,
+            # we iterate through the key with map_arg
+            k = self.create_arg(k)
+            map_arg(k, _no_nodes_error)
+        r[k] = self.create_arg(v)
+    return r
+_create_arg_bypass = {
+    t: lambda self, a: a
+    for t in [
+        *base_types,
+        type(None),
+        type(...),
+        torch._ops.OpOverload,
+        torch._ops.HigherOrderOperator,
+    ]
+}
+_create_arg_bypass[Proxy] = lambda self, a: a.node
+_create_arg_bypass[tuple] = lambda self, a: tuple([self.create_arg(elem) for elem in a])
+_create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a]
+_create_arg_bypass[dict] = _create_arg_dict
+_create_arg_bypass[immutable_list] = _create_arg_bypass[list]
+_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict]
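The new _create_arg_bypass table dispatches on the exact runtime type of the argument, so the common cases skip the isinstance() chain entirely while subclasses still fall through to the original checks. A standalone sketch of that dispatch pattern (the names convert, _convert_seq, and _FAST_PATH are illustrative, not part of torch.fx):

def _convert_seq(xs):
    return [convert(x) for x in xs]

# Exact-type fast path: a dict lookup on type(x) does no MRO walk, so subclasses miss it.
_FAST_PATH = {
    int: lambda x: x,
    float: lambda x: x,
    str: lambda x: x,
    type(None): lambda x: x,
    list: _convert_seq,
    tuple: lambda xs: tuple(convert(x) for x in xs),
}

def convert(x):
    handler = _FAST_PATH.get(type(x))
    if handler is not None:
        return handler(x)
    # Slow path: isinstance() still handles subclasses (e.g. an immutable_list-style
    # list subclass), so the fast path changes performance, not behavior.
    if isinstance(x, (list, tuple)):
        return type(x)(convert(v) for v in x)
    return x

print(convert([1, 2.5, "a", (3, None)]))  # [1, 2.5, 'a', (3, None)]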