[RELAND][export] Exempt autograd ops for predispatch export (#117448)

Summary: Reland of https://github.com/pytorch/pytorch/pull/116527/files

Test Plan: CI

Differential Revision: D52675324

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117448
Approved by: https://github.com/ydwu4
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2024-01-16 19:32:15 +00:00 committed by PyTorch MergeBot
parent 99e54744f7
commit 28be47c267
5 changed files with 66 additions and 1 deletions

View File

@ -2,6 +2,7 @@
# flake8: noqa # flake8: noqa
import copy import copy
import dataclasses import dataclasses
import io
import unittest import unittest
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -82,6 +82,43 @@ class TestSafeguard(TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
export(f3, (torch.randn(10, requires_grad=False),)) export(f3, (torch.randn(10, requires_grad=False),))
def test_global_autograd_exempt_predispatch(self):
def f1(a):
with torch.no_grad():
b = a + a
return b
def f2(a):
with torch.enable_grad():
b = a + a
return b
def f3(a):
with torch.set_grad_enabled(False):
b = a + a
return b
def f4(a):
with torch.set_grad_enabled(True):
b = a + a
return b
a = torch.randn(10)
from torch.export._trace import _export
with torch.no_grad():
_export(f1, (a,), pre_dispatch=True)
_export(f2, (a,), pre_dispatch=True)
_export(f3, (a,), pre_dispatch=True)
_export(f4, (a,), pre_dispatch=True)
with torch.enable_grad():
_export(f1, (a,), pre_dispatch=True)
_export(f2, (a,), pre_dispatch=True)
_export(f3, (a,), pre_dispatch=True)
_export(f4, (a,), pre_dispatch=True)
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

View File

@ -49,6 +49,22 @@ def get_filtered_export_db_tests():
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerialize(TestCase): class TestSerialize(TestCase):
def test_predispatch_export_with_autograd_op(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
with torch.enable_grad():
return x + x
with torch.no_grad():
from torch.export._trace import _export
ep = _export(Foo(), (torch.ones(10),), pre_dispatch=True)
with self.assertRaisesRegex(SerializeError, "Failed serializing node _set_grad_enabled"):
torch.export.save(ep, io.BytesIO())
def test_serialize_multiple_returns_from_node(self) -> None: def test_serialize_multiple_returns_from_node(self) -> None:
class MyModule(torch.nn.Module): class MyModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -159,6 +159,11 @@ class Verifier(metaclass=_VerifierMeta):
torch.sym_min, torch.sym_min,
torch.sym_not, torch.sym_not,
torch.sym_sqrt, torch.sym_sqrt,
# TODO (tmanlaibaatar)
# Predispatch export is able to contain autograd ops.
# These will be modeled as HOO later
torch._C._set_grad_enabled
) )
if not isinstance(op, _allowed_op_types()): if not isinstance(op, _allowed_op_types()):

View File

@ -683,7 +683,13 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None): def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {} kwargs = kwargs or {}
if func in _side_effectful_need_to_be_preserved_pre_dispatch: if func in _side_effectful_need_to_be_preserved_pre_dispatch:
return self.tracer.create_node("call_function", func, args, {}) # It's for passing the export verifier which needs to verify the meta['val']
# TODO(tmanlaibaatar): we should systematically couple it with expoert verifier,
# instead of hardcoding it here.
node = self.tracer.create_node("call_function", func, args, {})
if func is torch._C._set_grad_enabled:
node.meta['val'] = None
return node
# Don't actually run the function! We just want to trace the calls # Don't actually run the function! We just want to trace the calls
# into a graph. We don't actualy want to change global autograd state. # into a graph. We don't actualy want to change global autograd state.
return func(*args, **kwargs) return func(*args, **kwargs)