pytorch/torch/_dynamo/graph_deduplication.py

225 lines
8.5 KiB
Python

"""
This module implements graph deduplication functionality for TorchDynamo's optimization pipeline.
Graph deduplication identifies identical subgraphs in the computational graph and merges them
to reduce redundancy and improve performance. The process involves analyzing regions of the graph,
identifying structurally equivalent regions, and replacing them with a single shared implementation.
This optimization is particularly effective for models with repeated patterns or similar computational
structures across different parts of the network.
"""
import logging
import operator
from collections.abc import Iterable
from typing import Any
import torch
import torch.fx
from torch._dynamo import config
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from .graph_region_tracker import Node, Region
from .graph_utils import _detect_cycles, _flatten_args_kwargs
# Module-level logger; used to report regions that could not be substituted
# (see the log.debug call in _replace_region_with_subgraph).
log = logging.getLogger(__name__)
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]:  # type: ignore[no-untyped-def]
    """
    This is the main entry point for applying the graph deduplication pass.

    Deduplication occurs in two phases:

    1. Subgraph creation:
       One representative region is taken from each region group and a subgraph
       is created from it; that subgraph then replaces every region in the group.
       All nodes of the representative region are copied into the new subgraph,
       and every input that comes from outside the region becomes a placeholder.
       For the outputs, all regions in the group are scanned so that the largest
       set of output indices is used, and a single output node returning a tuple
       of all outputs is created.

    2. Graph replacement:
       For each subgraph input, the node index within the region and the argument
       index within that node's flattened args and kwargs are recorded once during
       subgraph creation. This determines which (region-external) nodes are passed
       to the subgraph and in which order. For the outputs, a getitem node is
       created for each output, and every region node with external users is
       replaced by the corresponding getitem node. Finally, the original region
       nodes are erased (no uses of them should be left in the graph).

    Region groups whose representative region has no external inputs are skipped.
    Individual regions whose inputs may alias or be mutated are left in place by
    ``_replace_region_with_subgraph``.

    The deduplication mutates ``output_graph`` in place.

    Args:
        output_graph: the Dynamo output graph. Its ``region_tracker``, ``graph``,
            ``nn_modules`` and ``install_subgraph`` members are used.

    Returns:
        A mapping from each installed subgraph's attribute name (as returned by
        ``output_graph.install_subgraph``) to the corresponding ``GraphModule``.
    """
    from torch._inductor.pattern_matcher import stable_topological_sort

    duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
        output_graph.graph
    )

    sub_gms: dict[str, torch.fx.GraphModule] = {}

    for region_group in duplicated_region_groups:
        # Output indices are the union over *all* regions in the group, so the
        # shared subgraph returns every value any member region needs externally.
        inds_with_external_users = _get_all_output_indices(region_group)
        # The first region is the representative the subgraph is built from.
        region = region_group[0]
        subgraph, node_ind_arg_inds = _create_subgraph(
            region, inds_with_external_users
        )

        # Ignore regions with no args for now, could they possibly be evaluated at compile time?
        if not node_ind_arg_inds:
            continue

        sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
        subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
        sub_gms[subgraph_name] = sub_gm
        # One get_attr node per subgraph, created at the start of the graph and
        # shared by every invoke_subgraph call for this group.
        with output_graph.graph.inserting_before():
            get_subgraph_node = output_graph.graph.create_node(
                "get_attr", subgraph_name, (), {}
            )

        for member_region in region_group:
            _replace_region_with_subgraph(
                output_graph.graph,
                member_region,
                get_subgraph_node,
                node_ind_arg_inds.keys(),
                inds_with_external_users,
                sub_gm,
                subgraph_name,
            )

    # Replacement nodes are created at the graph's insertion point, so restore a
    # valid topological order once every group has been processed.
    stable_topological_sort(output_graph.graph)
    return sub_gms
def _replace_region_with_subgraph(
    graph: torch.fx.Graph,
    region: Region,
    get_subgraph_node: Node,
    node_ind_arg_ind: Iterable[tuple[int, int]],
    inds_with_external_users: list[int],
    sub_gm: torch.fx.GraphModule,
    subgraph_name: str,
) -> None:
    """
    Replace the nodes of ``region`` in ``graph`` with a single call to
    ``torch.ops.higher_order.invoke_subgraph`` on the shared subgraph.

    Args:
        graph: the graph that contains ``region``; mutated in place.
        region: the nodes to replace. The erase loop below treats
            ``reversed(region)`` as reverse topological order, i.e. it assumes
            ``region`` is topologically ordered.
        get_subgraph_node: the ``get_attr`` node referring to the installed
            subgraph; passed as the first argument of ``invoke_subgraph``.
        node_ind_arg_ind: ``(node index in region, index into that node's
            flattened args/kwargs)`` pairs, one per subgraph input, in
            placeholder order.
        inds_with_external_users: indices of the region nodes whose values the
            subgraph returns; position ``i`` in this list corresponds to element
            ``i`` of the subgraph's output tuple.
        sub_gm: the subgraph module; used here only for the alias/mutation check.
        subgraph_name: the attribute name under which the subgraph was installed;
            passed as the second argument of ``invoke_subgraph``.

    If ``has_potential_input_alias_or_mutation`` flags the inputs, a debug
    message is logged and the function returns without modifying ``graph``.
    """
    # Look up this region's actual external inputs at the positions recorded for
    # the representative region the subgraph was built from.
    sub_args = []
    for node_ind, arg_ind in node_ind_arg_ind:
        node = region[node_ind]
        flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
        sub_args.append(flattened_args_kwargs[arg_ind])

    # invoke_subgraph(<get_attr node>, <subgraph name>, <tuple of inputs>)
    invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
    # The inputs' "example_value" metadata feeds the alias/mutation check.
    fake_inputs = [node.meta["example_value"] for node in sub_args]

    # Aliased or mutated inputs are not supported yet (NYI): leave this region as is.
    if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
        log.debug(
            "NYI: Failed to substitute region %s due to input alias or mutation",
            region,
        )
        return

    # Local import, as in apply_graph_deduplication.
    # NOTE(review): presumably deferred to avoid an import cycle with torch._inductor; confirm.
    from torch._inductor.pattern_matcher import stable_topological_sort

    # New nodes are created at the graph's current insertion point and may be out
    # of topological order until the graph is re-sorted.
    invoke_subgraph_node = graph.create_node(
        "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
    )
    # Route every use of an output node through getitem(invoke_subgraph_node, i),
    # carrying the replaced node's meta over to the getitem node.
    for ind, external_user_ind in enumerate(inds_with_external_users):
        node = region[external_user_ind]
        subgraph_output = graph.create_node(
            "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
        )
        node.replace_all_uses_with(subgraph_output, propagate_meta=True)

    # Erase in reverse topological order
    # (consumers before producers, so no erased node still has users in the region).
    for node in reversed(region):
        graph.erase_node(node)

        # Debug-only consistency checks after every erase, gated by config.
        if config.graph_deduplication_lint:
            _detect_cycles(graph)
            stable_topological_sort(graph)
            graph.lint()

    if config.graph_deduplication_lint:
        graph.lint()
def _get_external_inputs(
    region: Region,
) -> dict[Node, tuple[int, int]]:
    """
    Map each node that feeds ``region`` from outside to the location of its first use.

    The region's nodes are visited in order and each node's args/kwargs are
    flattened. The first time an fx ``Node`` that is not part of the region is
    seen, it is recorded together with ``(node_ind, arg_ind)``:
    - ``node_ind`` is the index of the consuming node within ``region``;
    - ``arg_ind`` is the argument's position within the flattened args/kwargs.

    Later uses of the same external node are ignored. The returned dict's
    insertion order is therefore the order of first use.
    """
    members = set(region)
    first_use: dict[Node, tuple[int, int]] = {}
    for consumer_ind, consumer in enumerate(region):
        flat_inputs = _flatten_args_kwargs((consumer.args, consumer.kwargs))
        for pos, candidate in enumerate(flat_inputs):
            # Only graph nodes can be external inputs; constants are skipped.
            if not isinstance(candidate, Node):
                continue
            # Skip nodes produced inside the region and inputs already recorded.
            if candidate in members or candidate in first_use:
                continue
            first_use[candidate] = (consumer_ind, pos)
    return first_use
def _get_all_output_indices(regions: list[Region]) -> list[int]:
    """
    Return, in ascending order, every node index that has an external user in any region.

    Index ``i`` is included when node ``i`` of at least one region in
    ``regions`` is used by a node outside that region. Scanning all regions
    makes one shared subgraph return every output any member region needs.
    """
    # Perhaps this could be recorded during region creation for more efficiency?
    all_inds: set[int] = set()
    for each_region in regions:
        # The helper accumulates into the set it is handed.
        _get_inds_with_external_users(each_region, all_inds)
    ordered = list(all_inds)
    ordered.sort()
    return ordered
def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None:
    """
    Add to ``inds_unique`` the index of each node in ``region`` that is used outside it.

    The node at position ``ind`` in ``region`` qualifies when at least one entry
    of its ``users`` is not itself a member of ``region``.

    ``inds_unique`` is updated in place and existing entries are kept, so callers
    can accumulate indices across several regions. Nothing is returned.
    """
    # Build the membership set once, rather than doing a linear ``in`` scan of
    # ``region`` for every user of every node.
    region_unique = set(region)
    for ind, node in enumerate(region):
        # any() stops at the first external user. set.add is already idempotent,
        # so no separate "already present" check is needed.
        if any(user not in region_unique for user in node.users):
            inds_unique.add(ind)
def _copy_nodes_and_remap_inputs(
    subgraph: torch.fx.Graph, region: Region
) -> dict[tuple[int, int], Any]:
    """
    Copy ``region`` into ``subgraph``, turning its external inputs into placeholders.

    Steps:
    1. Create one placeholder named ``subgraph_input_<node name>`` per external
       input of ``region``, in the order ``_get_external_inputs`` reports them.
    2. Copy every region node into ``subgraph`` with ``node_copy``. Arguments that
       refer to an external input or to an already-copied region node are
       rewritten to the matching subgraph node; any other argument is passed
       through unchanged.

    Returns:
        A dict used as an insertion-ordered set. Its keys are the
        ``(node index in region, flattened arg index)`` pairs locating each
        external input, in placeholder order; this is the subgraph's calling
        convention. Every value is ``None``.
    """
    external_inputs_to_indices = _get_external_inputs(region)
    old_to_new: dict[Node, Node] = {}
    # Insertion order mirrors placeholder creation order (the calling convention).
    placeholder_order: dict[tuple[int, int], Any] = {}

    for ext_node, arg_indices in external_inputs_to_indices.items():
        old_to_new[ext_node] = subgraph.placeholder(f"subgraph_input_{ext_node.name}")
        placeholder_order[arg_indices] = None

    def remap(old: Node) -> Node:
        # Fall back to the original node for anything not created in the subgraph.
        return old_to_new.get(old, old)

    for region_node in region:
        old_to_new[region_node] = subgraph.node_copy(region_node, remap)

    return placeholder_order
def _create_subgraph_outputs(
    subgraph: torch.fx.Graph, inds_to_output: list[int]
) -> None:
    """
    Append to ``subgraph`` an output node that returns the selected nodes as one tuple.

    ``inds_to_output`` indexes into the subgraph's nodes in graph order, counting
    only nodes whose op is neither ``"placeholder"`` nor ``"output"``. The tuple
    elements appear in the same order as ``inds_to_output``.
    """
    body_nodes = []
    for candidate in subgraph.nodes:
        if candidate.op != "placeholder" and candidate.op != "output":
            body_nodes.append(candidate)
    selected = tuple(map(body_nodes.__getitem__, inds_to_output))
    subgraph.output(selected)
def _create_subgraph(
    region: Region,
    inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, dict[tuple[int, int], Any]]:
    """
    Build a fresh fx graph that reproduces ``region``.

    The region's nodes are copied into a new graph, with placeholders for the
    inputs coming from outside the region. An output node is then appended that
    returns the nodes at ``inds_with_external_users`` as a tuple.

    Returns:
        A pair ``(graph, input_locations)``. ``input_locations`` is the ordered
        dict from ``_copy_nodes_and_remap_inputs``, whose keys locate each
        subgraph input as ``(node index, flattened arg index)``.
    """
    new_graph = torch.fx.Graph()
    input_locations = _copy_nodes_and_remap_inputs(new_graph, region)
    _create_subgraph_outputs(new_graph, inds_with_external_users)
    return new_graph, input_locations