Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[ca] clean up aot node deduping (#149064)
Rename the AOT nodes as we copy-paste them into the CA graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149064
Approved by: https://github.com/jansel
Parent: 96795e9533
Commit: f4368d8872
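At a glance, the change replaces the old post-hoc node matching with renaming at copy time: every node copied from an AOT backward graph gets an "aot{id}" prefix, and a per-id counter disambiguates graphs that are spliced into the compiled autograd (CA) graph more than once. A minimal standalone sketch of that naming scheme, using hypothetical helper names rather than the PR's actual code:

from collections import defaultdict

aot_id_counter: dict[int, int] = defaultdict(int)

def deduped_prefix(aot_id: int) -> str:
    # First occurrence of an AOT backward graph keeps its plain id ("aot1");
    # later occurrences of the same graph get a numeric suffix ("aot1_1", ...).
    deduped = str(aot_id)
    if aot_id_counter[aot_id]:
        deduped += f"_{aot_id_counter[aot_id]}"
    aot_id_counter[aot_id] += 1
    return f"aot{deduped}"

print(deduped_prefix(1), deduped_prefix(1), deduped_prefix(3))  # aot1 aot1_1 aot3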
@@ -2,6 +2,8 @@
 # flake8: noqa

 import functools
+import itertools
+from unittest import mock

 import torch
 import torch._dynamo.test_case
@@ -91,7 +93,10 @@ class _multiply_invoke(torch.nn.Module):
 """,
         )

-    def test_invoke_in_pt2_compiled_autograd(self):
+    @mock.patch(
+        "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
+    )
+    def test_invoke_in_pt2_compiled_autograd(self, _):
         graph = None

         def compiler_fn(gm):
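The new mock.patch decorator swaps torch._functorch.aot_autograd.AOT_COUNTER for a fresh itertools.count in each test, presumably so the aotN prefixes baked into the expected graphs stay deterministic. A small self-contained sketch of how new_callable=itertools.count behaves (Cfg and COUNTER are made-up stand-ins, not PyTorch APIs):

import itertools
from unittest import mock

class Cfg:
    COUNTER = itertools.count(100)  # pretend the global counter already advanced

with mock.patch.object(Cfg, "COUNTER", new_callable=itertools.count):
    print(next(Cfg.COUNTER))  # 0 -- the patch installs a fresh counter
print(next(Cfg.COUNTER))  # 100 -- the original counter is restored on exit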
@@ -121,7 +126,7 @@ class _multiply_invoke(torch.nn.Module):
         out.backward(grad_out)
         actual = normalize_gm(graph.print_readable(False))
         self.assertEqual(x.grad, grad_out * grad_out)
-        if backend in ["aot_eager", "inductor"]:
+        if backend == "aot_eager":
             self.assertExpectedInline(
                 actual,
                 """\
@@ -136,11 +141,36 @@ class GraphModule(torch.nn.Module):
         getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None

         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
-        getitem_11: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
+        aot1_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

-        new_grad: "f32[s0]" = torch.clone(getitem_11)
+        new_grad: "f32[s0]" = torch.clone(aot1_tangents_1)

-        result: "f32[s0]" = getitem_11 * getitem_11; getitem_11 = None
+        result: "f32[s0]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None

         new_grad_1: "f32[s0]" = torch.clone(result); result = None
         return (new_grad, new_grad_1)
+""",
+            )
+        elif backend == "inductor":
+            self.assertExpectedInline(
+                actual,
+                """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"):
+        l_inputs_ = L_inputs_
+        l_sizes_0_ = L_sizes_0_
+
+        getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
+
+        validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
+        getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
+
+        call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
+        aot3_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
+
+        new_grad: "f32[s0]" = torch.clone(aot3_tangents_1)
+
+        result: "f32[s0]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None
+
+        new_grad_1: "f32[s0]" = torch.clone(result); result = None
+        return (new_grad, new_grad_1)
@@ -149,7 +179,10 @@ class GraphModule(torch.nn.Module):

         graph = None

-    def test_invoke_in_pt2_compiled_autograd_side_effect(self):
+    @mock.patch(
+        "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
+    )
+    def test_invoke_in_pt2_compiled_autograd_side_effect(self, _):
         def _side_effect_stateful_fn2(x, obj):
             obj.counter = obj.counter + 1
             return _multiply(x)
@@ -211,13 +244,13 @@ class GraphModule(torch.nn.Module):
         getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None

         call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
-        getitem_11: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
+        aot0_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None

-        new_grad: "f32[s0]" = torch.clone(getitem_11)
+        new_grad: "f32[s0]" = torch.clone(aot0_tangents_1)

         add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None

-        result: "f32[s0]" = getitem_11 * getitem_11; getitem_11 = None
+        result: "f32[s0]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None

         new_grad_1: "f32[s0]" = torch.clone(result); result = None
         return (new_grad, new_grad_1, add)
@@ -22,7 +22,7 @@ import itertools
 import operator
 import time
 from collections import Counter, defaultdict
-from typing import Any, Optional, TYPE_CHECKING, Union
+from typing import Optional, TYPE_CHECKING, Union

 import torch
 import torch.utils._pytree as pytree
@@ -191,6 +191,7 @@ class AutogradCompilerInstance:
     ):
         counters["compiled_autograd"]["captures"] += 1
         self.id = next(COMPILE_COUNTER)
+        self.aot_id_counter: dict[int, int] = defaultdict(int)
         self.compile_context = make_compile_context(self.id)
         self.compile_context.__enter__()
         self.start_time_ns = time.time_ns()
@@ -200,8 +201,6 @@
             {"graph_id": self.id},
             log_pt2_compile_event=True,
         )
-        self.aot_graph_cls_name: Optional[str] = None
-        self.aot_graph_infos: dict[int, dict[str, Any]] = {}
         self.fx_tracer.root = torch.nn.Module()
         self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
         self.fx_tracer.tensor_attrs = {}
@@ -321,6 +320,7 @@ class AutogradCompilerInstance:
         CompiledFunction = ctx._forward_cls
         metadata = CompiledFunction.metadata
         maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
+        aot_id = CompiledFunction._aot_id
         del CompiledFunction

         @torch._dynamo.allow_in_graph  # type: ignore[misc]
@@ -381,9 +381,23 @@ class AutogradCompilerInstance:
         args_idx = 0
         value_remap = {}
         poutputs: Optional[list[torch.fx.Proxy]] = None
+
+        # names of nodes must appear only once in the fx.Graph
+        # dedup AOT backwards that appear multiple times
+        deduped_aot_id = str(aot_id)
+        if self.aot_id_counter[aot_id]:
+            deduped_aot_id += f"_{self.aot_id_counter[aot_id]}"
+        self.aot_id_counter[aot_id] += 1
+
+        def make_unique(node_name):
+            # make it both informative and unique
+            return f"aot{deduped_aot_id}_{node_name}"
+
         for node in ctx._bw_module.graph.nodes:
             if node.op == "placeholder":
-                value_remap[node] = pall_args[args_idx].node
+                ph = pall_args[args_idx].node
+                ph.name = make_unique(node.name)
+                value_remap[node] = ph
                 args_idx += 1
             elif node.op == "output":
                 assert len(node.args) == 1
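The ph.name = make_unique(node.name) assignments above rely on torch.fx node names being plain, reassignable strings. A minimal standalone sketch of the same prefixing idea on a freshly traced graph (not the PR's code; the aot0_ prefix is only an example), assuming Node.name accepts assignment as it does in current torch.fx:

import torch
import torch.fx

def double(x):
    return x * 2

gm = torch.fx.symbolic_trace(double)
for node in gm.graph.nodes:
    if node.op != "output":
        node.name = f"aot0_{node.name}"  # prefix every node, as make_unique does above
gm.recompile()
print(gm.graph)  # node names now read aot0_x, aot0_mul, ...
print(gm(torch.ones(3)))  # behavior unchanged: tensor([2., 2., 2.])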
@@ -400,6 +414,7 @@
                     self.fx_tracer.root, qualname, getattr(ctx._bw_module, name)
                 )
                 result = self.fx_tracer.create_node("get_attr", qualname, (), {})
+                result.name = make_unique(node.name)
                 value_remap[node] = result
             elif node.op == "call_function":
                 if node.target == torch.ops.aten.view.default:
@@ -410,9 +425,11 @@
                 result = self.fx_tracer.graph.node_copy(
                     node, lambda n: value_remap[n]
                 )
+                result.name = make_unique(node.name)
                 value_remap[node] = result
             else:
                 raise AssertionError("shouldn't get here")

+        assert poutputs is not None

         # In general we don't know what the shapes of the outputs are, so allocate
@@ -785,7 +802,6 @@
                 f"CompiledAutograd{self.id}PreReordering",
             ).print_readable(print_output=False),
         )
-        self.rename_aot_dispatcher_nodes()
         self.delay_unpack_hook_nodes()
         self.reorder_tensor_pre_hook_nodes()
         self.reorder_pre_hook_nodes_to_schedule_asap()
@@ -843,104 +859,6 @@
         self.compile_context.__exit__(None, None, None)
         return runtime_wrapper, self.compiler_fn(graph)

-    def rename_aot_dispatcher_nodes(self):
-        """
-        Renames nodes as they appear in the AOTDispatcher backward graphs, prefixed by AOT id
-        e.g. AOTDispatcher backward graph X's `sin_Y` -> `aotX_sin_Y`
-        """
-        if self.aot_graph_cls_name is None:
-            return
-
-        def is_similar(ca: torch.fx.node.Node, aot: torch.fx.node.Node):
-            # 1. comparing using target (for aten ops)
-            target_match = ca.target == aot.target
-            if not target_match:
-                # 2. comparing using name (for HOPs)
-                target_match = (
-                    hasattr(ca.target, "__name__")
-                    and hasattr(aot.target, "__name__")
-                    and ca.target.__name__ == aot.target.__name__
-                )
-            if (
-                not target_match
-                and hasattr(ca.target, "name")
-                and hasattr(aot.target, "name")
-                and aot.target.name() == "aten::reshape"
-                and hasattr(aot.meta.get("original_aten"), "name")
-            ):
-                # 3. undo view_to_reshape post grad pass
-                target_match = ca.target.name() == aot.meta["original_aten"].name()
-
-            return (
-                target_match
-                and ca.op == aot.op
-                and ca.type == aot.type
-                and len(ca.all_input_nodes) == len(aot.all_input_nodes)
-            )
-
-        # number of times we saw this AOT backward graph, used to dedup reused graphs
-        aot_id_counter: dict[int, int] = defaultdict(int)
-        for nodecall_index, info in self.aot_graph_infos.items():
-            ca_node_start_idx = info["ca_node_start_idx"]
-            aot_id = info["aot_id"]
-            aot_id_postfix = ""
-            aot_graph = info["aot_gm"].graph
-            if aot_id_counter[aot_id]:
-                aot_id_postfix = f"_{aot_id_counter[aot_id]}"
-            aot_id_counter[aot_id] += 1
-
-            # 1. Find the first op from user code in the AOT graph
-            aot_it = iter(aot_graph.nodes)
-            aot_node = next(aot_it)
-            assert aot_node is not None
-            try:
-                while aot_node.op != "call_function":
-                    aot_node = next(aot_it)
-            except StopIteration:
-                continue
-
-            try:
-                # 2. Find the first op in the compiled autograd graph segment
-                ca_it = iter(self.fx_tracer.graph.nodes)
-                for _ in range(ca_node_start_idx):
-                    next(ca_it)
-                ca_node = next(ca_it)
-
-                # Graphs should all end with output node
-                while ca_node.op != "output" and not is_similar(ca_node, aot_node):
-                    # The compiled autograd graph may contain lazily inserted ops
-                    # We skip those when aligning nodes
-                    ca_node = next(ca_it)
-
-                # 3. Keep aligned and rename nodes
-                while aot_node.op != "output" and ca_node.op != "output":
-                    if not ca_node.users:
-                        # TODO: DCE for compiled autograd graph
-                        ca_node = next(ca_it)
-                        continue
-
-                    if not is_similar(ca_node, aot_node):
-                        # There should be no lazily inserted ops in the middle of a match
-                        # So any deviation is an error
-                        raise StopIteration
-
-                    ca_node.name = f"aot{aot_id}{aot_id_postfix}_{aot_node.name}"
-                    for i, inp in enumerate(aot_node.all_input_nodes):
-                        ca_node.all_input_nodes[
-                            i
-                        ].name = f"aot{aot_id}{aot_id_postfix}_{inp.name}"
-
-                    aot_node = next(aot_it)
-                    ca_node = next(ca_it)
-            except StopIteration:
-                verbose_log.debug(
-                    "Failed to match %s%s (NodeCall %s) nodes with AOT backward graph %s nodes",
-                    self.aot_graph_cls_name,
-                    aot_id,
-                    nodecall_index,
-                    aot_id,
-                )

     @staticmethod
     def get_all_nodes(args):
         # filter out non-Node args, like None
@@ -1239,14 +1157,7 @@
                 """This compiled backward function was saved by AOTAutogradCache, which does not support
 compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
             )
-        self.aot_graph_cls_name = node_name
-        maybe_aot_id = forward_cls._aot_id
-        self.aot_graph_infos[nodecall_index] = {
-            "ca_node_start_idx": len(self.fx_tracer.graph.nodes),
-            "aot_id": maybe_aot_id,
-            "aot_gm": forward_cls._lazy_backward_info.bw_module,
-        }

         new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
         raw_stack_trace = CapturedTraceback.extract().format()[-1]
         new_stack_trace = raw_stack_trace.replace(