AC should work with pre-dispatch IR (#164505)
Previously we had to rely on turning off the export verifier because the AC (activation checkpointing) body was torch IR rather than aten IR. This PR changes tracing so that the AC body is created as an export-compatible IR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164505
Approved by: https://github.com/ydwu4, https://github.com/xmfan
parent 660e369a68
commit 91c211fb8c
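For context, here is a minimal sketch of the workflow this change enables: exporting a module whose forward runs under selective activation checkpointing, with the export verifier left on. It mirrors the new test in the diff below; the Model class and the final print are illustrative, not part of the commit.

import functools

import torch

# Same policy as the new test: save addmm activations, recompute the rest.
def policy_fn(ctx, op, *args, **kwargs):
    if op is torch.ops.aten.addmm.default:
        return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
    return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE

context_fn = functools.partial(
    torch.utils.checkpoint.create_selective_checkpoint_contexts, policy_fn
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 128)

    def forward(self, x):
        # Under strict export this call is traced to the
        # torch.ops.higher_order.tag_activation_checkpoint HOP; with this PR
        # the HOP body is aten IR, so the verifier accepts it.
        return torch.utils.checkpoint.checkpoint(
            self.linear, x, context_fn=context_fn
        )

ep = torch.export.export(Model(), (torch.randn(16, 128),), strict=True)
print(ep.graph)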
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: export"]
# ruff: noqa: F841
# flake8: noqa
import contextlib
import copy
import dataclasses
import functools
@@ -39,6 +40,7 @@ from torch._export.utils import (
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._higher_order_ops.scan import scan
@@ -1141,6 +1143,118 @@ graph():
        # instead of the scripted function, so we get x.sin()
        self.assertEqual(res, x.sin())

    def test_tag_ac_export(self):
        ops_to_save = [torch.ops.aten.addmm.default]

        def policy_fn(ctx, op, *args, **kwargs):
            if op in ops_to_save:
                return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
            else:
                return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE

        context_fn = functools.partial(
            torch.utils.checkpoint.create_selective_checkpoint_contexts, policy_fn
        )

        class Block(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(128, 128)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(128, 128)

            def forward(self, x):
                return self.linear2(self.relu(self.linear1(x)))

        # Wrap the block with checkpointing
        class CheckpointedBlock(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.block = Block()

            def forward(self, x):
                return torch.utils.checkpoint.checkpoint(
                    self.block, x, context_fn=context_fn
                )

        model = CheckpointedBlock()
        x = torch.randn(16, 128, requires_grad=True)

        ep = torch.export.export(model, (x,), strict=True)
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %p_block_linear1_weight : [num_users=1] = placeholder[target=p_block_linear1_weight]
    %p_block_linear1_bias : [num_users=1] = placeholder[target=p_block_linear1_bias]
    %p_block_linear2_weight : [num_users=1] = placeholder[target=p_block_linear2_weight]
    %p_block_linear2_bias : [num_users=1] = placeholder[target=p_block_linear2_bias]
    %x : [num_users=1] = placeholder[target=x]
    %wrap_body0 : [num_users=1] = get_attr[target=wrap_body0]
    %tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {})
    return (getitem,)""",
        )

        self.assertExpectedInline(
            str(ep.graph_module.wrap_body0.graph).strip(),
            """\
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %arg3_1 : [num_users=1] = placeholder[target=arg3_1]
    %arg4_1 : [num_users=1] = placeholder[target=arg4_1]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
    %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %arg3_1, %arg4_1), kwargs = {})
    return (linear_1,)""",
        )

        stack = contextlib.ExitStack()

        with stack:
            jwd = aot_export_joint_with_descriptors(stack, ep.module(), (x,))
            for node in jwd.graph_module.graph.nodes:
                if "recompute" in node.meta:
                    actual = node.meta["recompute"]
                    expected = policy_fn(None, node.target, None, None)
                    self.assertEqual(actual, expected)
            self.assertExpectedInline(
                str(jwd.graph_module.code).strip(),
                """\
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, primals_5, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t = torch.ops.aten.t.default(primals_1); primals_1 = None
    addmm = torch.ops.aten.addmm.default(primals_2, primals_5, t); primals_2 = None
    relu = torch.ops.aten.relu.default(addmm); addmm = None
    detach_9 = torch.ops.aten.detach.default(relu)
    detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
    detach_11 = torch.ops.aten.detach.default(detach_10); detach_10 = None
    t_1 = torch.ops.aten.t.default(primals_3); primals_3 = None
    addmm_1 = torch.ops.aten.addmm.default(primals_4, relu, t_1); primals_4 = None
    t_2 = torch.ops.aten.t.default(t_1); t_1 = None
    mm = torch.ops.aten.mm.default(tangents_1, t_2); t_2 = None
    t_3 = torch.ops.aten.t.default(tangents_1)
    mm_1 = torch.ops.aten.mm.default(t_3, relu); t_3 = relu = None
    t_4 = torch.ops.aten.t.default(mm_1); mm_1 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
    view = torch.ops.aten.view.default(sum_1, [128]); sum_1 = None
    t_5 = torch.ops.aten.t.default(t_4); t_4 = None
    detach_18 = torch.ops.aten.detach.default(detach_11); detach_11 = None
    detach_19 = torch.ops.aten.detach.default(detach_18); detach_18 = None
    threshold_backward = torch.ops.aten.threshold_backward.default(mm, detach_19, 0); mm = detach_19 = None
    t_6 = torch.ops.aten.t.default(t); t = None
    mm_2 = torch.ops.aten.mm.default(threshold_backward, t_6); t_6 = None
    t_7 = torch.ops.aten.t.default(threshold_backward)
    mm_3 = torch.ops.aten.mm.default(t_7, primals_5); t_7 = primals_5 = None
    t_8 = torch.ops.aten.t.default(mm_3); mm_3 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(threshold_backward, [0], True); threshold_backward = None
    view_1 = torch.ops.aten.view.default(sum_2, [128]); sum_2 = None
    t_9 = torch.ops.aten.t.default(t_8); t_8 = None
    return pytree.tree_unflatten([addmm_1, t_9, view_1, t_5, view, mm_2], self._out_spec)""",
            )

    def test_inline_script_class_method_recursive(self):
        f = 0.4
        i = 2
@@ -6,6 +6,7 @@ from typing import Any, Optional

import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.utils import reenter_make_fx
from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.fx import GraphModule

@@ -312,6 +313,9 @@ def proxy_mode_key(
    *args: Any,
    **kwargs: Any,
) -> tuple[torch.Tensor]:
    import torch.fx.traceback as fx_traceback
    from torch.fx import Interpreter

    assert proxy_mode.pre_dispatch, (
        "post-dispatch mode should have inlined in the Autograd key"
    )

@@ -319,11 +323,14 @@ def proxy_mode_key(
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)  # type: ignore[union-attr]
    proxy_kwargs = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, kwargs)  # type: ignore[union-attr]
    qualname = proxy_mode.tracer.get_fresh_qualname("wrap_body")  # type: ignore[union-attr]
    proxy_mode.tracer.root.register_module(qualname, gmod)  # type: ignore[union-attr]
    proxy_gmod = proxy_mode.tracer.unwrap_proxy(gmod)  # type: ignore[union-attr, call-overload]
    for node in proxy_gmod.graph.nodes:
        if "example_value" in node.meta:
            node.meta["val"] = node.meta["example_value"]

    # TODO (tmanlaibaatar) don't we need flat_apply here??
    flat_args, _ = pytree.tree_flatten((args, kwargs))
    with fx_traceback.preserve_node_meta():
        gmod_aten = reenter_make_fx(Interpreter(gmod).run)(*flat_args)
    gmod_aten.meta["_checkpoint_context_fn"] = gmod.meta["_checkpoint_context_fn"]
    proxy_mode.tracer.root.register_module(qualname, gmod_aten)  # type: ignore[union-attr]
    proxy_gmod = proxy_mode.tracer.unwrap_proxy(gmod_aten)  # type: ignore[union-attr, call-overload]
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        tag_activation_checkpoint,
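The heart of the change above is gmod_aten = reenter_make_fx(Interpreter(gmod).run)(*flat_args): the torch-IR body of the checkpoint HOP is re-run under a tracer so every op is re-dispatched and recorded as aten IR. Below is a minimal standalone sketch of that retracing pattern, using the public make_fx API in place of the internal reenter_make_fx helper; the function f and its input shape are illustrative only.

import torch
from torch.fx import Interpreter, symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.relu(x) + torch.sigmoid(x)

# symbolic_trace produces torch IR: nodes target torch.relu / torch.sigmoid.
gmod = symbolic_trace(f)

# Re-interpreting the torch-IR module under make_fx records the same
# computation again, this time as torch.ops.aten.* nodes, which is the same
# trick proxy_mode_key uses to make the checkpoint body export-compatible.
gmod_aten = make_fx(Interpreter(gmod).run)(torch.randn(4))
print(gmod_aten.graph)  # targets are now aten overloads (aten.relu.default, ...)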