# This file is copied from Meta internal repo and is not synced with the
# internal version. Once the internal version is fully mature, we should
# upstream again and retire the internal version. @yifuwang

import logging
import operator
from typing import Callable, List, Optional, Set, Tuple

from functorch import make_fx

import torch

from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.decomposition import select_decomp_table

MIN_ATEN_OPS_TO_LOWER = 10

logger: logging.Logger = logging.getLogger(__name__)


def _create_subgraph_module(
    inputs: List[torch.fx.Node],
    body: List[torch.fx.Node],
    outputs: List[torch.fx.Node],
) -> torch.fx.GraphModule:
    subgraph: torch.fx.Graph = torch.fx.Graph()
    node_to_subgraph_node = {}
    for idx, inp in enumerate(inputs):
        subgraph_inp = subgraph.placeholder(name=f"arg_{idx}")
        subgraph_inp.meta = inp.meta
        node_to_subgraph_node[inp] = subgraph_inp

    for node in body:
        subgraph_node = subgraph.node_copy(
            node, arg_transform=lambda x: node_to_subgraph_node[x]
        )
        node_to_subgraph_node[node] = subgraph_node

    subgraph.output(result=tuple(node_to_subgraph_node[x] for x in outputs))
    subgraph.eliminate_dead_code()
    subgraph.lint()
    return torch.fx.GraphModule(root={}, graph=subgraph)


def _is_container_node(node: torch.fx.Node) -> bool:
    if any(user.target == operator.getitem for user in node.users):
        assert all(user.target == operator.getitem for user in node.users), (
            "Malformed graph: a container node is used as input for non-getitem nodes."
            "\nNode: {fmt_node}\nUsers: {fmt_users}".format(
                fmt_node=node.format_node(),
                fmt_users="\n".join(u.format_node() for u in node.users),
            )
        )
        return True
    return False


def _lower_subgraph_nodes(
    gm: torch.fx.GraphModule,
    subgraph_name: str,
    subgraph_nodes: List[torch.fx.Node],
    dumper: Callable[[str], str],
) -> None:
    prologue: List[torch.fx.Node] = []
    inputs: List[torch.fx.Node] = []
    body: List[torch.fx.Node] = []
    visible: Set[torch.fx.Node] = set()

    # Inductor requires all graph inputs to be tensors. When adding a container
    # node as a subgraph input, add its descendant getitem nodes to the subgraph
    # prologue and add its leaf getitem nodes to the subgraph inputs.
    def add_input(arg: torch.fx.Node) -> None:
        stack = [arg]
        while len(stack) != 0:
            node = stack.pop()
            if _is_container_node(node):
                # We should only prepone nodes within subgraph_nodes
                prologue.extend(user for user in node.users if user in subgraph_nodes)
                stack.extend(node.users)
            else:
                if node not in visible:
                    inputs.append(node)
                    visible.add(node)

    for node in subgraph_nodes:
        if node.op == "get_attr":
            # Prepone get_attr to avoid having to copy
            # the attribute to the subgraph module.
            inputs.append(node)
            visible.add(node)
            continue

        for arg in node.all_input_nodes:
            if arg not in visible:
                add_input(arg)

        if node not in prologue:
            body.append(node)
            visible.add(node)

    outputs: List[torch.fx.Node] = []

    # Inductor requires all graph outputs to be tensors. When adding a container
    # node as a subgraph output, add its descendant getitem nodes to the subgraph
    # body and add its leaf getitem nodes to the subgraph outputs.
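    # Illustration (an assumption, not from the original comments): if the
    # last lowered node produces a container, e.g.
    #     split = torch.ops.aten.split.Tensor(x, 2)  # container node
    #     a = operator.getitem(split, 0)
    #     b = operator.getitem(split, 1)
    # then `split` itself cannot be a subgraph output; add_output pulls the
    # getitem users into the body and exposes `a` and `b` as tensor outputs.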
    def add_output(output: torch.fx.Node) -> None:
        stack = [output]
        while len(stack) != 0:
            node = stack.pop()
            if _is_container_node(node):
                body.extend(node.users)
                stack.extend(node.users)
            elif not all(user in visible for user in node.users):
                if node not in outputs:
                    outputs.append(node)

    for node in body:
        if not all(user in visible for user in node.users):
            add_output(node)

    assert len(inputs) == len(set(inputs))
    assert len(outputs) == len(set(outputs))

    subgraph_module = _create_subgraph_module(inputs, body, outputs)
    readable_tag = dumper(str(subgraph_module.graph))
    setattr(gm, subgraph_name, _InductorModule(subgraph_module))

    insertion_point = subgraph_nodes[-1].next
    for node in prologue:
        insertion_point.prepend(node)

    with gm.graph.inserting_before(insertion_point):
        # Insert subgraph call
        subgraph_call = gm.graph.create_node(
            op="call_module",
            target=subgraph_name,
            args=tuple(inputs),
            kwargs={"tag": readable_tag},
        )
        # Replace parent graph nodes with their corresponding subgraph outputs
        for idx, output in enumerate(outputs):
            new_output = gm.graph.create_node(
                op="call_function",
                target=operator.getitem,
                args=(subgraph_call, idx),
            )
            new_output.meta = output.meta
            output.replace_all_uses_with(new_output)

    # Erase lowered nodes from the parent graph
    for node in reversed(body + outputs):
        if len(node.users) == 0:
            gm.graph.erase_node(node)


class _InductorModule(torch.nn.Module):
    def __init__(self, gm: torch.fx.GraphModule) -> None:
        super().__init__()
        self.gm = gm
        self.compiled: Optional[
            Callable[[List[torch.Tensor]], List[torch.Tensor]]
        ] = None

    def forward(self, *args: torch.Tensor, tag: str) -> List[torch.Tensor]:
        if self.compiled is None:
            inductor_decompositions = select_decomp_table()
            # TODO: figure out why turning on cudagraphs causes exceptions.
            decomp_gm = make_fx(self.gm, decomposition_table=inductor_decompositions)(
                *args
            )
            logger.info("Lowering subgraph (%s) to Inductor...", tag)
            self.compiled = compile_fx_inner(
                decomp_gm,
                list(args),
                cudagraphs=False,
            )
            logger.info("Completed lowering subgraph (%s) to Inductor", tag)
        with torch.profiler.record_function(tag):
            assert self.compiled is not None
            return self.compiled(list(args))


def _is_inductor_compatible(node: torch.fx.Node) -> Tuple[bool, str]:
    # `has_tag` is not supported yet
    # if has_tag(node, "non_lowerable"):
    if node.target in (
        torch.ops.aten._fused_adam_.default,
        torch.ops.aten._fused_adam.default,
        torch.ops.aten._foreach_add_.Scalar,
        torch.ops.aten._foreach_add.Scalar,
    ):
        return False, "fused adam is not supported yet"

    # TODO(yifu): apparently having a meta kernel is not a necessary
    # condition for Inductor compatibility. We should refine the check.
    # Sneaking this one in for now to support comm_fusion_with_cat.
    if node.target == torch.ops.aten.flatten.using_ints:
        return True, ""

    if isinstance(node.target, torch._ops.OpOverload):
        if not node.target.has_kernel_for_dispatch_key(torch._C.DispatchKey.Meta):
            return False, f"{node.target} doesn't have a meta kernel registered"
    return True, ""


def _subgraph_predicate(nodes: List[torch.fx.Node]) -> bool:
    num_aten_ops = len([n for n in nodes if str(n.target).startswith("aten.")])
    return num_aten_ops >= MIN_ATEN_OPS_TO_LOWER


def partial_lower(
    gm: torch.fx.GraphModule,
    node_predicate: Callable[[torch.fx.Node], bool] = lambda x: True,
    subgraph_predicate: Callable[[List[torch.fx.Node]], bool] = lambda x: True,
    dumper: Callable[[str], str] = lambda x: "subgraph",
) -> torch.fx.GraphModule:
    """
    Lower Inductor-compatible portions of the graph module to Inductor.
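    The graph is scanned for maximal runs of consecutive lowerable nodes; each
    run that passes both predicates and contains at least MIN_ATEN_OPS_TO_LOWER
    aten ops is moved into a child module that is compiled with Inductor on
    its first call.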
    Args:
        node_predicate: user predicate for determining whether to consider
            a node for lowering.
        subgraph_predicate: user predicate for determining whether to consider
            a list of candidate nodes for lowering.
        dumper: a callback for dumping subgraphs for human inspection. For
            example, it can be a function that writes to disk/blob storage
            and returns the path/handle. The returned path/handle for each
            subgraph is attached to the subgraph call node in the parent
            graph and used as the label of the subgraph's profiler block.
    """
    nodes_per_subgraph: List[List[torch.fx.Node]] = [[]]
    ptr = next(iter(gm.graph.nodes))

    def _node_predicate(node: torch.fx.Node) -> Tuple[bool, str]:
        should_lower, reason = _is_inductor_compatible(node)
        if not should_lower:
            return should_lower, reason
        if not node_predicate(node):
            return False, "user predicate"
        return True, ""

    while ptr.op != "output":
        if ptr.op == "placeholder":
            ptr = ptr.next
            continue
        should_lower, reason = _node_predicate(ptr)
        if should_lower:
            nodes_per_subgraph[-1].append(ptr)
        else:
            if len(nodes_per_subgraph[-1]) > 0:
                logger.warning(
                    "partial_lower: graph break at %s. Reason: %s", str(ptr), reason
                )
            nodes_per_subgraph.append([])
        ptr = ptr.next

    nodes_per_subgraph = [
        nodes
        for nodes in nodes_per_subgraph
        if subgraph_predicate(nodes) and _subgraph_predicate(nodes)
    ]

    for idx, subgraph_nodes in enumerate(nodes_per_subgraph):
        subgraph_name = f"subgraph_{idx}"
        _lower_subgraph_nodes(gm, subgraph_name, subgraph_nodes, dumper)

    gm.graph.lint()
    gm.recompile()
    return gm
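

# A minimal usage sketch, not part of the original module. It assumes a
# PyTorch build whose Inductor internals match the APIs used above
# (compile_fx_inner, select_decomp_table). The toy function, predicates,
# and dumper below are hypothetical placeholders for illustration only.
if __name__ == "__main__":

    def _toy_fn(x: torch.Tensor) -> torch.Tensor:
        # Repeat a simple op chain so the aten-level graph clears the
        # MIN_ATEN_OPS_TO_LOWER (10) threshold.
        for _ in range(8):
            x = torch.relu(x) + 1.0
        return x

    example_input = torch.randn(4, 4)
    # make_fx yields an aten-level fx.GraphModule, which is what
    # _subgraph_predicate's "aten." target check expects.
    aten_gm = make_fx(_toy_fn)(example_input)
    lowered = partial_lower(
        aten_gm,
        node_predicate=lambda node: True,  # consider every node
        subgraph_predicate=lambda nodes: True,  # accept every candidate run
        dumper=lambda graph_str: "toy_subgraph",  # tag for logs and profiling
    )
    # The first call triggers Inductor compilation of each lowered subgraph.
    print(lowered(example_input))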