Assume sympy is always installed (#94903)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94903
Approved by: https://github.com/Skylion007, https://github.com/malfet
parent 23b1af0399
commit 89e16c4f18
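The change applies one mechanical pattern across configs, tests, and runtime code. A hedged sketch of the recurring before/after (not the exact code of any single file):

import unittest

# Before: guarded import, a HAS_SYMPY flag, and a skip decorator
# for sympy-dependent tests.
try:
    import sympy  # noqa: F401
    HAS_SYMPY = True
except ImportError:
    HAS_SYMPY = False
skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")

# After: sympy is assumed to always be installed, so the guard,
# the flag, and the decorator all disappear.
import sympy

The hunks below delete the try/except blocks and every @skipIfNoSympy use, replace runtime HAS_SYMPY checks with unconditional code, and add mypy overrides for sympy's missing type stubs.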
@@ -63,6 +63,12 @@ follow_imports = skip
 
 [mypy-numpy]
 ignore_missing_imports = True
 
+[mypy-sympy]
+ignore_missing_imports = True
+
+[mypy-sympy.*]
+ignore_missing_imports = True
+
 [mypy-mypy.*]
 ignore_missing_imports = True
mypy.ini (+6)
@@ -200,6 +200,12 @@ ignore_missing_imports = True
 
 [mypy-numpy.*]
 ignore_missing_imports = True
 
+[mypy-sympy]
+ignore_missing_imports = True
+
+[mypy-sympy.*]
+ignore_missing_imports = True
+
 [mypy-hypothesis.*]
 ignore_missing_imports = True
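Both stanza forms are added, mirroring the existing numpy entries: [mypy-sympy] matches `import sympy` itself, [mypy-sympy.*] its submodules. A hypothetical one-file illustration of the diagnostic these overrides suppress (sympy ships without type stubs; the exact message varies by mypy version):

# demo.py -- without the override, mypy reports roughly:
#   error: Skipping analyzing "sympy": module is installed, but missing
#   library stubs or py.typed marker
# With ignore_missing_imports = True, the module is typed as Any instead.
import sympy

x = sympy.Symbol("x")
print(x + 1)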
@@ -12,9 +12,8 @@ from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
     IS_ARM64,
-    IS_WINDOWS,
     compare_equal_outs_and_grads,
-    outs_and_grads
+    outs_and_grads,
 )
 import torch
 import torch.nn as nn
@@ -70,14 +69,6 @@ except ImportError:
     warnings.warn("Some tests use networkx but it was not installed",
                   UserWarning)
 
-try:
-    import sympy  # noqa: F401
-    # TODO(jansel): these tests fail on windows
-    HAS_SYMPY = not IS_WINDOWS
-except ImportError:
-    HAS_SYMPY = False
-skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
-
 # NB: numpy is a testing dependency!
 
 class AOTTestCase(TestCase):
@@ -1697,7 +1688,6 @@ def forward(self, primals_1, primals_2):
 
     @patch("functorch.compile.config.use_dynamic_shapes", True)
     @patch("functorch.compile.config.use_fake_tensor", True)
-    @skipIfNoSympy
     def test_output_op_depending_on_symint(self):
         """
         It won't be obvious from reading this test what it's testing for. We should probably make it into a more
@@ -1726,7 +1716,6 @@ def forward(self, primals_1, primals_2):
 
     @patch("functorch.compile.config.use_dynamic_shapes", True)
     @patch("functorch.compile.config.use_fake_tensor", True)
-    @skipIfNoSympy
     def test_default_partitioner_saves_symints_not_tensors_for_bw(self):
         """
         In this test, the important thing is that primals_1 is **only** needed in the backward
@@ -1919,7 +1908,6 @@ class TestPartitioning(AOTTestCase):
     @patch("functorch.compile.config.use_dynamic_shapes", True)
     @patch("functorch.compile.config.use_fake_tensor", True)
     @unittest.skipIf(not USE_NETWORKX, "networkx not available")
-    @skipIfNoSympy
     def test_min_cut_partitioner_save_shape(self):
 
         def f(x):
@@ -1960,7 +1948,6 @@ class TestPartitioning(AOTTestCase):
 
     @patch("functorch.compile.config.use_dynamic_shapes", True)
     @patch("functorch.compile.config.use_fake_tensor", True)
-    @skipIfNoSympy
     def test_default_partitioner_output_tensor_shape_tensor(self):
 
         inp = [
@@ -2025,7 +2012,6 @@ class TestPartitioning(AOTTestCase):
     @patch("functorch.compile.config.use_dynamic_shapes", True)
     @patch("functorch.compile.config.use_fake_tensor", True)
     @unittest.skipIf(not USE_NETWORKX, "networkx not available")
-    @skipIfNoSympy
     def test_min_cut_partitioner_output_tensor_shape_tensor(self):
 
         inp = [
@@ -2695,7 +2681,6 @@ class TestEagerFusionOpInfo(AOTTestCase):
         _test_aot_autograd_helper(self, device, dtype, op)
 
     @ops(op_db, allowed_dtypes=(torch.float,))
-    @skipIfNoSympy
     @patch("functorch.compile.config.use_dynamic_shapes", True)
     @patch("functorch.compile.config.use_fake_tensor", True)
     @patch("functorch.compile.config.use_functionalize", True)
@@ -2742,7 +2727,6 @@ class TestEagerFusionModuleInfo(AOTTestCase):
         _test_aot_autograd_module_helper(self, device, dtype, training, module_info)
 
     @modules(module_db, allowed_dtypes=(torch.float,))
-    @skipIfNoSympy
     @patch("functorch.compile.config.use_dynamic_shapes", True)
     @patch("functorch.compile.config.use_fake_tensor", True)
     @patch("functorch.compile.config.use_functionalize", True)
@@ -12,14 +12,7 @@ from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx import GraphModule
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.testing._internal.common_utils import TestCase
-
-
-try:
-    import sympy
-    HAS_SYMPY = True
-except ImportError:
-    HAS_SYMPY = False
-skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
+import sympy
 
 
 try:
@@ -813,7 +806,6 @@ class TypeCheckerTest(TestCase):
         if n.op == 'output':
             assert is_consistent(n.type, TensorType(b.size()))
 
-    @skipIfNoSympy
     @skipIfNoTorchVision
     def test_resnet50(self):
         gm_run = symbolic_trace(resnet50())
@@ -860,7 +852,6 @@ class TypeCheckerTest(TestCase):
             batch_sizes.add(n.type.__args__[0])
         assert (len(batch_sizes) == 1)
 
-    @skipIfNoSympy
     def test_type_check_batch_norm_symbolic(self):
         class BasicBlock(torch.nn.Module):
 
@@ -892,7 +883,6 @@ class TypeCheckerTest(TestCase):
         for n in graph.nodes:
             assert n.type == next(my_types)
 
-    @skipIfNoSympy
     def test_symbolic_add_with_broadcast(self):
         class M(torch.nn.Module):
             def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
@@ -921,7 +911,6 @@ class TypeCheckerTest(TestCase):
         for n in symbolic_traced.graph.nodes:
             assert n.type == next(expected_iter)
 
-    @skipIfNoSympy
     def test_symbolic_add_with_broadcast_2(self):
         class M(torch.nn.Module):
             def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))):
@@ -943,7 +932,6 @@ class TypeCheckerTest(TestCase):
         for n in symbolic_traced.graph.nodes:
             assert n.type == next(expected_iter)
 
-    @skipIfNoSympy
     def test_type_check_conv2D_types(self):
         class BasicBlock(torch.nn.Module):
             def __init__(self, inplanes, planes, stride=1):
@@ -971,7 +959,6 @@ class TypeCheckerTest(TestCase):
             assert isinstance(n.type.__args__[2], sympy.floor)
             assert isinstance(n.type.__args__[3], sympy.floor)
 
-    @skipIfNoSympy
     def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
 
         class BasicBlock(torch.nn.Module):
@@ -5,14 +5,14 @@ from torch._C import _disabled_torch_function_impl
 import torch.fx
 import torch.nn.functional as F
 from torch.testing._internal.common_utils import run_tests, TestCase, skipIfTorchDynamo, \
-    IS_WINDOWS, parametrize, instantiate_parametrized_tests
-import unittest
+    parametrize, instantiate_parametrized_tests
 import torch
 import operator
 import itertools
 import contextlib
 import math
 import copy
+import sympy
 from torch.utils._pytree import tree_map
 from torch.fx.experimental import symbolic_shapes
 from torch.fx.experimental.proxy_tensor import make_fx
@@ -24,15 +24,6 @@ from torch import SymBool, SymInt, SymFloat, sym_int
 
 aten = torch.ops.aten
 
-try:
-    import sympy
-    # TODO(jansel): these tests fail on windows
-    HAS_SYMPY = not IS_WINDOWS
-except ImportError:
-    HAS_SYMPY = False
-skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
-
-
 meta_funcs = {}
 
 
@@ -135,7 +126,6 @@ def create_symint(shape_env, i: int):
 @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
 class TestPySymInt(TestCase):
 
-    @skipIfNoSympy
     def test_arith_ops(self):
         shape_env = ShapeEnv()
         symints = []
@@ -150,7 +140,6 @@ class TestPySymInt(TestCase):
             self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))
 
 
-    @skipIfNoSympy
     def test_reverse_arith_ops(self):
         shape_env = ShapeEnv()
 
@@ -161,7 +150,6 @@ class TestPySymInt(TestCase):
         self.assertTrue(5 * a == 5 * 2)
 
 
-    @skipIfNoSympy
     def test_roundtrip(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@@ -187,7 +175,6 @@ class TestPySymInt(TestCase):
         self.assertTrue(isinstance(y.storage_offset(), SymInt))
         self.assertTrue(y.storage_offset() == 12)
 
-    @skipIfNoSympy
     def test_binary(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@@ -205,7 +192,6 @@ class TestPySymInt(TestCase):
         self.assertTrue(z.shape[1] == 4)
         self.assertTrue(z.shape[2] == 3)
 
-    @skipIfNoSympy
     def test_symint_args(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@@ -222,7 +208,6 @@ class TestPySymInt(TestCase):
         z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
         self.assertTrue(z.shape[2] == 2)
 
-    @skipIfNoSympy
     def test_symint_vargs(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@@ -267,13 +252,11 @@ class TestPySymInt(TestCase):
         z = y.expand((y.shape[1],))
         z = y.expand(y.shape[1])
 
-    @skipIfNoSympy
     def test_stride(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
         self.assertIsInstance(x.stride()[0], SymInt)
 
-    @skipIfNoSympy
     def test_size_expressions(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5), shape_env)
@@ -289,7 +272,6 @@ class TestPySymInt(TestCase):
         self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
         self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
 
-    @skipIfNoSympy
     def test_numel(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5), shape_env)
@@ -300,14 +282,12 @@ class TestPySymInt(TestCase):
         self.assertIsInstance(x.numel(), int)
         self.assertIsInstance(torch.numel(x), int)
 
-    @skipIfNoSympy
     def test_int_to_float(self):
         shape_env = ShapeEnv()
         x = create_symbolic_tensor("x", torch.randn(5), shape_env)
         r = sym_float(x.shape[0])
         self.assertIsInstance(r, torch.SymFloat, msg=type(r))
 
-    @skipIfNoSympy
     def test_aten_ops(self):
 
         shape_env = ShapeEnv()
@@ -330,21 +310,18 @@ class TestPySymInt(TestCase):
         # tuple of ints, not tuple
         torch.fx.symbolic_trace(m)
 
-    @skipIfNoSympy
     def test_meta_symint(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 2)
         r = torch.empty(a0, device='meta')
         self.assertIsInstance(r.shape[0], SymInt)
 
-    @skipIfNoSympy
     def test_guard_int(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 2)
         self.assertEqual(guard_int(a0), 2)
         self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
 
-    @skipIfNoSympy
     def test_sym_int(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 5)
@@ -371,7 +348,6 @@ class TestPySymInt(TestCase):
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(str(shape_env.guards[3][0]), """Eq(2*s2, 6)""")
 
-    @skipIfNoSympy
     def test_sym_sqrt(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 4)
@@ -380,7 +356,6 @@ class TestPySymInt(TestCase):
         self.assertIsInstance(r, torch.SymFloat, msg=type(r))
         self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")
 
-    @skipIfNoSympy
     def test_sym_floor(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 5)
@@ -393,7 +368,6 @@ class TestPySymInt(TestCase):
         self.assertIsInstance(r, torch.SymInt, msg=type(r))
         self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
 
-    @skipIfNoSympy
     def test_sym_ceil(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 5)
@@ -407,26 +381,22 @@ class TestPySymInt(TestCase):
         self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
 
 
-    @skipIfNoSympy
     def test_int_conversion(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 2)
         self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))
 
-    @skipIfNoSympy
     def test_data_dependent_guard(self):
         shape_env = ShapeEnv()
         s0 = shape_env.create_unbacked_symint()
         self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
 
-    @skipIfNoSympy
     def test_non_overlapping_and_dense(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 5)
         r = torch.empty_strided((a0, 7), (1, a0), device='meta')
         self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
 
-    @skipIfNoSympy
     def test_symint_as_scalar(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 2)
@@ -450,7 +420,6 @@ class TestPySymInt(TestCase):
 
         self.assertTrue(sym_int_encountered)
 
-    @skipIfNoSympy
     def test_deepcopy(self):
         shape_env = ShapeEnv()
         a0 = create_symint(shape_env, 2)
@@ -458,7 +427,6 @@ class TestPySymInt(TestCase):
         new_shape_env = copy.deepcopy(shape_env)
         self.assertEqual(len(new_shape_env.guards), 1)
 
-    @skipIfNoSympy
     def test_print_readable_with_symints(self):
         def f(a, b):
             dim0 = a.shape[0] + b.shape[0]
@@ -650,7 +618,6 @@ class TestFloorDiv(TestCase):
             yield (x, -y)
             yield (-x, -y)
 
-    @skipIfNoSympy
     def test_floordiv_float_int(self):
         values = (
             (2.5, 2.1),
@@ -664,7 +631,6 @@ class TestFloorDiv(TestCase):
         for x, y in TestFloorDiv.yield_test_cases(values):
             self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
 
-    @skipIfNoSympy
     def test_floordiv_bool(self):
         values = (
             (False, True),
@@ -685,7 +651,6 @@ class TestFloorDiv(TestCase):
                 rf", expected integer or real"),
             lambda: TestFloorDiv.torch_floordiv(x, y))
 
-    @skipIfNoSympy
     def test_floordiv_complex(self):
         values = (
             (1.5 + 2.5j, 1.3 + 3.5j),
@@ -706,7 +671,6 @@ class TestFloorDiv(TestCase):
                 rf", expected integer or real"),
             lambda: TestFloorDiv.torch_floordiv(x, y))
 
-    @skipIfNoSympy
     def test_floordiv_div_by_zero(self):
         values = (
             (2.5, 0),
@@ -724,7 +688,6 @@ class TestFloorDiv(TestCase):
             "division by zero",
             lambda: TestFloorDiv.torch_floordiv(x, y))
 
-    @skipIfNoSympy
     def test_floordiv_zero_base(self):
         values = (
             (0, 2.5),
@@ -738,7 +701,6 @@ class TestFloorDiv(TestCase):
         else:
             self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y))
 
-    @skipIfNoSympy
     def test_floordiv_div_by_one(self):
         values = (
             (2.5, 1),
@@ -750,7 +712,6 @@ class TestFloorDiv(TestCase):
         for x, y in TestFloorDiv.yield_test_cases(values):
             self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
 
-    @skipIfNoSympy
     def test_floordiv_simplify(self):
         # Tests how we simplify or evaluate FloorDiv without free variables
         shape_env = ShapeEnv()
@@ -770,7 +731,6 @@ class TestFloorDiv(TestCase):
             self.assertEqual(shape_env.simplify(expr), result)
             self.assertEqual(shape_env.evaluate_expr(expr), result)
 
-    @skipIfNoSympy
     def test_floordiv_assumptions(self):
         # We define two Symbols (with different names) for each type to make
         # sure the behavior is consistent regardless of whether both arguments
@@ -1,6 +1,6 @@
 # Owner(s): ["module: ProxyTensor"]
 
-from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, xfail_inherited_tests
+from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests
 import torch
 import unittest
 import warnings
@@ -27,13 +27,6 @@ import itertools
 
 aten = torch.ops.aten
 
-try:
-    import sympy  # noqa: F401
-    # TODO(jansel): these tests fail on windows
-    HAS_SYMPY = not IS_WINDOWS
-except ImportError:
-    HAS_SYMPY = False
-skipIfNoSympy = unittest.skipIf(not HAS_SYMPY, "no sympy")
 HAS_CUDA = torch.cuda.is_available()
 
 
@@ -735,7 +728,6 @@ class TestGenericProxyTensorFake(TestGenericProxyTensor):
     tracing_mode = "fake"
 
 
-@skipIfNoSympy
 @xfail_inherited_tests([
     "test_make_fx_overloads",
     "test_trace_subclasses",
@@ -812,7 +804,6 @@ def _trace(f, *args):
     return make_fx(f, tracing_mode="symbolic")(*inps)
 
 # TODO: Need to test the guards themselves specifically as well
-@skipIfNoSympy
 class TestSymbolicTracing(TestCase):
     def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
         """
@@ -1489,14 +1480,12 @@ class TestProxyTensorOpInfo(TestCase):
     def test_make_fx_fake_exhaustive(self, device, dtype, op):
         _test_make_fx_helper(self, device, dtype, op, "fake")
 
-    @skipIfNoSympy
     @ops(op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
              make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
     def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
         _test_make_fx_helper(self, device, dtype, op, "symbolic")
 
-    @skipIfNoSympy
     @ops(op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
              make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
@@ -6,6 +6,7 @@ import random
 from contextlib import contextmanager
 from functools import partial
 from typing import Callable, Optional, Tuple, Union
+import sympy
 
 import torch
 from torch import SymInt
@@ -126,7 +127,6 @@ class DebugInterpreter(fx.Interpreter):
         super().run(*args)
 
     def run_node(self, n):
-        import sympy
 
         def subst_symint(ni):
             if not isinstance(ni, SymInt):
@@ -8,12 +8,7 @@ from typing import Callable, Generic, List, NamedTuple, Optional, Set, TypeVar
 
 log = logging.getLogger(__name__)
 
-# TODO(voz): Stolen pattern, not sure why this is the case,
-# but mypy complains.
-try:
-    import sympy  # type: ignore[import]
-except ImportError:
-    log.warning("No sympy found")
+import sympy
 
 """
 torch._guards is the definitional source of truth for general purpose guard structures.
@@ -11,12 +11,7 @@ import itertools
 
 from torch.fx.experimental.unification import Var  # type: ignore[attr-defined]
 
-
-try:
-    import sympy  # type: ignore[import]
-    HAS_SYMPY = True
-except ImportError:
-    HAS_SYMPY = False
+import sympy
 
 _INFERENCE_RULES: Dict[Target, Callable] = {}
 _REFINEMENT_RULES: Dict[Target, Callable] = {}
@@ -305,7 +300,7 @@ def calculate_out_dimension(d_in, module_instance, index):
     dilation = (module_instance.dilation, module_instance.dilation) \
         if isinstance(module_instance.dilation, int) else module_instance.dilation
 
-    DIMENSION_TYPES = (int, sympy.Symbol) if HAS_SYMPY else (int,)
+    DIMENSION_TYPES = (int, sympy.Symbol)
 
     if d_in == Dyn:
         return Dyn
@@ -814,18 +809,15 @@ class Refine:
         """
        Replace all unknown types with fresh type variables.
        """
-        if HAS_SYMPY:
-            if isinstance(typ, Var):
-                return sympy.symbols(str(typ))
-            elif isinstance(typ, TensorType):
-                new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
-                return TensorType(tuple(new_args))
-            elif isinstance(typ, list):
-                return [self.convert_to_sympy_symbols(t) for t in typ]
-            elif isinstance(typ, tuple):
-                return (self.convert_to_sympy_symbols(t) for t in typ)
-            else:
-                return typ
+        if isinstance(typ, Var):
+            return sympy.symbols(str(typ))
+        elif isinstance(typ, TensorType):
+            new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__]
+            return TensorType(tuple(new_args))
+        elif isinstance(typ, list):
+            return [self.convert_to_sympy_symbols(t) for t in typ]
+        elif isinstance(typ, tuple):
+            return (self.convert_to_sympy_symbols(t) for t in typ)
+        else:
+            return typ
 
@@ -865,29 +857,26 @@ class Refine:
             pass
 
     def infer_symbolic_relations(self, n: Node):
-        if HAS_SYMPY:
-            n.type = self.convert_to_sympy_symbols(n.type)
-            if n.op == 'call_function':
-                if n.target in _RULES:
-                    return _RULES[n.target](n)
-                else:
-                    pass
-
-            if n.op == 'call_module':
-                module_instance = self.traced.get_submodule(n.target)
-                if type(module_instance) in _RULES:
-                    return _RULES[type(module_instance)](n, module_instance)
-                else:
-                    pass
-
-            if n.op == 'output':
-                def get_node_type(a):
-                    return a.type
-                n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
-                return n.type
-
-        else:
-            pass
+        n.type = self.convert_to_sympy_symbols(n.type)
+        if n.op == 'call_function':
+            if n.target in _RULES:
+                return _RULES[n.target](n)
+            else:
+                pass
+
+        if n.op == 'call_module':
+            module_instance = self.traced.get_submodule(n.target)
+            if type(module_instance) in _RULES:
+                return _RULES[type(module_instance)](n, module_instance)
+            else:
+                pass
+
+        if n.op == 'output':
+            def get_node_type(a):
+                return a.type
+            n.type = torch.fx.node.map_arg(n.args[0], get_node_type)
+            return n.type
@@ -25,14 +25,9 @@ log = logging.getLogger(__name__)
 class GuardOnDataDependentSymNode(RuntimeError):
     pass
 
-try:
-    import sympy  # type: ignore[import]
-    from sympy.printing.precedence import precedence  # type: ignore[import]  # noqa: F401
-    from sympy.printing.str import StrPrinter  # type: ignore[import]
-    from sympy.core.logic import fuzzy_and, fuzzy_or  # type: ignore[import]
-    HAS_SYMPY = True
-except ImportError:
-    HAS_SYMPY = False
+import sympy
+from sympy.printing.str import StrPrinter
+from sympy.core.logic import fuzzy_and, fuzzy_or
 
 aten = torch._ops.ops.aten  # type: ignore[has-type]
 
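The sympy submodule imports become unconditional here too, and the precedence helper from the old guarded block does not reappear in this hunk. For context, StrPrinter is how this file customizes the printing of guard expressions; a minimal hedged sketch of the subclassing pattern (a hypothetical printer, not pytorch's actual one):

import sympy
from sympy.printing.str import StrPrinter

class ParenSymbolPrinter(StrPrinter):
    # Wrap every symbol in parentheses, just to show the override hook.
    def _print_Symbol(self, expr):
        return f"({expr.name})"

s0 = sympy.Symbol("s0")
print(ParenSymbolPrinter().doprint(s0 + 1))  # prints: (s0) + 1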
@@ -408,7 +403,7 @@ class SymNode:
         return self.guard_bool("", 0)
 
 
-if HAS_SYMPY:
+if True:  # TODO: unindent
     # Overloaded to be compatible with regular Python.
     # https://github.com/pytorch/pytorch/issues/90900
     class Pow(sympy.Function):
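The guard becomes `if True:  # TODO: unindent` rather than dedenting the whole block, which keeps this already wide-reaching diff reviewable; the TODO flags the follow-up dedent, and the same trick appears again below. For context on what lives under these guards: Pow (like the FloorDiv exercised by TestFloorDiv above) follows sympy's standard custom-function pattern, where eval returns a concrete value when the arguments allow it or None to stay symbolic. A minimal hedged sketch with a hypothetical MyFloorDiv (not pytorch's actual class):

import sympy

class MyFloorDiv(sympy.Function):
    @classmethod
    def eval(cls, base, divisor):
        # Evaluate eagerly when both arguments are concrete numbers;
        # returning None keeps the expression unevaluated (symbolic).
        if base.is_Number and divisor.is_Number:
            return sympy.floor(base / divisor)

s0 = sympy.Symbol("s0", positive=True, integer=True)
print(MyFloorDiv(10, 3))   # 3
print(MyFloorDiv(s0, 3))   # MyFloorDiv(s0, 3) -- stays symbolic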
@@ -886,7 +881,7 @@ def _lru_cache(fn, maxsize=None):
     return wrapper
 
 
-if HAS_SYMPY:
+if True:  # TODO: unindent
     # This stub exists so we can easily add metadata to sympy symbols
     # NB: This inherits from Dummy, not Symbol, because Symbols with the same
     # name get interned.  This is bad for us as we want the metadata
@@ -1040,9 +1035,6 @@ class ShapeEnv:
     def create_symbol(self, val: int, source: Source) -> "sympy.Expr":
         assert isinstance(source, Source), f"{type(source)} {source}"
 
-        if not HAS_SYMPY:
-            raise RuntimeError("Need sympy installed to create symbolic shapes")
-
         if val < 0:
             from torch._dynamo.source import NegateSource
             return -self.create_symbol(-val, NegateSource(source))
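With the runtime guard gone, a missing sympy now surfaces as an ImportError when the module is imported rather than as a lazy RuntimeError on first symbol creation. A hedged sketch of that behavioral difference (simplified; the real create_symbol is a ShapeEnv method taking a Source):

import sympy  # module level: fails fast at import time if sympy is absent

def create_symbol(val: int) -> sympy.Symbol:
    # Previously a `if not HAS_SYMPY: raise RuntimeError(...)` check ran
    # on every call; now the dependency is verified once, at import.
    return sympy.Symbol(f"s{val}", positive=True, integer=True)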
@@ -1,6 +1,6 @@
 import dataclasses
 import itertools
-import sympy  # type: ignore[import]
+import sympy
 import operator
 import math
 import logging