[export] Exempt autograd ops for predispatch export (#116527)
Summary: We intend to preserve autograd ops for predispatch export. Therefore, we need to exempt the autograd ops in a few places, e.g. the verifier and proxy_tensor.py.

Test Plan: python test/export/test_export.py -k test_predispatch_export_with_autograd_op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116527
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #116339

This commit is contained in: parent 9431798521, commit af2ded23eb
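
To make the intent concrete, here is a minimal sketch of what "preserving autograd ops" looks like from the caller's side. It reuses the private _export entry point and the Foo module from the new test below, and assumes the returned ExportedProgram exposes the traced FX graph via .graph; after this change the grad-mode switches should appear as torch._C._set_grad_enabled call_function nodes instead of being executed and dropped during tracing:

    import torch
    from torch.export._trace import _export  # private entry point used by the new tests

    class Foo(torch.nn.Module):
        def forward(self, x):
            with torch.enable_grad():
                return x + x

    with torch.no_grad():
        ep = _export(Foo(), (torch.ones(10),), pre_dispatch=True)

    # The enable_grad/no_grad transitions should now survive as graph nodes.
    grad_nodes = [
        n for n in ep.graph.nodes
        if n.op == "call_function" and n.target is torch._C._set_grad_enabled
    ]
    assert grad_nodes, "expected grad-mode switches to be preserved in the graph"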

test/export/test_export.py
@@ -1583,6 +1583,18 @@ def forward(self, arg_0):
         ):
             _ = Constraint()
 
+    def test_predispatch_export_with_autograd_op(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                with torch.enable_grad():
+                    return x + x
+
+        with torch.no_grad():
+            ep = _export(Foo(), (torch.ones(10),), pre_dispatch=True)
+
     def test_train_eval_on_exported_preautograd_module(self):
         class Foo(torch.nn.Module):
             def __init__(self):

test/export/test_safeguard.py
@@ -61,6 +61,43 @@ class TestSafeguard(TestCase):
         ):
             export(f3, (a,))
 
+    def test_global_autograd_exempt_predispatch(self):
+        def f1(a):
+            with torch.no_grad():
+                b = a + a
+            return b
+
+        def f2(a):
+            with torch.enable_grad():
+                b = a + a
+            return b
+
+        def f3(a):
+            with torch.set_grad_enabled(False):
+                b = a + a
+            return b
+
+        def f4(a):
+            with torch.set_grad_enabled(True):
+                b = a + a
+            return b
+
+        a = torch.randn(10)
+
+        from torch.export._trace import _export
+
+        with torch.no_grad():
+            _export(f1, (a,), pre_dispatch=True)
+            _export(f2, (a,), pre_dispatch=True)
+            _export(f3, (a,), pre_dispatch=True)
+            _export(f4, (a,), pre_dispatch=True)
+
+        with torch.enable_grad():
+            _export(f1, (a,), pre_dispatch=True)
+            _export(f2, (a,), pre_dispatch=True)
+            _export(f3, (a,), pre_dispatch=True)
+            _export(f4, (a,), pre_dispatch=True)
+
     def test_tensor_autograd(self):
         # dynamo errors when Tensor.requires_grad_ change the autograd state
         def f1(a):
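
For contrast with the guarded path exercised just above (where export(f3, (a,)) is expected to raise inside an assertRaises block), a rough sketch of the behavioral split this new test pins down. The exact exception type is not shown here; it is asserted in the test file:

    import torch
    from torch.export import export
    from torch.export._trace import _export

    def f3(a):
        with torch.set_grad_enabled(False):
            b = a + a
        return b

    a = torch.randn(10)

    # Predispatch export records the grad-mode switch as a graph op and succeeds.
    _export(f3, (a,), pre_dispatch=True)

    # The default path still rejects changing global autograd state while tracing,
    # matching the assertRaises context above this hunk.
    try:
        export(f3, (a,))
    except Exception as e:
        print(f"non-predispatch export rejected f3: {e}")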

torch/_export/verifier.py
@@ -173,6 +173,8 @@ class Verifier(metaclass=_VerifierMeta):
                 torch.sym_min,
                 torch.sym_not,
                 torch.sym_sqrt,
+                # Predispatch export is able to contain autograd ops.
+                torch._C._set_grad_enabled
             )
 
         if not isinstance(op, _allowed_op_types()):
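
As a toy illustration only (not the verifier's actual code; the helper name below is hypothetical): the verifier accepts a call_function target if it is one of a few allowed op types or sits in an explicit allowlist of builtins, and torch._C._set_grad_enabled now joins that allowlist so predispatch graphs carrying grad-mode switches verify cleanly:

    import torch

    # Hypothetical stand-in for the allowlist extended by this hunk.
    _ALLOWED_BUILTIN_TARGETS = {
        torch.sym_min,
        torch.sym_not,
        torch.sym_sqrt,
        torch._C._set_grad_enabled,  # the new exemption
    }

    def target_is_allowed(target) -> bool:
        # Hypothetical helper: the real check also accepts OpOverload and other
        # allowed op types via isinstance, elided here.
        return target in _ALLOWED_BUILTIN_TARGETS

    assert target_is_allowed(torch._C._set_grad_enabled)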

torch/fx/experimental/proxy_tensor.py
@@ -569,7 +569,13 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
     def __torch_function__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
         if func in _side_effectful_need_to_be_preserved_pre_dispatch:
-            return self.tracer.create_node("call_function", func, args, {})
+            # It's for passing the export verifier, which needs to verify meta['val'].
+            # TODO(chundian): we should systematically couple it with the export verifier,
+            # instead of hardcoding it here.
+            node = self.tracer.create_node("call_function", func, args, {})
+            if func is torch._C._set_grad_enabled:
+                node.meta['val'] = None
+            return node
         # Don't actually run the function! We just want to trace the calls
         # into a graph. We don't actually want to change global autograd state.
         return func(*args, **kwargs)
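
A self-contained sketch of the pattern introduced above, using only public torch.fx graph-building APIs: the side-effectful call is recorded as a node rather than executed, and a placeholder meta['val'] marks that it produces no tensor value, which is what the verifier's meta check expects per the comment in the diff:

    import torch
    import torch.fx as fx

    graph = fx.Graph()
    x = graph.placeholder("x")

    # Record the grad-mode switch as a node instead of executing it.
    grad_node = graph.call_function(torch._C._set_grad_enabled, (False,))
    grad_node.meta["val"] = None  # no tensor output; None satisfies meta checks

    out = graph.call_function(torch.add, (x, x))
    graph.output(out)
    print(graph)  # the _set_grad_enabled call appears inline with the compute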