# Owner(s): ["oncall: export"] import torch from torch._dynamo.test_case import TestCase from torch._export.tools import report_exportability from torch.testing._internal.common_utils import run_tests torch.library.define( "testlib::op_missing_meta", "(Tensor(a!) x, Tensor(b!) z) -> Tensor", tags=torch.Tag.pt2_compliant_tag, ) @torch.library.impl("testlib::op_missing_meta", "cpu") @torch._dynamo.disable def op_missing_meta(x, z): x.add_(5) z.add_(5) return x + z class TestExportTools(TestCase): def test_report_exportability_basic(self): class Module(torch.nn.Module): def forward(self, x, y): return x[0] + y f = Module() inp = ([torch.ones(1, 3)], torch.ones(1, 3)) report = report_exportability(f, inp) self.assertTrue(len(report) == 1) self.assertTrue(report[""] is None) def test_report_exportability_with_issues(self): class Unsupported(torch.nn.Module): def forward(self, x): return torch.ops.testlib.op_missing_meta(x, x.cos()) class Supported(torch.nn.Module): def forward(self, x): return x.sin() class Module(torch.nn.Module): def __init__(self) -> None: super().__init__() self.unsupported = Unsupported() self.supported = Supported() def forward(self, x): y = torch.nonzero(x) return self.unsupported(y) + self.supported(y) f = Module() inp = (torch.ones(4, 4),) report = report_exportability(f, inp, strict=False, pre_dispatch=True) self.assertTrue(report[""] is not None) self.assertTrue(report["unsupported"] is not None) self.assertTrue(report["supported"] is None) if __name__ == "__main__": run_tests()