diff --git a/test/package/package_a/os_imports.py b/test/package/package_a/std_sys_module_hacks.py similarity index 70% rename from test/package/package_a/os_imports.py rename to test/package/package_a/std_sys_module_hacks.py index 2d88ecc95bf..fa8df64f20d 100644 --- a/test/package/package_a/os_imports.py +++ b/test/package/package_a/std_sys_module_hacks.py @@ -1,5 +1,8 @@ import os # noqa: F401 import os.path # noqa: F401 +import typing # noqa: F401 +import typing.io # noqa: F401 +import typing.re # noqa: F401 import torch diff --git a/test/package/test_misc.py b/test/package/test_misc.py index 636b97d642f..76ff2e0bd96 100644 --- a/test/package/test_misc.py +++ b/test/package/test_misc.py @@ -197,14 +197,16 @@ class TestMisc(PackageTestCase): self.assertTrue(imported_mod.is_from_package()) self.assertFalse(mod.is_from_package()) - def test_os_path_edge_case(self): + def test_std_lib_sys_hackery_checks(self): """ - Both 'os' and 'os.path' should be able to be imported into a package. + The standard library performs sys.module assignment hackery which + causes modules who do this hackery to fail on import. See + https://github.com/pytorch/pytorch/issues/57490 for more information. """ - import package_a.os_imports + import package_a.std_sys_module_hacks buffer = BytesIO() - mod = package_a.os_imports.Module() + mod = package_a.std_sys_module_hacks.Module() with PackageExporter(buffer, verbose=False) as pe: pe.intern("**") diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index f256c085d6b..d40601709bd 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -402,10 +402,14 @@ class PackageImporter(Importer): message = "import of {} halted; " "None in sys.modules".format(name) raise ModuleNotFoundError(message, name=name) - # To handle https://github.com/pytorch/pytorch/issues/57490, where os's - # creation of os.path via the hacking of sys.modules is not import friendly + # To handle https://github.com/pytorch/pytorch/issues/57490, where std's + # creation of fake submodules via the hacking of sys.modules is not import + # friendly if name == "os": self.modules["os.path"] = cast(Any, module).path + elif name == "typing": + self.modules["typing.io"] = cast(Any, module).io + self.modules["typing.re"] = cast(Any, module).re return module