Assume sympy is always installed (#94903)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94903
Approved by: https://github.com/Skylion007, https://github.com/malfet
This commit is contained in:
Edward Z. Yang 2023-02-15 17:57:21 -05:00 committed by PyTorch MergeBot
parent 23b1af0399
commit 89e16c4f18
11 changed files with 54 additions and 146 deletions

View File

@ -63,6 +63,12 @@ follow_imports = skip
[mypy-numpy] [mypy-numpy]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-sympy]
ignore_missing_imports = True
[mypy-sympy.*]
ignore_missing_imports = True
[mypy-mypy.*] [mypy-mypy.*]
ignore_missing_imports = True ignore_missing_imports = True

View File

@ -200,6 +200,12 @@ ignore_missing_imports = True
[mypy-numpy.*] [mypy-numpy.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-sympy]
ignore_missing_imports = True
[mypy-sympy.*]
ignore_missing_imports = True
[mypy-hypothesis.*] [mypy-hypothesis.*]
ignore_missing_imports = True ignore_missing_imports = True

View File

@ -12,9 +12,8 @@ from torch.testing._internal.common_utils import (
TestCase, TestCase,
run_tests, run_tests,
IS_ARM64, IS_ARM64,
IS_WINDOWS,
compare_equal_outs_and_grads, compare_equal_outs_and_grads,
outs_and_grads outs_and_grads,
) )
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -70,14 +69,6 @@ except ImportError:
warnings.warn("Some tests use networkx but it was not installed", warnings.warn("Some tests use networkx but it was not installed",
UserWarning) UserWarning)
try:
import sympy # noqa: F401
# TODO(jansel): these tests fail on windows
HAS_SYMPY = not IS_WINDOWS
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
# NB: numpy is a testing dependency! # NB: numpy is a testing dependency!
class AOTTestCase(TestCase): class AOTTestCase(TestCase):
@ -1697,7 +1688,6 @@ def forward(self, primals_1, primals_2):
@patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_dynamic_shapes", True)
@patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_fake_tensor", True)
@skipIfNoSympy
def test_output_op_depending_on_symint(self): def test_output_op_depending_on_symint(self):
""" """
It won't be obvious from reading this test what it's testing for. We should probably make it into a more It won't be obvious from reading this test what it's testing for. We should probably make it into a more
@ -1726,7 +1716,6 @@ def forward(self, primals_1, primals_2):
@patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_dynamic_shapes", True)
@patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_fake_tensor", True)
@skipIfNoSympy
def test_default_partitioner_saves_symints_not_tensors_for_bw(self): def test_default_partitioner_saves_symints_not_tensors_for_bw(self):
""" """
In this test, the important thing is that primals_1 is **only** needed in the backward In this test, the important thing is that primals_1 is **only** needed in the backward
@ -1919,7 +1908,6 @@ class TestPartitioning(AOTTestCase):
@patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_dynamic_shapes", True)
@patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_fake_tensor", True)
@unittest.skipIf(not USE_NETWORKX, "networkx not available") @unittest.skipIf(not USE_NETWORKX, "networkx not available")
@skipIfNoSympy
def test_min_cut_partitioner_save_shape(self): def test_min_cut_partitioner_save_shape(self):
def f(x): def f(x):
@ -1960,7 +1948,6 @@ class TestPartitioning(AOTTestCase):
@patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_dynamic_shapes", True)
@patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_fake_tensor", True)
@skipIfNoSympy
def test_default_partitioner_output_tensor_shape_tensor(self): def test_default_partitioner_output_tensor_shape_tensor(self):
inp = [ inp = [
@ -2025,7 +2012,6 @@ class TestPartitioning(AOTTestCase):
@patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_dynamic_shapes", True)
@patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_fake_tensor", True)
@unittest.skipIf(not USE_NETWORKX, "networkx not available") @unittest.skipIf(not USE_NETWORKX, "networkx not available")
@skipIfNoSympy
def test_min_cut_partitioner_output_tensor_shape_tensor(self): def test_min_cut_partitioner_output_tensor_shape_tensor(self):
inp = [ inp = [
@ -2695,7 +2681,6 @@ class TestEagerFusionOpInfo(AOTTestCase):
_test_aot_autograd_helper(self, device, dtype, op) _test_aot_autograd_helper(self, device, dtype, op)
@ops(op_db, allowed_dtypes=(torch.float,)) @ops(op_db, allowed_dtypes=(torch.float,))
@skipIfNoSympy
@patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_dynamic_shapes", True)
@patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_fake_tensor", True)
@patch("functorch.compile.config.use_functionalize", True) @patch("functorch.compile.config.use_functionalize", True)
@ -2742,7 +2727,6 @@ class TestEagerFusionModuleInfo(AOTTestCase):
_test_aot_autograd_module_helper(self, device, dtype, training, module_info) _test_aot_autograd_module_helper(self, device, dtype, training, module_info)
@modules(module_db, allowed_dtypes=(torch.float,)) @modules(module_db, allowed_dtypes=(torch.float,))
@skipIfNoSympy
@patch("functorch.compile.config.use_dynamic_shapes", True) @patch("functorch.compile.config.use_dynamic_shapes", True)
@patch("functorch.compile.config.use_fake_tensor", True) @patch("functorch.compile.config.use_fake_tensor", True)
@patch("functorch.compile.config.use_functionalize", True) @patch("functorch.compile.config.use_functionalize", True)

View File

@ -12,14 +12,7 @@ from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx import GraphModule from torch.fx import GraphModule
from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.shape_prop import ShapeProp
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import TestCase
import sympy
try:
import sympy
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
try: try:
@ -813,7 +806,6 @@ class TypeCheckerTest(TestCase):
if n.op == 'output': if n.op == 'output':
assert is_consistent(n.type, TensorType(b.size())) assert is_consistent(n.type, TensorType(b.size()))
@skipIfNoSympy
@skipIfNoTorchVision @skipIfNoTorchVision
def test_resnet50(self): def test_resnet50(self):
gm_run = symbolic_trace(resnet50()) gm_run = symbolic_trace(resnet50())
@ -860,7 +852,6 @@ class TypeCheckerTest(TestCase):
batch_sizes.add(n.type.__args__[0]) batch_sizes.add(n.type.__args__[0])
assert (len(batch_sizes) == 1) assert (len(batch_sizes) == 1)
@skipIfNoSympy
def test_type_check_batch_norm_symbolic(self): def test_type_check_batch_norm_symbolic(self):
class BasicBlock(torch.nn.Module): class BasicBlock(torch.nn.Module):
@ -892,7 +883,6 @@ class TypeCheckerTest(TestCase):
for n in graph.nodes: for n in graph.nodes:
assert n.type == next(my_types) assert n.type == next(my_types)
@skipIfNoSympy
def test_symbolic_add_with_broadcast(self): def test_symbolic_add_with_broadcast(self):
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))): def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
@ -921,7 +911,6 @@ class TypeCheckerTest(TestCase):
for n in symbolic_traced.graph.nodes: for n in symbolic_traced.graph.nodes:
assert n.type == next(expected_iter) assert n.type == next(expected_iter)
@skipIfNoSympy
def test_symbolic_add_with_broadcast_2(self): def test_symbolic_add_with_broadcast_2(self):
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))): def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))):
@ -943,7 +932,6 @@ class TypeCheckerTest(TestCase):
for n in symbolic_traced.graph.nodes: for n in symbolic_traced.graph.nodes:
assert n.type == next(expected_iter) assert n.type == next(expected_iter)
@skipIfNoSympy
def test_type_check_conv2D_types(self): def test_type_check_conv2D_types(self):
class BasicBlock(torch.nn.Module): class BasicBlock(torch.nn.Module):
def __init__(self, inplanes, planes, stride=1): def __init__(self, inplanes, planes, stride=1):
@ -971,7 +959,6 @@ class TypeCheckerTest(TestCase):
assert isinstance(n.type.__args__[2], sympy.floor) assert isinstance(n.type.__args__[2], sympy.floor)
assert isinstance(n.type.__args__[3], sympy.floor) assert isinstance(n.type.__args__[3], sympy.floor)
@skipIfNoSympy
def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self): def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
class BasicBlock(torch.nn.Module): class BasicBlock(torch.nn.Module):

View File

@ -5,14 +5,14 @@ from torch._C import _disabled_torch_function_impl
import torch.fx import torch.fx
import torch.nn.functional as F import torch.nn.functional as F
from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, \ from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, \
IS_WINDOWS, parametrize, instantiate_parametrized_tests parametrize, instantiate_parametrized_tests
import unittest
import torch import torch
import operator import operator
import itertools import itertools
import contextlib import contextlib
import math import math
import copy import copy
import sympy
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torch.fx.experimental import symbolic_shapes from torch.fx.experimental import symbolic_shapes
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
@ -24,15 +24,6 @@ from torch import SymBool, SymInt, SymFloat, sym_int
aten = torch.ops.aten aten = torch.ops.aten
try:
import sympy
# TODO(jansel): these tests fail on windows
HAS_SYMPY = not IS_WINDOWS
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
meta_funcs = {} meta_funcs = {}
@ -135,7 +126,6 @@ def create_symint(shape_env, i: int):
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)") @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase): class TestPySymInt(TestCase):
@skipIfNoSympy
def test_arith_ops(self): def test_arith_ops(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
symints = [] symints = []
@ -150,7 +140,6 @@ class TestPySymInt(TestCase):
self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])) self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))
@skipIfNoSympy
def test_reverse_arith_ops(self): def test_reverse_arith_ops(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
@ -161,7 +150,6 @@ class TestPySymInt(TestCase):
self.assertTrue(5 * a == 5 * 2) self.assertTrue(5 * a == 5 * 2)
@skipIfNoSympy
def test_roundtrip(self): def test_roundtrip(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@ -187,7 +175,6 @@ class TestPySymInt(TestCase):
self.assertTrue(isinstance(y.storage_offset(), SymInt)) self.assertTrue(isinstance(y.storage_offset(), SymInt))
self.assertTrue(y.storage_offset() == 12) self.assertTrue(y.storage_offset() == 12)
@skipIfNoSympy
def test_binary(self): def test_binary(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@ -205,7 +192,6 @@ class TestPySymInt(TestCase):
self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3) self.assertTrue(z.shape[2] == 3)
@skipIfNoSympy
def test_symint_args(self): def test_symint_args(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@ -222,7 +208,6 @@ class TestPySymInt(TestCase):
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1) z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
self.assertTrue(z.shape[2] == 2) self.assertTrue(z.shape[2] == 2)
@skipIfNoSympy
def test_symint_vargs(self): def test_symint_vargs(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@ -267,13 +252,11 @@ class TestPySymInt(TestCase):
z = y.expand((y.shape[1],)) z = y.expand((y.shape[1],))
z = y.expand(y.shape[1]) z = y.expand(y.shape[1])
@skipIfNoSympy
def test_stride(self): def test_stride(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], SymInt) self.assertIsInstance(x.stride()[0], SymInt)
@skipIfNoSympy
def test_size_expressions(self): def test_size_expressions(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env) x = create_symbolic_tensor("x", torch.randn(5), shape_env)
@ -289,7 +272,6 @@ class TestPySymInt(TestCase):
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
@skipIfNoSympy
def test_numel(self): def test_numel(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env) x = create_symbolic_tensor("x", torch.randn(5), shape_env)
@ -300,14 +282,12 @@ class TestPySymInt(TestCase):
self.assertIsInstance(x.numel(), int) self.assertIsInstance(x.numel(), int)
self.assertIsInstance(torch.numel(x), int) self.assertIsInstance(torch.numel(x), int)
@skipIfNoSympy
def test_int_to_float(self): def test_int_to_float(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env) x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = sym_float(x.shape[0]) r = sym_float(x.shape[0])
self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertIsInstance(r, torch.SymFloat, msg=type(r))
@skipIfNoSympy
def test_aten_ops(self): def test_aten_ops(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
@ -330,21 +310,18 @@ class TestPySymInt(TestCase):
# tuple of ints, not tuple # tuple of ints, not tuple
torch.fx.symbolic_trace(m) torch.fx.symbolic_trace(m)
@skipIfNoSympy
def test_meta_symint(self): def test_meta_symint(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta') r = torch.empty(a0, device='meta')
self.assertIsInstance(r.shape[0], SymInt) self.assertIsInstance(r.shape[0], SymInt)
@skipIfNoSympy
def test_guard_int(self): def test_guard_int(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(a0), 2) self.assertEqual(guard_int(a0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""") self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
@skipIfNoSympy
def test_sym_int(self): def test_sym_int(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5) a0 = create_symint(shape_env, 5)
@ -371,7 +348,6 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[3][0]), """Eq(2*s2, 6)""") self.assertExpectedInline(str(shape_env.guards[3][0]), """Eq(2*s2, 6)""")
@skipIfNoSympy
def test_sym_sqrt(self): def test_sym_sqrt(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 4) a0 = create_symint(shape_env, 4)
@ -380,7 +356,6 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""") self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")
@skipIfNoSympy
def test_sym_floor(self): def test_sym_floor(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5) a0 = create_symint(shape_env, 5)
@ -393,7 +368,6 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
@skipIfNoSympy
def test_sym_ceil(self): def test_sym_ceil(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5) a0 = create_symint(shape_env, 5)
@ -407,26 +381,22 @@ class TestPySymInt(TestCase):
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
@skipIfNoSympy
def test_int_conversion(self): def test_int_conversion(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0)) self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
@skipIfNoSympy
def test_data_dependent_guard(self): def test_data_dependent_guard(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint() s0 = shape_env.create_unbacked_symint()
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0)) self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
@skipIfNoSympy
def test_non_overlapping_and_dense(self): def test_non_overlapping_and_dense(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5) a0 = create_symint(shape_env, 5)
r = torch.empty_strided((a0, 7), (1, a0), device='meta') r = torch.empty_strided((a0, 7), (1, a0), device='meta')
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
@skipIfNoSympy
def test_symint_as_scalar(self): def test_symint_as_scalar(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
@ -450,7 +420,6 @@ class TestPySymInt(TestCase):
self.assertTrue(sym_int_encountered) self.assertTrue(sym_int_encountered)
@skipIfNoSympy
def test_deepcopy(self): def test_deepcopy(self):
shape_env = ShapeEnv() shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2) a0 = create_symint(shape_env, 2)
@ -458,7 +427,6 @@ class TestPySymInt(TestCase):
new_shape_env = copy.deepcopy(shape_env) new_shape_env = copy.deepcopy(shape_env)
self.assertEqual(len(new_shape_env.guards), 1) self.assertEqual(len(new_shape_env.guards), 1)
@skipIfNoSympy
def test_print_readable_with_symints(self): def test_print_readable_with_symints(self):
def f(a, b): def f(a, b):
dim0 = a.shape[0] + b.shape[0] dim0 = a.shape[0] + b.shape[0]
@ -650,7 +618,6 @@ class TestFloorDiv(TestCase):
yield (x, -y) yield (x, -y)
yield (-x, -y) yield (-x, -y)
@skipIfNoSympy
def test_floordiv_float_int(self): def test_floordiv_float_int(self):
values = ( values = (
(2.5, 2.1), (2.5, 2.1),
@ -664,7 +631,6 @@ class TestFloorDiv(TestCase):
for x, y in TestFloorDiv.yield_test_cases(values): for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)) self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
@skipIfNoSympy
def test_floordiv_bool(self): def test_floordiv_bool(self):
values = ( values = (
(False, True), (False, True),
@ -685,7 +651,6 @@ class TestFloorDiv(TestCase):
rf", expected integer or real"), rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y)) lambda: TestFloorDiv.torch_floordiv(x, y))
@skipIfNoSympy
def test_floordiv_complex(self): def test_floordiv_complex(self):
values = ( values = (
(1.5 + 2.5j, 1.3 + 3.5j), (1.5 + 2.5j, 1.3 + 3.5j),
@ -706,7 +671,6 @@ class TestFloorDiv(TestCase):
rf", expected integer or real"), rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y)) lambda: TestFloorDiv.torch_floordiv(x, y))
@skipIfNoSympy
def test_floordiv_div_by_zero(self): def test_floordiv_div_by_zero(self):
values = ( values = (
(2.5, 0), (2.5, 0),
@ -724,7 +688,6 @@ class TestFloorDiv(TestCase):
"division by zero", "division by zero",
lambda: TestFloorDiv.torch_floordiv(x, y)) lambda: TestFloorDiv.torch_floordiv(x, y))
@skipIfNoSympy
def test_floordiv_zero_base(self): def test_floordiv_zero_base(self):
values = ( values = (
(0, 2.5), (0, 2.5),
@ -738,7 +701,6 @@ class TestFloorDiv(TestCase):
else: else:
self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y))
@skipIfNoSympy
def test_floordiv_div_by_one(self): def test_floordiv_div_by_one(self):
values = ( values = (
(2.5, 1), (2.5, 1),
@ -750,7 +712,6 @@ class TestFloorDiv(TestCase):
for x, y in TestFloorDiv.yield_test_cases(values): for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)) self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
@skipIfNoSympy
def test_floordiv_simplify(self): def test_floordiv_simplify(self):
# Tests how we simplify or evaluate FloorDiv without free variables # Tests how we simplify or evaluate FloorDiv without free variables
shape_env = ShapeEnv() shape_env = ShapeEnv()
@ -770,7 +731,6 @@ class TestFloorDiv(TestCase):
self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.simplify(expr), result)
self.assertEqual(shape_env.evaluate_expr(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result)
@skipIfNoSympy
def test_floordiv_assumptions(self): def test_floordiv_assumptions(self):
# We define two Symbols (with different names) for each type to make # We define two Symbols (with different names) for each type to make
# sure the behavior is consistent regardless of whether both arguments # sure the behavior is consistent regardless of whether both arguments

View File

@ -1,6 +1,6 @@
# Owner(s): ["module: ProxyTensor"] # Owner(s): ["module: ProxyTensor"]
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, xfail_inherited_tests from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests
import torch import torch
import unittest import unittest
import warnings import warnings
@ -27,13 +27,6 @@ import itertools
aten = torch.ops.aten aten = torch.ops.aten
try:
import sympy # noqa: F401
# TODO(jansel): these tests fail on windows
HAS_SYMPY = not IS_WINDOWS
except ImportError:
HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
HAS_CUDA = torch.cuda.is_available() HAS_CUDA = torch.cuda.is_available()
@ -735,7 +728,6 @@ class TestGenericProxyTensorFake(TestGenericProxyTensor):
tracing_mode = "fake" tracing_mode = "fake"
@skipIfNoSympy
@xfail_inherited_tests([ @xfail_inherited_tests([
"test_make_fx_overloads", "test_make_fx_overloads",
"test_trace_subclasses", "test_trace_subclasses",
@ -812,7 +804,6 @@ def _trace(f, *args):
return make_fx(f, tracing_mode="symbolic")(*inps) return make_fx(f, tracing_mode="symbolic")(*inps)
# TODO: Need to test the guards themselves specifically as well # TODO: Need to test the guards themselves specifically as well
@skipIfNoSympy
class TestSymbolicTracing(TestCase): class TestSymbolicTracing(TestCase):
def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True): def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
""" """
@ -1489,14 +1480,12 @@ class TestProxyTensorOpInfo(TestCase):
def test_make_fx_fake_exhaustive(self, device, dtype, op): def test_make_fx_fake_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "fake") _test_make_fx_helper(self, device, dtype, op, "fake")
@skipIfNoSympy
@ops(op_db, allowed_dtypes=(torch.float,)) @ops(op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
def test_make_fx_symbolic_exhaustive(self, device, dtype, op): def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
_test_make_fx_helper(self, device, dtype, op, "symbolic") _test_make_fx_helper(self, device, dtype, op, "symbolic")
@skipIfNoSympy
@ops(op_db, allowed_dtypes=(torch.float,)) @ops(op_db, allowed_dtypes=(torch.float,))
@skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace', @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures) make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)

View File

@ -6,6 +6,7 @@ import random
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from typing import Callable, Optional, Tuple, Union from typing import Callable, Optional, Tuple, Union
import sympy
import torch import torch
from torch import SymInt from torch import SymInt
@ -126,7 +127,6 @@ class DebugInterpreter(fx.Interpreter):
super().run(*args) super().run(*args)
def run_node(self, n): def run_node(self, n):
import sympy
def subst_symint(ni): def subst_symint(ni):
if not isinstance(ni, SymInt): if not isinstance(ni, SymInt):

View File

@ -8,12 +8,7 @@ from typing import Callable, Generic, List, NamedTuple, Optional, Set, TypeVar
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# TODO(voz): Stolen pattern, not sure why this is the case, import sympy
# but mypy complains.
try:
import sympy # type: ignore[import]
except ImportError:
log.warning("No sympy found")
""" """
torch._guards is the definitional source of truth for general purpose guard structures. torch._guards is the definitional source of truth for general purpose guard structures.

View File

@ -11,12 +11,7 @@ import itertools
from torch.fx.experimental.unification import Var # type: ignore[attr-defined] from torch.fx.experimental.unification import Var # type: ignore[attr-defined]
import sympy
try:
import sympy # type: ignore[import]
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
_INFERENCE_RULES: Dict[Target, Callable] = {} _INFERENCE_RULES: Dict[Target, Callable] = {}
_REFINEMENT_RULES: Dict[Target, Callable] = {} _REFINEMENT_RULES: Dict[Target, Callable] = {}
@ -305,7 +300,7 @@ def calculate_out_dimension(d_in, module_instance, index):
dilation = (module_instance.dilation, module_instance.dilation) \ dilation = (module_instance.dilation, module_instance.dilation) \
if isinstance(module_instance.dilation, int) else module_instance.dilation if isinstance(module_instance.dilation, int) else module_instance.dilation
DIMENSION_TYPES = (int, sympy.Symbol) if HAS_SYMPY else (int,) DIMENSION_TYPES = (int, sympy.Symbol)
if d_in == Dyn: if d_in == Dyn:
return Dyn return Dyn
@ -814,18 +809,15 @@ class Refine:
""" """
Replace all unknown types with fresh type variables. Replace all unknown types with fresh type variables.
""" """
if HAS_SYMPY: if isinstance(typ, Var):
if isinstance(typ, Var): return sympy.symbols(str(typ))
return sympy.symbols(str(typ)) elif isinstance(typ, TensorType):
elif isinstance(typ, TensorType): new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] return TensorType(tuple(new_args))
return TensorType(tuple(new_args)) elif isinstance(typ, list):
elif isinstance(typ, list): return [self.convert_to_sympy_symbols(t) for t in typ]
return [self.convert_to_sympy_symbols(t) for t in typ] elif isinstance(typ, tuple):
elif isinstance(typ, tuple): return (self.convert_to_sympy_symbols(t) for t in typ)
return (self.convert_to_sympy_symbols(t) for t in typ)
else:
return typ
else: else:
return typ return typ
@ -865,29 +857,26 @@ class Refine:
pass pass
def infer_symbolic_relations(self, n: Node): def infer_symbolic_relations(self, n: Node):
if HAS_SYMPY: n.type = self.convert_to_sympy_symbols(n.type)
n.type = self.convert_to_sympy_symbols(n.type) if n.op == 'call_function':
if n.op == 'call_function': if n.target in _RULES:
if n.target in _RULES: return _RULES[n.target](n)
return _RULES[n.target](n)
else:
pass
if n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _RULES:
return _RULES[type(module_instance)](n, module_instance)
else:
pass
if n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else: else:
pass pass
if n.op == 'call_module':
module_instance = self.traced.get_submodule(n.target)
if type(module_instance) in _RULES:
return _RULES[type(module_instance)](n, module_instance)
else:
pass
if n.op == 'output':
def get_node_type(a):
return a.type
n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
return n.type
else: else:
pass pass

View File

@ -25,14 +25,9 @@ log = logging.getLogger(__name__)
class GuardOnDataDependentSymNode(RuntimeError): class GuardOnDataDependentSymNode(RuntimeError):
pass pass
try: import sympy
import sympy # type: ignore[import] from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence # type: ignore[import] # noqa: F401 from sympy.core.logic import fuzzy_and, fuzzy_or
from sympy.printing.str import StrPrinter # type: ignore[import]
from sympy.core.logic import fuzzy_and, fuzzy_or # type: ignore[import]
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
aten = torch._ops.ops.aten # type: ignore[has-type] aten = torch._ops.ops.aten # type: ignore[has-type]
@ -408,7 +403,7 @@ class SymNode:
return self.guard_bool("", 0) return self.guard_bool("", 0)
if HAS_SYMPY: if True: # TODO: unindent
# Overloaded to be compatible with regular Python. # Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900 # https://github.com/pytorch/pytorch/issues/90900
class Pow(sympy.Function): class Pow(sympy.Function):
@ -886,7 +881,7 @@ def _lru_cache(fn, maxsize=None):
return wrapper return wrapper
if HAS_SYMPY: if True: # TODO: unindent
# This stub exists so we can easily add metadata to sympy symbols # This stub exists so we can easily add metadata to sympy symbols
# NB: This inherits from Dummy, not Symbol, because Symbols with the same # NB: This inherits from Dummy, not Symbol, because Symbols with the same
# name get interned. This is bad for us as we want the metadata # name get interned. This is bad for us as we want the metadata
@ -1040,9 +1035,6 @@ class ShapeEnv:
def create_symbol(self, val: int, source: Source) -> "sympy.Expr": def create_symbol(self, val: int, source: Source) -> "sympy.Expr":
assert isinstance(source, Source), f"{type(source)} {source}" assert isinstance(source, Source), f"{type(source)} {source}"
if not HAS_SYMPY:
raise RuntimeError("Need sympy installed to create symbolic shapes")
if val < 0: if val < 0:
from torch._dynamo.source import NegateSource from torch._dynamo.source import NegateSource
return -self.create_symbol(-val, NegateSource(source)) return -self.create_symbol(-val, NegateSource(source))

View File

@ -1,6 +1,6 @@
import dataclasses import dataclasses
import itertools import itertools
import sympy # type: ignore[import] import sympy
import operator import operator
import math import math
import logging import logging