Revert "[compile] Regional inductor compilation with fx.annotate (#164776)"

This reverts commit 1e4c7dffa3.

Reverted https://github.com/pytorch/pytorch/pull/164776 on behalf of https://github.com/malfet due to Looks like this one broke everything, not the top of the stack ([comment](https://github.com/pytorch/pytorch/pull/164776#issuecomment-3393725466))
PyTorch MergeBot 2025-10-11 23:14:23 +00:00
parent a19123b37e
commit 8d49cd5b26
6 changed files with 0 additions and 434 deletions

View File

@@ -1081,8 +1081,6 @@ coverage_ignore_functions = [
    "loop_pass",
    "these_before_those_pass_constraint",
    "this_before_that_pass_constraint",
    # torch.fx.passes.regional_inductor
    "regional_inductor",
    # torch.fx.passes.reinplace
    "reinplace",
    # torch.fx.passes.split_module

View File

@@ -1169,7 +1169,6 @@ The set of leaf modules can be customized by overriding
.. py:module:: torch.fx.passes.operator_support
.. py:module:: torch.fx.passes.param_fetch
.. py:module:: torch.fx.passes.pass_manager
.. py:module:: torch.fx.passes.regional_inductor
.. py:module:: torch.fx.passes.reinplace
.. py:module:: torch.fx.passes.runtime_assert
.. py:module:: torch.fx.passes.shape_prop

View File

@@ -1,295 +0,0 @@
# Owner(s): ["module: dynamo"]

import functools

import torch
import torch._inductor.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.fx.passes.regional_inductor import regional_inductor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.triton_utils import requires_cuda_and_triton


# Open questions / follow-ups
# 1) CSE behavior with meta custom nodes
# Common subexpression elimination may not differentiate between distinct meta
# custom nodes and could remove expressions, which might confuse users.
#
# 2) SAC: recompute vs. forward size
# If the recomputed forward is smaller than the original forward, do we end up
# compiling only the smaller region?
#
# 3) fx_traceback.annotate nesting
# How does nesting behave? Are there any ordering requirements?
#
# 4) Planned uses for annotations
# a) compile flex
# b) streams
# c) nn.Module info to organize MoE runtime
# d) pipeline-parallel stages
# e) rename graph nodes for easier debugging
# f) disallow nested regional compile


def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=regional_inductor,
        bw_compiler=regional_inductor,
    )


@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class RegionalInductorTests(torch._inductor.test_case.TestCase):
    # TODO - this should not be needed because we should turn this on in Dynamo,
    # but for some reason, tests fail without it.
    def setUp(self):
        super().setUp()
        self.cm = torch.fx.traceback.preserve_node_meta()
        self.cm.__enter__()

    def tearDown(self):
        super().tearDown()
        self.cm.__exit__(None, None, None)

    def test_simple(self):
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)

        opt_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        x = torch.randn(10, requires_grad=True)
        y = torch.randn(10, requires_grad=True)

        # Check that inductor compilation is called twice
        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y))
        self.assertEqual(len(codes), 2)

    def test_repeated_blocks(self):
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                a = fn(x, y)
                return fn(a, y)

        mod = Mod()

        opt_mod = torch.compile(
            mod, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        x = torch.randn(10, requires_grad=True)
        y = torch.randn(10, requires_grad=True)

        # Check that inductor compilation is called 4 times:
        # there will be 2 partitions in the fwd and 2 in the bwd, totalling 4
        _, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
        self.assertEqual(len(codes), 4)

    def test_invoke_subgraph(self):
        # Checks that custom metadata on get_attr nodes is propagated
        @torch.compiler.nested_compile_region
        def gn(x):
            return torch.sin(x)

        def fn(x):
            x = x + 1
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                z = gn(x)
            return torch.sigmoid(z)

        opt_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        x = torch.randn(10, requires_grad=True)
        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        self.assertEqual(len(codes), 2)

    def test_invoke_subgraph_inner(self):
        # Checks that the inductor regions are searched recursively.
        @torch.compiler.nested_compile_region
        def gn(x):
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                return torch.sin(x)

        def fn(x):
            x = x + 1
            x = gn(x)
            x = x + 1
            x = gn(x)
            return torch.sigmoid(x)

        opt_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        x = torch.randn(10, requires_grad=True)
        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        # invoke_subgraph is called twice, but the code inside is compiled
        # once, so in total 2 (1 fwd + 1 bwd)
        self.assertEqual(len(codes), 2)

    @requires_cuda_and_triton
    def test_flex_attention(self):
        def _squared(score, b, h, m, n):
            return score * score

        def mask_mod(b, h, q, k):
            return q >= 0

        a = 12
        b = 64
        block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)

        def fn(x):
            x = torch.sin(x)
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
            return torch.cos(x)

        x = torch.randn(
            1,
            1,
            a * b,
            b,
            dtype=torch.bfloat16,
            device="cuda",
            requires_grad=True,
        )

        opt_fn = torch.compile(
            fn,
            backend=aot_eager_regional_inductor(),
            fullgraph=True,
        )

        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        # flex in forward and flex_backward in backward
        self.assertEqual(len(codes), 2)

    @requires_cuda_and_triton
    def test_selective_ac_flex(self):
        class FlexAttentionModule(torch.nn.Module):
            def __init__(self, hidden_size, num_heads):
                super().__init__()
                self.hidden_size = hidden_size
                self.num_heads = num_heads
                self.head_dim = hidden_size // num_heads

                # In-projections (query, key, value)
                self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
                self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
                self.v_proj = torch.nn.Linear(hidden_size, hidden_size)

                # Out-projection
                self.out_proj = torch.nn.Linear(hidden_size, hidden_size)

            def forward(self, x):
                batch_size, seq_len, _ = x.size()

                # Project queries, keys, and values
                q = (
                    self.q_proj(x)
                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
                    .transpose(1, 2)
                )
                k = (
                    self.k_proj(x)
                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
                    .transpose(1, 2)
                )
                v = (
                    self.v_proj(x)
                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
                    .transpose(1, 2)
                )

                # Apply flex attention
                with torch.fx.traceback.annotate({"compile_with_inductor": 0}):
                    attn_output = flex_attention(
                        q,
                        k,
                        v,
                    )

                # Reshape output
                attn_output = (
                    attn_output.transpose(1, 2)
                    .contiguous()
                    .view(batch_size, seq_len, self.hidden_size)
                )

                # Out projection
                output = self.out_proj(attn_output)
                return output

        from torch.utils.checkpoint import (
            checkpoint,
            create_selective_checkpoint_contexts,
        )

        ops_to_save = [
            torch.ops.aten.mm.default,
        ]
        context_fn = functools.partial(
            create_selective_checkpoint_contexts, ops_to_save
        )

        # Define a model that uses FlexAttention with selective activation checkpointing
        class SacModule(torch.nn.Module):
            def __init__(self, hidden_size, num_heads, context_fn):
                super().__init__()
                self.flex_attn = FlexAttentionModule(hidden_size, num_heads)
                self.context_fn = context_fn

            def forward(self, x):
                def flex_attn_fn(x):
                    return self.flex_attn(x)

                output = checkpoint(
                    flex_attn_fn,
                    x,
                    use_reentrant=False,
                    context_fn=self.context_fn,
                )
                return output

        flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
            "cuda", dtype=torch.bfloat16
        )
        x = torch.ones(8, 1024, 512, device="cuda", dtype=torch.bfloat16)
        compiled_module = torch.compile(
            flex_module, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        _, codes = run_fw_bw_and_get_code(lambda: compiled_module(x))
        # flex in forward and flex_backward in backward
        self.assertEqual(len(codes), 2)


if __name__ == "__main__":
    run_tests()
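The tests above key off the "compile_with_inductor" entry that fx_traceback.annotate writes into each node's meta["custom"] dict. As a hedged illustration (not part of the reverted code), a wrapper compiler such as the hypothetical _debug_regional_inductor below could be swapped in for regional_inductor to print which AOT graph nodes actually carry that annotation before delegating to the pass:

# Hedged sketch: list annotated nodes, then run the regional inductor pass.
# Assumes the reverted torch.fx.passes.regional_inductor module is importable;
# _debug_regional_inductor is a name invented here for illustration.
from torch._dynamo.backends.common import aot_autograd
from torch.fx.passes.regional_inductor import regional_inductor


def _debug_regional_inductor(gm, example_inputs):
    for node in gm.graph.nodes:
        # fx_traceback.annotate stores its dict under node.meta["custom"]
        custom = node.meta.get("custom") or {}
        if "compile_with_inductor" in custom:
            print(f"marked for inductor: {node.format_node()}")
    return regional_inductor(gm, example_inputs)


debug_backend = aot_autograd(
    fw_compiler=_debug_regional_inductor, bw_compiler=_debug_regional_inductor
)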

View File

@@ -854,7 +854,6 @@ def run_joint_graph_passes_on_hops(
        with joint_gm.graph.inserting_after(fw_node):
            new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}")
            new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name)
            new_fw_mod_attr.meta = copy.copy(fw_node.args[0].meta)

        # new_hop_fw_gm output signature is (*fw_outs, *saved_tensors)
        with joint_gm.graph.inserting_after(new_fw_mod_attr):
@@ -907,7 +906,6 @@ def run_joint_graph_passes_on_hops(
        with joint_gm.graph.inserting_after(bw_node):
            new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1])
            new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name)
            new_bw_mod_attr.meta = copy.copy(bw_node.args[0].meta)

        with joint_gm.graph.inserting_after(new_bw_mod_attr):
            new_bw_node = joint_gm.graph.call_function(

View File

@@ -4,7 +4,6 @@ from . import (
    net_min_base,
    operator_support,
    param_fetch,
    regional_inductor,
    reinplace,
    runtime_assert,
    shape_prop,

View File

@@ -1,133 +0,0 @@
# mypy: allow-untyped-defs
import functools
import logging

import torch
from torch.fx._compatibility import compatibility


logger = logging.getLogger(__name__)

__all__ = ["regional_inductor"]


# standalone_compile returns a callable class object, which does not sit well
# with the FX graph call_function node op, which expects a function. This is
# just a wrapper function to make FX graph codegen happy.
def _dummy_wrapper(fn):
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


def _partition_by_supported_nodes(gm, supported_ops, prefix):
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.utils.fuser_utils import fuse_by_partitions

    partitioner = CapabilityBasedPartitioner(
        gm, supported_ops, allows_single_node_partition=True
    )

    candidate_partitions = partitioner.propose_partitions()

    partitioned_gm = fuse_by_partitions(
        partitioner.graph_module,
        [partition.nodes for partition in candidate_partitions],
        prefix=prefix,
        always_return_tuple=True,
    )
    return partitioned_gm


def _compile_submod(gm, prefix):
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target.startswith(prefix):
            fake_inputs = []
            for inp_node in node.all_input_nodes:
                if hasattr(inp_node, "meta") and "val" in inp_node.meta:
                    fake_inputs.append(inp_node.meta["val"])
                else:
                    raise RuntimeError(
                        f"Partition is invalid because a non-fake tensor value was seen for {inp_node}"
                    )

            submod = getattr(gm, node.target)

            # _dummy_wrapper is to make call_function happy
            compiled_submod = _dummy_wrapper(
                torch._inductor.standalone_compile(
                    submod, fake_inputs, dynamic_shapes="from_tracing_context"
                )
            )

            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    compiled_submod, args=node.args, kwargs=node.kwargs
                )
                new_node.meta = node.meta
                node.replace_all_uses_with(new_node)
                gm.graph.erase_node(node)
                del gm._modules[node.target]

    gm.recompile()
    return gm


def _needs_inductor_compile(node):
    return (
        node.op not in ("placeholder", "output")
        and hasattr(node, "meta")
        and node.meta.get("custom", None)
        and "compile_with_inductor" in node.meta["custom"]
    )


def _compile_fx_annotated_nodes_with_inductor(gm):
    from torch.fx.passes.operator_support import OperatorSupport

    found_marked_node = False
    for node in gm.graph.nodes:
        if _needs_inductor_compile(node):
            found_marked_node = True
            break

    if not found_marked_node:
        logger.info("No inductor marked nodes found")
        return gm

    class InductorMarkedNodes(OperatorSupport):
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            return _needs_inductor_compile(node)

    marked_nodes = InductorMarkedNodes()
    gm = _partition_by_supported_nodes(gm, marked_nodes, "__marked_inductor_submod")
    gm = _compile_submod(gm, "__marked_inductor_submod")
    return gm


def _recursive_compile_fx_annotated_nodes_with_inductor(gm):
    for node in gm.graph.find_nodes(op="get_attr"):
        if _needs_inductor_compile(node):
            # If the get_attr itself is marked for compile, the outer graph will
            # take care of it. If we don't skip it here, we end up with nested
            # regional inductor compiles that do not work well.
            continue

        submod = getattr(gm, node.target)
        if isinstance(submod, torch.fx.GraphModule):
            _recursive_compile_fx_annotated_nodes_with_inductor(submod)

    return _compile_fx_annotated_nodes_with_inductor(gm)


@compatibility(is_backward_compatible=False)
def regional_inductor(gm, *example_args):
    """
    Scoops out inductor-marked regions and compiles them with inductor.
    """
    # fuser utils create new nodes using create_proxy, which retains the seq_nr
    # metadata and causes issues
    with torch.fx.traceback.preserve_node_meta(enable=False):
        return _recursive_compile_fx_annotated_nodes_with_inductor(gm)
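For reference, a minimal sketch of how this pass is wired up end to end, distilled from the deleted test_simple above and assuming the reverted torch.fx.passes.regional_inductor module is importable: regional_inductor is passed as both the forward and backward compiler to aot_autograd, and only the code under fx_traceback.annotate is partitioned out and compiled with inductor, while the surrounding ops run through the aot_eager path.

# Minimal usage sketch (assumption: the reverted module is available).
import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.backends.common import aot_autograd
from torch.fx.passes.regional_inductor import regional_inductor


def fn(x, y):
    sin = torch.sin(x)
    # Only this annotated region is handed to inductor by the pass.
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)


backend = aot_autograd(fw_compiler=regional_inductor, bw_compiler=regional_inductor)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
out = opt_fn(torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True))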