# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest

import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._export.wrappers import _mark_strict_experimental
from torch._functorch.aot_autograd import aot_export_module
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperiment(TestCase):
    """Tests for the experimental ``_mark_strict_experimental`` wrapper.

    A module class decorated with ``_mark_strict_experimental`` is expected to
    appear in a non-strict export as a ``torch.ops.higher_order.strict_mode``
    call over a nested ``strict_graph_0`` subgraph.
    """

    def test_with_buffer_as_submodule(self):
        """Export a module whose strict-marked submodule reads a buffer.

        With ``strict=False``, the submodule's buffer is lifted to a graph
        input (``b_submodule_buffer1``). It is passed together with the
        submodule's input into the ``strict_mode`` call. The generated code of
        both the top-level graph and the nested subgraph is pinned.

        The same module is then exported with ``strict=True``, and the
        exported module's outputs are compared against eager execution.
        """

        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer1", torch.ones(3))

            def forward(self, x):
                y = x + 2
                # In-place op on a local tensor; it shows up functionalized
                # (as ``add_1``) in the expected subgraph below.
                y.add_(4)
                # NOTE: mutating the buffer in place doesn't work today with
                # the strict_mode HOO, so the in-place version is left
                # commented out and an out-of-place copy is used instead.
                # self.buffer1.add_(6)
                buffer_updated = self.buffer1 + 6
                return x.sum() + y.sum() + buffer_updated.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                x_v2 = x.sin()
                return (self.submodule(x_v2), x + 3)

        inp = torch.randn(3)
        # Non-strict export: the strict-marked submodule becomes a strict_mode
        # HOO call over the nested ``strict_graph_0`` subgraph.
        ep = torch.export.export(M(), (inp,), strict=False)
        self.assertExpectedInline(
            str(ep.graph_module.code.strip()),
            """\
def forward(self, b_submodule_buffer1, x):
    sin = torch.ops.aten.sin.default(x)
    strict_graph_0 = self.strict_graph_0
    strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1));  strict_graph_0 = sin = b_submodule_buffer1 = None
    getitem_2 = strict_mode[0];  strict_mode = None
    add = torch.ops.aten.add.Tensor(x, 3);  x = None
    return (getitem_2, add)""",
        )

        # Code of the nested subgraph traced from B.forward; arg0_1 is B's
        # input and arg1_1 is the lifted buffer.
        self.assertExpectedInline(
            str(ep.graph_module.strict_graph_0.code.strip()),
            """\
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 2)
    add_1 = torch.ops.aten.add.Tensor(add, 4);  add = None
    add_2 = torch.ops.aten.add.Tensor(arg1_1, 6);  arg1_1 = None
    sum_1 = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
    sum_2 = torch.ops.aten.sum.default(add_1);  add_1 = None
    add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2);  sum_1 = sum_2 = None
    sum_3 = torch.ops.aten.sum.default(add_2);  add_2 = None
    add_4 = torch.ops.aten.add.Tensor(add_3, sum_3);  add_3 = sum_3 = None
    return (add_4,)""",
        )

        # Strict export of a fresh instance; its results are checked against
        # eager execution of that same instance.
        eager_mod = M()
        ep = torch.export.export(eager_mod, (inp,), strict=True)

        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

        # Same comparison a second time; presumably to check that repeated
        # calls stay consistent with eager (no stale state) -- TODO confirm.
        graph_res_1, graph_res_2 = ep.module()(inp)
        eager_res_1, eager_res_2 = eager_mod(inp)
        self.assertTrue(torch.allclose(graph_res_2, eager_res_2))
        self.assertTrue(torch.allclose(graph_res_1, eager_res_1))

    def test_mark_strict_with_container_type(self):
        """Non-strict export raises when a strict-marked submodule receives a container input.

        The expected ``RuntimeError`` message must match
        "strict_mode HOO doesn't work unless".
        """

        @_mark_strict_experimental
        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                # ``x`` arrives as a tuple holding a tensor, not a bare tensor.
                x0 = x[0][0]
                return x0.sum()

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submodule = B()

            def forward(self, x):
                return self.submodule(x)

        # ``inp`` is the export args tuple, so M.forward receives the
        # container ``(tensor,)`` as its single positional argument.
        inp = ((torch.randn(3),),)
        with self.assertRaisesRegex(
            RuntimeError, "strict_mode HOO doesn't work unless"
        ):
            ep = torch.export.export(M(), inp, strict=False)


if __name__ == "__main__":
    run_tests()