Revert "[compile] Regional inductor compilation with fx.annotate (#164776)"
This reverts commit 1e4c7dffa3.
Reverted https://github.com/pytorch/pytorch/pull/164776 on behalf of https://github.com/malfet due to Looks like this one broke everything, not the top of the stack ([comment](https://github.com/pytorch/pytorch/pull/164776#issuecomment-3393725466))
This commit is contained in:
parent a19123b37e
commit 8d49cd5b26
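For context, the reverted change let users wrap a region of a model in `fx_traceback.annotate({"compile_with_inductor": ...})` and have just that region compiled by inductor while the rest of the graph ran through aot_eager. A minimal sketch of the flow, reconstructed from the deleted test file below (the backend helper is test-local, and the `torch.fx.passes.regional_inductor` import no longer resolves after this revert):

```python
import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.backends.common import aot_autograd
from torch.fx.passes.regional_inductor import regional_inductor  # removed by this revert


def fn(x, y):
    sin = torch.sin(x)
    # Only the annotated region is partitioned out and compiled with inductor.
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        mul = sin * y
        add = mul + 1
    return torch.sin(add)


# regional_inductor runs as both the forward and the backward compiler.
backend = aot_autograd(fw_compiler=regional_inductor, bw_compiler=regional_inductor)
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
out = opt_fn(torch.randn(10, requires_grad=True), torch.randn(10, requires_grad=True))
```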
First, the revert drops the module from the docs coverage list and module index:

```diff
@@ -1081,8 +1081,6 @@ coverage_ignore_functions = [
     "loop_pass",
     "these_before_those_pass_constraint",
     "this_before_that_pass_constraint",
-    # torch.fx.passes.regional_inductor
-    "regional_inductor",
     # torch.fx.passes.reinplace
     "reinplace",
     # torch.fx.passes.split_module
```
```diff
@@ -1169,7 +1169,6 @@ The set of leaf modules can be customized by overriding
 .. py:module:: torch.fx.passes.operator_support
 .. py:module:: torch.fx.passes.param_fetch
 .. py:module:: torch.fx.passes.pass_manager
-.. py:module:: torch.fx.passes.regional_inductor
 .. py:module:: torch.fx.passes.reinplace
 .. py:module:: torch.fx.passes.runtime_assert
 .. py:module:: torch.fx.passes.shape_prop
```
The revert deletes the regional-inductor test file in full (hunk `@@ -1,295 +0,0 @@`):

```python
# Owner(s): ["module: dynamo"]

import functools

import torch
import torch._inductor.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.test_case import run_tests
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.fx.passes.regional_inductor import regional_inductor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.triton_utils import requires_cuda_and_triton


# Open questions / follow-ups
# 1) CSE behavior with meta custom nodes
#    Common subexpression elimination may not differentiate between distinct
#    meta custom nodes and could remove expressions, which might confuse users.
#
# 2) SAC: recompute vs. forward size
#    If the recomputed forward is smaller than the original forward, do we end
#    up compiling only the smaller region?
#
# 3) fx_traceback.annotate nesting
#    How does nesting behave? Are there any ordering requirements?
#
# 4) Planned uses for annotations
#    a) compile flex
#    b) streams
#    c) nn.Module info to organize MoE runtime
#    d) pipeline-parallel stages
#    e) rename graph nodes for easier debugging
#    f) disallow nested regional compile


def aot_eager_regional_inductor():
    return aot_autograd(
        fw_compiler=regional_inductor,
        bw_compiler=regional_inductor,
    )


@skipIfTorchDynamo("Not a suitable dynamo wrapped test")
class RegionalInductorTests(torch._inductor.test_case.TestCase):
    # TODO - we should not need this because we should turn it on in Dynamo,
    # but for some reason the tests fail without it.
    def setUp(self):
        super().setUp()
        self.cm = torch.fx.traceback.preserve_node_meta()
        self.cm.__enter__()

    def tearDown(self):
        super().tearDown()
        self.cm.__exit__(None, None, None)

    def test_simple(self):
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)

        opt_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )
        x = torch.randn(10, requires_grad=True)
        y = torch.randn(10, requires_grad=True)

        # Check that inductor compilation is called twice (1 fwd + 1 bwd)
        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x, y))
        self.assertEqual(len(codes), 2)

    def test_repeated_blocks(self):
        def fn(x, y):
            sin = torch.sin(x)

            with fx_traceback.annotate({"compile_with_inductor": 0}):
                mul = sin * y
                add = mul + 1

            return torch.sin(add)

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                a = fn(x, y)
                return fn(a, y)

        mod = Mod()

        opt_mod = torch.compile(
            mod, backend=aot_eager_regional_inductor(), fullgraph=True
        )
        x = torch.randn(10, requires_grad=True)
        y = torch.randn(10, requires_grad=True)

        # Check that inductor compilation is called 4 times:
        # there will be 2 partitions in the fwd and 2 in the bwd, totaling 4
        _, codes = run_fw_bw_and_get_code(lambda: opt_mod(x, y))
        self.assertEqual(len(codes), 4)

    def test_invoke_subgraph(self):
        # Checks that custom metadata on get_attr nodes is propagated
        @torch.compiler.nested_compile_region
        def gn(x):
            return torch.sin(x)

        def fn(x):
            x = x + 1
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                z = gn(x)
            return torch.sigmoid(z)

        opt_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )
        x = torch.randn(10, requires_grad=True)

        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        self.assertEqual(len(codes), 2)

    def test_invoke_subgraph_inner(self):
        # Checks that the inductor regions are searched recursively.
        @torch.compiler.nested_compile_region
        def gn(x):
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                return torch.sin(x)

        def fn(x):
            x = x + 1
            x = gn(x)
            x = x + 1
            x = gn(x)
            return torch.sigmoid(x)

        opt_fn = torch.compile(
            fn, backend=aot_eager_regional_inductor(), fullgraph=True
        )
        x = torch.randn(10, requires_grad=True)

        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        # invoke_subgraph is called twice, but the inner code is compiled only
        # once - so 2 in total (1 fwd + 1 bwd)
        self.assertEqual(len(codes), 2)

    @requires_cuda_and_triton
    def test_flex_attention(self):
        def _squared(score, b, h, m, n):
            return score * score

        def mask_mod(b, h, q, k):
            return q >= 0

        a = 12
        b = 64
        block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)

        def fn(x):
            x = torch.sin(x)
            with fx_traceback.annotate({"compile_with_inductor": 0}):
                x = flex_attention(x, x, x, block_mask=block_mask, score_mod=_squared)
            return torch.cos(x)

        x = torch.randn(
            1,
            1,
            a * b,
            b,
            dtype=torch.bfloat16,
            device="cuda",
            requires_grad=True,
        )

        opt_fn = torch.compile(
            fn,
            backend=aot_eager_regional_inductor(),
            fullgraph=True,
        )

        _, codes = run_fw_bw_and_get_code(lambda: opt_fn(x))
        # flex in forward and flex_backward in backward
        self.assertEqual(len(codes), 2)

    @requires_cuda_and_triton
    def test_selective_ac_flex(self):
        class FlexAttentionModule(torch.nn.Module):
            def __init__(self, hidden_size, num_heads):
                super().__init__()
                self.hidden_size = hidden_size
                self.num_heads = num_heads
                self.head_dim = hidden_size // num_heads

                # In-projections (query, key, value)
                self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
                self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
                self.v_proj = torch.nn.Linear(hidden_size, hidden_size)

                # Out-projection
                self.out_proj = torch.nn.Linear(hidden_size, hidden_size)

            def forward(self, x):
                batch_size, seq_len, _ = x.size()

                # Project queries, keys, and values
                q = (
                    self.q_proj(x)
                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
                    .transpose(1, 2)
                )
                k = (
                    self.k_proj(x)
                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
                    .transpose(1, 2)
                )
                v = (
                    self.v_proj(x)
                    .view(batch_size, seq_len, self.num_heads, self.head_dim)
                    .transpose(1, 2)
                )

                # Apply flex attention
                with torch.fx.traceback.annotate({"compile_with_inductor": 0}):
                    attn_output = flex_attention(
                        q,
                        k,
                        v,
                    )

                # Reshape output
                attn_output = (
                    attn_output.transpose(1, 2)
                    .contiguous()
                    .view(batch_size, seq_len, self.hidden_size)
                )

                # Out projection
                output = self.out_proj(attn_output)

                return output

        from torch.utils.checkpoint import (
            checkpoint,
            create_selective_checkpoint_contexts,
        )

        ops_to_save = [
            torch.ops.aten.mm.default,
        ]
        context_fn = functools.partial(
            create_selective_checkpoint_contexts, ops_to_save
        )

        # Define a model that uses FlexAttention with selective activation
        # checkpointing
        class SacModule(torch.nn.Module):
            def __init__(self, hidden_size, num_heads, context_fn):
                super().__init__()
                self.flex_attn = FlexAttentionModule(hidden_size, num_heads)
                self.context_fn = context_fn

            def forward(self, x):
                def flex_attn_fn(x):
                    return self.flex_attn(x)

                output = checkpoint(
                    flex_attn_fn,
                    x,
                    use_reentrant=False,
                    context_fn=self.context_fn,
                )

                return output

        flex_module = SacModule(hidden_size=512, num_heads=8, context_fn=context_fn).to(
            "cuda", dtype=torch.bfloat16
        )
        x = torch.ones(8, 1024, 512, device="cuda", dtype=torch.bfloat16)
        compiled_module = torch.compile(
            flex_module, backend=aot_eager_regional_inductor(), fullgraph=True
        )

        _, codes = run_fw_bw_and_get_code(lambda: compiled_module(x))
        # flex in forward and flex_backward in backward
        self.assertEqual(len(codes), 2)


if __name__ == "__main__":
    run_tests()
```
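The annotation context works by stamping its dict onto `node.meta["custom"]` for every FX node traced inside it; `_needs_inductor_compile` in the deleted pass (last hunk below) keys off exactly that. A minimal sketch of the round trip, assuming `make_fx` under `preserve_node_meta` records annotations the same way the dynamo-driven tests above do:

```python
import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        x = x * 2  # this node should carry the custom annotation
    return x + 1


# The test's setUp enters preserve_node_meta(); without it the annotation
# is not recorded on the traced nodes.
with fx_traceback.preserve_node_meta():
    gm = make_fx(fn)(torch.randn(4))

for node in gm.graph.nodes:
    print(node.name, node.meta.get("custom"))
# Expected under the stated assumption: the mul node prints
# {'compile_with_inductor': 0}; the placeholder/add/output nodes print None.
```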
In the joint-graph HOP pass, the revert drops the lines that copied the original node's meta onto the newly inserted get_attr nodes (indentation approximated from context):

```diff
@@ -854,7 +854,6 @@ def run_joint_graph_passes_on_hops(
         with joint_gm.graph.inserting_after(fw_node):
             new_fw_mod_attr_name = add_new_hop_gm(new_fw_hop_gm, f"fw{identifier}")
             new_fw_mod_attr = joint_gm.graph.get_attr(new_fw_mod_attr_name)
-            new_fw_mod_attr.meta = copy.copy(fw_node.args[0].meta)

         # new_hop_fw_gm output signature is (*fw_outs, *saved_tensors)
         with joint_gm.graph.inserting_after(new_fw_mod_attr):
```
```diff
@@ -907,7 +906,6 @@ def run_joint_graph_passes_on_hops(
         with joint_gm.graph.inserting_after(bw_node):
             new_bw_mod_attr_name = add_new_hop_gm(new_bw_hop_gm, bw_node.args[1])
             new_bw_mod_attr = joint_gm.graph.get_attr(new_bw_mod_attr_name)
-            new_bw_mod_attr.meta = copy.copy(bw_node.args[0].meta)

         with joint_gm.graph.inserting_after(new_bw_mod_attr):
             new_bw_node = joint_gm.graph.call_function(
```
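The reverted feature depended on that meta copy: it is how a `compile_with_inductor` annotation survived the joint-graph rewrite of nested HOP subgraphs, which is what `test_invoke_subgraph` in the deleted test file above exercised. A small standalone illustration of the idiom (hypothetical snippet, not the surrounding pass):

```python
import copy

import torch


gm = torch.fx.symbolic_trace(lambda x: x.relu())
node = next(n for n in gm.graph.nodes if n.op == "call_method")
node.meta["custom"] = {"compile_with_inductor": 0}

# copy.copy gives the new node its own top-level meta dict; nested values
# such as the "custom" annotation dict are still shared (shallow copy).
new_meta = copy.copy(node.meta)
assert new_meta["custom"] is node.meta["custom"]
```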
The module is also removed from the `torch.fx.passes` package imports:

```diff
@@ -4,7 +4,6 @@ from . import (
     net_min_base,
     operator_support,
     param_fetch,
-    regional_inductor,
     reinplace,
     runtime_assert,
     shape_prop,
```
Finally, the revert deletes the `torch.fx.passes.regional_inductor` module itself (hunk `@@ -1,133 +0,0 @@`):

```python
# mypy: allow-untyped-defs

import functools
import logging

import torch
from torch.fx._compatibility import compatibility


logger = logging.getLogger(__name__)

__all__ = ["regional_inductor"]


# standalone_compile returns a callable class object - this does not sit well
# with the fx graph node op call_function, which expects a function. So this
# is just a wrapper function to make fx graph codegen happy.
def _dummy_wrapper(fn):
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


def _partition_by_supported_nodes(gm, supported_ops, prefix):
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.utils.fuser_utils import fuse_by_partitions

    partitioner = CapabilityBasedPartitioner(
        gm, supported_ops, allows_single_node_partition=True
    )

    candidate_partitions = partitioner.propose_partitions()
    partitioned_gm = fuse_by_partitions(
        partitioner.graph_module,
        [partition.nodes for partition in candidate_partitions],
        prefix=prefix,
        always_return_tuple=True,
    )

    return partitioned_gm


def _compile_submod(gm, prefix):
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target.startswith(prefix):
            fake_inputs = []
            for inp_node in node.all_input_nodes:
                if hasattr(inp_node, "meta") and "val" in inp_node.meta:
                    fake_inputs.append(inp_node.meta["val"])
                else:
                    raise RuntimeError(
                        f"Partition is bad because a non-fake tensor value was seen: {inp_node}"
                    )

            submod = getattr(gm, node.target)

            # _dummy_wrapper is to make call_function happy
            compiled_submod = _dummy_wrapper(
                torch._inductor.standalone_compile(
                    submod, fake_inputs, dynamic_shapes="from_tracing_context"
                )
            )

            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    compiled_submod, args=node.args, kwargs=node.kwargs
                )
                new_node.meta = node.meta
                node.replace_all_uses_with(new_node)
                gm.graph.erase_node(node)
                del gm._modules[node.target]

    gm.recompile()
    return gm


def _needs_inductor_compile(node):
    return (
        node.op not in ("placeholder", "output")
        and hasattr(node, "meta")
        and node.meta.get("custom", None)
        and "compile_with_inductor" in node.meta["custom"]
    )


def _compile_fx_annotated_nodes_with_inductor(gm):
    from torch.fx.passes.operator_support import OperatorSupport

    found_marked_node = False
    for node in gm.graph.nodes:
        if _needs_inductor_compile(node):
            found_marked_node = True
            break

    if not found_marked_node:
        logger.info("No inductor marked nodes found")
        return gm

    class InductorMarkedNodes(OperatorSupport):
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            return _needs_inductor_compile(node)

    marked_nodes = InductorMarkedNodes()
    gm = _partition_by_supported_nodes(gm, marked_nodes, "__marked_inductor_submod")
    gm = _compile_submod(gm, "__marked_inductor_submod")
    return gm


def _recursive_compile_fx_annotated_nodes_with_inductor(gm):
    for node in gm.graph.find_nodes(op="get_attr"):
        if _needs_inductor_compile(node):
            # If the get_attr itself is marked for compile, the outer graph
            # will take care of it. If we don't skip it here, we end up with
            # nested regional inductor compiles that do not work well.
            continue
        submod = getattr(gm, node.target)
        if isinstance(submod, torch.fx.GraphModule):
            _recursive_compile_fx_annotated_nodes_with_inductor(submod)

    return _compile_fx_annotated_nodes_with_inductor(gm)


@compatibility(is_backward_compatible=False)
def regional_inductor(gm, *example_args):
    """
    Scoops out inductor-marked regions and compiles them with inductor.
    """
    # fuser utils create new nodes using create_proxy, which retains the
    # seq_nr metadata and causes issues
    with torch.fx.traceback.preserve_node_meta(enable=False):
        return _recursive_compile_fx_annotated_nodes_with_inductor(gm)
```
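Note that `_compile_submod` above insists every partition input carry a fake tensor under `node.meta["val"]`; AOTAutograd-produced graphs satisfy this, which is why the pass was wired in through `aot_autograd(fw_compiler=..., bw_compiler=...)` in the tests. A sketch of calling the pass directly under that assumption, using `make_fx` in fake mode (hypothetical usage, not from the deleted code; outside a `torch.compile` tracing context the `dynamic_shapes="from_tracing_context"` path may fall back to static shapes):

```python
import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.regional_inductor import regional_inductor  # removed by this revert


def fn(x):
    with fx_traceback.annotate({"compile_with_inductor": 0}):
        x = torch.sin(x)
    return x + 1


with fx_traceback.preserve_node_meta():
    # tracing_mode="fake" populates node.meta["val"] with FakeTensors,
    # which _compile_submod reads as example inputs for standalone_compile.
    gm = make_fx(fn, tracing_mode="fake")(torch.randn(8))

compiled_gm = regional_inductor(gm)
print(compiled_gm(torch.randn(8)))
```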