From 89e16c4f184ab41c7d93cf5ac9edf738c1b67937 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 15 Feb 2023 17:57:21 -0500 Subject: [PATCH] Assume sympy is always installed (#94903) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/94903 Approved by: https://github.com/Skylion007, https://github.com/malfet --- mypy-strict.ini | 6 ++ mypy.ini | 6 ++ test/functorch/test_aotdispatch.py | 18 +---- test/fx/test_gradual_type.py | 15 +--- test/test_dynamic_shapes.py | 44 +----------- test/test_proxy_tensor.py | 13 +--- torch/_functorch/compilers.py | 2 +- torch/_guards.py | 7 +- .../experimental/graph_gradual_typechecker.py | 69 ++++++++----------- torch/fx/experimental/symbolic_shapes.py | 18 ++--- torch/utils/_sympy/value_ranges.py | 2 +- 11 files changed, 54 insertions(+), 146 deletions(-) diff --git a/mypy-strict.ini b/mypy-strict.ini index 3e5edf90dc3..e4d9d7a143e 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -63,6 +63,12 @@ follow_imports = skip [mypy-numpy] ignore_missing_imports = True +[mypy-sympy] +ignore_missing_imports = True + +[mypy-sympy.*] +ignore_missing_imports = True + [mypy-mypy.*] ignore_missing_imports = True diff --git a/mypy.ini b/mypy.ini index 1fc2e11c3e0..380f432c480 100644 --- a/mypy.ini +++ b/mypy.ini @@ -200,6 +200,12 @@ ignore_missing_imports = True [mypy-numpy.*] ignore_missing_imports = True +[mypy-sympy] +ignore_missing_imports = True + +[mypy-sympy.*] +ignore_missing_imports = True + [mypy-hypothesis.*] ignore_missing_imports = True diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 4f2529ae60b..dae4f4c12fa 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -12,9 +12,8 @@ from torch.testing._internal.common_utils import ( TestCase, run_tests, IS_ARM64, - IS_WINDOWS, compare_equal_outs_and_grads, - outs_and_grads + outs_and_grads, ) import torch import torch.nn as nn @@ -70,14 +69,6 @@ except ImportError: warnings.warn("Some tests use networkx but it was not installed", UserWarning) -try: - import sympy # noqa: F401 - # TODO(jansel): these tests fail on windows - HAS_SYMPY = not IS_WINDOWS -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") - # NB: numpy is a testing dependency! class AOTTestCase(TestCase): @@ -1697,7 +1688,6 @@ def forward(self, primals_1, primals_2): @patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_fake_tensor", True) - @skipIfNoSympy def test_output_op_depending_on_symint(self): """ It won't be obvious from reading this test what it's testing for. We should probably make it into a more @@ -1726,7 +1716,6 @@ def forward(self, primals_1, primals_2): @patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_fake_tensor", True) - @skipIfNoSympy def test_default_partitioner_saves_symints_not_tensors_for_bw(self): """ In this test, the important thing is that primals_1 is **only** needed in the backward @@ -1919,7 +1908,6 @@ class TestPartitioning(AOTTestCase): @patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_fake_tensor", True) @unittest.skipIf(not USE_NETWORKX, "networkx not available") - @skipIfNoSympy def test_min_cut_partitioner_save_shape(self): def f(x): @@ -1960,7 +1948,6 @@ class TestPartitioning(AOTTestCase): @patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_fake_tensor", True) - @skipIfNoSympy def test_default_partitioner_output_tensor_shape_tensor(self): inp = [ @@ -2025,7 +2012,6 @@ class TestPartitioning(AOTTestCase): @patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_fake_tensor", True) @unittest.skipIf(not USE_NETWORKX, "networkx not available") - @skipIfNoSympy def test_min_cut_partitioner_output_tensor_shape_tensor(self): inp = [ @@ -2695,7 +2681,6 @@ class TestEagerFusionOpInfo(AOTTestCase): _test_aot_autograd_helper(self, device, dtype, op) @ops(op_db, allowed_dtypes=(torch.float,)) - @skipIfNoSympy @patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_functionalize", True) @@ -2742,7 +2727,6 @@ class TestEagerFusionModuleInfo(AOTTestCase): _test_aot_autograd_module_helper(self, device, dtype, training, module_info) @modules(module_db, allowed_dtypes=(torch.float,)) - @skipIfNoSympy @patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_functionalize", True) diff --git a/test/fx/test_gradual_type.py b/test/fx/test_gradual_type.py index 1e678de3a5b..23c6496b3a2 100644 --- a/test/fx/test_gradual_type.py +++ b/test/fx/test_gradual_type.py @@ -12,14 +12,7 @@ from torch.fx.experimental.rewriter import RewritingTracer from torch.fx import GraphModule from torch.fx.passes.shape_prop import ShapeProp from torch.testing._internal.common_utils import TestCase - - -try: - import sympy - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") +import sympy try: @@ -813,7 +806,6 @@ class TypeCheckerTest(TestCase): if n.op == 'output': assert is_consistent(n.type, TensorType(b.size())) - @skipIfNoSympy @skipIfNoTorchVision def test_resnet50(self): gm_run = symbolic_trace(resnet50()) @@ -860,7 +852,6 @@ class TypeCheckerTest(TestCase): batch_sizes.add(n.type.__args__[0]) assert (len(batch_sizes) == 1) - @skipIfNoSympy def test_type_check_batch_norm_symbolic(self): class BasicBlock(torch.nn.Module): @@ -892,7 +883,6 @@ class TypeCheckerTest(TestCase): for n in graph.nodes: assert n.type == next(my_types) - @skipIfNoSympy def test_symbolic_add_with_broadcast(self): class M(torch.nn.Module): def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): @@ -921,7 +911,6 @@ class TypeCheckerTest(TestCase): for n in symbolic_traced.graph.nodes: assert n.type == next(expected_iter) - @skipIfNoSympy def test_symbolic_add_with_broadcast_2(self): class M(torch.nn.Module): def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))): @@ -943,7 +932,6 @@ class TypeCheckerTest(TestCase): for n in symbolic_traced.graph.nodes: assert n.type == next(expected_iter) - @skipIfNoSympy def test_type_check_conv2D_types(self): class BasicBlock(torch.nn.Module): def __init__(self, inplanes, planes, stride=1): @@ -971,7 +959,6 @@ class TypeCheckerTest(TestCase): assert isinstance(n.type.__args__[2], sympy.floor) assert isinstance(n.type.__args__[3], sympy.floor) - @skipIfNoSympy def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self): class BasicBlock(torch.nn.Module): diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 480b83dc8b3..6b095ef3c30 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -5,14 +5,14 @@ from torch._C import _disabled_torch_function_impl import torch.fx import torch.nn.functional as F from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, \ - IS_WINDOWS, parametrize, instantiate_parametrized_tests -import unittest + parametrize, instantiate_parametrized_tests import torch import operator import itertools import contextlib import math import copy +import sympy from torch.utils._pytree import tree_map from torch.fx.experimental import symbolic_shapes from torch.fx.experimental.proxy_tensor import make_fx @@ -24,15 +24,6 @@ from torch import SymBool, SymInt, SymFloat, sym_int aten = torch.ops.aten -try: - import sympy - # TODO(jansel): these tests fail on windows - HAS_SYMPY = not IS_WINDOWS -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") - - meta_funcs = {} @@ -135,7 +126,6 @@ def create_symint(shape_env, i: int): @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)") class TestPySymInt(TestCase): - @skipIfNoSympy def test_arith_ops(self): shape_env = ShapeEnv() symints = [] @@ -150,7 +140,6 @@ class TestPySymInt(TestCase): self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])) - @skipIfNoSympy def test_reverse_arith_ops(self): shape_env = ShapeEnv() @@ -161,7 +150,6 @@ class TestPySymInt(TestCase): self.assertTrue(5 * a == 5 * 2) - @skipIfNoSympy def test_roundtrip(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) @@ -187,7 +175,6 @@ class TestPySymInt(TestCase): self.assertTrue(isinstance(y.storage_offset(), SymInt)) self.assertTrue(y.storage_offset() == 12) - @skipIfNoSympy def test_binary(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) @@ -205,7 +192,6 @@ class TestPySymInt(TestCase): self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) - @skipIfNoSympy def test_symint_args(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) @@ -222,7 +208,6 @@ class TestPySymInt(TestCase): z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1) self.assertTrue(z.shape[2] == 2) - @skipIfNoSympy def test_symint_vargs(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) @@ -267,13 +252,11 @@ class TestPySymInt(TestCase): z = y.expand((y.shape[1],)) z = y.expand(y.shape[1]) - @skipIfNoSympy def test_stride(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) self.assertIsInstance(x.stride()[0], SymInt) - @skipIfNoSympy def test_size_expressions(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) @@ -289,7 +272,6 @@ class TestPySymInt(TestCase): self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) - @skipIfNoSympy def test_numel(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) @@ -300,14 +282,12 @@ class TestPySymInt(TestCase): self.assertIsInstance(x.numel(), int) self.assertIsInstance(torch.numel(x), int) - @skipIfNoSympy def test_int_to_float(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) r = sym_float(x.shape[0]) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) - @skipIfNoSympy def test_aten_ops(self): shape_env = ShapeEnv() @@ -330,21 +310,18 @@ class TestPySymInt(TestCase): # tuple of ints, not tuple torch.fx.symbolic_trace(m) - @skipIfNoSympy def test_meta_symint(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) r = torch.empty(a0, device='meta') self.assertIsInstance(r.shape[0], SymInt) - @skipIfNoSympy def test_guard_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) self.assertEqual(guard_int(a0), 2) self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") - @skipIfNoSympy def test_sym_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) @@ -371,7 +348,6 @@ class TestPySymInt(TestCase): self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline(str(shape_env.guards[3][0]), """Eq(2*s2, 6)""") - @skipIfNoSympy def test_sym_sqrt(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 4) @@ -380,7 +356,6 @@ class TestPySymInt(TestCase): self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""") - @skipIfNoSympy def test_sym_floor(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) @@ -393,7 +368,6 @@ class TestPySymInt(TestCase): self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") - @skipIfNoSympy def test_sym_ceil(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) @@ -407,26 +381,22 @@ class TestPySymInt(TestCase): self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") - @skipIfNoSympy def test_int_conversion(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0)) - @skipIfNoSympy def test_data_dependent_guard(self): shape_env = ShapeEnv() s0 = shape_env.create_unbacked_symint() self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0)) - @skipIfNoSympy def test_non_overlapping_and_dense(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) r = torch.empty_strided((a0, 7), (1, a0), device='meta') self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) - @skipIfNoSympy def test_symint_as_scalar(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) @@ -450,7 +420,6 @@ class TestPySymInt(TestCase): self.assertTrue(sym_int_encountered) - @skipIfNoSympy def test_deepcopy(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) @@ -458,7 +427,6 @@ class TestPySymInt(TestCase): new_shape_env = copy.deepcopy(shape_env) self.assertEqual(len(new_shape_env.guards), 1) - @skipIfNoSympy def test_print_readable_with_symints(self): def f(a, b): dim0 = a.shape[0] + b.shape[0] @@ -650,7 +618,6 @@ class TestFloorDiv(TestCase): yield (x, -y) yield (-x, -y) - @skipIfNoSympy def test_floordiv_float_int(self): values = ( (2.5, 2.1), @@ -664,7 +631,6 @@ class TestFloorDiv(TestCase): for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)) - @skipIfNoSympy def test_floordiv_bool(self): values = ( (False, True), @@ -685,7 +651,6 @@ class TestFloorDiv(TestCase): rf", expected integer or real"), lambda: TestFloorDiv.torch_floordiv(x, y)) - @skipIfNoSympy def test_floordiv_complex(self): values = ( (1.5 + 2.5j, 1.3 + 3.5j), @@ -706,7 +671,6 @@ class TestFloorDiv(TestCase): rf", expected integer or real"), lambda: TestFloorDiv.torch_floordiv(x, y)) - @skipIfNoSympy def test_floordiv_div_by_zero(self): values = ( (2.5, 0), @@ -724,7 +688,6 @@ class TestFloorDiv(TestCase): "division by zero", lambda: TestFloorDiv.torch_floordiv(x, y)) - @skipIfNoSympy def test_floordiv_zero_base(self): values = ( (0, 2.5), @@ -738,7 +701,6 @@ class TestFloorDiv(TestCase): else: self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) - @skipIfNoSympy def test_floordiv_div_by_one(self): values = ( (2.5, 1), @@ -750,7 +712,6 @@ class TestFloorDiv(TestCase): for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)) - @skipIfNoSympy def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() @@ -770,7 +731,6 @@ class TestFloorDiv(TestCase): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) - @skipIfNoSympy def test_floordiv_assumptions(self): # We define two Symbols (with different names) for each type to make # sure the behavior is consistent regardless of whether both arguments diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 743e09be5b6..e6be94a864b 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1,6 +1,6 @@ # Owner(s): ["module: ProxyTensor"] -from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, xfail_inherited_tests +from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests import torch import unittest import warnings @@ -27,13 +27,6 @@ import itertools aten = torch.ops.aten -try: - import sympy # noqa: F401 - # TODO(jansel): these tests fail on windows - HAS_SYMPY = not IS_WINDOWS -except ImportError: - HAS_SYMPY = False -skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy") HAS_CUDA = torch.cuda.is_available() @@ -735,7 +728,6 @@ class TestGenericProxyTensorFake(TestGenericProxyTensor): tracing_mode = "fake" -@skipIfNoSympy @xfail_inherited_tests([ "test_make_fx_overloads", "test_trace_subclasses", @@ -812,7 +804,6 @@ def _trace(f, *args): return make_fx(f, tracing_mode="symbolic")(*inps) # TODO: Need to test the guards themselves specifically as well -@skipIfNoSympy class TestSymbolicTracing(TestCase): def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True): """ @@ -1489,14 +1480,12 @@ class TestProxyTensorOpInfo(TestCase): def test_make_fx_fake_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "fake") - @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") - @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace', make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures) diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index 6f944f6f483..735fcadb1c4 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -6,6 +6,7 @@ import random from contextlib import contextmanager from functools import partial from typing import Callable, Optional, Tuple, Union +import sympy import torch from torch import SymInt @@ -126,7 +127,6 @@ class DebugInterpreter(fx.Interpreter): super().run(*args) def run_node(self, n): - import sympy def subst_symint(ni): if not isinstance(ni, SymInt): diff --git a/torch/_guards.py b/torch/_guards.py index 0591d4048d9..76cfb77548e 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -8,12 +8,7 @@ from typing import Callable, Generic, List, NamedTuple, Optional, Set, TypeVar log = logging.getLogger(__name__) -# TODO(voz): Stolen pattern, not sure why this is the case, -# but mypy complains. -try: - import sympy # type: ignore[import] -except ImportError: - log.warning("No sympy found") +import sympy """ torch._guards is the definitional source of truth for general purpose guard structures. diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index 7ffabc9c699..f1c7428ce60 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -11,12 +11,7 @@ import itertools from torch.fx.experimental.unification import Var # type: ignore[attr-defined] - -try: - import sympy # type: ignore[import] - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False +import sympy _INFERENCE_RULES: Dict[Target, Callable] = {} _REFINEMENT_RULES: Dict[Target, Callable] = {} @@ -305,7 +300,7 @@ def calculate_out_dimension(d_in, module_instance, index): dilation = (module_instance.dilation, module_instance.dilation) \ if isinstance(module_instance.dilation, int) else module_instance.dilation - DIMENSION_TYPES = (int, sympy.Symbol) if HAS_SYMPY else (int,) + DIMENSION_TYPES = (int, sympy.Symbol) if d_in == Dyn: return Dyn @@ -814,18 +809,15 @@ class Refine: """ Replace all unknown types with fresh type variables. """ - if HAS_SYMPY: - if isinstance(typ, Var): - return sympy.symbols(str(typ)) - elif isinstance(typ, TensorType): - new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] - return TensorType(tuple(new_args)) - elif isinstance(typ, list): - return [self.convert_to_sympy_symbols(t) for t in typ] - elif isinstance(typ, tuple): - return (self.convert_to_sympy_symbols(t) for t in typ) - else: - return typ + if isinstance(typ, Var): + return sympy.symbols(str(typ)) + elif isinstance(typ, TensorType): + new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.convert_to_sympy_symbols(t) for t in typ] + elif isinstance(typ, tuple): + return (self.convert_to_sympy_symbols(t) for t in typ) else: return typ @@ -865,29 +857,26 @@ class Refine: pass def infer_symbolic_relations(self, n: Node): - if HAS_SYMPY: - n.type = self.convert_to_sympy_symbols(n.type) - if n.op == 'call_function': - if n.target in _RULES: - return _RULES[n.target](n) - else: - pass - - if n.op == 'call_module': - module_instance = self.traced.get_submodule(n.target) - if type(module_instance) in _RULES: - return _RULES[type(module_instance)](n, module_instance) - else: - pass - - if n.op == 'output': - def get_node_type(a): - return a.type - n.type = torch.fx.node.map_arg(n.args[0], get_node_type) - return n.type - + n.type = self.convert_to_sympy_symbols(n.type) + if n.op == 'call_function': + if n.target in _RULES: + return _RULES[n.target](n) else: pass + + if n.op == 'call_module': + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _RULES: + return _RULES[type(module_instance)](n, module_instance) + else: + pass + + if n.op == 'output': + def get_node_type(a): + return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + else: pass diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 0ec36829789..2bc814e7782 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -25,14 +25,9 @@ log = logging.getLogger(__name__) class GuardOnDataDependentSymNode(RuntimeError): pass -try: - import sympy # type: ignore[import] - from sympy.printing.precedence import precedence # type: ignore[import] # noqa: F401 - from sympy.printing.str import StrPrinter # type: ignore[import] - from sympy.core.logic import fuzzy_and, fuzzy_or # type: ignore[import] - HAS_SYMPY = True -except ImportError: - HAS_SYMPY = False +import sympy +from sympy.printing.str import StrPrinter +from sympy.core.logic import fuzzy_and, fuzzy_or aten = torch._ops.ops.aten # type: ignore[has-type] @@ -408,7 +403,7 @@ class SymNode: return self.guard_bool("", 0) -if HAS_SYMPY: +if True: # TODO: unindent # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 class Pow(sympy.Function): @@ -886,7 +881,7 @@ def _lru_cache(fn, maxsize=None): return wrapper -if HAS_SYMPY: +if True: # TODO: unindent # This stub exists so we can easily add metadata to sympy symbols # NB: This inherits from Dummy, not Symbol, because Symbols with the same # name get interned. This is bad for us as we want the metadata @@ -1040,9 +1035,6 @@ class ShapeEnv: def create_symbol(self, val: int, source: Source) -> "sympy.Expr": assert isinstance(source, Source), f"{type(source)} {source}" - if not HAS_SYMPY: - raise RuntimeError("Need sympy installed to create symbolic shapes") - if val < 0: from torch._dynamo.source import NegateSource return -self.create_symbol(-val, NegateSource(source)) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 12cfaec83e2..9996dd710cd 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -1,6 +1,6 @@ import dataclasses import itertools -import sympy # type: ignore[import] +import sympy import operator import math import logging