pytorch/torch/_inductor/subgraph_lowering.py
Peter Bell 24b64fc482 [HOP][inductor] Support pytrees as associative_scan input (#122137)
This allows `associative_scan` to take an arbitrary pytree of tensors,
which is flattened to its leaves before calling the `associative_scan`
higher order operator.

I also add support in inductor for generating code that scans over
sequences of tensors.

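For example, the scanned input can now be an arbitrary pytree such as a
dict of tensors. A minimal sketch, assuming the prototype entry point in
`torch._higher_order_ops.associative_scan` (the exact import path and
signature may differ):

    import torch
    from torch._higher_order_ops.associative_scan import associative_scan

    def combine(x, y):
        # x and y have the same pytree structure as the scanned input
        return {"sum": x["sum"] + y["sum"], "prod": x["prod"] * y["prod"]}

    xs = {"sum": torch.randn(8, 4), "prod": torch.randn(8, 4)}
    out = associative_scan(combine, xs, 0)  # out mirrors the structure of xs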
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122137
Approved by: https://github.com/lezcano, https://github.com/Chillee
ghstack dependencies: #119430
2024-05-06 11:29:28 +00:00

"""Utilities for lowering subgraphs used by higher order operators
"""
import functools
import operator
from dataclasses import dataclass
from typing import List, Optional, TypeVar
import torch
from . import ir
from .exc import SubgraphLoweringException
from .ops_handler import SimpleCSEHandler
from .virtualized import ops, V, WrapperHandler
T = TypeVar("T")
class PointwiseSubgraphLowering(torch.fx.Interpreter):
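    """Lowers a subgraph of pure pointwise ops without touching the root graph.

    Only pointwise `torch._ops.OpOverload`s (plus `operator.getitem` on
    python collections) are admitted; mutation and buffer registration
    raise SubgraphLoweringException, keeping the result a pure expression
    that can be inlined into an enclosing lowering.
    """
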
    graph_outputs: Optional[List[ir.IRNode]]

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        root_graph_lowering: "torch._inductor.graph.GraphLowering",
    ):
        super().__init__(gm)
        self.graph_outputs = None
        self.root_graph = root_graph_lowering

    @property
    def sizevars(self):
        return self.root_graph.sizevars

    def mark_buffer_mutated(self, name):
        raise SubgraphLoweringException("Mutations are not supported in this context")

    def register_buffer(self, data):
        raise SubgraphLoweringException(
            "Buffer creation is not supported in this context"
        )

    def call_function(self, target, args, kwargs):
        from .lowering import lowerings

        # getitem on python collections is handled by the base Interpreter
        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
            return super().call_function(target, args, kwargs)

        assert isinstance(target, torch._ops.OpOverload)

        if target not in lowerings:
            raise SubgraphLoweringException(
                f"{target} not supported in subgraph (missing lowering)"
            )

        if torch.Tag.pointwise not in target.tags:
            raise SubgraphLoweringException(
                f"Only pointwise operators are supported in this context, but got {target}"
            )

        return lowerings[target](*args, **kwargs)

    def output(self, target, args, kwargs):
        assert len(args) == 1
        self.graph_outputs = args[0]


@dataclass
class InputDescriptor:
    """Describes a single scalar input to the lowered subgraph."""

    dtype: torch.dtype
    device: torch.device


class TracingOpsHandler(WrapperHandler[T]):
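    """An ops handler that records every `ops.*` call into an FX graph.

    The first placeholder ("ops") stands for the ops handler itself and the
    remaining placeholders stand for the scalar inputs, so the resulting
    GraphModule can later be replayed against any active ops handler.
    """
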
    def __init__(self, tracer, num_inputs):
        parent = tracer.create_proxy("placeholder", "ops", (), {})
        super().__init__(parent)

        self.tracer = tracer
        self.placeholders = [
            self.tracer.create_proxy("placeholder", f"input{i}", (), {})
            for i in range(num_inputs)
        ]

    def placeholder(self, idx):
        return self.placeholders[idx]

    def output(self, *args):
        return self.tracer.create_node(
            "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {}
        )


def lower_pointwise_subgraph(subgraph: ir.Subgraph, inputs: List[InputDescriptor]):
    """Lower `subgraph` to ir.Pointwise nodes, then re-trace them into a
    single callable inner_fn that can be inlined into an enclosing lowering.
    """

    # Fake inputs that just emit ops.placeholder(i); the real scalar values
    # are substituted when the lowered graph is replayed (loop_idx is unused
    # because each input is a scalar).
    def fake_inner_fn(loop_idx, input_idx):
        return ops.placeholder(input_idx)
    graph_inputs = [
        ir.Pointwise.create(
            device=desc.device,
            dtype=desc.dtype,
            inner_fn=functools.partial(fake_inner_fn, input_idx=i),
            ranges=[],
        )
        for i, desc in enumerate(inputs)
    ]

    gm = subgraph.graph_module
    pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph)
    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
        pw_subgraph.run(*graph_inputs)
    # Combine multiple pointwise computations into a single graph module
    # Do this by tracing through each individually and doing CSE
    tracer = torch.fx.Tracer()
    tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
    trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs)))

    assert pw_subgraph.graph_outputs is not None
    with V.set_ops_handler(trace_ops):
        output_irs = []
        for out_var in pw_subgraph.graph_outputs:
            assert isinstance(out_var, ir.TensorBox), type(out_var)
            assert out_var.get_size() == []
            assert isinstance(out_var.data, ir.StorageBox)
            assert isinstance(out_var.data.data, ir.Pointwise)

            idx = ()
            ir_out = out_var.data.data.inner_fn(idx)
            output_irs.append(ir_out)

        ops.output(*output_irs)

    lowered_gm = torch.fx.GraphModule({}, tracer.graph)

    def inner_fn(*args, **kwargs):
        return lowered_gm(V.get_ops_handler(), *args, **kwargs)

    return inner_fn
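

# A hedged usage sketch (illustrative, not verbatim): a higher order op
# lowering, e.g. the one for associative_scan in torch/_inductor/lowering.py,
# can consume this helper along these lines:
#
#     subgraph_inputs = [
#         InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
#         for x in [*lhs_args, *rhs_args]  # hypothetical argument lists
#     ]
#     combine_fn = lower_pointwise_subgraph(subgraph, subgraph_inputs)
#
# The returned combine_fn(*scalars) can then be invoked from another kernel's
# inner_fn; the recorded ops replay against whichever handler
# V.get_ops_handler() returns at that point.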