[ca] clean up aot node deduping (#149064)

Rename the AOT backward graph nodes as we copy-paste them into the compiled autograd (CA) graph, instead of renaming them after the fact in `rename_aot_dispatcher_nodes`.
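For illustration, a minimal standalone sketch of the naming scheme this commit moves to: nodes get an `aot<id>` prefix at copy time, and a per-AOT-id counter disambiguates repeated splices of the same AOT backward graph so fx node names stay unique. `make_unique`, `aot_id_counter`, and `deduped_aot_id` mirror the names in the diff below; `start_splice` is a hypothetical wrapper for this sketch only (in the real code this happens inline in the graph-copy loop).

from collections import defaultdict

# Sketch only: one counter per AOT id, bumped each time that AOT backward
# graph is copy-pasted into the compiled autograd (CA) graph.
aot_id_counter: dict[int, int] = defaultdict(int)

def start_splice(aot_id: int):
    # Hypothetical helper: compute the deduped prefix once per splice,
    # then bump the counter for the next reuse of the same AOT graph.
    deduped_aot_id = str(aot_id)
    if aot_id_counter[aot_id]:
        deduped_aot_id += f"_{aot_id_counter[aot_id]}"
    aot_id_counter[aot_id] += 1

    def make_unique(node_name: str) -> str:
        # make the fx node name both informative and unique
        return f"aot{deduped_aot_id}_{node_name}"

    return make_unique

rename = start_splice(1)
print(rename("tangents_1"))  # aot1_tangents_1
rename = start_splice(1)     # same AOT graph reused in the same CA graph
print(rename("tangents_1"))  # aot1_1_tangents_1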

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149064
Approved by: https://github.com/jansel
Author: Simon Fan (2025-03-14 14:57:20 -07:00), committed by PyTorch MergeBot
parent 96795e9533
commit f4368d8872
2 changed files with 63 additions and 119 deletions

View File

@@ -2,6 +2,8 @@
# flake8: noqa
import functools
import itertools
from unittest import mock
import torch
import torch._dynamo.test_case
@@ -91,7 +93,10 @@ class _multiply_invoke(torch.nn.Module):
""",
)
def test_invoke_in_pt2_compiled_autograd(self):
@mock.patch(
"torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
)
def test_invoke_in_pt2_compiled_autograd(self, _):
graph = None
def compiler_fn(gm):
@@ -121,7 +126,7 @@ class _multiply_invoke(torch.nn.Module):
out.backward(grad_out)
actual = normalize_gm(graph.print_readable(False))
self.assertEqual(x.grad, grad_out * grad_out)
if backend in ["aot_eager", "inductor"]:
if backend == "aot_eager":
self.assertExpectedInline(
actual,
"""\
@@ -136,11 +141,36 @@ class GraphModule(torch.nn.Module):
getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
getitem_11: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
aot1_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad: "f32[s0]" = torch.clone(getitem_11)
new_grad: "f32[s0]" = torch.clone(aot1_tangents_1)
result: "f32[s0]" = getitem_11 * getitem_11; getitem_11 = None
result: "f32[s0]" = aot1_tangents_1 * aot1_tangents_1; aot1_tangents_1 = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
""",
)
elif backend == "inductor":
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_ : list, L_sizes_0_: "Sym(s0)"):
l_inputs_ = L_inputs_
l_sizes_0_ = L_sizes_0_
getitem: "f32[s0]" = l_inputs_[0]; l_inputs_ = None
validate_outputs = torch__dynamo_compiled_autograd_ops_validate_outputs([getitem], [((None, None, device(type='cpu'), 6, 0, None), [l_sizes_0_], False)]); getitem = l_sizes_0_ = None
getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
aot3_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad: "f32[s0]" = torch.clone(aot3_tangents_1)
result: "f32[s0]" = aot3_tangents_1 * aot3_tangents_1; aot3_tangents_1 = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
return (new_grad, new_grad_1)
@@ -149,7 +179,10 @@ class GraphModule(torch.nn.Module):
graph = None
def test_invoke_in_pt2_compiled_autograd_side_effect(self):
@mock.patch(
"torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count
)
def test_invoke_in_pt2_compiled_autograd_side_effect(self, _):
def _side_effect_stateful_fn2(x, obj):
obj.counter = obj.counter + 1
return _multiply(x)
@@ -211,13 +244,13 @@ class GraphModule(torch.nn.Module):
getitem_9: "f32[s0]" = validate_outputs[0]; validate_outputs = None
call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((), [], getitem_9); getitem_9 = None
getitem_11: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
aot0_tangents_1: "f32[s0]" = call_aot_bwd_prologue[0]; call_aot_bwd_prologue = None
new_grad: "f32[s0]" = torch.clone(getitem_11)
new_grad: "f32[s0]" = torch.clone(aot0_tangents_1)
add: "Sym(s7 + 1)" = l_hooks_1_keywords_fn_keywords_obj_counter + 1; l_hooks_1_keywords_fn_keywords_obj_counter = None
result: "f32[s0]" = getitem_11 * getitem_11; getitem_11 = None
result: "f32[s0]" = aot0_tangents_1 * aot0_tangents_1; aot0_tangents_1 = None
new_grad_1: "f32[s0]" = torch.clone(result); result = None
return (new_grad, new_grad_1, add)

View File

@@ -22,7 +22,7 @@ import itertools
import operator
import time
from collections import Counter, defaultdict
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Optional, TYPE_CHECKING, Union
import torch
import torch.utils._pytree as pytree
@@ -191,6 +191,7 @@ class AutogradCompilerInstance:
):
counters["compiled_autograd"]["captures"] += 1
self.id = next(COMPILE_COUNTER)
self.aot_id_counter: dict[int, int] = defaultdict(int)
self.compile_context = make_compile_context(self.id)
self.compile_context.__enter__()
self.start_time_ns = time.time_ns()
@@ -200,8 +201,6 @@ class AutogradCompilerInstance:
{"graph_id": self.id},
log_pt2_compile_event=True,
)
self.aot_graph_cls_name: Optional[str] = None
self.aot_graph_infos: dict[int, dict[str, Any]] = {}
self.fx_tracer.root = torch.nn.Module()
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
self.fx_tracer.tensor_attrs = {}
@@ -321,6 +320,7 @@ class AutogradCompilerInstance:
CompiledFunction = ctx._forward_cls
metadata = CompiledFunction.metadata
maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
aot_id = CompiledFunction._aot_id
del CompiledFunction
@torch._dynamo.allow_in_graph # type: ignore[misc]
@@ -381,9 +381,23 @@ class AutogradCompilerInstance:
args_idx = 0
value_remap = {}
poutputs: Optional[list[torch.fx.Proxy]] = None
# names of nodes must appear only once in the fx.Graph
# dedup AOT backwards that appear multiple times
deduped_aot_id = str(aot_id)
if self.aot_id_counter[aot_id]:
deduped_aot_id += f"_{self.aot_id_counter[aot_id]}"
self.aot_id_counter[aot_id] += 1
def make_unique(node_name):
# make it both informative and unique
return f"aot{deduped_aot_id}_{node_name}"
for node in ctx._bw_module.graph.nodes:
if node.op == "placeholder":
value_remap[node] = pall_args[args_idx].node
ph = pall_args[args_idx].node
ph.name = make_unique(node.name)
value_remap[node] = ph
args_idx += 1
elif node.op == "output":
assert len(node.args) == 1
@@ -400,6 +414,7 @@ class AutogradCompilerInstance:
self.fx_tracer.root, qualname, getattr(ctx._bw_module, name)
)
result = self.fx_tracer.create_node("get_attr", qualname, (), {})
result.name = make_unique(node.name)
value_remap[node] = result
elif node.op == "call_function":
if node.target == torch.ops.aten.view.default:
@@ -410,9 +425,11 @@ class AutogradCompilerInstance:
result = self.fx_tracer.graph.node_copy(
node, lambda n: value_remap[n]
)
result.name = make_unique(node.name)
value_remap[node] = result
else:
raise AssertionError("shouldn't get here")
assert poutputs is not None
# In general we don't know what the shapes of the outputs are, so allocate
@@ -785,7 +802,6 @@ class AutogradCompilerInstance:
f"CompiledAutograd{self.id}PreReordering",
).print_readable(print_output=False),
)
self.rename_aot_dispatcher_nodes()
self.delay_unpack_hook_nodes()
self.reorder_tensor_pre_hook_nodes()
self.reorder_pre_hook_nodes_to_schedule_asap()
@@ -843,104 +859,6 @@ class AutogradCompilerInstance:
self.compile_context.__exit__(None, None, None)
return runtime_wrapper, self.compiler_fn(graph)
def rename_aot_dispatcher_nodes(self):
"""
Renames nodes as they appear in the AOTDispatcher backward graphs, prefixed by AOT id
e.g. AOTDispatcher backward graph X's `sin_Y` -> `aotX_sin_Y`
"""
if self.aot_graph_cls_name is None:
return
def is_similar(ca: torch.fx.node.Node, aot: torch.fx.node.Node):
# 1. comparing using target (for aten ops)
target_match = ca.target == aot.target
if not target_match:
# 2. comparing using name (for HOPs)
target_match = (
hasattr(ca.target, "__name__")
and hasattr(aot.target, "__name__")
and ca.target.__name__ == aot.target.__name__
)
if (
not target_match
and hasattr(ca.target, "name")
and hasattr(aot.target, "name")
and aot.target.name() == "aten::reshape"
and hasattr(aot.meta.get("original_aten"), "name")
):
# 3. undo view_to_reshape post grad pass
target_match = ca.target.name() == aot.meta["original_aten"].name()
return (
target_match
and ca.op == aot.op
and ca.type == aot.type
and len(ca.all_input_nodes) == len(aot.all_input_nodes)
)
# number of times we saw this AOT backward graph, used to dedup reused graphs
aot_id_counter: dict[int, int] = defaultdict(int)
for nodecall_index, info in self.aot_graph_infos.items():
ca_node_start_idx = info["ca_node_start_idx"]
aot_id = info["aot_id"]
aot_id_postfix = ""
aot_graph = info["aot_gm"].graph
if aot_id_counter[aot_id]:
aot_id_postfix = f"_{aot_id_counter[aot_id]}"
aot_id_counter[aot_id] += 1
# 1. Find the first op from user code in the AOT graph
aot_it = iter(aot_graph.nodes)
aot_node = next(aot_it)
assert aot_node is not None
try:
while aot_node.op != "call_function":
aot_node = next(aot_it)
except StopIteration:
continue
try:
# 2. Find the first op in the compiled autograd graph segment
ca_it = iter(self.fx_tracer.graph.nodes)
for _ in range(ca_node_start_idx):
next(ca_it)
ca_node = next(ca_it)
# Graphs should all end with output node
while ca_node.op != "output" and not is_similar(ca_node, aot_node):
# The compiled autograd graph may contain lazily inserted ops
# We skip those when aligning nodes
ca_node = next(ca_it)
# 3. Keep aligned and rename nodes
while aot_node.op != "output" and ca_node.op != "output":
if not ca_node.users:
# TODO: DCE for compiled autograd graph
ca_node = next(ca_it)
continue
if not is_similar(ca_node, aot_node):
# There should be no lazily inserted ops in the middle of a match
# So any deviation is an error
raise StopIteration
ca_node.name = f"aot{aot_id}{aot_id_postfix}_{aot_node.name}"
for i, inp in enumerate(aot_node.all_input_nodes):
ca_node.all_input_nodes[
i
].name = f"aot{aot_id}{aot_id_postfix}_{inp.name}"
aot_node = next(aot_it)
ca_node = next(ca_it)
except StopIteration:
verbose_log.debug(
"Failed to match %s%s (NodeCall %s) nodes with AOT backward graph %s nodes",
self.aot_graph_cls_name,
aot_id,
nodecall_index,
aot_id,
)
@staticmethod
def get_all_nodes(args):
# filter out non-Node args, like None
@@ -1239,14 +1157,7 @@ class AutogradCompilerInstance:
"""This compiled backward function was saved by AOTAutogradCache, which does not support
compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
)
self.aot_graph_cls_name = node_name
maybe_aot_id = forward_cls._aot_id
self.aot_graph_infos[nodecall_index] = {
"ca_node_start_idx": len(self.fx_tracer.graph.nodes),
"aot_id": maybe_aot_id,
"aot_gm": forward_cls._lazy_backward_info.bw_module,
}
new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
raw_stack_trace = CapturedTraceback.extract().format()[-1]
new_stack_trace = raw_stack_trace.replace(