[Hierarchical Compile] Replace tracing alias and mutation check with dynamo impl (#152570)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152570
Approved by: https://github.com/anijain2305
ghstack dependencies: #152389, #152505, #152410, #152506
This commit is contained in:
Michael Lazos 2025-05-12 22:19:59 -07:00 committed by PyTorch MergeBot
parent 57dafb90ef
commit a415c9831f
2 changed files with 210 additions and 107 deletions

View File

@ -461,12 +461,6 @@ class GraphModule(torch.nn.Module):
)
def test_input_mutation(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def inner_fn2(x, y):
x0 = x + 1
y0 = y + 1
@ -476,9 +470,6 @@ class GraphModule(torch.nn.Module):
def fn(x, y):
x0 = torch.sin(x)
_y0 = torch.cos(y)
# o0 = inner_fn(x0, y0)
# o1 = inner_fn(x0, o0)
o2 = inner_fn2(x0, y)
o3 = inner_fn2(x0.clone(), y.clone())
return o2 + o3
@ -985,10 +976,7 @@ class <lambda>(torch.nn.Module):
)
def test_mutation_ordering(self):
from torch._dynamo.graph_deduplication import (
_populate_additional_deps,
_stable_topological_sort,
)
from torch._dynamo.graph_deduplication import _stable_topological_sort
def inner_fn(x, y):
x0 = x.view(x.size())
@ -1013,74 +1001,109 @@ class <lambda>(torch.nn.Module):
x_clone = x.clone()
y_clone = y.clone()
graph, tracker = extract_graph_and_tracker(fn, x_clone, y_clone)
graph, _ = extract_graph_and_tracker(fn, x_clone, y_clone)
def graph_code(graph):
return graph.python_code("self").src
def get_node(name):
return next(n for n in graph.nodes if n.name == name)
additional_deps = _populate_additional_deps(
graph, tracker.node_to_mutated_arg_positions
)
self.assertExpectedInline(
additional_deps,
"""defaultdict(<class 'torch.utils._ordered_set.OrderedSet'>, {add_: OrderedSet([x0, x0_1]), invoke_subgraph: OrderedSet([add_]), invoke_subgraph_1: OrderedSet([add_, mul_]), mul_: OrderedSet([invoke_subgraph])})""",
graph_code(graph),
"""\
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
x0 = l_x_.view((10, 10))
o0 = x0.view((10, 10)); x0 = None
x0_1 = l_x_.view((10, 10))
o1 = x0_1.view((10, 10)); x0_1 = None
add_ = l_x_.add_(l_x_); add_ = None
add_2 = o0 + o1; o0 = o1 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
mul_ = l_y_.mul_(l_y_); mul_ = None
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
return (add_4,)
""",
)
# Shuffle nodes in the graph
add_ = get_node("add_")
mul_ = get_node("mul_")
x0 = get_node("x0")
x0.append(mul_)
o1 = get_node("o1")
o1.append(add_)
o1.append(mul_)
add_2 = get_node("add_2")
add_2.append(add_)
self.assertExpectedInline(
graph,
graph_code(graph),
"""\
graph():
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
return (add_4,)""",
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
x0 = l_x_.view((10, 10))
o0 = x0.view((10, 10)); x0 = None
x0_1 = l_x_.view((10, 10))
o1 = x0_1.view((10, 10)); x0_1 = None
mul_ = l_y_.mul_(l_y_); mul_ = None
add_2 = o0 + o1; o0 = o1 = None
add_ = l_x_.add_(l_x_); add_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
return (add_4,)
""",
)
_stable_topological_sort(
graph, torch._dynamo.graph_deduplication.last_node_to_additional_deps
)
_stable_topological_sort(graph, additional_deps)
self.assertExpectedInline(
graph,
graph_code(graph),
"""\
graph():
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
return (add_4,)""",
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
x0 = l_x_.view((10, 10))
o0 = x0.view((10, 10)); x0 = None
x0_1 = l_x_.view((10, 10))
o1 = x0_1.view((10, 10)); x0_1 = None
add_2 = o0 + o1; o0 = o1 = None
add_ = l_x_.add_(l_x_); add_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
mul_ = l_y_.mul_(l_y_); mul_ = None
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
return (add_4,)
""",
)

View File

@ -11,20 +11,28 @@ import logging
import operator
from collections import defaultdict
from collections.abc import Generator, Iterable
from typing import Any
from typing import Optional
import torch
import torch.fx
from torch._dynamo import config
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._ordered_set import OrderedSet
from .graph_region_tracker import Node, Region
from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique
# Represents an index into the region
# to select a node and then
# an index into that node's
# flattened arguments
UsageIndex = tuple[int, int]
log = logging.getLogger(__name__)
last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def]
"""
@ -57,6 +65,9 @@ when they are created in output_graph.
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
output_graph.graph
)
node_to_mutated_arg_positions = (
output_graph.region_tracker.node_to_mutated_arg_positions
)
node_to_additional_deps = _populate_additional_deps(
output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions
)
@ -68,11 +79,11 @@ when they are created in output_graph.
region = region_group[0]
(
subgraph,
node_ind_arg_inds,
external_node_usages,
) = _create_subgraph(region, inds_with_external_users)
# Ignore regions with no args for now, could they possibly be evaluated at compile time?
if not list(node_ind_arg_inds):
if not list(external_node_usages):
continue
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
@ -82,21 +93,26 @@ when they are created in output_graph.
get_subgraph_node = output_graph.graph.create_node(
"get_attr", subgraph_name, (), {}
)
for region in region_group:
_replace_region_with_subgraph(
output_graph.graph,
region,
get_subgraph_node,
node_ind_arg_inds.keys(),
external_node_usages,
inds_with_external_users,
sub_gm,
subgraph_name,
node_to_additional_deps,
node_to_mutated_arg_positions,
)
# This is to expose the updated node_to_additional_deps to tests
global last_node_to_additional_deps
last_node_to_additional_deps = node_to_additional_deps
_stable_topological_sort(
output_graph.graph,
node_to_additional_deps, # type: ignore[arg-type]
node_to_additional_deps,
)
return sub_gms
@ -105,27 +121,34 @@ def _replace_region_with_subgraph(
graph: torch.fx.Graph,
region: Region,
get_subgraph_node: Node,
node_ind_arg_ind: Iterable[tuple[int, int]],
external_node_usages: Iterable[OrderedSet[UsageIndex]],
inds_with_external_users: list[int],
sub_gm: torch.fx.GraphModule,
subgraph_name: str,
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
sub_args = []
for node_ind, arg_ind in node_ind_arg_ind:
for usages in external_node_usages:
node_ind, usage_ind = next(iter(usages))
node = region[node_ind]
flattened_args_kwargs = _get_flat_args(node, {})
sub_args.append(flattened_args_kwargs[arg_ind])
for user_ind, node_usage_ind in usages:
user = region[user_ind]
if user in node_to_mutated_arg_positions:
if node_usage_ind in node_to_mutated_arg_positions[user]:
log.debug(
"NYI: Failed to substitute region %s due to mutation", region
)
return
sub_args.append(flattened_args_kwargs[usage_ind])
# Input/Output aliasing not supported in HOPs today
# Note: we should use the nodes in the original graph (the region here)
# because we use the original traced example values for this check
if _has_aliasing(region, sub_args, inds_with_external_users):
return
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
fake_inputs = [node.meta["example_value"] for node in sub_args]
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
log.debug(
"NYI: Failed to substitute region %s due to input alias or mutation",
region,
)
return
invoke_subgraph_node = graph.create_node(
"call_function",
@ -143,6 +166,10 @@ def _replace_region_with_subgraph(
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
# Remove any nodes with additional deps
# This is safe; we've guaranteed that there is
# no input mutation, so all additional deps
# will be internal to the subgraph
node_to_additional_deps.pop(node, None)
for deps in node_to_additional_deps.values():
try:
@ -152,27 +179,27 @@ def _replace_region_with_subgraph(
pass
if config.graph_deduplication_lint:
_detect_cycles(graph, node_to_additional_deps)
print(_detect_cycles(graph, node_to_additional_deps))
_stable_topological_sort(graph, node_to_additional_deps)
graph.lint()
def _get_external_inputs(
region: Region,
) -> dict[Node, tuple[int, int]]:
external_node_to_indices = dict()
) -> dict[Node, OrderedSet[UsageIndex]]:
external_node_to_usages = defaultdict[Node, OrderedSet[UsageIndex]](OrderedSet)
region_unique = set(region)
for node_ind, node in enumerate(region):
flattened_args_kwargs = _get_flat_args(node, {})
for arg_ind, in_node in enumerate(flattened_args_kwargs):
if (
isinstance(in_node, Node)
and in_node not in region_unique
and in_node not in external_node_to_indices
):
external_node_to_indices[in_node] = (node_ind, arg_ind)
if isinstance(in_node, Node) and in_node not in region_unique:
# in_node may occur in multiple nodes' flat_args
# track this so we can check if the arg is mutated
# Previously, we only needed to track one occurrence
# to be able to map that node to a placeholder
external_node_to_usages[in_node].add((node_ind, arg_ind))
return external_node_to_indices
return external_node_to_usages
def _get_all_output_indices(regions: list[Region]) -> list[int]:
@ -195,17 +222,14 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
def _copy_nodes_and_remap_inputs(
subgraph: torch.fx.Graph, region: Region
) -> dict[tuple[int, int], Any]:
external_inputs_to_indices = _get_external_inputs(region)
indices_to_placeholder_ind: dict[tuple[int, int], Any] = {}
) -> list[OrderedSet[UsageIndex]]:
external_input_to_usages = _get_external_inputs(region)
external_node_usages = list[OrderedSet[UsageIndex]]()
region_to_subgraph_node = {}
for node in external_inputs_to_indices.keys():
for node, usage_indices in external_input_to_usages.items():
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
region_to_subgraph_node[node] = placeholder
arg_indices = external_inputs_to_indices[node]
# Note: insertion order matches the order in which placeholders were created
# for the calling convention of the subgraph
indices_to_placeholder_ind[arg_indices] = None
external_node_usages.append(usage_indices)
def map_arg(node: Node) -> Node:
if node in region_to_subgraph_node:
@ -217,7 +241,7 @@ def _copy_nodes_and_remap_inputs(
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
region_to_subgraph_node[node] = subgraph_node
return indices_to_placeholder_ind
return external_node_usages
def _create_subgraph_outputs(
@ -231,11 +255,11 @@ def _create_subgraph_outputs(
def _create_subgraph(
region: Region,
inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, dict[tuple[int, int], Any]]:
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
subgraph: torch.fx.Graph = torch.fx.Graph()
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
_create_subgraph_outputs(subgraph, inds_with_external_users)
return subgraph, node_ind_input_inds
return subgraph, external_node_usages
def _stable_topological_sort(
@ -357,3 +381,59 @@ def _add_mutation_dependencies(
node_to_additional_deps[node].add(user)
elif user > node:
node_to_additional_deps[user].add(node)
def _has_aliasing(
region: Region, inputs: list[Node], inds_with_external_users: list[int]
) -> bool:
input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
for node in inputs:
example_value = node.meta["example_value"]
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
if storage in input_storages:
# input-input aliasing
log.debug(
"NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s, %s",
region,
input_storages[storage],
node,
)
return True
input_storages[storage] = node
output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
for i in inds_with_external_users:
out_node = region[i]
if out_node:
example_value = out_node.meta["example_value"]
assert not isinstance(example_value, list)
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
if storage in output_storages:
# output-output aliasing
log.debug(
"NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s",
region,
output_storages[storage],
out_node,
)
return True
output_storages[storage] = out_node
intersected_storages = input_storages.keys() & output_storages.keys()
if len(intersected_storages) > 0:
# input-output aliasing
aliased = [
(input_storages[s], output_storages[s]) for s in intersected_storages
]
aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
log.debug(
"NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s",
region,
aliased,
)
return True
return False