Summary:
X-link: https://github.com/pytorch/executorch/pull/1817

Basic support for non-persistent buffers, which are buffers that do not show up in the state dict.

One weird twist is that most of our other systems (FX, aot_export, dynamo) have completely buggy handling of non-persistent buffers. I tried to go on a wild goose chase to fix them all, but it got to be too much. So I introduced some sad rewrite passes in `_export` to make the final state dict correctly align with the original module's state dict.

This exposed some bugs/ambiguous handling of parameters/buffers in existing test code. For example, `TestSaveLoad.test_save_buffer` traced over a module that was not in the root module hierarchy and caused some weird behavior. I think we should error explicitly on use cases like this: https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the tests or skipped them.

As a side effect, this diff tightened up quite a few sloppy behaviors around state dict handling:
- Tensor attributes were getting promoted to be buffers (bad!).
- Tracing through a module not in the children of the root module would add its parameters/buffers to the state dict (bad!). This behavior is unlikely to show up in user code, since the model would be totally broken, but it did show up in a bunch of tests.

Test Plan: unit tests, sandcastle

Differential Revision: D53340041

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118969
Approved by: https://github.com/guangy10, https://github.com/huydhn, https://github.com/titaiwangms
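For context, a minimal sketch of what "non-persistent buffer" means here; the module and buffer names below are illustrative, not taken from this PR:

```python
import torch
from torch.export import export


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: appears in the state dict.
        self.register_buffer("scale", torch.ones(3))
        # Non-persistent buffer: tracked by the module, but excluded
        # from the state dict.
        self.register_buffer("cache", torch.zeros(3), persistent=False)

    def forward(self, x):
        return x * self.scale + self.cache


m = M()
assert "scale" in m.state_dict()
assert "cache" not in m.state_dict()

# With this change, export is expected to handle the non-persistent buffer
# and keep the exported program's state dict aligned with the module's.
ep = export(M(), (torch.randn(3),))
```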
# Owner(s): ["oncall: export"]
import unittest

import torch
import torch._dynamo as torchdynamo
from torch.export import export
from torch.testing._internal.common_utils import run_tests, TestCase


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestSafeguard(TestCase):
    # If the autograd state doesn't change, dynamo eliminates the autograd state
    # manager op and export can succeed. Otherwise, the autograd state manager op
    # is preserved in the produced graph and export will fail.
    def test_global_autograd(self):
        class F1(torch.nn.Module):
            def forward(self, a):
                with torch.no_grad():
                    b = a + a
                return b

        f1 = F1()

        class F2(torch.nn.Module):
            def forward(self, a):
                with torch.enable_grad():
                    b = a + a
                return b

        f2 = F2()

        class F3(torch.nn.Module):
            def forward(self, a):
                with torch.set_grad_enabled(False):
                    b = a + a
                return b

        f3 = F3()

        class F4(torch.nn.Module):
            def forward(self, a):
                with torch.set_grad_enabled(True):
                    b = a + a
                return b

        f4 = F4()

        a = torch.randn(10)
        with torch.no_grad():
            export(f1, (a,))
            export(f2, (a,))
            export(f3, (a,))
            export(f4, (a,))

        with torch.enable_grad():
            export(f2, (a,))
            export(f4, (a,))

            with self.assertRaisesRegex(
                RuntimeError, "Encountered autograd state manager op.*"
            ):
                export(f1, (a,))

            with self.assertRaisesRegex(
                RuntimeError, "Encountered autograd state manager op.*"
            ):
                export(f3, (a,))

    def test_tensor_autograd(self):
        # dynamo errors when Tensor.requires_grad_ changes the autograd state
        class F1(torch.nn.Module):
            def forward(self, a):
                a.requires_grad_(True)
                b = a + a
                return b

        f1 = F1()

        # dynamo errors when Tensor.requires_grad_ changes the autograd state
        class F2(torch.nn.Module):
            def forward(self, a):
                a.requires_grad_(False)
                b = a + a
                return b

        f2 = F2()

        # dynamo always errors on Tensor.requires_grad
        class F3(torch.nn.Module):
            def forward(self, a):
                a.requires_grad = False
                b = a + a
                return b

        f3 = F3()

        # No state change: the requested requires_grad matches the input's
        # existing flag, so export succeeds.
        export(f1, (torch.randn(10, requires_grad=True),))
        export(f2, (torch.randn(10, requires_grad=False),))

        # Flipping the input's requires_grad flag (f1, f2), or assigning
        # Tensor.requires_grad at all (f3), makes export fail.
        with self.assertRaises(RuntimeError):
            export(f1, (torch.randn(10, requires_grad=False),))
        with self.assertRaises(RuntimeError):
            export(f2, (torch.randn(10, requires_grad=True),))
        with self.assertRaises(RuntimeError):
            export(f3, (torch.randn(10, requires_grad=False),))
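
    # Pre-dispatch export is exempt from the autograd-state safeguard: unlike
    # export() above, _export(..., pre_dispatch=True) succeeds for all four
    # modules under both no_grad and enable_grad.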
    def test_global_autograd_exempt_predispatch(self):
        class F1(torch.nn.Module):
            def forward(self, a):
                with torch.no_grad():
                    b = a + a
                return b

        f1 = F1()

        class F2(torch.nn.Module):
            def forward(self, a):
                with torch.enable_grad():
                    b = a + a
                return b

        f2 = F2()

        class F3(torch.nn.Module):
            def forward(self, a):
                with torch.set_grad_enabled(False):
                    b = a + a
                return b

        f3 = F3()

        class F4(torch.nn.Module):
            def forward(self, a):
                with torch.set_grad_enabled(True):
                    b = a + a
                return b

        f4 = F4()

        a = torch.randn(10)

        from torch.export._trace import _export

        with torch.no_grad():
            _export(f1, (a,), pre_dispatch=True)
            _export(f2, (a,), pre_dispatch=True)
            _export(f3, (a,), pre_dispatch=True)
            _export(f4, (a,), pre_dispatch=True)

        with torch.enable_grad():
            _export(f1, (a,), pre_dispatch=True)
            _export(f2, (a,), pre_dispatch=True)
            _export(f3, (a,), pre_dispatch=True)
            _export(f4, (a,), pre_dispatch=True)


if __name__ == "__main__":
    run_tests()