Don't record autograd state ops while torch.compile in pre-dispatch export (#121736)

Summary: Refer to OSS PR for details

Test Plan: CI

Differential Revision: D54812833

In pre-dispatch export, we install a special torch-function proxy mode that intercepts the torch._C._set_grad_enabled op so the user's intended grad state (train/eval) is captured in the graph. This becomes problematic when export traces torch.cond, because torch.cond calls torch.compile internally: the top-level tracer is still active, so we end up recording unwanted autograd context-manager calls that happen inside dynamo framework code rather than in user code. The fix is to temporarily remove this proxy mode while the internal torch.compile call runs. Autograd ops inside the cond branches are still captured, because dynamo translates those context managers into higher-order ops (HOPs) for us, so no special proxy-mode interception is needed there.
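For context, the "special proxy torch mode" is a TorchFunctionMode subclass (PreDispatchTorchFunctionMode in torch.fx.experimental.proxy_tensor). Below is a minimal, hedged sketch of the interception mechanism it relies on; GradStateObserver is a made-up name, not the real class, and it assumes a PyTorch build where torch._C._set_grad_enabled is routed through the torch-function mode stack (the mechanism pre-dispatch export depends on):

```python
import torch
from torch.overrides import TorchFunctionMode


class GradStateObserver(TorchFunctionMode):
    # Hypothetical stand-in for PreDispatchTorchFunctionMode: it merely logs
    # grad-state toggles instead of recording them into a graph.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch._C._set_grad_enabled:
            print(f"saw torch._C._set_grad_enabled({args[0]})")
        return func(*args, **kwargs)


with GradStateObserver():
    with torch.no_grad():  # entering toggles grad off -> observed by the mode
        pass               # exiting restores the previous grad state -> observed again
```

Because every grad-state toggle made while the mode is on the stack is observed, the toggles dynamo performs internally during the torch.compile call inside torch.cond were being recorded as well; the helper added in this PR temporarily pops the mode off the stack around that call.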

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121736
Approved by: https://github.com/anijain2305, https://github.com/ydwu4
Author: Tugsbayasgalan (Tugsuu) Manlaibaatar
Date: 2024-03-14 23:06:10 +00:00
Committed by: PyTorch MergeBot
Parent: bd7beef529
Commit: 57b20c51b9
4 changed files with 84 additions and 9 deletions


@@ -3399,6 +3399,52 @@ def forward(self, arg0_1, arg1_1, arg2_1):
         # this doesn't work today
         gm_unflat_strict = unflatten(ep)
 
+    def test_predispatch_cond(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("pred", torch.tensor(False))
+                self.register_buffer("t", torch.tensor(10))
+
+            def forward(self, x, y):
+                def true_fn(x, y):
+                    with torch.enable_grad():
+                        return x - 1 + self.t + y
+
+                return torch.cond(
+                    self.pred,
+                    true_fn,
+                    lambda x, y: x + 1 - self.t + y,
+                    [x, y],
+                )
+
+        model = Model()
+        exported_program = torch.export._trace._export(
+            model,
+            (torch.tensor(10), torch.tensor(12)),
+            {},
+            dynamic_shapes=None,
+            pre_dispatch=True,
+            strict=False
+        )
+        self.assertExpectedInline(str(exported_program.graph_module.code.strip()), """\
+def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
+    true_graph_0 = self.true_graph_0
+    false_graph_0 = self.false_graph_0
+    conditional = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, [arg1_1, arg2_1, arg3_1]); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = arg3_1 = None
+    getitem = conditional[0]; conditional = None
+    return (getitem,)""") # noqa: B950
+
+        self.assertExpectedInline(str(exported_program.graph_module.true_graph_0.code.strip()), """\
+def forward(self, arg0_1, arg1_1, arg2_1):
+    _set_grad_enabled = torch._C._set_grad_enabled(True)
+    sub = torch.ops.aten.sub.Tensor(arg1_1, 1); arg1_1 = None
+    add = torch.ops.aten.add.Tensor(sub, arg0_1); sub = arg0_1 = None
+    add_1 = torch.ops.aten.add.Tensor(add, arg2_1); add = arg2_1 = None
+    _set_grad_enabled_1 = torch._C._set_grad_enabled(False)
+    return (add_1,)""")
+
     def test_non_persistent_buffer(self):
         class MyModule(torch.nn.Module):
             def __init__(self):


@@ -3040,9 +3040,6 @@ def forward(self, arg0_1, arg1_1):
         gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
         self.assertExpectedInline(str(gm.code).strip(), """\
 def forward(self, arg0_1):
-    _set_grad_enabled = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_1 = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_2 = torch._C._set_grad_enabled(False)
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None
@@ -3096,9 +3093,6 @@ def forward(self, arg0_1):
         gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True)
         self.assertExpectedInline(str(gm.code).strip(), """\
 def forward(self, arg0_1):
-    _set_grad_enabled = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_1 = torch._C._set_grad_enabled(False)
-    _set_grad_enabled_2 = torch._C._set_grad_enabled(False)
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
     conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None


@@ -24,6 +24,7 @@ from torch._higher_order_ops.utils import (
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import (
+    _temp_remove_pre_dispatch_torch_function_mode,
     disable_proxy_modes_tracing,
     ProxyTorchDispatchMode,
     track_tensor_tree,
@@ -133,9 +134,10 @@ def cond(pred, true_fn, false_fn, operands):
 
     with _set_compilation_env():
         with torch._dynamo.utils.disable_cache_limit():
-            return torch.compile(cond_op, backend="eager", fullgraph=True)(
-                pred, true_fn, false_fn, operands
-            )
+            with _temp_remove_pre_dispatch_torch_function_mode():
+                return torch.compile(cond_op, backend="eager", fullgraph=True)(
+                    pred, true_fn, false_fn, operands
+                )
 
 
 """


@@ -532,6 +532,39 @@ class PythonKeyTracer(Tracer):
         return e
 
 
+@contextmanager
+def _temp_remove_pre_dispatch_torch_function_mode():
+    from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode
+    temp_elements = []
+    pre_dispatch_mode = None
+
+    while _len_torch_function_stack() > 0:
+        mode = _pop_mode()
+        if isinstance(mode, PreDispatchTorchFunctionMode):
+            pre_dispatch_mode = mode
+            break
+        else:
+            temp_elements.append(mode)
+
+    for mode in reversed(temp_elements):
+        _push_mode(mode)
+
+    try:
+        yield
+    finally:
+        if pre_dispatch_mode is not None:
+            count = len(temp_elements)
+            while count > 0:
+                mode = _pop_mode()
+                count -= 1
+
+            temp_elements.append(pre_dispatch_mode)
+
+            for mode in reversed(temp_elements):
+                _push_mode(mode)
+
+
 @torch._disable_dynamo
 def dispatch_trace(
     root: Union[torch.nn.Module, Callable],