# Owner(s): ["oncall: export"] import unittest import torch from torch._dynamo.eval_frame import is_dynamo_supported from torch.export import Dim from torch.export.experimental import _ExportPackage from torch.testing._internal.common_utils import run_tests, TestCase @unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported") class TestPackage(TestCase): def test_basic(self): def fn(x: torch.Tensor) -> torch.Tensor: return x + 1 x = torch.randn(3, 2) package = _ExportPackage() self.assertEqual( package._exporter("fn", fn)(x), fn(x), ) self.assertEqual(len(package.methods), 1) self.assertEqual(len(package.methods["fn"].fallbacks), 1) self.assertEqual(len(package.methods["fn"].overloads), 0) def test_more_than_once(self): def fn(x: torch.Tensor) -> torch.Tensor: return x + 1 x = torch.randn(3, 2) package = _ExportPackage() exporter = package._exporter("fn", fn) exporter(x) with self.assertRaisesRegex( RuntimeError, "Cannot export .* more than once", ): exporter(x) def test_error(self): def fn(x: torch.Tensor) -> torch.Tensor: return x + 1 x = torch.randn(3, 2) package = _ExportPackage() exporter = package._exporter("fn", fn, fallback="error") with self.assertRaisesRegex( RuntimeError, "Cannot export fallback .* when fallback policy is set to 'error'", ): exporter(x) def test_overloads(self): class Module(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: if x.shape[0] == 4: return x + 1 elif x.shape[0] == 3: return x - 1 else: return x + 2 fn = Module() x = torch.randn(3, 2) x2 = torch.randn(4, 2) x3 = torch.randn(5, 2) def spec(self, x): assert x.shape[0] == 3 def spec2(self, x): assert x.shape[0] == 4 def spec3(self, x): assert x.shape[0] >= 5 return {"x": (Dim("batch", min=5), Dim.STATIC)} package = _ExportPackage() exporter = ( package._exporter("fn", fn) ._define_overload("spec", spec) ._define_overload("spec2", spec2) ._define_overload("spec3", spec3) ) self.assertEqual(exporter(x), x - 1) self.assertEqual(exporter(x2), x2 + 1) self.assertEqual(exporter(x3), x3 + 2) self.assertEqual(len(package.methods), 1) self.assertEqual(len(package.methods["fn"].overloads), 3) if __name__ == "__main__": run_tests()