mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Hierarchical Compile] Replace tracing alias and mutation check with dynamo impl (#152570)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152570 Approved by: https://github.com/anijain2305 ghstack dependencies: #152389, #152505, #152410, #152506
This commit is contained in:
parent
57dafb90ef
commit
a415c9831f
|
|
@ -461,12 +461,6 @@ class GraphModule(torch.nn.Module):
|
|||
)
|
||||
|
||||
def test_input_mutation(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def inner_fn2(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 1
|
||||
|
|
@ -476,9 +470,6 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
def fn(x, y):
|
||||
x0 = torch.sin(x)
|
||||
_y0 = torch.cos(y)
|
||||
# o0 = inner_fn(x0, y0)
|
||||
# o1 = inner_fn(x0, o0)
|
||||
o2 = inner_fn2(x0, y)
|
||||
o3 = inner_fn2(x0.clone(), y.clone())
|
||||
return o2 + o3
|
||||
|
|
@ -985,10 +976,7 @@ class <lambda>(torch.nn.Module):
|
|||
)
|
||||
|
||||
def test_mutation_ordering(self):
|
||||
from torch._dynamo.graph_deduplication import (
|
||||
_populate_additional_deps,
|
||||
_stable_topological_sort,
|
||||
)
|
||||
from torch._dynamo.graph_deduplication import _stable_topological_sort
|
||||
|
||||
def inner_fn(x, y):
|
||||
x0 = x.view(x.size())
|
||||
|
|
@ -1013,74 +1001,109 @@ class <lambda>(torch.nn.Module):
|
|||
x_clone = x.clone()
|
||||
y_clone = y.clone()
|
||||
|
||||
graph, tracker = extract_graph_and_tracker(fn, x_clone, y_clone)
|
||||
graph, _ = extract_graph_and_tracker(fn, x_clone, y_clone)
|
||||
|
||||
def graph_code(graph):
|
||||
return graph.python_code("self").src
|
||||
|
||||
def get_node(name):
|
||||
return next(n for n in graph.nodes if n.name == name)
|
||||
|
||||
additional_deps = _populate_additional_deps(
|
||||
graph, tracker.node_to_mutated_arg_positions
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
additional_deps,
|
||||
"""defaultdict(<class 'torch.utils._ordered_set.OrderedSet'>, {add_: OrderedSet([x0, x0_1]), invoke_subgraph: OrderedSet([add_]), invoke_subgraph_1: OrderedSet([add_, mul_]), mul_: OrderedSet([invoke_subgraph])})""",
|
||||
graph_code(graph),
|
||||
"""\
|
||||
|
||||
|
||||
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
x0 = l_x_.view((10, 10))
|
||||
o0 = x0.view((10, 10)); x0 = None
|
||||
x0_1 = l_x_.view((10, 10))
|
||||
o1 = x0_1.view((10, 10)); x0_1 = None
|
||||
add_ = l_x_.add_(l_x_); add_ = None
|
||||
add_2 = o0 + o1; o0 = o1 = None
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
|
||||
mul_ = l_y_.mul_(l_y_); mul_ = None
|
||||
getitem = invoke_subgraph[0]; invoke_subgraph = None
|
||||
sum_5 = getitem.sum(); getitem = None
|
||||
add_3 = add_2 + sum_5; add_2 = sum_5 = None
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
|
||||
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
sum_6 = getitem_1.sum(); getitem_1 = None
|
||||
add_4 = add_3 + sum_6; add_3 = sum_6 = None
|
||||
return (add_4,)
|
||||
""",
|
||||
)
|
||||
|
||||
# Shuffle nodes in the graph
|
||||
add_ = get_node("add_")
|
||||
mul_ = get_node("mul_")
|
||||
x0 = get_node("x0")
|
||||
x0.append(mul_)
|
||||
o1 = get_node("o1")
|
||||
o1.append(add_)
|
||||
o1.append(mul_)
|
||||
add_2 = get_node("add_2")
|
||||
add_2.append(add_)
|
||||
|
||||
self.assertExpectedInline(
|
||||
graph,
|
||||
graph_code(graph),
|
||||
"""\
|
||||
graph():
|
||||
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
|
||||
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
|
||||
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
|
||||
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
|
||||
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
|
||||
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
|
||||
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
|
||||
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
|
||||
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
|
||||
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
|
||||
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
|
||||
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
|
||||
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
|
||||
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
|
||||
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
|
||||
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
|
||||
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
|
||||
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
|
||||
return (add_4,)""",
|
||||
|
||||
|
||||
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
x0 = l_x_.view((10, 10))
|
||||
o0 = x0.view((10, 10)); x0 = None
|
||||
x0_1 = l_x_.view((10, 10))
|
||||
o1 = x0_1.view((10, 10)); x0_1 = None
|
||||
mul_ = l_y_.mul_(l_y_); mul_ = None
|
||||
add_2 = o0 + o1; o0 = o1 = None
|
||||
add_ = l_x_.add_(l_x_); add_ = None
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
|
||||
getitem = invoke_subgraph[0]; invoke_subgraph = None
|
||||
sum_5 = getitem.sum(); getitem = None
|
||||
add_3 = add_2 + sum_5; add_2 = sum_5 = None
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
|
||||
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
sum_6 = getitem_1.sum(); getitem_1 = None
|
||||
add_4 = add_3 + sum_6; add_3 = sum_6 = None
|
||||
return (add_4,)
|
||||
""",
|
||||
)
|
||||
_stable_topological_sort(
|
||||
graph, torch._dynamo.graph_deduplication.last_node_to_additional_deps
|
||||
)
|
||||
_stable_topological_sort(graph, additional_deps)
|
||||
self.assertExpectedInline(
|
||||
graph,
|
||||
graph_code(graph),
|
||||
"""\
|
||||
graph():
|
||||
%subgraph_0 : [num_users=2] = get_attr[target=subgraph_0]
|
||||
%l_x_ : torch.Tensor [num_users=5] = placeholder[target=L_x_]
|
||||
%l_y_ : torch.Tensor [num_users=3] = placeholder[target=L_y_]
|
||||
%x0 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
|
||||
%o0 : [num_users=1] = call_method[target=view](args = (%x0, (10, 10)), kwargs = {})
|
||||
%x0_1 : [num_users=1] = call_method[target=view](args = (%l_x_, (10, 10)), kwargs = {})
|
||||
%o1 : [num_users=1] = call_method[target=view](args = (%x0_1, (10, 10)), kwargs = {})
|
||||
%add_ : [num_users=0] = call_method[target=add_](args = (%l_x_, %l_x_), kwargs = {})
|
||||
%add_2 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o1), kwargs = {})
|
||||
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
|
||||
%mul_ : [num_users=0] = call_method[target=mul_](args = (%l_y_, %l_y_), kwargs = {})
|
||||
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
|
||||
%sum_5 : [num_users=1] = call_method[target=sum](args = (%getitem,), kwargs = {})
|
||||
%add_3 : [num_users=1] = call_function[target=operator.add](args = (%add_2, %sum_5), kwargs = {})
|
||||
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %l_y_), kwargs = {})
|
||||
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
|
||||
%sum_6 : [num_users=1] = call_method[target=sum](args = (%getitem_1,), kwargs = {})
|
||||
%add_4 : [num_users=1] = call_function[target=operator.add](args = (%add_3, %sum_6), kwargs = {})
|
||||
return (add_4,)""",
|
||||
|
||||
|
||||
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
x0 = l_x_.view((10, 10))
|
||||
o0 = x0.view((10, 10)); x0 = None
|
||||
x0_1 = l_x_.view((10, 10))
|
||||
o1 = x0_1.view((10, 10)); x0_1 = None
|
||||
add_2 = o0 + o1; o0 = o1 = None
|
||||
add_ = l_x_.add_(l_x_); add_ = None
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
|
||||
mul_ = l_y_.mul_(l_y_); mul_ = None
|
||||
getitem = invoke_subgraph[0]; invoke_subgraph = None
|
||||
sum_5 = getitem.sum(); getitem = None
|
||||
add_3 = add_2 + sum_5; add_2 = sum_5 = None
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
|
||||
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
sum_6 = getitem_1.sum(); getitem_1 = None
|
||||
add_4 = add_3 + sum_6; add_3 = sum_6 = None
|
||||
return (add_4,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,20 +11,28 @@ import logging
|
|||
import operator
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch._dynamo import config
|
||||
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .graph_region_tracker import Node, Region
|
||||
from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique
|
||||
|
||||
|
||||
# Represents an index into the region
|
||||
# to select a node and then
|
||||
# an index into that node's
|
||||
# flattened arguments
|
||||
UsageIndex = tuple[int, int]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None
|
||||
|
||||
|
||||
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def]
|
||||
"""
|
||||
|
|
@ -57,6 +65,9 @@ when they are created in output_graph.
|
|||
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
|
||||
output_graph.graph
|
||||
)
|
||||
node_to_mutated_arg_positions = (
|
||||
output_graph.region_tracker.node_to_mutated_arg_positions
|
||||
)
|
||||
node_to_additional_deps = _populate_additional_deps(
|
||||
output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions
|
||||
)
|
||||
|
|
@ -68,11 +79,11 @@ when they are created in output_graph.
|
|||
region = region_group[0]
|
||||
(
|
||||
subgraph,
|
||||
node_ind_arg_inds,
|
||||
external_node_usages,
|
||||
) = _create_subgraph(region, inds_with_external_users)
|
||||
|
||||
# Ignore regions with no args for now, could they possibly be evaluated at compile time?
|
||||
if not list(node_ind_arg_inds):
|
||||
if not list(external_node_usages):
|
||||
continue
|
||||
|
||||
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
|
||||
|
|
@ -82,21 +93,26 @@ when they are created in output_graph.
|
|||
get_subgraph_node = output_graph.graph.create_node(
|
||||
"get_attr", subgraph_name, (), {}
|
||||
)
|
||||
|
||||
for region in region_group:
|
||||
_replace_region_with_subgraph(
|
||||
output_graph.graph,
|
||||
region,
|
||||
get_subgraph_node,
|
||||
node_ind_arg_inds.keys(),
|
||||
external_node_usages,
|
||||
inds_with_external_users,
|
||||
sub_gm,
|
||||
subgraph_name,
|
||||
node_to_additional_deps,
|
||||
node_to_mutated_arg_positions,
|
||||
)
|
||||
|
||||
# This is to expose the updated node_to_additional_deps to tests
|
||||
global last_node_to_additional_deps
|
||||
last_node_to_additional_deps = node_to_additional_deps
|
||||
|
||||
_stable_topological_sort(
|
||||
output_graph.graph,
|
||||
node_to_additional_deps, # type: ignore[arg-type]
|
||||
node_to_additional_deps,
|
||||
)
|
||||
return sub_gms
|
||||
|
||||
|
|
@ -105,27 +121,34 @@ def _replace_region_with_subgraph(
|
|||
graph: torch.fx.Graph,
|
||||
region: Region,
|
||||
get_subgraph_node: Node,
|
||||
node_ind_arg_ind: Iterable[tuple[int, int]],
|
||||
external_node_usages: Iterable[OrderedSet[UsageIndex]],
|
||||
inds_with_external_users: list[int],
|
||||
sub_gm: torch.fx.GraphModule,
|
||||
subgraph_name: str,
|
||||
node_to_additional_deps: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
|
||||
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
|
||||
) -> None:
|
||||
sub_args = []
|
||||
for node_ind, arg_ind in node_ind_arg_ind:
|
||||
for usages in external_node_usages:
|
||||
node_ind, usage_ind = next(iter(usages))
|
||||
node = region[node_ind]
|
||||
flattened_args_kwargs = _get_flat_args(node, {})
|
||||
sub_args.append(flattened_args_kwargs[arg_ind])
|
||||
for user_ind, node_usage_ind in usages:
|
||||
user = region[user_ind]
|
||||
if user in node_to_mutated_arg_positions:
|
||||
if node_usage_ind in node_to_mutated_arg_positions[user]:
|
||||
log.debug(
|
||||
"NYI: Failed to substitute region %s due to mutation", region
|
||||
)
|
||||
return
|
||||
sub_args.append(flattened_args_kwargs[usage_ind])
|
||||
|
||||
# Input/Output aliasing not supported in HOPs today
|
||||
# Note: we should use the nodes in the original graph (the region here)
|
||||
# because we use the original traced example values for this check
|
||||
if _has_aliasing(region, sub_args, inds_with_external_users):
|
||||
return
|
||||
|
||||
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
|
||||
fake_inputs = [node.meta["example_value"] for node in sub_args]
|
||||
|
||||
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
|
||||
log.debug(
|
||||
"NYI: Failed to substitute region %s due to input alias or mutation",
|
||||
region,
|
||||
)
|
||||
return
|
||||
|
||||
invoke_subgraph_node = graph.create_node(
|
||||
"call_function",
|
||||
|
|
@ -143,6 +166,10 @@ def _replace_region_with_subgraph(
|
|||
# Erase in reverse topological order
|
||||
for node in reversed(region):
|
||||
graph.erase_node(node)
|
||||
# Remove any nodes with additional deps
|
||||
# This is safe; we've guaranteed that there is
|
||||
# no input mutation, so all additional deps
|
||||
# will be internal to the subgraph
|
||||
node_to_additional_deps.pop(node, None)
|
||||
for deps in node_to_additional_deps.values():
|
||||
try:
|
||||
|
|
@ -152,27 +179,27 @@ def _replace_region_with_subgraph(
|
|||
pass
|
||||
|
||||
if config.graph_deduplication_lint:
|
||||
_detect_cycles(graph, node_to_additional_deps)
|
||||
print(_detect_cycles(graph, node_to_additional_deps))
|
||||
_stable_topological_sort(graph, node_to_additional_deps)
|
||||
graph.lint()
|
||||
|
||||
|
||||
def _get_external_inputs(
|
||||
region: Region,
|
||||
) -> dict[Node, tuple[int, int]]:
|
||||
external_node_to_indices = dict()
|
||||
) -> dict[Node, OrderedSet[UsageIndex]]:
|
||||
external_node_to_usages = defaultdict[Node, OrderedSet[UsageIndex]](OrderedSet)
|
||||
region_unique = set(region)
|
||||
for node_ind, node in enumerate(region):
|
||||
flattened_args_kwargs = _get_flat_args(node, {})
|
||||
for arg_ind, in_node in enumerate(flattened_args_kwargs):
|
||||
if (
|
||||
isinstance(in_node, Node)
|
||||
and in_node not in region_unique
|
||||
and in_node not in external_node_to_indices
|
||||
):
|
||||
external_node_to_indices[in_node] = (node_ind, arg_ind)
|
||||
if isinstance(in_node, Node) and in_node not in region_unique:
|
||||
# in_node may occur in multiple nodes' flat_args
|
||||
# track this so we can check if the arg is mutated
|
||||
# Previously, we only needed to track one occurrence
|
||||
# to be able to map that node to a placeholder
|
||||
external_node_to_usages[in_node].add((node_ind, arg_ind))
|
||||
|
||||
return external_node_to_indices
|
||||
return external_node_to_usages
|
||||
|
||||
|
||||
def _get_all_output_indices(regions: list[Region]) -> list[int]:
|
||||
|
|
@ -195,17 +222,14 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
|
|||
|
||||
def _copy_nodes_and_remap_inputs(
|
||||
subgraph: torch.fx.Graph, region: Region
|
||||
) -> dict[tuple[int, int], Any]:
|
||||
external_inputs_to_indices = _get_external_inputs(region)
|
||||
indices_to_placeholder_ind: dict[tuple[int, int], Any] = {}
|
||||
) -> list[OrderedSet[UsageIndex]]:
|
||||
external_input_to_usages = _get_external_inputs(region)
|
||||
external_node_usages = list[OrderedSet[UsageIndex]]()
|
||||
region_to_subgraph_node = {}
|
||||
for node in external_inputs_to_indices.keys():
|
||||
for node, usage_indices in external_input_to_usages.items():
|
||||
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
|
||||
region_to_subgraph_node[node] = placeholder
|
||||
arg_indices = external_inputs_to_indices[node]
|
||||
# Note: insertion order matches the order in which placeholders were created
|
||||
# for the calling convention of the subgraph
|
||||
indices_to_placeholder_ind[arg_indices] = None
|
||||
external_node_usages.append(usage_indices)
|
||||
|
||||
def map_arg(node: Node) -> Node:
|
||||
if node in region_to_subgraph_node:
|
||||
|
|
@ -217,7 +241,7 @@ def _copy_nodes_and_remap_inputs(
|
|||
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
|
||||
region_to_subgraph_node[node] = subgraph_node
|
||||
|
||||
return indices_to_placeholder_ind
|
||||
return external_node_usages
|
||||
|
||||
|
||||
def _create_subgraph_outputs(
|
||||
|
|
@ -231,11 +255,11 @@ def _create_subgraph_outputs(
|
|||
def _create_subgraph(
|
||||
region: Region,
|
||||
inds_with_external_users: list[int],
|
||||
) -> tuple[torch.fx.Graph, dict[tuple[int, int], Any]]:
|
||||
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
|
||||
subgraph: torch.fx.Graph = torch.fx.Graph()
|
||||
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
|
||||
external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
|
||||
_create_subgraph_outputs(subgraph, inds_with_external_users)
|
||||
return subgraph, node_ind_input_inds
|
||||
return subgraph, external_node_usages
|
||||
|
||||
|
||||
def _stable_topological_sort(
|
||||
|
|
@ -357,3 +381,59 @@ def _add_mutation_dependencies(
|
|||
node_to_additional_deps[node].add(user)
|
||||
elif user > node:
|
||||
node_to_additional_deps[user].add(node)
|
||||
|
||||
|
||||
def _has_aliasing(
    region: Region, inputs: list[Node], inds_with_external_users: list[int]
) -> bool:
    """Report whether substituting ``region`` would involve tensor aliasing.

    The check compares the storages of the ``example_value`` tensors recorded
    in each node's ``meta`` and looks for three kinds of sharing:

    * two entries of ``inputs`` backed by the same storage (input-input),
    * two externally used region nodes backed by the same storage
      (output-output),
    * an input and an externally used region node backed by the same storage
      (input-output).

    Args:
        region: The nodes of the candidate region; indexed by the entries of
            ``inds_with_external_users``.
        inputs: The nodes that would be passed as arguments to the subgraph.
        inds_with_external_users: Indices into ``region`` of the nodes whose
            values are used outside the region.

    Returns:
        ``True`` as soon as any aliasing is found (a debug message naming the
        offending nodes is logged), ``False`` otherwise.
    """
    input_storages: dict[StorageWeakRef, torch.fx.Node] = {}

    for node in inputs:
        example_value = node.meta["example_value"]
        # Non-tensor example values have no storage to compare.
        if isinstance(example_value, torch.Tensor):
            storage = StorageWeakRef(example_value._typed_storage())
            if storage in input_storages:
                # input-input aliasing
                log.debug(
                    "NYI: Failed to substitute region %s due to input-input aliasing detected at nodes %s, %s",
                    region,
                    input_storages[storage],
                    node,
                )
                return True
            input_storages[storage] = node

    output_storages: dict[StorageWeakRef, torch.fx.Node] = {}
    for i in inds_with_external_users:
        out_node = region[i]
        if out_node:
            example_value = out_node.meta["example_value"]
            assert not isinstance(example_value, list)
            if isinstance(example_value, torch.Tensor):
                storage = StorageWeakRef(example_value._typed_storage())
                if storage in output_storages:
                    # output-output aliasing
                    log.debug(
                        "NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s",
                        region,
                        output_storages[storage],
                        out_node,
                    )
                    return True
                output_storages[storage] = out_node

    # Any storage present on both sides means an output shares memory with an input.
    intersected_storages = input_storages.keys() & output_storages.keys()
    if len(intersected_storages) > 0:
        # input-output aliasing
        aliased_pairs = [
            (input_storages[s], output_storages[s]) for s in intersected_storages
        ]
        aliased = ", ".join([f"{i} and {o}" for i, o in aliased_pairs])
        log.debug(
            "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s",
            region,
            aliased,
        )
        return True

    return False
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user