diff --git a/test/export/test_export.py b/test/export/test_export.py index fd7fc4a109e..dc7738e07b2 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1370,7 +1370,7 @@ graph(): self.mod.forward = hacked_up_forward.__get__(self.mod, Foo) def __call__(self, x, y): - ep = torch.export.export(self.mod, (x, y), strict=True).module() + ep = export(self.mod, (x, y), strict=True).module() out = ep(x, y) return out @@ -1379,13 +1379,31 @@ graph(): foo = Foo() ref = ReferenceControl(foo) - with self.assertWarnsRegex( - UserWarning, - "While exporting, we found certain side effects happened in the model.forward. " - "Here are the list of potential sources you can double check: " - "\[\"L\['global_list'\]\", \"L\['self'\].bank\", \"L\['self'\].bank_dict\"", - ): - ref(torch.randn(4, 4), torch.randn(4, 4)) + # TODO (tmanlaibaatar) this kinda sucks but today there is no good way to get + # good source name. We should have an util that post processes dynamo source names + # to be more readable. + if is_strict_v2_test(self._testMethodName): + with self.assertWarnsRegex( + UserWarning, + r"(L\['self']\._export_root\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank" + r"|L\['self']\._export_root\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank_dict" + r"|L\['self']\._export_root\.forward\.__func__\.__closure__\[0\]\.cell_contents)", + ): + ref(torch.randn(4, 4), torch.randn(4, 4)) + elif is_inline_and_install_strict_test(self._testMethodName): + with self.assertWarnsRegex( + UserWarning, + r"(L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank" + r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank_dict" + r"|L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[0\]\.cell_contents)", + ): + ref(torch.randn(4, 4), torch.randn(4, 4)) + else: + with self.assertWarnsRegex( + UserWarning, + r"(L\['global_list'\]|L\['self'\]\.bank|L\['self'\]\.bank_dict)", + ): + ref(torch.randn(4, 4), torch.randn(4, 4)) def test_mask_nonzero_static(self): class TestModule(torch.nn.Module): diff --git a/test/export/test_export_with_inline_and_install.py b/test/export/test_export_with_inline_and_install.py index 2bc6aa3c678..0894a8e6844 100644 --- a/test/export/test_export_with_inline_and_install.py +++ b/test/export/test_export_with_inline_and_install.py @@ -89,10 +89,6 @@ unittest.expectedFailure( unittest.expectedFailure( InlineAndInstallStrictExportTestExport.test_retrace_pre_autograd_inline_and_install_strict # noqa: F821 ) -# this is because detect leak test has export root -unittest.expectedFailure( - InlineAndInstallStrictExportTestExport.test_detect_leak_strict_inline_and_install_strict # noqa: F821 -) if __name__ == "__main__": from torch._dynamo.test_case import run_tests