[RELAND][export] Exempt autograd ops for predispatch export (#117448)
Summary: Reland of https://github.com/pytorch/pytorch/pull/116527/files

Test Plan: CI

Differential Revision: D52675324

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117448
Approved by: https://github.com/ydwu4
This commit is contained in:
parent 99e54744f7
commit 28be47c267
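For orientation before the diff, a minimal sketch of the behavior being relanded, assuming only the private _export entry point that the tests below use: pre-dispatch export now traces through a global autograd-state toggle such as torch.no_grad() instead of rejecting it.

import torch
from torch.export._trace import _export  # private entry point, as used in the tests below

def fn(a):
    with torch.no_grad():  # toggles global autograd state inside the traced function
        b = a + a
    return b

# With pre_dispatch=True the grad toggle is captured in the graph rather than
# raising, regardless of the ambient grad mode at export time.
ep = _export(fn, (torch.randn(10),), pre_dispatch=True)
print(ep)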
@@ -2,6 +2,7 @@
 # flake8: noqa
 import copy
 import dataclasses
+import io
 import unittest
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -82,6 +82,43 @@ class TestSafeguard(TestCase):
         with self.assertRaises(RuntimeError):
             export(f3, (torch.randn(10, requires_grad=False),))

+    def test_global_autograd_exempt_predispatch(self):
+        def f1(a):
+            with torch.no_grad():
+                b = a + a
+            return b
+
+        def f2(a):
+            with torch.enable_grad():
+                b = a + a
+            return b
+
+        def f3(a):
+            with torch.set_grad_enabled(False):
+                b = a + a
+            return b
+
+        def f4(a):
+            with torch.set_grad_enabled(True):
+                b = a + a
+            return b
+
+        a = torch.randn(10)
+
+        from torch.export._trace import _export
+
+        with torch.no_grad():
+            _export(f1, (a,), pre_dispatch=True)
+            _export(f2, (a,), pre_dispatch=True)
+            _export(f3, (a,), pre_dispatch=True)
+            _export(f4, (a,), pre_dispatch=True)
+
+        with torch.enable_grad():
+            _export(f1, (a,), pre_dispatch=True)
+            _export(f2, (a,), pre_dispatch=True)
+            _export(f3, (a,), pre_dispatch=True)
+            _export(f4, (a,), pre_dispatch=True)
+

 if __name__ == "__main__":
     run_tests()
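A hedged aside on the test above: one way to see what it produces is to look for the recorded toggles in the exported graph. ep.graph and Node.format_node are standard torch.fx APIs; the snippet itself is an assumption, not part of the commit.

# Assumed inspection snippet, reusing f1 and a from the test above: with
# pre_dispatch=True the grad toggles survive as plain call_function nodes
# targeting torch._C._set_grad_enabled instead of being executed.
ep = _export(f1, (a,), pre_dispatch=True)
for node in ep.graph.nodes:
    if node.op == "call_function" and node.target is torch._C._set_grad_enabled:
        print(node.format_node())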
@@ -49,6 +49,22 @@ def get_filtered_export_db_tests():

 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
 class TestSerialize(TestCase):
+    def test_predispatch_export_with_autograd_op(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                with torch.enable_grad():
+                    return x + x
+
+        with torch.no_grad():
+            from torch.export._trace import _export
+            ep = _export(Foo(), (torch.ones(10),), pre_dispatch=True)
+
+        with self.assertRaisesRegex(SerializeError, "Failed serializing node _set_grad_enabled"):
+            torch.export.save(ep, io.BytesIO())
+
     def test_serialize_multiple_returns_from_node(self) -> None:
         class MyModule(torch.nn.Module):
             def __init__(self):
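One detail the hunk above leaves implicit is where SerializeError comes from; the import below is an assumption about the test file's existing imports, shown only so the assertion reads in context.

# Assumed import context for the assertRaisesRegex call above.
from torch._export.serde.serialize import SerializeError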
@@ -159,6 +159,11 @@ class Verifier(metaclass=_VerifierMeta):
             torch.sym_min,
             torch.sym_not,
             torch.sym_sqrt,
+            # TODO (tmanlaibaatar)
+            # Predispatch export is able to contain autograd ops.
+            # These will be modeled as HOO later
+            torch._C._set_grad_enabled
         )

         if not isinstance(op, _allowed_op_types()):
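For context, a simplified sketch of the check this allow-list feeds into; `_allowed_op_types()` is visible in the surrounding lines, while the helper below and its exact logic are assumptions.

# Assumed, simplified sketch of the verifier's call_function check: a node
# target must be an allowed op type (e.g. an OpOverload) or sit in the
# explicit allow-list, which now also contains torch._C._set_grad_enabled.
from torch._ops import HigherOrderOperator, OpOverload

def _is_allowed_target(op, allowed_torch_functions) -> bool:
    if isinstance(op, (OpOverload, HigherOrderOperator)):
        return True
    return op in allowed_torch_functions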
@@ -683,7 +683,13 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
     def __torch_function__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
         if func in _side_effectful_need_to_be_preserved_pre_dispatch:
-            return self.tracer.create_node("call_function", func, args, {})
+            # It's for passing the export verifier, which needs to verify the meta['val']
+            # TODO(tmanlaibaatar): we should systematically couple it with the export verifier,
+            # instead of hardcoding it here.
+            node = self.tracer.create_node("call_function", func, args, {})
+            if func is torch._C._set_grad_enabled:
+                node.meta['val'] = None
+            return node
         # Don't actually run the function! We just want to trace the calls
         # into a graph. We don't actually want to change global autograd state.
         return func(*args, **kwargs)
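Finally, a self-contained sketch of the interception pattern the hunk above extends. It assumes, as the hunk itself does, that torch._C._set_grad_enabled is dispatched through TorchFunctionMode; everything else (the class and its recording list) is illustrative rather than the real tracer.

import torch
from torch.overrides import TorchFunctionMode

class RecordGradToggles(TorchFunctionMode):
    # Illustrative stand-in for PreDispatchTorchFunctionMode: record the grad
    # toggle instead of executing it, leaving global autograd state untouched.
    def __init__(self):
        super().__init__()
        self.recorded = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch._C._set_grad_enabled:
            self.recorded.append(("_set_grad_enabled", args))
            return None
        return func(*args, **kwargs)

with RecordGradToggles() as mode:
    with torch.no_grad():  # __enter__ and __exit__ both call _set_grad_enabled
        pass

print(mode.recorded)  # e.g. [('_set_grad_enabled', (False,)), ('_set_grad_enabled', (True,))]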