Don't record autograd state ops while torch.compile in pre-dispatch export (#121736)
Summary: Refer to OSS PR for details.

Test Plan: CI

Differential Revision: D54812833

In pre-dispatch export we have a special proxy torch mode that intercepts torch._C._set_grad_enabled so that the user's train/eval intent is captured correctly. However, this becomes problematic when we trace torch.cond during export, because torch.cond calls torch.compile internally. As a result, we end up capturing unwanted autograd context-manager calls that happen inside dynamo framework code while the top-level tracer is still active. We fix this by temporarily turning off the proxy torch mode around that internal torch.compile call. Autograd ops inside cond branches are still captured, because dynamo translates them into higher-order ops (HOPs) for us, so we don't need the special proxy mode to intercept them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121736
Approved by: https://github.com/anijain2305, https://github.com/ydwu4
This commit is contained in:
parent bd7beef529
commit 57b20c51b9
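For context, here is a minimal sketch of the scenario this change targets, mirroring the new test_predispatch_cond test in the diff below. The module, tensor values, and the final print are illustrative; the private torch.export._trace._export call is used exactly as in that test.

import torch
from torch.export._trace import _export


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pred", torch.tensor(False))

    def forward(self, x):
        def true_fn(x):
            # A user-level autograd context manager inside a cond branch.
            with torch.enable_grad():
                return x + 1

        # torch.cond compiles its body with torch.compile internally; before
        # this fix, the pre-dispatch proxy mode also recorded the
        # torch._C._set_grad_enabled calls made by dynamo framework code.
        return torch.cond(self.pred, true_fn, lambda x: x - 1, [x])


# Pre-dispatch export should capture only the user's grad-mode changes
# (inside the cond branches), not dynamo's internal ones.
ep = _export(M(), (torch.tensor(3.0),), {}, dynamic_shapes=None,
             pre_dispatch=True, strict=False)
print(ep.graph_module.code)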
@@ -3399,6 +3399,52 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         # this doesn't work today
         gm_unflat_strict = unflatten(ep)
 
+    def test_predispatch_cond(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("pred", torch.tensor(False))
+                self.register_buffer("t", torch.tensor(10))
+
+            def forward(self, x, y):
+                def true_fn(x, y):
+                    with torch.enable_grad():
+                        return x - 1 + self.t + y
+
+                return torch.cond(
+                    self.pred,
+                    true_fn,
+                    lambda x, y: x + 1 - self.t + y,
+                    [x, y],
+                )
+
+        model = Model()
+        exported_program = torch.export._trace._export(
+            model,
+            (torch.tensor(10), torch.tensor(12)),
+            {},
+            dynamic_shapes=None,
+            pre_dispatch=True,
+            strict=False
+        )
+
+        self.assertExpectedInline(str(exported_program.graph_module.code.strip()), """\
+def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
+    true_graph_0 = self.true_graph_0
+    false_graph_0 = self.false_graph_0
+    conditional = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, [arg1_1, arg2_1, arg3_1]); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = arg3_1 = None
+    getitem = conditional[0]; conditional = None
+    return (getitem,)""")  # noqa: B950
+
+        self.assertExpectedInline(str(exported_program.graph_module.true_graph_0.code.strip()), """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    _set_grad_enabled = torch._C._set_grad_enabled(True)
+    sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
+    add = torch.ops.aten.add.Tensor(sub, arg0_1); sub = arg0_1 = None
+    add_1 = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None
+    _set_grad_enabled_1 = torch._C._set_grad_enabled(False)
+    return (add_1,)""")
+
     def test_non_persistent_buffer(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -3040,9 +3040,6 @@ def forward(self, arg0_1, arg1_1):
         gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
         self.assertExpectedInline(str(gm.code).strip(), """\
 def forward(self, arg0_1):
-    _set_grad_enabled = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_1 = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_2 = torch._C._set_grad_enabled(False)
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
@@ -3096,9 +3093,6 @@ def forward(self, arg0_1):
         gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
         self.assertExpectedInline(str(gm.code).strip(), """\
 def forward(self, arg0_1):
-    _set_grad_enabled = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_1 = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_2 = torch._C._set_grad_enabled(False)
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
@@ -24,6 +24,7 @@ from torch._higher_order_ops.utils import (
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
+    _temp_remove_pre_dispatch_torch_function_mode,
     disable_proxy_modes_tracing,
     ProxyTorchDispatchMode,
     track_tensor_tree,
@@ -133,9 +134,10 @@ def cond(pred, true_fn, false_fn, operands):
 
     with _set_compilation_env():
         with torch._dynamo.utils.disable_cache_limit():
-            return torch.compile(cond_op, backend="eager", fullgraph=True)(
-                pred, true_fn, false_fn, operands
-            )
+            with _temp_remove_pre_dispatch_torch_function_mode():
+                return torch.compile(cond_op, backend="eager", fullgraph=True)(
+                    pred, true_fn, false_fn, operands
+                )
 
 
 """
@@ -532,6 +532,39 @@ class PythonKeyTracer(Tracer):
         return e
 
 
+@contextmanager
+def _temp_remove_pre_dispatch_torch_function_mode():
+    from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
+    temp_elements = []
+    pre_dispatch_mode = None
+
+    while _len_torch_function_stack() > 0:
+        mode = _pop_mode()
+        if isinstance(mode, PreDispatchTorchFunctionMode):
+            pre_dispatch_mode = mode
+            break
+        else:
+            temp_elements.append(mode)
+
+    for mode in reversed(temp_elements):
+        _push_mode(mode)
+
+    try:
+        yield
+
+    finally:
+        if pre_dispatch_mode is not None:
+            count = len(temp_elements)
+            while count > 0:
+                mode = _pop_mode()
+                count -= 1
+
+            temp_elements.append(pre_dispatch_mode)
+
+            for mode in reversed(temp_elements):
+                _push_mode(mode)
+
+
 @torch._disable_dynamo
 def dispatch_trace(
     root: Union[torch.nn.Module, Callable],
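The helper added above pops modes off the torch function mode stack until it finds the pre-dispatch mode, removes it, re-pushes the remaining modes for the duration of the block, and restores the original stack order on exit. Below is a minimal standalone sketch of that same pop-and-restore pattern, using a plain Python list in place of the real mode stack; the names and the PreDispatchMode class are illustrative stand-ins, not PyTorch APIs.

from contextlib import contextmanager


class PreDispatchMode:
    """Stand-in for PreDispatchTorchFunctionMode."""


_mode_stack = []  # stand-in for the torch function mode stack (top = end of list)


@contextmanager
def temp_remove_pre_dispatch_mode():
    popped = []
    pre_dispatch = None

    # Pop until the pre-dispatch mode is found (and removed) or the stack is empty.
    while _mode_stack:
        mode = _mode_stack.pop()
        if isinstance(mode, PreDispatchMode):
            pre_dispatch = mode
            break
        popped.append(mode)

    # Re-push the modes that sat above it, preserving their original order.
    for mode in reversed(popped):
        _mode_stack.append(mode)

    try:
        yield
    finally:
        if pre_dispatch is not None:
            # Pop the re-pushed modes, slot the pre-dispatch mode back
            # underneath them, then push them again in the original order.
            for _ in range(len(popped)):
                _mode_stack.pop()
            popped.append(pre_dispatch)
            for mode in reversed(popped):
                _mode_stack.append(mode)

For example, if _mode_stack starts as [PreDispatchMode(), other_mode], the body of the with block runs with only other_mode on the stack, and afterwards the pre-dispatch mode sits back underneath other_mode exactly as before.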