[package] typing.io/re edge case hack (#60666)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/60666

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D29367847

Pulled By: Lilyjjo

fbshipit-source-id: 2c38140fbb3eab61ae3de60ab475243f0338c547
This commit is contained in:
Lily Johnson 2021-06-24 14:52:06 -07:00 committed by Facebook GitHub Bot
parent 375d201086
commit fe4ded01f7
3 changed files with 15 additions and 6 deletions

View File

@ -1,5 +1,8 @@
import os # noqa: F401
import os.path # noqa: F401
import typing # noqa: F401
import typing.io # noqa: F401
import typing.re # noqa: F401
import torch

View File

@ -197,14 +197,16 @@ class TestMisc(PackageTestCase):
self.assertTrue(imported_mod.is_from_package())
self.assertFalse(mod.is_from_package())
def test_os_path_edge_case(self):
def test_std_lib_sys_hackery_checks(self):
"""
Both 'os' and 'os.path' should be able to be imported into a package.
The standard library performs sys.module assignment hackery which
causes modules who do this hackery to fail on import. See
https://github.com/pytorch/pytorch/issues/57490 for more information.
"""
import package_a.os_imports
import package_a.std_sys_module_hacks
buffer = BytesIO()
mod = package_a.os_imports.Module()
mod = package_a.std_sys_module_hacks.Module()
with PackageExporter(buffer, verbose=False) as pe:
pe.intern("**")

View File

@ -402,10 +402,14 @@ class PackageImporter(Importer):
message = "import of {} halted; " "None in sys.modules".format(name)
raise ModuleNotFoundError(message, name=name)
# To handle https://github.com/pytorch/pytorch/issues/57490, where os's
# creation of os.path via the hacking of sys.modules is not import friendly
# To handle https://github.com/pytorch/pytorch/issues/57490, where std's
# creation of fake submodules via the hacking of sys.modules is not import
# friendly
if name == "os":
self.modules["os.path"] = cast(Any, module).path
elif name == "typing":
self.modules["typing.io"] = cast(Any, module).io
self.modules["typing.re"] = cast(Any, module).re
return module