mirror of https://github.com/zebrajr/pytorch.git
[fx] Optimize TracerBase.create_arg and Graph._gen_python_code (#148292)
Before: 19502951 function calls (18702776 primitive calls) in 8.533 seconds
After: 16402551 function calls (15602452 primitive calls) in 7.701 seconds

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148292
Approved by: https://github.com/oulgen
ghstack dependencies: #148243, #148260, #148261, #148303, #148288
This commit is contained in:
parent 5eb0337cfd
commit 8531d247ba
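The before/after numbers in the message read like cProfile output over a tracing workload. For reference, a comparable profile can be collected as below; the traced function is illustrative, not the actual benchmark behind the quoted numbers:

import cProfile

import torch
import torch.fx


def f(x):
    # tracing exercises TracerBase.create_arg for every arg/kwarg,
    # and Graph python codegen when the GraphModule is built
    return torch.relu(x + x)


cProfile.run("torch.fx.symbolic_trace(f)", sort="cumtime")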
--- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv
+++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

@@ -1,20 +1,20 @@
-add_loop_eager,compile_time_instruction_count,2859000000,0.015
+add_loop_eager,compile_time_instruction_count,2799000000,0.015

-add_loop_eager_dynamic,compile_time_instruction_count,6250000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,6131000000,0.025

-add_loop_inductor,compile_time_instruction_count,27960000000,0.015
+add_loop_inductor,compile_time_instruction_count,27570000000,0.015

-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43820000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43120000000,0.025

-add_loop_inductor_gpu,compile_time_instruction_count,24280000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,23910000000,0.015

@@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,954100000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17380000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17150000000,0.015

-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15580000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15410000000,0.015

@@ -34,15 +34,15 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,98740000
-update_hint_regression,compile_time_instruction_count,1629000000,0.02
+update_hint_regression,compile_time_instruction_count,1615000000,0.02

-sum_floordiv_regression,compile_time_instruction_count,1036000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1029000000,0.015

-symint_sum,compile_time_instruction_count,3094000000,0.015
+symint_sum,compile_time_instruction_count,3038000000,0.015

@@ -54,12 +54,12 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5734000000,0
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,7980000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,7814000000,0.015

-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3606000000,0.015
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3573000000,0.015

-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9862000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9747000000,0.015
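Each CSV row is benchmark name, metric, expected instruction count, and a relative noise tolerance. The expected counts are refreshed here because the speedup moves the measured counts outside the old tolerance band; a rough sketch of that check, not the harness's exact logic:

def within_band(actual, expected, tol):
    return abs(actual - expected) <= expected * tol


# add_loop_eager improved ~2.1%, past its 1.5% band, so the row is updated
assert not within_band(2799000000, 2859000000, 0.015)
assert within_band(2799000000, 2799000000, 0.015)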
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py

@@ -479,38 +479,24 @@ class CodeGen:
                 # Common case: this is a regular module name like 'foo.bar.baz'
                 return add_global(typename, o)

-        codes = {
-            "yellow": "\033[33m",
-            "cyan": "\033[36m",
-            "green": "\033[32m",
-            "blue": "\033[34m",
-            "red": "\033[31m",
-            "dim": "\033[2m",
-            "dim_blue": "\033[2m\033[34m",
-            "dim_green": "\033[2m\033[32m",
-            "reset": "\033[0m",
-        }
-
-        def make_wrapper_func(name):
-            def f(s):
-                if colored:
-                    return f"{codes[name]}{s}{codes['reset']}"
-                return s
-
-            return f
-
-        yellow = make_wrapper_func("yellow")  # noqa: F841
-        cyan = make_wrapper_func("cyan")  # noqa: F841
-        red = make_wrapper_func("red")
-        green = make_wrapper_func("green")  # noqa: F841
-        dim_green = make_wrapper_func("dim_green")
-        dim = make_wrapper_func("dim")
-        dim_blue = make_wrapper_func("dim_blue")
-        blue = make_wrapper_func("blue")
+        if colored:
+            red = _color_fns["red"]
+            dim_green = _color_fns["dim_green"]
+            dim = _color_fns["dim"]
+            dim_blue = _color_fns["dim_blue"]
+            blue = _color_fns["blue"]
+        else:
+            red = _identity
+            dim_green = _identity
+            dim = _identity
+            dim_blue = _identity
+            blue = _identity

         def _get_repr(arg: Any) -> str:
-            # Handle NamedTuples (if it has `_fields`) via add_global.
-            if isinstance(arg, tuple) and hasattr(arg, "_fields"):
+            if isinstance(arg, Node):  # first because common
+                return repr(arg)
+            elif isinstance(arg, tuple) and hasattr(arg, "_fields"):
+                # Handle NamedTuples (if it has `_fields`) via add_global.
                 qualified_name = _get_qualified_name(type(arg))
                 global_name = add_global(qualified_name, type(arg))
                 return f"{global_name}{repr(tuple(arg))}"
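The old code rebuilt the codes dict and eight make_wrapper_func closures on every codegen call; the new code binds names to precomputed module-level _color_fns entries (defined near the bottom of the file) or to _identity. A self-contained sketch of the same pattern, with illustrative names:

import timeit


def _make_color_fn(code):
    def f(s):
        return f"{code}{s}\033[0m"

    return f


# built once at import time
_COLOR_FNS = {"red": _make_color_fn("\033[31m")}


def per_call():
    # old shape: rebuild dict and closure on every call
    codes = {"red": "\033[31m", "reset": "\033[0m"}

    def red(s):
        return f"{codes['red']}{s}{codes['reset']}"

    return red("x")


def hoisted():
    # new shape: dictionary lookup of a prebuilt function
    return _COLOR_FNS["red"]("x")


print(timeit.timeit(per_call, number=100_000))
print(timeit.timeit(hoisted, number=100_000))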
@@ -524,8 +510,6 @@ class CodeGen:
                 cls = arg.__class__
                 clsname = add_global(cls.__name__, cls)
                 return f"{clsname}.{arg.name}"
-            elif isinstance(arg, Node):
-                return repr(arg)
             elif isinstance(arg, torch.Tensor):
                 size = list(arg.size())
                 dtype = str(arg.dtype).split(".")[-1]
@@ -545,11 +529,9 @@ class CodeGen:
         def _format_args(
             args: tuple[Argument, ...], kwargs: dict[str, Argument]
         ) -> str:
-            args_s = ", ".join(_get_repr(a) for a in args)
-            kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
-            if args_s and kwargs_s:
-                return f"{args_s}, {kwargs_s}"
-            return args_s or kwargs_s
+            res = [_get_repr(a) for a in args]
+            res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()])
+            return ", ".join(res)

         # Run through reverse nodes and record the first instance of a use
         # of a given node. This represents the *last* use of the node in the
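Building a single list and joining once is both shorter and cheaper than joining twice and branching on emptiness; the two forms agree on all cases. An illustrative check:

def old_format(args_s, kwargs_s):
    if args_s and kwargs_s:
        return f"{args_s}, {kwargs_s}"
    return args_s or kwargs_s


def new_format(parts):
    return ", ".join(parts)


assert old_format("a, b", "k = 1") == new_format(["a", "b", "k = 1"])
assert old_format("a, b", "") == new_format(["a", "b"])
assert old_format("", "k = 1") == new_format(["k = 1"])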
@@ -564,8 +546,8 @@ class CodeGen:
                 user_to_last_uses.setdefault(user, []).append(n)

         for node in reversed(nodes):
-            map_arg(node.args, lambda n: register_last_uses(n, node))
-            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+            for input_node in node._input_nodes:
+                register_last_uses(input_node, node)

         def delete_unused_values(user: Node):
             """
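Node._input_nodes already holds every Node reachable from a node's args and kwargs (deduplicated, in order), so iterating it directly replaces two map_arg traversals plus a fresh lambda per node. It is an internal attribute; a small check of the behavior the change relies on:

import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
y = g.call_function(torch.add, (x, x))

# x appears twice in args but only once in _input_nodes
assert list(y._input_nodes) == [x]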
@@ -603,23 +585,23 @@ class CodeGen:
             """
             nonlocal prev_stacktrace

-            if node.op not in {"placeholder", "output"}:
-                if node.stack_trace:
-                    if node.stack_trace != prev_stacktrace:
-                        prev_stacktrace = node.stack_trace
-                        summary_str = ""
-
-                        if parsed_stack_trace := _parse_stack_trace(node.stack_trace):
-                            summary_str = parsed_stack_trace.get_summary_str()
-
-                        body.append(f'\n {dim("# " + summary_str)}\n')
+            if node.op not in ("placeholder", "output"):
+                stack_trace = node.stack_trace
+                if stack_trace:
+                    if stack_trace != prev_stacktrace:
+                        prev_stacktrace = stack_trace
+                        if parsed_stack_trace := _parse_stack_trace(stack_trace):
+                            summary_str = parsed_stack_trace.get_summary_str()
+                        else:
+                            summary_str = ""
+                        body.append(f'\n {dim(f"# {summary_str}")}\n')
                 elif prev_stacktrace != "":
                     prev_stacktrace = ""
                     no_stacktrace_msg = "# No stacktrace found for following nodes"
                     body.append(f"\n{dim(no_stacktrace_msg)}\n")

         def stringify_shape(shape: Iterable) -> str:
-            return f"[{', '.join(str(x) for x in shape)}]"
+            return f"[{', '.join([str(x) for x in shape])}]"

         def emit_node(node: Node):
             maybe_type_annotation = (
@@ -777,8 +759,8 @@ class CodeGen:
         new_lines: list[str] = []
         cur_idx = None
         for line in "".join(body).split("\n"):
-            counter = re.search(r"# COUNTER: (\d+)", line)
-            if counter and counter.group(1) is not None:
+            counter = _counter_regexp.search(line)
+            if counter is not None:
                 cur_idx = int(counter.group(1))
             else:
                 lineno_map[len(new_lines) + prologue_len] = cur_idx
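Hoisting the pattern into the module-level _counter_regexp avoids re-entering re's pattern cache on every line; and since group 1 is not optional in the pattern, counter.group(1) can never be None once the search matched, so that extra check was dead. A quick comparison of the two call shapes:

import re
import timeit

_counter_regexp = re.compile(r"# COUNTER: (\d+)")
line = "x = y  # COUNTER: 42"

print(timeit.timeit(lambda: re.search(r"# COUNTER: (\d+)", line), number=100_000))
print(timeit.timeit(lambda: _counter_regexp.search(line), number=100_000))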
@@ -1207,12 +1189,10 @@ class Graph:

         # Null out this Node's argument nodes so that the Nodes referred to
         # can update their ``users`` accordingly
-        new_args = map_arg(to_erase.args, lambda n: None)
-        assert isinstance(new_args, tuple)
-        to_erase.args = new_args
-        new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
-        assert isinstance(new_kwargs, dict)
-        to_erase.kwargs = new_kwargs
+        to_erase._update_args_kwargs(
+            map_arg(to_erase._args, lambda n: None),
+            map_arg(to_erase._kwargs, lambda n: None),
+        )

     @compatibility(is_backward_compatible=True)
     def inserting_before(self, n: Optional[Node] = None):
@@ -1723,21 +1703,14 @@ class Graph:
         seen_names: set[str] = set()
         seen_values: set[Node] = set()
         for node in self.nodes:
-            if node.op not in [
-                "placeholder",
-                "call_method",
-                "call_module",
-                "call_function",
-                "get_attr",
-                "output",
-            ]:
+            if node.op not in _legal_ops:
                 raise RuntimeError(f"Node {node} had unknown opcode {node.op}!")
             if node.graph is not self:
                 raise RuntimeError(f"Node '{node}' does not belong to this Graph!")
             if node not in self._find_nodes_lookup_table:
                 raise RuntimeError(f"Node '{node}' is not added to the side table")
-            map_arg(node.args, lambda arg: check_arg(arg, node))
-            map_arg(node.kwargs, lambda arg: check_arg(arg, node))
+            for arg in node._input_nodes:
+                check_arg(arg, node)
             seen_values.add(node)

             if node.name in seen_names:
@@ -1956,6 +1929,32 @@ class Graph:
         return on_generate_code_context_manager()


+def _identity(x):
+    return x
+
+
+def _make_color_fn(code):
+    def f(s):
+        reset = "\033[0m"
+        return f"{code}{s}{reset}"
+
+    return f
+
+
+_color_codes = {
+    "yellow": "\033[33m",
+    "cyan": "\033[36m",
+    "green": "\033[32m",
+    "blue": "\033[34m",
+    "red": "\033[31m",
+    "dim": "\033[2m",
+    "dim_blue": "\033[2m\033[34m",
+    "dim_green": "\033[2m\033[32m",
+}
+_color_fns = {k: _make_color_fn(v) for k, v in _color_codes.items()}
+_counter_regexp = re.compile(r"# COUNTER: (\d+)")
+
+
 reflectable_magic_methods = {
     "add": "{} + {}",
     "sub": "{} - {}",
--- a/torch/fx/proxy.py
+++ b/torch/fx/proxy.py

@@ -15,11 +15,12 @@ from typing import Any, Callable, Optional

 import torch
 import torch.fx.traceback as fx_traceback
-from torch._C import _fx_map_aggregate as map_aggregate
+from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg
 from torch.utils._traceback import CapturedTraceback

 from ._compatibility import compatibility
 from .graph import Graph, magic_methods, reflectable_magic_methods
+from .immutable_collections import immutable_dict, immutable_list
 from .node import Argument, base_types, Node, Target
 from .operator_schemas import check_for_mutable_operation
@@ -302,6 +303,13 @@ class TracerBase:
         # into the graph. In particular, Tensor operations should go into the graph,
         # but non-Tensor operations shouldn't. What that means is that constructors
         # for new types *SHOULD NOT* become nodes in the FX graph.
+        handler = _create_arg_bypass.get(type(a))
+        if handler is not None:
+            # this is just a performance optimization and can be removed if needed
+            # for common types, we have a fast path to avoid isinstance() overhead
+            # this doesn't remove the checks below since we need to handle subclasses
+            return handler(self, a)
+
         if isinstance(a, Proxy):
             return a.node  # most common arg type goes first
         elif hasattr(a, "__fx_create_arg__"):
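The fast path is a single dict probe on the exact type(a), replacing a long isinstance chain for the overwhelmingly common argument types; subclasses miss the table and fall through to the original checks, which is why those checks stay. A standalone sketch of the dispatch pattern, with illustrative handlers:

_bypass = {
    int: lambda a: a,
    float: lambda a: a,
    str: lambda a: a,
}


def convert(a):
    handler = _bypass.get(type(a))  # exact-type hit only
    if handler is not None:
        return handler(a)
    # slow path still needed for subclasses, e.g. bool (a subclass of int)
    if isinstance(a, (int, float, str)):
        return a
    if isinstance(a, tuple):
        return tuple(convert(e) for e in a)
    raise TypeError(type(a))


assert convert(3) == 3          # fast path
assert convert(True) is True    # bool misses the table; isinstance catches it
assert convert((1, "x")) == (1, "x")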
@@ -318,24 +326,7 @@ class TracerBase:
         elif isinstance(a, list):
             return [self.create_arg(elem) for elem in a]
         elif isinstance(a, dict):
-
-            def no_node(arg):
-                if isinstance(arg, Node):
-                    raise RuntimeError(
-                        "Keys for dictionaries used as an argument cannot contain a "
-                        f"Node. Got key: {k}"
-                    )
-
-            r = {}
-            for k, v in a.items():
-                # Check for invalid dict keys. We do not want a Proxy to appear
-                # anywhere within the key. Since keys can be collection types,
-                # we iterate through the key with map_aggregate
-                k = self.create_arg(k)
-                map_aggregate(k, no_node)
-
-                r[k] = self.create_arg(v)
-            return r
+            return _create_arg_dict(self, a)
         elif isinstance(a, slice):
             return slice(
                 self.create_arg(a.start),
@@ -746,3 +737,41 @@ def _define_reflectable(orig_method_name):

 for orig_method_name in reflectable_magic_methods:
     _define_reflectable(orig_method_name)
+
+
+def _no_nodes_error(arg):
+    raise RuntimeError(
+        "Keys for dictionaries used as an argument cannot contain a "
+        f"Node. Got key: {arg}"
+    )
+
+
+def _create_arg_dict(self, a):
+    r = {}
+    for k, v in a.items():
+        if not isinstance(k, str):
+            # Check for invalid dict keys. We do not want a Proxy to appear
+            # anywhere within the key. Since keys can be collection types,
+            # we iterate through the key with map_arg
+            k = self.create_arg(k)
+            map_arg(k, _no_nodes_error)
+        r[k] = self.create_arg(v)
+    return r
+
+
+_create_arg_bypass = {
+    t: lambda self, a: a
+    for t in [
+        *base_types,
+        type(None),
+        type(...),
+        torch._ops.OpOverload,
+        torch._ops.HigherOrderOperator,
+    ]
+}
+_create_arg_bypass[Proxy] = lambda self, a: a.node
+_create_arg_bypass[tuple] = lambda self, a: tuple([self.create_arg(elem) for elem in a])
+_create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a]
+_create_arg_bypass[dict] = _create_arg_dict
+_create_arg_bypass[immutable_list] = _create_arg_bypass[list]
+_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict]
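Because the table is keyed by exact type, subclasses such as immutable_list and immutable_dict need their own entries even though isinstance would already treat them as list/dict; that is also what keeps the bypass a pure optimization, since anything not in the table still reaches the original checks. A minimal illustration of the exact-type behavior:

class MyList(list):
    pass


table = {list: "list handler"}
assert table.get(type([1, 2])) == "list handler"
assert table.get(type(MyList())) is None  # subclass misses; isinstance path handles it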