AC should work with pre-dispatch IR (#164505)

Previously we had to rely on turning off the export verifier because the AC body was torch IR instead of ATen IR. This PR makes it so that we create an IR that is export-compatible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164505
Approved by: https://github.com/ydwu4, https://github.com/xmfan
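
For context, a minimal sketch of the pattern this unblocks (condensed from the new test below; the module, sizes, and names here are illustrative, not part of the change itself): a module whose forward uses selective activation checkpointing can now be exported directly, with the AC body traced as an ATen-IR submodule of a tag_activation_checkpoint call.

import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

ops_to_save = [torch.ops.aten.addmm.default]


def policy_fn(ctx, op, *args, **kwargs):
    # Save addmm results in forward; recompute everything else in backward.
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)


class CheckpointedMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 128)
        )

    def forward(self, x):
        return checkpoint(self.net, x, use_reentrant=False, context_fn=context_fn)


x = torch.randn(16, 128, requires_grad=True)
ep = torch.export.export(CheckpointedMLP(), (x,), strict=True)
# The checkpointed body is traced into a call to
# torch.ops.higher_order.tag_activation_checkpoint whose body submodule is
# ATen IR, so the exported program passes verification without disabling
# the verifier.
print(ep)
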
Author: Tugsbayasgalan Manlaibaatar
Date: 2025-10-03 09:59:50 -07:00
Committed by: PyTorch MergeBot
Parent: 660e369a68
Commit: 91c211fb8c

2 changed files with 126 additions and 5 deletions


@@ -1,6 +1,7 @@
# Owner(s): ["oncall: export"]
# ruff: noqa: F841
# flake8: noqa
import contextlib
import copy
import dataclasses
import functools
@@ -39,6 +40,7 @@ from torch._export.utils import (
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.scan import scan
@@ -1141,6 +1143,118 @@ graph():
        # instead of the scripted function, so we get x.sin()
        self.assertEqual(res, x.sin())

    def test_tag_ac_export(self):
        ops_to_save = [torch.ops.aten.addmm.default]

        def policy_fn(ctx, op, *args, **kwargs):
            if op in ops_to_save:
                return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
            else:
                return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE

        context_fn = functools.partial(
            torch.utils.checkpoint.create_selective_checkpoint_contexts, policy_fn
        )

        class Block(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(128, 128)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(128, 128)

            def forward(self, x):
                return self.linear2(self.relu(self.linear1(x)))

        # Wrap the block with checkpointing
        class CheckpointedBlock(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.block = Block()

            def forward(self, x):
                return torch.utils.checkpoint.checkpoint(
                    self.block, x, context_fn=context_fn
                )

        model = CheckpointedBlock()
        x = torch.randn(16, 128, requires_grad=True)
        ep = torch.export.export(model, (x,), strict=True)
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %p_block_linear1_weight : [num_users=1] = placeholder[target=p_block_linear1_weight]
    %p_block_linear1_bias : [num_users=1] = placeholder[target=p_block_linear1_bias]
    %p_block_linear2_weight : [num_users=1] = placeholder[target=p_block_linear2_weight]
    %p_block_linear2_bias : [num_users=1] = placeholder[target=p_block_linear2_bias]
    %x : [num_users=1] = placeholder[target=x]
    %wrap_body0 : [num_users=1] = get_attr[target=wrap_body0]
    %tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {})
    return (getitem,)""",
        )
        self.assertExpectedInline(
            str(ep.graph_module.wrap_body0.graph).strip(),
            """\
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %arg3_1 : [num_users=1] = placeholder[target=arg3_1]
    %arg4_1 : [num_users=1] = placeholder[target=arg4_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %arg3_1, %arg4_1), kwargs = {})
    return (linear_1,)""",
        )

        # The recompute tags chosen by the policy should survive into the
        # joint graph produced by aot_export_joint_with_descriptors.
        stack = contextlib.ExitStack()
        with stack:
            jwd = aot_export_joint_with_descriptors(stack, ep.module(), (x,))
            for node in jwd.graph_module.graph.nodes:
                if "recompute" in node.meta:
                    actual = node.meta["recompute"]
                    expected = policy_fn(None, node.target, None, None)
                    self.assertEqual(actual, expected)
            self.assertExpectedInline(
                str(jwd.graph_module.code).strip(),
                """\
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, primals_5, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t = torch.ops.aten.t.default(primals_1); primals_1 = None
    addmm = torch.ops.aten.addmm.default(primals_2, primals_5, t); primals_2 = None
    relu = torch.ops.aten.relu.default(addmm); addmm = None
    detach_9 = torch.ops.aten.detach.default(relu)
    detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
    detach_11 = torch.ops.aten.detach.default(detach_10); detach_10 = None
    t_1 = torch.ops.aten.t.default(primals_3); primals_3 = None
    addmm_1 = torch.ops.aten.addmm.default(primals_4, relu, t_1); primals_4 = None
    t_2 = torch.ops.aten.t.default(t_1); t_1 = None
    mm = torch.ops.aten.mm.default(tangents_1, t_2); t_2 = None
    t_3 = torch.ops.aten.t.default(tangents_1)
    mm_1 = torch.ops.aten.mm.default(t_3, relu); t_3 = relu = None
    t_4 = torch.ops.aten.t.default(mm_1); mm_1 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
    view = torch.ops.aten.view.default(sum_1, [128]); sum_1 = None
    t_5 = torch.ops.aten.t.default(t_4); t_4 = None
    detach_18 = torch.ops.aten.detach.default(detach_11); detach_11 = None
    detach_19 = torch.ops.aten.detach.default(detach_18); detach_18 = None
    threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_19, 0); mm = detach_19 = None
    t_6 = torch.ops.aten.t.default(t); t = None
    mm_2 = torch.ops.aten.mm.default(threshold_backward, t_6); t_6 = None
    t_7 = torch.ops.aten.t.default(threshold_backward)
    mm_3 = torch.ops.aten.mm.default(t_7, primals_5); t_7 = primals_5 = None
    t_8 = torch.ops.aten.t.default(mm_3); mm_3 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True); threshold_backward = None
    view_1 = torch.ops.aten.view.default(sum_2, [128]); sum_2 = None
    t_9 = torch.ops.aten.t.default(t_8); t_8 = None
    return pytree.tree_unflatten([addmm_1, t_9, view_1, t_5, view, mm_2], self._out_spec)""",
            )

    def test_inline_script_class_method_recursive(self):
        f = 0.4
        i = 2


@@ -6,6 +6,7 @@ from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.utils import reenter_make_fx
from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.fx import GraphModule
@@ -312,6 +313,9 @@ def proxy_mode_key(
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor]:
    import torch.fx.traceback as fx_traceback
    from torch.fx import Interpreter

    assert proxy_mode.pre_dispatch, (
        "post-dispatch mode should have inlined in the Autograd key"
    )
@@ -319,11 +323,14 @@
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) # type: ignore[union-attr]
    proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, kwargs) # type: ignore[union-attr]
    qualname = proxy_mode.tracer.get_fresh_qualname("wrap_body") # type: ignore[union-attr]
    proxy_mode.tracer.root.register_module(qualname, gmod) # type: ignore[union-attr]
    proxy_gmod = proxy_mode.tracer.unwrap_proxy(gmod) # type: ignore[union-attr, call-overload]
    for node in proxy_gmod.graph.nodes:
        if "example_value" in node.meta:
            node.meta["val"] = node.meta["example_value"]
    # Retrace the AC body into pre-dispatch ATen IR so the registered
    # submodule is export-verifier compatible.
    # TODO (tmanlaibaatar) don't we need flat_apply here??
    flat_args, _ = pytree.tree_flatten((args, kwargs))
    with fx_traceback.preserve_node_meta():
        gmod_aten = reenter_make_fx(Interpreter(gmod).run)(*flat_args)
    gmod_aten.meta["_checkpoint_context_fn"] = gmod.meta["_checkpoint_context_fn"]
    proxy_mode.tracer.root.register_module(qualname, gmod_aten) # type: ignore[union-attr]
    proxy_gmod = proxy_mode.tracer.unwrap_proxy(gmod_aten) # type: ignore[union-attr, call-overload]
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        tag_activation_checkpoint,