diff --git a/test/test_testing.py b/test/test_testing.py index 7aebaa78e41..1c668f8476c 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2240,6 +2240,17 @@ class TestImports(TestCase): out = self._check_python_output("import torch") self.assertEqual(out, "") + def test_not_import_sympy(self) -> None: + out = self._check_python_output("import torch;import sys;print('sympy' not in sys.modules)") + self.assertEqual(out.strip(), "True", + "PyTorch should not depend on SymPy at import time as importing SymPy is *very* slow.\n" + "See the beginning of the following blog post for how to profile and find which file is importing sympy:\n" + "https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589\n\n" + "If you hit this error, you may want to:\n" + " - Refactor your code to avoid depending on sympy files you may not need to depend\n" + " - Use TYPE_CHECKING if you are using sympy + strings if you are using sympy on type annotations\n" + " - Import things that depend on SymPy locally") + @unittest.skipIf(IS_WINDOWS, "importing torch+CUDA on CPU results in warning") @parametrize('path', ['torch', 'functorch']) def test_no_mutate_global_logging_on_import(self, path) -> None: