Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Several ops cannot be lowered to Inductor. This PR copies the internal implementation of partial_lower (credit to @yifuwang) to torch.distributed._spmd to unblock OSS usage. The internal version will be kept until it is fully mature, at which point it will be upstreamed again and replace this version.

Differential Revision: [D44970278](https://our.internmc.facebook.com/intern/diff/D44970278/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99069
Approved by: https://github.com/mrshenli, https://github.com/lessw2020
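A minimal usage sketch (illustrative only; the traced function and inputs are assumptions, and the import path assumes the file lands at torch/distributed/_spmd/partial_lower.py):

    import torch
    from functorch import make_fx
    from torch.distributed._spmd.partial_lower import partial_lower  # assumed path

    def fn(x):
        # Enough aten ops to clear the MIN_ATEN_OPS_TO_LOWER (10) threshold.
        for _ in range(12):
            x = x.sin() + x.cos()
        return x

    gm = make_fx(fn)(torch.randn(8, 8))
    gm = partial_lower(gm)  # eligible runs of aten ops become Inductor subgraphs

With the default predicates, a candidate subgraph is lowered only if it contains at least MIN_ATEN_OPS_TO_LOWER (10) aten ops; smaller runs stay in the parent graph.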
269 lines
9.7 KiB
Python
# This file is copied from Meta internal repo and is not synced with the
# internal version. Once the internal version is fully mature, we should
# upstream again and retire the internal version. @yifuwang

import logging
import operator
from typing import Callable, List, Optional, Set, Tuple

from functorch import make_fx

import torch

from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.decomposition import select_decomp_table

MIN_ATEN_OPS_TO_LOWER = 10

logger: logging.Logger = logging.getLogger(__name__)

def _create_subgraph_module(
    inputs: List[torch.fx.Node], body: List[torch.fx.Node], outputs: List[torch.fx.Node]
) -> torch.fx.GraphModule:
    subgraph: torch.fx.Graph = torch.fx.Graph()
    node_to_subgraph_node = {}
    for idx, inp in enumerate(inputs):
        subgraph_inp = subgraph.placeholder(name=f"arg_{idx}")
        subgraph_inp.meta = inp.meta
        node_to_subgraph_node[inp] = subgraph_inp

    for node in body:
        subgraph_node = subgraph.node_copy(
            node, arg_transform=lambda x: node_to_subgraph_node[x]
        )
        node_to_subgraph_node[node] = subgraph_node

    subgraph.output(result=tuple(node_to_subgraph_node[x] for x in outputs))
    subgraph.eliminate_dead_code()
    subgraph.lint()
    return torch.fx.GraphModule(root={}, graph=subgraph)

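# For reference (illustrative only), for inputs [a, b], a two-node body, and a
# single output, the module built above has a graph of roughly this shape:
#
#     %arg_0 : placeholder            # meta copied from input node a
#     %arg_1 : placeholder            # meta copied from input node b
#     %node_1, %node_2 : copied body  # args remapped via node_to_subgraph_node
#     return (%node_2,)               # outputs are always a tuple
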
def _is_container_node(node: torch.fx.Node) -> bool:
    if any(user.target == operator.getitem for user in node.users):
        assert all(user.target == operator.getitem for user in node.users), (
            "Malformed graph: a container node is used as input for non-getitem nodes."
            "\nNode: {fmt_node}\nUsers: {fmt_users}".format(
                fmt_node=node.format_node(),
                fmt_users="\n".join(u.format_node() for u in node.users),
            )
        )
        return True
    return False

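# A minimal illustration (hypothetical graph, not part of this module) of the
# container pattern detected above: ops that return multiple tensors, such as
# aten.split, yield a node whose users are all operator.getitem calls:
#
#     %split = aten.split.Tensor(%x, 2)    # container node -> True
#     %a = operator.getitem(%split, 0)     # leaf getitem
#     %b = operator.getitem(%split, 1)     # leaf getitem
#
# The getitem leaves themselves are ordinarily consumed by non-getitem nodes,
# so _is_container_node returns False for them.
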
def _lower_subgraph_nodes(
    gm: torch.fx.GraphModule,
    subgraph_name: str,
    subgraph_nodes: List[torch.fx.Node],
    dumper: Callable[[str], str],
) -> None:
    prologue: List[torch.fx.Node] = []
    inputs: List[torch.fx.Node] = []
    body: List[torch.fx.Node] = []
    visible: Set[torch.fx.Node] = set()

    # Inductor requires all graph inputs to be tensors. When adding a container
    # node as subgraph input, add its descendant getitem nodes to the subgraph
    # prologue and add its leaf getitem nodes to the subgraph input.
    def add_input(arg: torch.fx.Node) -> None:
        stack = [arg]
        while len(stack) != 0:
            node = stack.pop()
            if _is_container_node(node):
                # We should only prepone nodes within subgraph_nodes
                prologue.extend(user for user in node.users if user in subgraph_nodes)
                stack.extend(node.users)
            else:
                if node not in visible:
                    inputs.append(node)
                    visible.add(node)

    for node in subgraph_nodes:
        if node.op == "get_attr":
            # Prepone get_attr to avoid having to copy
            # the attribute to the subgraph module.
            inputs.append(node)
            visible.add(node)
            continue

        for arg in node.all_input_nodes:
            if arg not in visible:
                add_input(arg)

        if node not in prologue:
            body.append(node)
            visible.add(node)

    outputs: List[torch.fx.Node] = []

    # Inductor requires all graph outputs to be tensors. When adding a container
    # node as subgraph output, add its descendant getitem nodes to the subgraph
    # body and add its leaf getitem nodes to the subgraph output.
    def add_output(output: torch.fx.Node) -> None:
        stack = [output]
        while len(stack) != 0:
            node = stack.pop()
            if _is_container_node(node):
                body.extend(node.users)
                stack.extend(node.users)
            elif not all(user in visible for user in node.users):
                if node not in outputs:
                    outputs.append(node)

    for node in body:
        if not all(user in visible for user in node.users):
            add_output(node)

    assert len(inputs) == len(set(inputs))
    assert len(outputs) == len(set(outputs))

    subgraph_module = _create_subgraph_module(inputs, body, outputs)
    readable_tag = dumper(str(subgraph_module.graph))
    setattr(gm, subgraph_name, _InductorModule(subgraph_module))

    insertion_point = subgraph_nodes[-1].next
    for node in prologue:
        insertion_point.prepend(node)

    with gm.graph.inserting_before(insertion_point):
        # Insert subgraph call
        subgraph_call = gm.graph.create_node(
            op="call_module",
            target=subgraph_name,
            args=tuple(inputs),
            kwargs={"tag": readable_tag},
        )
        # Replace parent graph nodes with their corresponding subgraph outputs
        for idx, output in enumerate(outputs):
            new_output = gm.graph.create_node(
                op="call_function",
                target=operator.getitem,
                args=(subgraph_call, idx),
            )
            new_output.meta = output.meta
            output.replace_all_uses_with(new_output)

    # Erase lowered nodes from the parent graph
    for node in reversed(body + outputs):
        if len(node.users) == 0:
            gm.graph.erase_node(node)

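# Roughly, the rewrite above turns a parent graph like (hypothetical,
# illustrative only):
#
#     %add = aten.add.Tensor(%x, %y)
#     %mul = aten.mul.Tensor(%add, %y)
#     return (%mul,)
#
# into:
#
#     %subgraph_0 = call_module[target=subgraph_0](
#         args=(%x, %y), kwargs={"tag": "subgraph"})
#     %getitem = operator.getitem(%subgraph_0, 0)
#     return (%getitem,)
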
class _InductorModule(torch.nn.Module):
    def __init__(self, gm: torch.fx.GraphModule) -> None:
        super().__init__()
        self.gm = gm
        self.compiled: Optional[
            Callable[[List[torch.Tensor]], List[torch.Tensor]]
        ] = None

    def forward(self, *args: torch.Tensor, tag: str) -> List[torch.Tensor]:
        if self.compiled is None:
            inductor_decompositions = select_decomp_table()
            # TODO: figure out why turning on cudagraphs causes exceptions.
            decomp_gm = make_fx(self.gm, decomposition_table=inductor_decompositions)(
                *args
            )
            logger.info("Lowering subgraph (%s) to Inductor...", tag)
            self.compiled = compile_fx_inner(
                decomp_gm,
                list(args),
                cudagraphs=False,
            )
            logger.info("Completed lowering subgraph (%s) to Inductor", tag)
        with torch.profiler.record_function(tag):
            assert self.compiled is not None
            return self.compiled(list(args))

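# Note: compilation happens lazily on the first call and is built from that
# call's inputs. A sketch of direct use (hypothetical; these modules are
# normally created and invoked via _lower_subgraph_nodes):
#
#     mod = _InductorModule(subgraph_gm)
#     out = mod(t0, t1, tag="subgraph_0")  # compiles on first call
#     out = mod(t0, t1, tag="subgraph_0")  # reuses self.compiled
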
def _is_inductor_compatible(node: torch.fx.Node) -> Tuple[bool, str]:
    # `has_tag` is not supported yet
    # if has_tag(node, "non_lowerable"):

    if node.target in (
        torch.ops.aten._fused_adam_.default,
        torch.ops.aten._fused_adam.default,
        torch.ops.aten._foreach_add_.Scalar,
        torch.ops.aten._foreach_add.Scalar,
    ):
        return False, "fused adam is not supported yet"

    # TODO(yifu): apparently having a meta kernel is not a necessary
    # condition for Inductor compatibility. We should refine the check.
    # Sneaking this one in for now to support comm_fusion_with_cat.
    if node.target == torch.ops.aten.flatten.using_ints:
        return True, ""

    if isinstance(node.target, torch._ops.OpOverload):
        if not node.target.has_kernel_for_dispatch_key(torch._C.DispatchKey.Meta):
            return False, f"{node.target} doesn't have a meta kernel registered"
    return True, ""

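# The meta-kernel check above can be probed directly in a REPL (illustrative):
#
#     torch.ops.aten.add.Tensor.has_kernel_for_dispatch_key(
#         torch._C.DispatchKey.Meta
#     )  # True for ops whose outputs can be inferred without real data
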
def _subgraph_predicate(nodes: List[torch.fx.Node]) -> bool:
    num_aten_ops = len([n for n in nodes if str(n.target).startswith("aten.")])
    return num_aten_ops >= MIN_ATEN_OPS_TO_LOWER

def partial_lower(
    gm: torch.fx.GraphModule,
    node_predicate: Callable[[torch.fx.Node], bool] = lambda x: True,
    subgraph_predicate: Callable[[List[torch.fx.Node]], bool] = lambda x: True,
    dumper: Callable[[str], str] = lambda x: "subgraph",
) -> torch.fx.GraphModule:
    """
    Lower Inductor-compatible portions of the graph module to Inductor.

    Args:
        gm: the graph module to partially lower.
        node_predicate: user predicate for determining whether to consider a node for
            lowering.
        subgraph_predicate: user predicate for determining whether to consider a list of
            candidate nodes for lowering.
        dumper: a callback for dumping subgraphs for human digestion. For example, it
            can be a function that writes to disk/blob storage and returns the
            path/handle. The returned path/handle for each subgraph will be made
            available in the subgraph call node in the parent graph, and will also be
            used as the label of the profiler block for the subgraph.
    """
    nodes_per_subgraph: List[List[torch.fx.Node]] = [[]]
    ptr = next(iter(gm.graph.nodes))

    def _node_predicate(node: torch.fx.Node) -> Tuple[bool, str]:
        should_lower, reason = _is_inductor_compatible(node)
        if not should_lower:
            return should_lower, reason
        if not node_predicate(node):
            return False, "user predicate"
        return True, ""

    while ptr.op != "output":
        if ptr.op == "placeholder":
            ptr = ptr.next
            continue
        should_lower, reason = _node_predicate(ptr)
        if should_lower:
            nodes_per_subgraph[-1].append(ptr)
        else:
            if len(nodes_per_subgraph[-1]) > 0:
                logger.warning(
                    "partial_lower: graph break at %s. Reason: %s", str(ptr), reason
                )
            nodes_per_subgraph.append([])
        ptr = ptr.next

    nodes_per_subgraph = [
        nodes
        for nodes in nodes_per_subgraph
        if subgraph_predicate(nodes) and _subgraph_predicate(nodes)
    ]

    for idx, subgraph_nodes in enumerate(nodes_per_subgraph):
        subgraph_name = f"subgraph_{idx}"
        _lower_subgraph_nodes(gm, subgraph_name, subgraph_nodes, dumper)

    gm.graph.lint()
    gm.recompile()
    return gm
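# Example dumper (a sketch; the file naming scheme is an assumption): persist
# each subgraph to disk and use the returned path as the subgraph's tag and
# profiler label.
#
#     import tempfile
#
#     def dump_to_disk(graph_str: str) -> str:
#         with tempfile.NamedTemporaryFile(
#             mode="w", suffix=".graph", delete=False
#         ) as f:
#             f.write(graph_str)
#             return f.name
#
#     lowered = partial_lower(gm, dumper=dump_to_disk)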