""" This module implements graph deduplication functionality for TorchDynamo's optimization pipeline. Graph deduplication identifies identical subgraphs in the computational graph and merges them to reduce redundancy and improve performance. The process involves analyzing regions of the graph, identifying structurally equivalent regions, and replacing them with a single shared implementation. This optimization is particularly effective for models with repeated patterns or similar computational structures across different parts of the network. """ import logging import operator from collections.abc import Iterable from typing import Any import torch import torch.fx from torch._dynamo import config from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation from .graph_region_tracker import Node, Region from .graph_utils import _detect_cycles, _flatten_args_kwargs log = logging.getLogger(__name__) def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def] """ This is the main entry point for applying the graph deduplication pass. \ Deduplication occurs in two phases: 1. Subgraph creation: Subgraph creation works by taking one representative region from each region \ group and creating a subgraph from it, which will then be used to replace all regions \ in the group. This is implemented by first copying all nodes of the region to the new \ subgraph and then finding all inputs which are not within the region and creating placeholders \ for them. For the outputs, all regions in a region group need to be scanned to ensure the \ largest set of outputs is found, and then an output node is created which returns \ a tuple of all outputs. 2. Graph replacement: To replace each region with the extracted subgraph, the node index in the region \ and argument index within the node's flattened args and kwargs are recorded once during \ subgraph creation. This allows us to determine which (external to the region) nodes and \ in which order these nodes are passed as inputs. For the outputs, getitem nodes are created \ for each output, and all nodes in the region with external outputs are replaced by the proper \ getitem node. Finally, all original nodes are erased (there should be no uses of these \ left in the graph). The deduplication mutates the output_graph argument in place. Returns a mapping of nodes to their subgraph output replacement node to remap outputs when they are created in output_graph. """ from torch._inductor.pattern_matcher import stable_topological_sort duplicated_region_groups = output_graph.region_tracker.get_identical_regions( output_graph.graph ) sub_gms: dict[str, torch.fx.GraphModule] = {} for region_group in duplicated_region_groups: inds_with_external_users = _get_all_output_indices(region_group) region = region_group[0] ( subgraph, node_ind_arg_inds, ) = _create_subgraph(region, inds_with_external_users) # Ignore regions with no args for now, could they possibly be evaluated at compile time? if not list(node_ind_arg_inds): continue sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph) subgraph_name = output_graph.install_subgraph("subgraph", sub_gm) sub_gms[subgraph_name] = sub_gm with output_graph.graph.inserting_before(): get_subgraph_node = output_graph.graph.create_node( "get_attr", subgraph_name, (), {} ) for region in region_group: _replace_region_with_subgraph( output_graph.graph, region, get_subgraph_node, node_ind_arg_inds.keys(), inds_with_external_users, sub_gm, subgraph_name, ) stable_topological_sort(output_graph.graph) return sub_gms def _replace_region_with_subgraph( graph: torch.fx.Graph, region: Region, get_subgraph_node: Node, node_ind_arg_ind: Iterable[tuple[int, int]], inds_with_external_users: list[int], sub_gm: torch.fx.GraphModule, subgraph_name: str, ) -> None: sub_args = [] for node_ind, arg_ind in node_ind_arg_ind: node = region[node_ind] flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs)) sub_args.append(flattened_args_kwargs[arg_ind]) invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args)) fake_inputs = [node.meta["example_value"] for node in sub_args] if has_potential_input_alias_or_mutation(sub_gm, fake_inputs): log.debug( "NYI: Failed to substitute region %s due to input alias or mutation", region, ) return from torch._inductor.pattern_matcher import stable_topological_sort invoke_subgraph_node = graph.create_node( "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {} ) for ind, external_user_ind in enumerate(inds_with_external_users): node = region[external_user_ind] subgraph_output = graph.create_node( "call_function", operator.getitem, (invoke_subgraph_node, ind), {} ) node.replace_all_uses_with(subgraph_output, propagate_meta=True) # Erase in reverse topological order for node in reversed(region): graph.erase_node(node) if config.graph_deduplication_lint: _detect_cycles(graph) stable_topological_sort(graph) graph.lint() if config.graph_deduplication_lint: graph.lint() def _get_external_inputs( region: Region, ) -> dict[Node, tuple[int, int]]: external_node_to_indices = dict() region_unique = set(region) for node_ind, node in enumerate(region): flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs)) for arg_ind, in_node in enumerate(flattened_args_kwargs): if ( isinstance(in_node, Node) and in_node not in region_unique and in_node not in external_node_to_indices ): external_node_to_indices[in_node] = (node_ind, arg_ind) return external_node_to_indices def _get_all_output_indices(regions: list[Region]) -> list[int]: # Scan all regions to get the set of all possible output nodes indices in the region # perhaps we can record this information during region creation for more efficiency? inds_with_external_users: set[int] = set() for region in regions: _get_inds_with_external_users(region, inds_with_external_users) return sorted(inds_with_external_users) def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None: for ind, node in enumerate(region): for user in node.users: if user not in region: if ind not in inds_unique: inds_unique.add(ind) def _copy_nodes_and_remap_inputs( subgraph: torch.fx.Graph, region: Region ) -> dict[tuple[int, int], Any]: external_inputs_to_indices = _get_external_inputs(region) indices_to_placeholder_ind: dict[tuple[int, int], Any] = {} region_to_subgraph_node = {} for node in external_inputs_to_indices.keys(): placeholder = subgraph.placeholder(f"subgraph_input_{node.name}") region_to_subgraph_node[node] = placeholder arg_indices = external_inputs_to_indices[node] # Note: insertion order matches the order in which placeholders were created # for the calling convention of the subgraph indices_to_placeholder_ind[arg_indices] = None def map_arg(node: Node) -> Node: if node in region_to_subgraph_node: return region_to_subgraph_node[node] else: return node for node in region: subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old)) region_to_subgraph_node[node] = subgraph_node return indices_to_placeholder_ind def _create_subgraph_outputs( subgraph: torch.fx.Graph, inds_to_output: list[int] ) -> None: node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")] out_tup = tuple(node_list[ind] for ind in inds_to_output) subgraph.output(out_tup) def _create_subgraph( region: Region, inds_with_external_users: list[int], ) -> tuple[torch.fx.Graph, dict[tuple[int, int], Any]]: subgraph: torch.fx.Graph = torch.fx.Graph() node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region) _create_subgraph_outputs(subgraph, inds_with_external_users) return subgraph, node_ind_input_inds