mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make torch._dynamo lazy-importable (#104368)
Use [PEP-562](https://peps.python.org/pep-0562) to import `_dynamo` and `_inductor` only when needed. - Remove redundant imports from tests - Add `test_lazy_imports_are_lazy` to make sure they will not get imported by accident <!-- copilot:poem --> ### <samp>🤖 Generated by Copilot at bae8e90</samp> > _Sing, O Muse, of the daring deeds of PyTorch, the swift and fiery_ > _framework of deep learning, that with skill and cunning wrought_ > _many wonders of dynamic compilation, using the hidden powers_ > _of `_dynamo` and `_inductor`, the secret modules of LLVM and MLIR._ Pull Request resolved: https://github.com/pytorch/pytorch/pull/104368 Approved by: https://github.com/msaroufim, https://github.com/albanD
This commit is contained in:
parent
d0a72ec5e4
commit
fea683491e
|
|
@ -6,7 +6,6 @@ import tempfile
|
|||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from torch._dynamo.testing import CompileCounter
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import sys
|
|||
import unittest
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch._dynamo
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_MACOS,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from unittest.mock import patch
|
|||
import numpy as np
|
||||
import sympy
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from torch._C import FileCheck
|
||||
from torch._dynamo.testing import rand_strided
|
||||
from torch._dynamo.utils import same
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import sys
|
|||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.config as dynamo_config
|
||||
from torch import nn
|
||||
from torch._dynamo.debug_utils import same_two_models
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import warnings
|
|||
|
||||
import torch
|
||||
|
||||
import torch._dynamo
|
||||
import torch._dynamo.config as dynamo_config
|
||||
import torch.nn as nn
|
||||
from torch._inductor import config
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import unittest
|
|||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.config as dynamo_config
|
||||
import torch._inductor.config as inductor_config
|
||||
import torch._inductor.utils
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from unittest.mock import patch
|
|||
|
||||
import functorch
|
||||
|
||||
import torch._dynamo
|
||||
import torch
|
||||
import torch._inductor.config as config
|
||||
from torch._dynamo.backends.registry import register_backend
|
||||
from torch._inductor import metrics
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ import numpy as np
|
|||
|
||||
import torch
|
||||
|
||||
import torch._dynamo
|
||||
import torch._dynamo.config as dynamo_config
|
||||
import torch.nn as nn
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from unittest.mock import patch
|
|||
|
||||
import torch
|
||||
|
||||
import torch._dynamo
|
||||
from torch._dynamo.test_case import run_tests
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
|
|
|
|||
|
|
@ -2139,6 +2139,15 @@ instantiate_device_type_tests(TestTestParametrizationDeviceType, globals())
|
|||
|
||||
|
||||
class TestImports(TestCase):
|
||||
@classmethod
|
||||
def _check_python_output(cls, program) -> str:
|
||||
return subprocess.check_output(
|
||||
[sys.executable, "-W", "all", "-c", program],
|
||||
stderr=subprocess.STDOUT,
|
||||
# On Windows, opening the subprocess with the default CWD makes `import torch`
|
||||
# fail, so just set CWD to this script's directory
|
||||
cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
|
||||
|
||||
def test_circular_dependencies(self) -> None:
|
||||
""" Checks that all modules inside torch can be imported
|
||||
Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """
|
||||
|
|
@ -2187,14 +2196,14 @@ class TestImports(TestCase):
|
|||
raise RuntimeError(f"Failed to import {mod_name}: {e}") from e
|
||||
self.assertTrue(inspect.ismodule(mod))
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO enable on Windows")
|
||||
def test_lazy_imports_are_lazy(self) -> None:
|
||||
out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))")
|
||||
self.assertEqual(out.strip(), "True")
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
|
||||
def test_no_warning_on_import(self) -> None:
|
||||
out = subprocess.check_output(
|
||||
[sys.executable, "-W", "all", "-c", "import torch"],
|
||||
stderr=subprocess.STDOUT,
|
||||
# On Windows, opening the subprocess with the default CWD makes `import torch`
|
||||
# fail, so just set CWD to this script's directory
|
||||
cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8")
|
||||
out = self._check_python_output("import torch")
|
||||
self.assertEqual(out, "")
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning")
|
||||
|
|
@ -2212,10 +2221,7 @@ class TestImports(TestCase):
|
|||
'logging.root.setLevel(logging.INFO)',
|
||||
f'_logger.info("{expected}")'
|
||||
]
|
||||
out = subprocess.check_output(
|
||||
[sys.executable, "-W", "all", "-c", "; ".join(commands)],
|
||||
stderr=subprocess.STDOUT,
|
||||
).decode("utf-8")
|
||||
out = self._check_python_output("; ".join(commands))
|
||||
self.assertEqual(out.strip(), expected)
|
||||
|
||||
class TestOpInfos(TestCase):
|
||||
|
|
|
|||
|
|
@ -1652,7 +1652,6 @@ def compile(model: Optional[Callable] = None, *,
|
|||
disable=disable)
|
||||
return fn
|
||||
|
||||
import torch._dynamo
|
||||
if mode is not None and options is not None:
|
||||
raise RuntimeError("Either mode or options can be specified, but both can't be specified at the same time.")
|
||||
if mode is None and options is None:
|
||||
|
|
@ -1743,12 +1742,26 @@ _deprecated_attrs = {
|
|||
"has_cudnn": torch.backends.cudnn.is_available,
|
||||
"has_mkldnn": torch.backends.mkldnn.is_available,
|
||||
}
|
||||
|
||||
_lazy_modules = {
|
||||
"_dynamo",
|
||||
"_inductor",
|
||||
}
|
||||
|
||||
def __getattr__(name):
|
||||
# Deprecated attrs
|
||||
replacement = _deprecated_attrs.get(name)
|
||||
if replacement is not None:
|
||||
import warnings
|
||||
warnings.warn(f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", stacklevel=2)
|
||||
return replacement()
|
||||
|
||||
# Lazy modules
|
||||
if name in _lazy_modules:
|
||||
import importlib
|
||||
return importlib.import_module(f".{name}", __name__)
|
||||
|
||||
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
||||
|
||||
from . import _logging
|
||||
_logging._init_logs()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user