diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py index d9b0643e03c..e3847cbb2ae 100644 --- a/test/dynamo/test_compile.py +++ b/test/dynamo/test_compile.py @@ -6,7 +6,6 @@ import tempfile import unittest import torch -import torch._dynamo from torch._dynamo.testing import CompileCounter diff --git a/test/inductor/test_cpp_wrapper.py b/test/inductor/test_cpp_wrapper.py index 7c6c2042f6d..90241998e0a 100644 --- a/test/inductor/test_cpp_wrapper.py +++ b/test/inductor/test_cpp_wrapper.py @@ -3,7 +3,7 @@ import sys import unittest from typing import NamedTuple -import torch._dynamo +import torch from torch._inductor import config from torch.testing._internal.common_utils import ( IS_MACOS, diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 3392b95a2de..b070124f28a 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -9,7 +9,6 @@ from unittest.mock import patch import numpy as np import sympy import torch -import torch._dynamo from torch._C import FileCheck from torch._dynamo.testing import rand_strided from torch._dynamo.utils import same diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index a6ab143c86e..f377018bc56 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -4,7 +4,6 @@ import sys import unittest import torch -import torch._dynamo import torch._dynamo.config as dynamo_config from torch import nn from torch._dynamo.debug_utils import same_two_models diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index bc9f663970e..ba587e9f4e2 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -9,7 +9,6 @@ import warnings import torch -import torch._dynamo import torch._dynamo.config as dynamo_config import torch.nn as nn from torch._inductor import config diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py index 1f57fd9f407..d307e48e69e 100644 --- a/test/inductor/test_minifier.py +++ b/test/inductor/test_minifier.py @@ -4,7 +4,6 @@ import unittest from unittest.mock import patch import torch -import torch._dynamo import torch._dynamo.config as dynamo_config import torch._inductor.config as inductor_config import torch._inductor.utils diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index b60c52fb1c9..612a80ec1de 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -4,7 +4,7 @@ from unittest.mock import patch import functorch -import torch._dynamo +import torch import torch._inductor.config as config from torch._dynamo.backends.registry import register_backend from torch._inductor import metrics diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 4988a1aa69d..346db3ce1b0 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -21,7 +21,6 @@ import numpy as np import torch -import torch._dynamo import torch._dynamo.config as dynamo_config import torch.nn as nn from torch._dispatch.python import enable_python_dispatcher diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 322b1ef93a3..c5329214455 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -10,7 +10,6 @@ from unittest.mock import patch import torch -import torch._dynamo from torch._dynamo.test_case import run_tests from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, diff --git a/test/test_testing.py b/test/test_testing.py index e1a1606a74c..bc6aca1d515 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2139,6 +2139,15 @@ instantiate_device_type_tests(TestTestParametrizationDeviceType, globals()) class TestImports(TestCase): + @classmethod + def _check_python_output(cls, program) -> str: + return subprocess.check_output( + [sys.executable, "-W", "all", "-c", program], + stderr=subprocess.STDOUT, + # On Windows, opening the subprocess with the default CWD makes `import torch` + # fail, so just set CWD to this script's directory + cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") + def test_circular_dependencies(self) -> None: """ Checks that all modules inside torch can be imported Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """ @@ -2187,14 +2196,14 @@ class TestImports(TestCase): raise RuntimeError(f"Failed to import {mod_name}: {e}") from e self.assertTrue(inspect.ismodule(mod)) + @unittest.skipIf(IS_WINDOWS, "TODO enable on Windows") + def test_lazy_imports_are_lazy(self) -> None: + out = self._check_python_output("import sys;import torch;print(all(x not in sys.modules for x in torch._lazy_modules))") + self.assertEqual(out.strip(), "True") + @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning") def test_no_warning_on_import(self) -> None: - out = subprocess.check_output( - [sys.executable, "-W", "all", "-c", "import torch"], - stderr=subprocess.STDOUT, - # On Windows, opening the subprocess with the default CWD makes `import torch` - # fail, so just set CWD to this script's directory - cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") + out = self._check_python_output("import torch") self.assertEqual(out, "") @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning") @@ -2212,10 +2221,7 @@ class TestImports(TestCase): 'logging.root.setLevel(logging.INFO)', f'_logger.info("{expected}")' ] - out = subprocess.check_output( - [sys.executable, "-W", "all", "-c", "; ".join(commands)], - stderr=subprocess.STDOUT, - ).decode("utf-8") + out = self._check_python_output("; ".join(commands)) self.assertEqual(out.strip(), expected) class TestOpInfos(TestCase): diff --git a/torch/__init__.py b/torch/__init__.py index 5800d25f6ba..3fc8910f059 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1652,7 +1652,6 @@ def compile(model: Optional[Callable] = None, *, disable=disable) return fn - import torch._dynamo if mode is not None and options is not None: raise RuntimeError("Either mode or options can be specified, but both can't be specified at the same time.") if mode is None and options is None: @@ -1743,12 +1742,26 @@ _deprecated_attrs = { "has_cudnn": torch.backends.cudnn.is_available, "has_mkldnn": torch.backends.mkldnn.is_available, } + +_lazy_modules = { + "_dynamo", + "_inductor", +} + def __getattr__(name): + # Deprecated attrs replacement = _deprecated_attrs.get(name) if replacement is not None: import warnings warnings.warn(f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", stacklevel=2) return replacement() + + # Lazy modules + if name in _lazy_modules: + import importlib + return importlib.import_module(f".{name}", __name__) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + from . import _logging _logging._init_logs()