pytorch/test/test_dynamic_shapes.py
Avik Chaudhuri ebc7039bcb New export API with dynamic shape specifications instead of constraints (#108448)
In our experience, using `constraints` / `dynamic_dim` with the existing export API is (subjectively) clunky and (objectively) verbose in common cases.

This PR implements a new design for the export API that replaces the use of `constraints` / `dynamic_dim` with a new way of specifying dynamic shapes, involving the following concepts:
* a constructor `Dim` for first-class named dynamic dimensions with ranges (similar to `functorch.dim`, and analogous to internal symbolic sizes)
* a mechanism that uses the above in `export` calls to associate inputs to their dynamic shape specifications (`dynamic_shapes`)

Design doc: https://docs.google.com/presentation/d/168U7XK72C_WSsZpGESP6Cho9udh193fi0gfjxCNcJ4E/edit#slide=id.p (Meta-only). Note that we only implement Option 1 in that doc. An older version of this PR also implemented Option 3, which is an alternative way of specifying dynamic shapes using tensor type annotations on the exported callable; but we have moved that to future work for now.

See docs for these new features in `torch.export`. The existing `torch.export.export` is modified to use the new API, `torch._export.export__RC__`, whenever `constraints=None`. We have not deprecated the existing API yet, but will do so in a follow-up.

Constraint violation errors arising through use of the new API will now contain suggested fixes using the new API. No longer do we need to report all specializations for static dimensions and suggest all constraints over dynamic dimensions to fix such errors. Instead, due to the redesign, the suggested fixes are much more concise, only involving modifying the definitions of relevant `Dim`s.

Differential Revision: [D48919204](https://our.internmc.facebook.com/intern/diff/D48919204/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108448
Approved by: https://github.com/suo, https://github.com/gmagogsfm
2023-09-22 06:58:26 +00:00

2014 lines
70 KiB
Python

# Owner(s): ["oncall: jit"]
import contextlib
import copy
import itertools
import inspect
import math
import operator
import re
import sympy
import torch
import torch.fx
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental import symbolic_shapes
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
DimConstraints,
DimDynamic,
expect_true,
guard_bool,
guard_float,
guard_int,
GuardOnDataDependentSymNode,
ShapeEnv,
sym_float,
sym_sqrt,
SymNode,
to_node,
is_symbolic,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import FloorDiv, Mod
# Shorthand for the ATen operator namespace used throughout this file.
aten = torch.ops.aten
# Registry mapping an operator overload to the Python meta function that
# implements it; entries are added by register_meta and looked up by
# FakeSymbolicTensor.__torch_dispatch__.
meta_funcs = {}
def register_meta(op):
    """Return a decorator that records a function in ``meta_funcs``.

    ``op`` may be a single operator overload or a pytree (for example a list)
    of overloads; the decorated function is stored under every leaf. The
    decorator hands the function back unchanged.
    """
    def decorator(f):
        def _record(overload):
            meta_funcs[overload] = f

        # tree_map is used purely to visit every leaf; its result is dropped.
        tree_map(_record, op)
        return f

    return decorator
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    """Meta function for ``aten.add.Tensor`` and ``aten.sub.Tensor``.

    The result is a new empty tensor with the shape of ``a``; ``b`` is not
    inspected.
    """
    result_shape = a.shape
    return a.new_empty(result_shape)
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    """Meta function for ``aten.cat.default``.

    Args:
        tensors: non-empty sequence of tensors to concatenate.
        dim: dimension to concatenate along; a negative value counts from the
            last dimension of ``tensors[0]``.

    Returns:
        ``tensors[0].new_empty(...)`` whose size along ``dim`` is the sum of
        the inputs' sizes along ``dim`` and whose other sizes are those of
        ``tensors[0]``.

    Raises:
        AssertionError: if an input disagrees with ``tensors[0]`` on a
            dimension other than ``dim``.
    """
    shape = tensors[0].shape
    # Normalize a negative dim so the ``idx == dim`` comparison below can
    # match; without this no length is accumulated and the output size along
    # ``dim`` would be 0.
    if dim < 0:
        dim += len(shape)
    concat_length = 0
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta function for ``aten.narrow_copy.default``.

    Returns a new empty tensor shaped like ``a`` except that the size along
    ``dim`` is ``length``. ``start`` and any extra keyword arguments are
    accepted but not used.
    """
    narrowed = tuple(
        length if axis == dim else extent
        for axis, extent in enumerate(a.shape)
    )
    return a.new_empty(narrowed)
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Meta function for ``aten.expand.default``.

    Returns a new empty tensor of shape ``size`` created from ``a``.
    ``implicit`` is accepted but not used.
    """
    expanded = a.new_empty(size)
    return expanded
def create_contiguous(shape):
    """Return row-major (contiguous) strides for ``shape`` as a list.

    The last dimension gets stride 1 and every earlier dimension gets the
    product of all sizes to its right, e.g. ``(5, 4, 3) -> [12, 3, 1]``.
    An empty shape yields an empty list. Sizes are only multiplied, so any
    objects supporting ``*`` can be used as sizes.

    Args:
        shape: sliceable sequence of dimension sizes.

    Returns:
        A list with one stride per dimension of ``shape``.
    """
    if len(shape) == 0:
        return []
    strides = [1]
    # Walk the sizes right-to-left, skipping the outermost one: the stride of
    # dimension i is the product of the sizes of dimensions i+1..n-1.
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    strides.reverse()
    return strides
class FakeSymbolicTensor(torch.Tensor):
    """Wrapper tensor subclass used by the tests to carry symbolic sizes.

    Operators that reach ``__torch_dispatch__`` are served from the
    ``meta_funcs`` registry; ``aten.new_empty.default`` is handled inline and
    every other operator raises ``RuntimeError``.
    """

    @staticmethod
    def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
        # TODO: this is wrong in general
        # NB: the ``sym_strides`` argument is ignored; strides are always
        # recomputed from ``sym_shape``.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls, sym_shape,
            sym_stride, storage_offset,
            dtype=dtype, layout=layout, requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Return a FakeSymbolicTensor of ``shape`` sharing this tensor's metadata."""
        return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        """Route ``func_overload`` to its registered meta function.

        Raises:
            RuntimeError: if the operator has no entry in ``meta_funcs`` and
                is not ``aten.new_empty.default``.
        """
        # ``kwargs`` defaults to None, which cannot be unpacked with ``**``.
        if kwargs is None:
            kwargs = {}

        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)

        raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env):
    """Build a FakeSymbolicTensor mirroring ``arg`` with symbolic geometry.

    Every dimension of ``arg`` is marked ``DimDynamic.DUCK`` with no
    constraint, and ``shape_env`` allocates the sizes, strides and storage
    offset under a ``ConstantSource`` named ``name``.
    """
    from torch._dynamo.source import ConstantSource

    rank = arg.dim()
    symbolic_geometry = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=ConstantSource(name),
        dynamic_dims=[DimDynamic.DUCK] * rank,
        constraint_dims=[None] * rank,
    )
    sizes, strides, offset = symbolic_geometry
    return FakeSymbolicTensor(
        sizes, strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, offset
    )
def create_symtype(cls, pytype, shape_env, val):
    """Create a symbolic value of type ``cls`` whose hint is ``val``.

    A fresh symbol is allocated in ``shape_env`` under a ``ConstantSource``
    whose name is derived from the current number of entries in
    ``shape_env.var_to_val``; it is wrapped in a ``SymNode`` of Python type
    ``pytype`` and passed to ``cls``.
    """
    from torch._dynamo.source import ConstantSource

    source_name = f"__testing_only{len(shape_env.var_to_val)}"
    sym_expr = shape_env.create_symbol(
        val,
        source=ConstantSource(source_name),
        dynamic_dim=DimDynamic.DUCK,
        constraint_dim=None,
    )
    node = SymNode(sym_expr, shape_env, pytype, hint=val)
    return cls(node)
def create_symint(shape_env, i: int):
    """Create a ``SymInt`` in ``shape_env`` whose hint is ``i``."""
    return create_symtype(cls=SymInt, pytype=int, shape_env=shape_env, val=i)
def create_symbool(shape_env, b: bool):
    """Create a ``SymBool`` in ``shape_env`` whose hint is ``b``."""
    return create_symtype(cls=SymBool, pytype=bool, shape_env=shape_env, val=b)
def create_symfloat(shape_env, f: float):
    """Create a ``SymFloat`` in ``shape_env`` whose hint is ``f``."""
    return create_symtype(cls=SymFloat, pytype=float, shape_env=shape_env, val=f)
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
    """Tests for SymInt/SymFloat values and symbolic tensors backed by a ShapeEnv."""

    def test_arith_ops(self):
        """Binary arithmetic on SymInts agrees with the same op on plain ints."""
        shape_env = ShapeEnv()
        symints = []
        for i in range(2, 5):
            symints.append((i, create_symint(shape_env, i)))

        ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
        division_ops = (operator.floordiv, operator.mod)

        for op in ops:
            for args in itertools.permutations(symints, 2):
                (lhs_val, lhs_sym), (rhs_val, rhs_sym) = args
                if isinstance(lhs_sym, int):
                    continue
                # Only the division-like ops need a non-zero right operand.
                if op in division_ops and rhs_val == 0:
                    continue
                self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))

    def test_reverse_arith_ops(self):
        """Reflected ``//`` and ``*`` work with an int on the left-hand side."""
        shape_env = ShapeEnv()

        a = create_symint(shape_env, 2)
        self.assertTrue(5 // a == 5 // 2)

        a = create_symint(shape_env, 2)
        self.assertTrue(5 * a == 5 * 2)

    def test_roundtrip(self):
        """Sizes and storage offset of a symbolic tensor are SymInts with the right values."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

        self.assertTrue(not isinstance(x.shape[0], SymNode))
        self.assertTrue(isinstance(x.shape[0], SymInt))

        self.assertTrue(x.shape[0] == 5)
        self.assertTrue(x.shape[1] == 4)
        # These comparisons must be spelled out: assertTrue(a, b) would treat
        # ``b`` as the failure message and only check the truthiness of ``a``.
        self.assertTrue(x.shape[2] == 3)

        self.assertTrue(x.size()[0] == 5)
        self.assertTrue(x.size()[1] == 4)
        # Should be simplifiable to an integer.
        # Ref: https://github.com/pytorch/pytorch/pull/107492
        self.assertTrue(isinstance(x.size()[1], SymInt))
        self.assertTrue(isinstance(x.size()[1].node.maybe_as_int(), int))  # due to guard above
        self.assertTrue(x.size()[2] == 3)

        self.assertTrue(x.size(0) == 5)
        self.assertTrue(x.size(1) == 4)
        self.assertTrue(x.size(2) == 3)
        self.assertTrue(isinstance(x.size(2), SymInt))
        self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))

        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
        self.assertTrue(isinstance(y.storage_offset(), SymInt))
        self.assertTrue(y.storage_offset() == 12)

    def test_binary(self):
        """Adding symbolic tensors yields the expected output sizes."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # broadcasting
        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

    def test_symint_args(self):
        """SymInts and SymInt expressions are accepted as ``narrow_copy`` lengths."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2
        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == y.shape[2])

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
        self.assertTrue(z.shape[2] == 2)

    def test_symint_vargs(self):
        """``expand`` accepts SymInts as varargs or a list, mixed with ints."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # shape list
        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(x.shape[0], y.shape[1], 3)
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints in a list
        z = y.expand((x.shape[0], y.shape[1], 3))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(5, y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        z = y.expand((y.shape[1],))
        z = y.expand(y.shape[1])

    def test_stride(self):
        """Strides of a symbolic tensor are SymInts."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
        self.assertIsInstance(x.stride()[0], SymInt)

    def test_size_expressions(self):
        """Branching on a symbolic size records a ``>`` guard on that size."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.shape[0], x.shape[0])
        # Both branches are identical on purpose: the comparison is evaluated
        # only for the guard it records, which is inspected below.
        if expand_x.shape[0] > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op, _bt = shape_env.guards[-1]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        # assertEqual, not assertTrue: assertTrue(a, b) would treat ``b`` as
        # the failure message and compare nothing.
        self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
        self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
        self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))

    def test_numel(self):
        """``numel`` is a SymInt for symbolic tensors and an int for regular ones."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        self.assertIsInstance(x.numel(), torch.SymInt)
        self.assertIsInstance(torch.numel(x), torch.SymInt)

        x = torch.rand(3, 3)
        self.assertIsInstance(x.numel(), int)
        self.assertIsInstance(torch.numel(x), int)

    def test_int_to_float(self):
        """``sym_float`` of a SymInt is a SymFloat."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        r = sym_float(x.shape[0])
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))

    def test_aten_ops(self):
        """ATen overloads can be called directly with SymInt arguments."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

    def test_fx_trace_intlist(self):
        """``symbolic_trace`` handles a pad tuple built from traced sizes."""
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        x = torch.rand(1, 3, 4, 4)
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        torch.fx.symbolic_trace(m)

    def test_meta_symint(self):
        """A meta tensor created from a SymInt keeps a SymInt size."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        r = torch.empty(a0, device='meta')
        self.assertIsInstance(r.shape[0], SymInt)

    def test_guard_int(self):
        """``guard_int`` returns the hint and records an equality guard."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(a0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_sym_int(self):
        """``sym_int`` of SymInt and SymFloat inputs returns SymInts with the expected guards."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = sym_int(a0)
        self.assertEqual(r, 5)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")

        a1 = create_symint(shape_env, 7)
        r = sym_int(a1 / 2)
        self.assertEqual(guard_int(r), 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s1/2), 3)""")

        a3 = create_symint(shape_env, 3)
        r = sym_int(2.0 * sym_float(a3))
        self.assertEqual(guard_int(r), 6)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(2*s2, 6)""")

    def test_sym_sqrt(self):
        """``sym_sqrt`` of a SymInt is a SymFloat guarded on its value."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 4)
        r = sym_sqrt(a0)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")

    def test_sym_floor(self):
        """``math.floor`` on symbolic floats returns a SymInt."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.floor(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_sym_ceil(self):
        """``math.ceil`` on symbolic floats returns a SymInt."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.ceil(a0 / 2)
        self.assertEqual(r, 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""")
        # This half previously called math.floor, so the ceil test never
        # exercised ceil on a float-times-int product.
        r = math.ceil(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_int_conversion(self):
        """``int()`` on a SymInt records an equality guard."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        int(a0)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_data_dependent_guard(self):
        """Guarding on an unbacked SymInt raises GuardOnDataDependentSymNode."""
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

    def test_expect_true_basic(self):
        """``expect_true`` on an unbacked SymInt defers a runtime assert."""
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        # This doesn't error
        self.assertTrue(expect_true(i0 == 0))
        # This generates a deferred runtime assert
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
            """[Eq(i0, 0)]"""
        )
        self.assertIn("test_dynamic_shapes.py", shape_env.deferred_runtime_asserts[i0.node.expr][0].msg)
        # After expecting true, guards now resolve given the runtime assert
        bool(i0 == 0)

    def test_expect_true_with_s0(self):
        """``expect_true`` works for a relation between unbacked and backed SymInts."""
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 5)
        i0 = shape_env.create_unbacked_symint()
        self.assertTrue(expect_true(i0 <= s0))
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
            """[i0 <= s0]"""
        )
        self.assertTrue(i0 <= s0)
        self.assertFalse(i0 > s0)

    def test_expect_true_prefer_later(self):
        """A deferred assert over two unbacked SymInts is filed under the later one."""
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        i1 = shape_env.create_unbacked_symint()
        self.assertTrue(expect_true(i0 + i1 == 10))
        # Importantly, this is put in i1, not i0!
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1.node.expr]]),
            """[Eq(i0 + i1, 10)]"""
        )
        self.assertTrue(i0 + i1 == 10)
        # NB: We currently don't support deriving that we can substitute
        # i0 + i1 with 10; maybe we should, but this means our rewriting
        # system is no longer confluent (it's probably OK though, because
        # you're unlikely to get other equalities like this on the
        # unbacked SymInts.)

    def test_expect_true_double_digits(self):
        """Deferred asserts are filed correctly once symbol names reach two digits."""
        shape_env = ShapeEnv()
        ia = [shape_env.create_unbacked_symint() for _ in range(11)]  # allocate 11 (i0..i10)
        self.assertEqual(str(ia[-1]), "i10")
        self.assertTrue(expect_true(sum(ia) == 20))
        self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)

    def test_non_overlapping_and_dense(self):
        """``is_non_overlapping_and_dense`` holds for a symbolic strided meta tensor."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = torch.empty_strided((a0, 7), (1, a0), device='meta')
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))

    def test_specialize_zero_one(self):
        """``a0 != 1`` adds a guard only when ``specialize_zero_one`` is off."""
        shape_env = ShapeEnv(specialize_zero_one=True)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(specialize_zero_one=False)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 1)

    def test_duck_shape(self):
        """Equal-hint SymInts compare equal without a guard only under duck shaping."""
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(duck_shape=False)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 1)

    def test_int_bool(self):
        """Truth-testing a SymInt adds no guard here."""
        # See https://github.com/pytorch/pytorch/issues/95981
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        assert a0
        self.assertEqual(len(shape_env.guards), 0)

    def test_symint_as_scalar(self):
        """A SymInt passed as ``alpha`` reaches ``__torch_dispatch__`` intact."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)

        sym_int_encountered = False

        class TestSymInt(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                assert func == torch.ops.aten.add.Tensor

                nonlocal sym_int_encountered
                # WARNING: do not do identity tests on the outer
                # SymInt/SymFloat, they are NOT STABLE
                sym_int_encountered = kwargs["alpha"].node is a0.node
                kwargs["alpha"] = 0
                return func(*args)

        x = torch.rand([4, 4])
        with TestSymInt():
            y = torch.add(x, x, alpha=a0)

        self.assertTrue(sym_int_encountered)

    def test_deepcopy(self):
        """Deep-copying a ShapeEnv preserves its guards."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        assert a0 < 4
        new_shape_env = copy.deepcopy(shape_env)
        self.assertEqual(len(new_shape_env.guards), 1)

    def test_print_readable_with_symints(self):
        """``print_readable`` output of a symbolically traced graph matches the expected text."""
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        out = fx_g.print_readable(print_output=False)

        # NOTE(review): the expected text below is reproduced exactly as it
        # appears in this copy of the file; its leading whitespace looks like
        # it was stripped -- confirm against the real print_readable() output.
        self.assertExpectedInline(out.strip(), """\
class f(torch.nn.Module):
def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):
# No stacktrace found for following nodes
sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None
sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: f32[s0 + s2, 2*s1] = native_dropout[0]
getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""") # noqa: B950
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestSymNumberMagicMethods(TestCase):
    """Checks magic methods on SymInt/SymFloat/SymBool against plain Python results."""

    def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
        """Compare ``fn`` on plain inputs with ``fn`` on symbolic versions of them.

        The reference result is computed from ``inp1``/``inp2`` directly, then
        the operation is repeated with the first argument symbolic, the second
        argument symbolic, and both symbolic (the last two only for binary
        functions). Cases expected to raise are wrapped by ``maybe_xfail``.
        """
        # Helper function
        # NB: don't use one as that will get specialized
        seed_node = (create_symint(shape_env, 2) / 2.).node
        bool_seed_node = (create_symint(shape_env, 2) == 2).node

        def get_sym_inp(inp):
            # Wrap a plain bool/int/float in the matching symbolic type.
            # NB: this must come before int
            if isinstance(inp, bool):
                return torch.SymBool(to_node(bool_seed_node, inp))
            elif isinstance(inp, int):
                return torch.SymInt(to_node(seed_node, inp))
            else:
                return torch.SymFloat(to_node(seed_node, inp))

        def maybe_xfail(inp1, inp2):
            # Return an assertRaises context for input combinations that are
            # expected to raise, and a no-op context otherwise.
            if fn == "sym_sqrt" and inp1 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
                # ZeroDivisionError: division by zero
                return self.assertRaises((ZeroDivisionError,))
            elif fn == "pow" and inp1 == 0 and inp2 < 0:
                # ZeroDivisionError: 0.0 cannot be raised to a negative power
                return self.assertRaises((ZeroDivisionError,))
            elif fn == "pow" and inp1 < 0 and inp2 in (2.5, -2.5) and (
                type(inp1) in (SymFloat, SymInt) or
                type(inp2) in (SymFloat, SymInt)
            ):
                # Complex result, which we do not support:
                # TypeError: Cannot convert complex to float
                return self.assertRaises((TypeError,))
            elif fn in ("lshift", "rshift") and not (
                isinstance(inp1, (SymInt, int)) and
                isinstance(inp2, (SymInt, int))
            ):
                # TypeError: unsupported operand type(s)
                return self.assertRaises((TypeError,))
            elif fn in ("lshift", "rshift") and inp2 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            else:
                return contextlib.nullcontext()

        # Resolve the callable for ``fn`` from math, symbolic_shapes or operator.
        if fn in symbolic_shapes.magic_methods_on_math:
            lambda_apply = getattr(math, fn)
        elif fn in symbolic_shapes.magic_methods_on_submodule:
            lambda_apply = getattr(symbolic_shapes, fn)
        elif fn in symbolic_shapes.magic_methods_on_operator_with_trailing_underscore:
            lambda_apply = getattr(operator, f"{fn}_")
        else:
            lambda_apply = getattr(operator, fn)

        def guard_fn(v):
            # Turn a (possibly symbolic) value into a concrete one using the
            # guard function matching its type.
            if type(v) in (SymBool, bool):
                return guard_bool(v)
            elif type(v) in (SymFloat, float):
                return guard_float(v)
            else:  # SymInt, int
                return guard_int(v)

        # Get reference result
        with maybe_xfail(inp1, inp2):
            if is_unary_fn:
                ref_out = lambda_apply(inp1)
            else:
                ref_out = lambda_apply(inp1, inp2)

        # Symified first arg
        sym_inp1 = get_sym_inp(inp1)
        with maybe_xfail(sym_inp1, inp2):
            if is_unary_fn:
                out = lambda_apply(sym_inp1)
            else:
                out = lambda_apply(sym_inp1, inp2)
            if fn not in symbolic_shapes.alternate_impl_if_hinted_methods:
                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        if is_unary_fn:
            return

        # Symified second arg
        sym_inp2 = get_sym_inp(inp2)
        with maybe_xfail(inp1, sym_inp2):
            out = lambda_apply(inp1, sym_inp2)
            if fn not in symbolic_shapes.alternate_impl_if_hinted_methods:
                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        # Symified both args
        with maybe_xfail(sym_inp1, sym_inp2):
            out = lambda_apply(sym_inp1, sym_inp2)
            if fn not in symbolic_shapes.alternate_impl_if_hinted_methods:
                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

    @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
    def test_bool_method(self, fn):
        """Run ``_do_test`` with (True, False) for every bool magic method."""
        if fn not in symbolic_shapes.bool_magic_methods:
            self.skipTest(f"{fn} is non-bool")

        is_unary_fn = fn in symbolic_shapes.unary_magic_methods
        shape_env = ShapeEnv()
        self._do_test(fn, True, False, shape_env, is_unary_fn)

    @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
    @parametrize("first_type", ["int", "float"])
    @parametrize("second_type", ["int", "float"])
    def test_method(self, fn, first_type, second_type):
        """Run ``_do_test`` over signed value pairs for every non-bool magic method."""
        if first_type == "float":
            # TODO: Hmm, this looks like we skip all floats
            self.skipTest(f"{fn} is not a float magic method")

        is_unary_fn = fn in symbolic_shapes.unary_magic_methods
        # Second argument is ignored for unary function. So only run for one type
        if is_unary_fn and second_type == "float":
            self.skipTest(f"{fn} is unary and already tested")

        if fn in symbolic_shapes.bool_magic_methods:
            self.skipTest(f"{fn} is bool")

        # Only floats here since these will be converted to int if necessary.
        # We also ignore complex and bool.
        values = (
            0.0,
            1.0,
            2.5,
        )

        neg_values = tuple(-x for x in values)

        # Cover every sign combination of the two inputs.
        for inp1, inp2 in itertools.chain(
            itertools.product(values, values),
            itertools.product(values, neg_values),
            itertools.product(neg_values, values),
            itertools.product(neg_values, neg_values),
        ):
            if first_type == "int":
                inp1 = int(inp1)
            if second_type == "int":
                inp2 = int(inp2)

            # A fresh ShapeEnv is created for every input pair.
            shape_env = ShapeEnv()

            self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)

    def get_constant_bool(self, val):
        """Return a SymBool wrapping the constant bool symnode for ``val``."""
        return SymBool(torch._C._get_constant_bool_symnode(val))

    def test_symnode_hashing(self):
        """Symbolic Sym* values are unhashable; singleton/constant ones and SymNodes hash."""
        shape_env = ShapeEnv()

        # SymInt, SymBool, SymFloat are unhashable
        unhashable = (
            create_symint(shape_env, 3),
            create_symbool(shape_env, True),
            create_symfloat(shape_env, 4.2),
        )

        for x in unhashable:
            with self.assertRaisesRegex(TypeError, "unhashable"):
                hash(x)

        # Singleton SymInt, constant SymBool, SymNode are hashable
        j1 = torch._C._get_singleton_int(1)
        j1_copy = torch._C._get_singleton_int(1)
        j2 = torch._C._get_singleton_int(2)
        t = self.get_constant_bool(True)
        t_copy = self.get_constant_bool(True)
        f = self.get_constant_bool(False)
        n = create_symint(shape_env, 3).node
        m = self.get_constant_bool(True).node

        self.assertIs(j1 == j1_copy, True)
        self.assertEqual(hash(j1), hash(j1_copy))
        self.assertIs(j1 == j2, False)
        self.assertNotEqual(hash(j1), hash(j2))
        self.assertIs(t == t_copy, True)
        self.assertEqual(hash(t), hash(t_copy))
        self.assertIs(t == f, False)
        self.assertNotEqual(hash(t), hash(f))

        # Only checks that hashing these does not raise.
        hash(n)
        hash(m)

    def test_non_symbolic_symnode(self):
        """Exercise comparisons and boolean ops on singleton ints and constant SymBools."""
        j1 = torch._C._get_singleton_int(1)
        j2 = torch._C._get_singleton_int(1)
        j3 = torch._C._get_singleton_int(3)

        self.assertIsInstance(j1, torch.SymInt)
        self.assertNotIsInstance(j1, int)

        with self.assertRaisesRegex(RuntimeError, "add not supported by SingletonSymNode"):
            j1 + 3

        self.assertFalse(j1 == 3)
        self.assertFalse(3 >= j2)

        self.assertIs(j1 == j1, True)
        self.assertIs(j1 == j2, True)
        self.assertIs(j1 == j3, False)
        self.assertIs(j1 != j3, True)
        self.assertIs(j1 != j2, False)

        x = self.get_constant_bool(True)
        #
        # Unary
        #
        # op(constant SymBool)
        self.assertIs(x.__sym_not__(), False)

        #
        # Binary
        #
        # op(constant SymBool, bool)
        # op(constant SymBool, constant SymBool)
        # op(bool, constant SymBool)
        self.assertIs(operator.and_(x, True), True)
        self.assertIs(operator.and_(x, x), True)
        self.assertIs(operator.and_(True, x), True)

        # op(symbolic SymBool, constant Symbool)
        # op(constant SymBool, symbolic Symbool)
        shape_env = ShapeEnv()
        a = create_symint(shape_env, 2)
        b = create_symint(shape_env, 2)
        c = a == b  # symbolic SymBool
        d = self.get_constant_bool(True)
        e = operator.and_(c, d)
        f = operator.and_(d, c)
        self.assertTrue(is_symbolic(e))
        self.assertTrue(is_symbolic(f))
        self.assertIs(e.node.guard_bool("", 0), True)
        self.assertIs(f.node.guard_bool("", 0), True)

        # Comparing sizes
        sz1 = torch.Size([j1, j1, j1])
        sz2 = torch.Size([j1, j1, j1])
        self.assertIs(sz1 == sz2, True)

        sz1 = torch.Size([3, j1, 4])
        sz2 = torch.Size([3, j2, 4])
        self.assertIs(sz1 == sz2, True)
        self.assertIs(sz1 != sz2, False)


# Instantiate the tests declared with @parametrize above for this class.
instantiate_parametrized_tests(TestSymNumberMagicMethods)
class TestFloorDiv(TestCase):
    """Checks the sympy ``FloorDiv`` function against Python ``//`` and its error cases."""

    @staticmethod
    def python_floordiv(x, y):
        """Reference implementation: Python's own floor division."""
        return x // y

    @staticmethod
    def torch_floordiv(x, y):
        """Evaluate ``FloorDiv(x, y)`` through a fresh ShapeEnv."""
        # Note: we fully evaluate here since FloorDiv might not always do
        # that.
        shape_env = ShapeEnv()
        return shape_env.evaluate_expr(FloorDiv(x, y))

    @staticmethod
    def yield_test_cases(values, negate=True):
        """Yield each ``(x, y)`` pair and, if ``negate``, its three sign variants."""
        for x, y in values:
            yield (x, y)
            if negate:
                yield (-x, y)
                yield (x, -y)
                yield (-x, -y)

    def test_floordiv_float_int(self):
        """FloorDiv matches ``//`` for mixed float/int operands of either sign."""
        values = (
            (2.5, 2.1),
            (2.1, 2.5),
            (2.0, 2.1),
            (7, 2.5),
            (2.1, 7),
            (7, 2),
        )

        for x, y in TestFloorDiv.yield_test_cases(values):
            self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_bool(self):
        """FloorDiv rejects bool operands; the int-converted result matches ``//``."""
        values = (
            (False, True),
            (True, 2.5),
            (2.5, True),
            (False, 7),
            (7, True),
        )

        for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
            # Compares to int since our FloorDiv has no bool support
            self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(int(x), int(y)))
            # Tests that our impl throws
            self.assertRaisesRegex(
                TypeError,
                (rf"unsupported operand type\(s\) for //: "
                 rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
                 rf", expected integer or real"),
                lambda: TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_complex(self):
        """Both Python ``//`` and FloorDiv raise TypeError for complex operands."""
        values = (
            (1.5 + 2.5j, 1.3 + 3.5j),
            (1.5 + 2.5j, 2.5),
            (2.5, 1.5 + 2.5j),
            (1.5 + 2.5j, 7),
            (7, 1.5 + 2.5j),
        )

        for x, y in TestFloorDiv.yield_test_cases(values):
            # We don't test error messages to avoid depending on Python
            # interpreter version
            self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y))
            self.assertRaisesRegex(
                TypeError,
                (rf"unsupported operand type\(s\) for //: "
                 rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
                 rf", expected integer or real"),
                lambda: TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_div_by_zero(self):
        """Division by a zero divisor raises ZeroDivisionError."""
        values = (
            (2.5, 0),
            (2.1, 0.0),
            (2.3, sympy.Symbol("s", zero=True)),
        )

        for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
            # We don't test error messages to avoid depending on Python
            # interpreter version
            if type(y) is not sympy.Symbol:
                self.assertRaises(ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y))
            self.assertRaisesRegex(
                ZeroDivisionError,
                "division by zero",
                lambda: TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_zero_base(self):
        """A zero base gives the same result as ``//`` (0 for a zero Symbol)."""
        values = (
            (0, 2.5),
            (0.0, 2.1),
            (sympy.Symbol("s", zero=True), 2.3),
        )

        for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
            if type(x) is not sympy.Symbol:
                self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
            else:
                self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_div_by_one(self):
        """FloorDiv by one matches ``//`` for ints and floats of either sign."""
        values = (
            (2.5, 1),
            (2.1, 1.0),
            (2, 1.0),
            (2, 1),
        )

        for x, y in TestFloorDiv.yield_test_cases(values):
            self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_simplify(self):
        """Constant FloorDiv expressions evaluate to the same number on every path."""
        # Tests how we simplify or evaluate FloorDiv without free variables
        shape_env = ShapeEnv()
        result = 21
        exprs = (
            7 * FloorDiv(6, 2),
            7 * FloorDiv(6.28, 2),
            7 * FloorDiv(6.28, 2.0),
            7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))),
        )

        for expr in exprs:
            self.assertEqual(expr, result)
            self.assertEqual(expr.doit(deep=False), result)
            self.assertEqual(expr.doit(deep=True), result)
            self.assertEqual(sympy.simplify(expr), result)
            self.assertEqual(shape_env.simplify(expr), result)
            self.assertEqual(shape_env.evaluate_expr(expr), result)

    def test_floordiv_simplify_rational(self):
        """FloorDiv by a rational rewrites to a multiplication (floored if needed)."""
        a = sympy.Symbol("a", integer=True)
        b = sympy.Symbol("b")

        cases = [
            (FloorDiv(a, sympy.Rational(1, 8)), 8 * a),
            (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)),
        ]

        for expr, expected in cases:
            self.assertEqual(expr, expected)

    def test_floordiv_assumptions(self):
        """Check ``is_integer``/``is_real`` of FloorDiv for each pair of symbol kinds."""
        # We define two Symbols (with different names) for each type to make
        # sure the behavior is consistent regardless of whether both arguments
        # are the same object or not.
        cases = (
            sympy.Symbol("i1", integer=True),
            sympy.Symbol("i2", integer=True),
            sympy.Symbol("r1", real=True),
            sympy.Symbol("r2", real=True),
            sympy.Symbol("c1", complex=True, real=False, integer=False),
            sympy.Symbol("c2", complex=True, real=False, integer=False),
            sympy.Symbol("s1"),
            sympy.Symbol("s2"),
        )

        for base, divisor in itertools.product(cases, repeat=2):
            def op():
                return FloorDiv(base, divisor)

            def is_complex(x):
                return x.is_integer is False and x.is_real is False and x.is_complex

            if is_complex(base) or is_complex(divisor):
                self.assertRaisesRegex(
                    TypeError,
                    (r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
                     r" expected integer or real"),
                    op)
                continue

            # Bind the result to its own name instead of rebinding ``op``.
            expr = op()

            # In regular Python, x//x == 1.0 if x is a float, but FloorDiv
            # always returns an integer 1 when both args are the same object.
            # This even works for Symbols with no assumptions specified.
            if base is divisor:
                self.assertTrue(expr.is_integer)
                self.assertTrue(expr.is_real)
            elif base.is_integer and divisor.is_integer:
                self.assertTrue(expr.is_integer)
                self.assertTrue(expr.is_real)
            else:
                self.assertEqual(expr.is_integer, None)
                self.assertTrue(expr.is_real)
class TestDimConstraints(TestCase):
def test_dim_constraints_reduce_congruences_simple(self):
    # Several congruences stored for one symbol are expected to be
    # collapsed by reduce_congruences() into a single congruence.
    from sympy import Symbol
    from torch.fx.experimental.symbolic_shapes import DimConstraints

    s = Symbol("s", positive=True, integer=True)
    half = s / 2

    constraints = DimConstraints({}, {}, set(), {})
    # Written straight into the private table instead of going through add().
    constraints._congruences[s] = {
        half % 2,
        half % 8,
        half % 4,
        s % 2,
        ((s / 16) + 2) % 4,
    }

    reduced = constraints.reduce_congruences()
    self.assertEqual(reduced[s], {(s + 32) % 64})
def test_dim_constraints_reduce_inequalities_simple(self):
    # Exercises sympy's reduce_inequalities on a redundant system over one
    # positive integer symbol: first the (dis)equalities and inequalities
    # alone, then again after one equality has been added to the system.
    from sympy import Eq, Interval, Ne, Symbol
    from sympy.solvers.inequalities import reduce_inequalities

    s = Symbol("s", positive=True, integer=True)
    half = s / 2

    def solution_set(relations):
        return reduce_inequalities(relations, s).as_set()

    relations = {
        s >= 2,
        Ne(8 * s, 16),
        Ne(half, 1),
        Ne(16 * s, 32),
        s < 16,
        Ne(s, 2),
        half < 16,
        half > 1,
        half >= 2,
        Ne(3 * s / 2, 3),
    }
    self.assertEqual(solution_set(relations), Interval.Ropen(4, 16))

    relations.add(Eq(half, 4))
    self.assertEqual(solution_set(relations), {8})
def test_dim_constraints_solve_full(self):
    # Feeds DimConstraints a large, fixed set of equalities and constraints
    # over four symbols (s0, s1, s5, s6) tied to the sizes of inputs a..f,
    # calls solve() and remove_redundant_dynamic_results(), and then checks
    # three things: the _static_results set, the _dynamic_results set, and
    # the two code snippets rendered by prettify_results().
    #
    # FloorDiv, Mod, DimConstraints, inspect and re are not imported inside
    # this method; they are expected to come from module scope.
    from sympy import Eq, Integer, Ne, Symbol
    from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource

    # One TensorPropertySource per (input name, size index) pair. The srcN
    # numbering does not follow declaration order.
    src0 = TensorPropertySource(
        base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
    )
    src2 = TensorPropertySource(
        base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
    )
    src3 = TensorPropertySource(
        base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
    )
    src4 = TensorPropertySource(
        base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
    )
    src1 = TensorPropertySource(
        base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
    )
    src7 = TensorPropertySource(
        base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
    )
    src5 = TensorPropertySource(
        base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
    )
    src8 = TensorPropertySource(
        base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
    )
    src6 = TensorPropertySource(
        base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
    )
    src9 = TensorPropertySource(
        base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
    )
    src10 = TensorPropertySource(
        base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
    )
    src11 = TensorPropertySource(
        base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
    )
    src12 = TensorPropertySource(
        base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
    )

    s0 = Symbol("s0", positive=True, integer=True)
    s1 = Symbol("s1", positive=True, integer=True)
    s5 = Symbol("s5", positive=True, integer=True)
    s6 = Symbol("s6", positive=True, integer=True)

    # Sources listed for each symbol. src11 and src12 are deliberately
    # absent: they are only equated with integer constants further down.
    symbol_to_source = {
        s0: [src0, src2, src3, src4],
        s1: [src1, src7],
        s5: [src5, src8],
        s6: [src6, src9, src10],
    }
    # Concrete value passed for each symbol, and the set of symbols passed
    # as marked dynamic (all four of them).
    var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
    marked_dynamic = {s0, s1, s5, s6}
    dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic, {})

    # Equalities: each non-first source of a symbol is equated with that
    # symbol; src11 and src12 are equated with the constants 1 and 3.
    dim_constraints.add_equality(src2, s0)
    dim_constraints.add_equality(src3, s0)
    dim_constraints.add_equality(src4, s0)
    dim_constraints.add_equality(src7, s1)
    dim_constraints.add_equality(src8, s5)
    dim_constraints.add_equality(src9, s6)
    dim_constraints.add_equality(src10, s6)
    dim_constraints.add_equality(src11, Integer(1))
    dim_constraints.add_equality(src12, Integer(3))

    # Constraints over the symbols. The list contains exact repeats (for
    # example Ne(FloorDiv(s1, 2), 1) is added twice just below).
    # NOTE(review): presumably captured verbatim from a real trace --
    # confirm before de-duplicating or reordering anything here.
    dim_constraints.add(s1**2 <= 2147483647)
    dim_constraints.add(32 * s1**2 <= 2147483647)
    dim_constraints.add(s0 < 16)
    dim_constraints.add(Eq(Mod(s1, 2), 0))
    dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
    dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
    dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
    dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
    dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
    dim_constraints.add(
        64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
        + 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
        + 64
        <= 2147483647
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
            + 1,
            1,
        )
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
        + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
        + 1
        > 1
    )
    dim_constraints.add(
        128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
        + 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
        + 128
        <= 2147483647
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
            + 1,
            1,
        )
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
        + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
        + 1
        > 1
    )
    dim_constraints.add(
        256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        + 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
        + 256
        <= 2147483647
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
            + 1,
            1,
        )
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
        + 1
        > 1
    )
    dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
        <= 2147483647
    )
    dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
    dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
    dim_constraints.add(
        Ne(
            60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60 * s0,
            0,
        )
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 1,
            1,
        )
    )
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 1,
            0,
        )
    )
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 1
        >= 0
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
    dim_constraints.add(
        1
        < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
    dim_constraints.add(
        Ne(
            60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60 * s0,
            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 120,
        )
    )
    dim_constraints.add(
        120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 120
        > 0
    )
    dim_constraints.add(
        Eq(
            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
            + 60 * (Mod(s0, 2)),
            0,
        )
    )
    dim_constraints.add(
        Ne(
            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 120,
            0,
        )
    )
    dim_constraints.add(
        Ne(
            60
            * (FloorDiv(s0, 2))
            * (FloorDiv(s0, (FloorDiv(s0, 2))))
            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120
            * FloorDiv(s0, 2)
            * FloorDiv(s0, (FloorDiv(s0, 2)))
            * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
            0,
        )
    )
    dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
    dim_constraints.add(
        Ne(
            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60,
            0,
        )
    )
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
        >= 0
    )
    dim_constraints.add(
        1
        < 60
        * (FloorDiv(s0, (FloorDiv(s0, 2))))
        * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
    )
    dim_constraints.add(Ne(16 * s0, 32))
    dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
    dim_constraints.add(Ne(16 * s0, 32))
    dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
    dim_constraints.add(FloorDiv(s0, 2) >= 2)
    dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
    dim_constraints.add(1 < FloorDiv(s0, 2))
    dim_constraints.add(Ne(s0, 2))
    dim_constraints.add(
        60
        * (FloorDiv(s0, (FloorDiv(s0, 2))))
        * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
        >= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
    )
    dim_constraints.add(
        60
        * (FloorDiv(s0, 2))
        * (FloorDiv(s0, (FloorDiv(s0, 2))))
        * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120
        * FloorDiv(s0, 2)
        * FloorDiv(s0, (FloorDiv(s0, 2)))
        * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
        > 0
    )
    dim_constraints.add(
        Ne(
            60
            * (FloorDiv(s0, 2))
            * (FloorDiv(s0, (FloorDiv(s0, 2))))
            * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120
            * FloorDiv(s0, 2)
            * FloorDiv(s0, (FloorDiv(s0, 2)))
            * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
            3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
        )
    )
    dim_constraints.add(
        Ne(
            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 20,
            0,
        )
    )
    dim_constraints.add(
        20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 20
        >= 0
    )
    dim_constraints.add(
        Ne(
            20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 20,
            20,
        )
    )
    dim_constraints.add(
        Ne(
            20
            * (
                Mod(
                    1,
                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                    - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                    + 1,
                )
            ),
            0,
        )
    )
    dim_constraints.add(
        Ne(
            20
            * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
            * (
                Mod(
                    1,
                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                    - 2
                    * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                    + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
                )
            )
            - 20
            * Mod(
                1,
                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                - 2
                * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
            ),
            0,
        )
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 1
        >= 1
    )
    dim_constraints.add(
        20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 20
        >= 0
    )
    dim_constraints.add(
        20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 20
        >= 1
    )
    dim_constraints.add(
        20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 20
        >= 2
    )
    dim_constraints.add(
        20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 20
        > 1
    )
    dim_constraints.add(
        20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 20
        < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
    )
    dim_constraints.add(
        Ne(
            60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60,
            60,
        )
    )
    dim_constraints.add(
        Ne(
            FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 1,
        )
    )
    dim_constraints.add(
        Eq(
            (FloorDiv((FloorDiv(s1, 2) - 1), 8))
            * (
                Mod(
                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                    - 2
                    * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                    / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                    + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
                    1,
                )
            )
            - Mod(
                (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                - 2
                * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
                + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
                1,
            ),
            0,
        )
    )
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 1,
            FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
        )
    )
    dim_constraints.add(Ne(8 * s0, 16))
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
        >= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 1
    )
    dim_constraints.add(
        60
        * (FloorDiv(s0, (FloorDiv(s0, 2))))
        * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
        <= 2147483647
    )
    dim_constraints.add(
        90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 90
        <= 2147483647
    )
    dim_constraints.add(FloorDiv(s0, 2) < 16)
    dim_constraints.add(FloorDiv(s0, 2) > 1)
    dim_constraints.add(
        Ne(
            90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 90 * (FloorDiv(s0, 2)),
            0,
        )
    )
    dim_constraints.add(
        1
        < 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 90
    )
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 1
        > 1
    )
    dim_constraints.add(
        60
        * (FloorDiv(s0, (FloorDiv(s0, 2))))
        * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
        > 1
    )
    dim_constraints.add(
        Ne(
            60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60 * (FloorDiv(s0, 2)),
            0,
        )
    )
    dim_constraints.add(
        90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 90
        > 1
    )
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
        > 1
    )
    dim_constraints.add(
        Ne(
            60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 60 * (FloorDiv(s0, 2)),
            3 * (FloorDiv(s0, 2)),
        )
    )
    dim_constraints.add(
        60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60 * (FloorDiv(s0, 2))
        > 0
    )
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
        > 0
    )
    dim_constraints.add(
        Ne(
            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 120,
            0,
        )
    )
    dim_constraints.add(
        1
        < 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 120
    )
    dim_constraints.add(
        Ne(
            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 120,
            6,
        )
    )
    dim_constraints.add(
        120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 120
        > 0
    )
    dim_constraints.add(
        Ne(
            120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 120,
            0,
        )
    )
    dim_constraints.add(
        120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 120
        <= 2147483647
    )
    dim_constraints.add(
        120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 120
        <= 20480
    )
    dim_constraints.add(
        Ne(
            90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 90,
            0,
        )
    )
    dim_constraints.add(
        120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 120
        > 1
    )
    dim_constraints.add(
        90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 90
        <= 20480
    )
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 60
        <= 20480
    )
    dim_constraints.add(
        Ne(
            240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 240,
            0,
        )
    )
    # 6 * s5 == 132, i.e. s5 == 22 (the value given for s5 in var_to_val).
    dim_constraints.add(Eq(6 * s5, 132))
    dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
    dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
    dim_constraints.add(
        Ne(
            64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 64 * (FloorDiv(s0, 2)),
            0,
        )
    )
    dim_constraints.add(
        1
        < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 64
    )
    dim_constraints.add(
        64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 64
        <= 2147483647
    )
    dim_constraints.add(
        64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 64
        > 1
    )
    dim_constraints.add(
        62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 62
        <= 2147483647
    )
    dim_constraints.add(
        Ne(
            62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 62 * (FloorDiv(s0, 2)),
            0,
        )
    )
    dim_constraints.add(
        1
        < 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 62
    )
    dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
    dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
    dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
    dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
    dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
    dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
    dim_constraints.add(
        64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 576
        <= 2147483647
    )
    dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
    dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
    dim_constraints.add(
        Ne(
            64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 576 * (FloorDiv(s0, 2)),
            0,
        )
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 9,
            1,
        )
    )
    dim_constraints.add(
        Ne(
            (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 9,
            0,
        )
    )
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 9
        >= 0
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
    dim_constraints.add(
        1
        < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 576
    )
    dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
    dim_constraints.add(
        Ne(
            64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 576 * (FloorDiv(s0, 2)),
            256,
        )
    )
    dim_constraints.add(
        Eq(
            64
            * (
                Mod(
                    (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                    - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                    + 9 * (FloorDiv(s0, 2)),
                    4,
                )
            ),
            0,
        )
    )
    dim_constraints.add(
        Eq(
            FloorDiv(s0, 2),
            FloorDiv(
                (
                    (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                    - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                    + 9 * (FloorDiv(s0, 2))
                ),
                4,
            ),
        )
    )
    dim_constraints.add(
        Eq(
            FloorDiv(
                (
                    (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                    - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                    + 9 * (FloorDiv(s0, 2))
                ),
                4,
            ),
            FloorDiv(s0, 2),
        )
    )
    dim_constraints.add(Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0))
    dim_constraints.add(
        Eq(
            64
            * (
                Mod(
                    (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
                    - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
                    + 1,
                    4,
                )
            ),
            0,
        )
    )
    dim_constraints.add(
        64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 576 * (FloorDiv(s0, 2))
        > 0
    )
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 9
        >= 1
    )
    dim_constraints.add(
        Eq(
            64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 576,
            256,
        )
    )
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 540
        <= 2147483647
    )
    dim_constraints.add(
        Ne(
            60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 540 * (FloorDiv(s0, 2)),
            0,
        )
    )
    dim_constraints.add(
        1
        < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 540
    )
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 9
        <= 2147483647
    )
    dim_constraints.add(
        Ne(
            (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
            - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
            + 9 * (FloorDiv(s0, 2)),
            0,
        )
    )
    dim_constraints.add(
        1
        < (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 9
    )
    dim_constraints.add(
        (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 9
        > 1
    )
    dim_constraints.add(
        60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
        - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
        + 540
        > 1
    )
    # Lower bound of 2 for each of the four symbols.
    dim_constraints.add(s0 >= 2)
    dim_constraints.add(s1 >= 2)
    dim_constraints.add(s6 >= 2)
    dim_constraints.add(s5 >= 2)

    # Solve the system, then prune the dynamic results before inspecting them.
    dim_constraints.solve()
    dim_constraints.remove_redundant_dynamic_results()

    # Expected specializations: every dimension of a and b that has a source
    # above, dim 0 of c and d, and dim 1 of f.
    self.assertEqual(dim_constraints._static_results, {
        "L['c'].size()[0] == 8",
        "L['d'].size()[0] == 8",
        "L['a'].size()[2] == 96",
        "L['f'].size()[1] == 1",
        "L['a'].size()[3] == 96",
        "L['b'].size()[2] == 3",
        "L['b'].size()[1] == 22",
        "L['b'].size()[0] == 8",
        "L['a'].size()[1] == 22",
        "L['a'].size()[0] == 8",
    })
    # Expected dynamic results: dim 1 of d and of e, each expressed as equal
    # to dim 1 of c (the three sources listed for s6).
    self.assertEqual(dim_constraints._dynamic_results, {
        "dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
        "dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
    })

    # prettify_results is given the signature of a function whose parameter
    # names match the local names used in the sources above.
    def dummy_fn(a, b, c, d, e, f):
        pass

    action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
    # The rendered text must contain exactly two ```-fenced snippets (the
    # unpacking fails otherwise): static specializations first, dynamic
    # constraints second.
    static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
    # NOTE(review): every line of the two expected literals below starts at
    # column 0, including the bodies of the generated functions -- confirm
    # this matches what prettify_results actually emits.
    expected_static = '''
def specializations(a, b, c, d, e, f):
# a:
assert a.size()[0] == 8
assert a.size()[1] == 22
assert a.size()[2] == 96
assert a.size()[3] == 96
# b:
assert b.size()[0] == 8
assert b.size()[1] == 22
assert b.size()[2] == 3
# c:
assert c.size()[0] == 8
# d:
assert d.size()[0] == 8
# f:
assert f.size()[1] == 1
'''
    expected_dynamic = '''
def specify_constraints(a, b, c, d, e, f):
return [
# d:
dynamic_dim(d, 1) == dynamic_dim(c, 1),
# e:
dynamic_dim(e, 1) == dynamic_dim(c, 1),
]
'''
    self.assertEqual(static_code, expected_static)
    self.assertEqual(dynamic_code, expected_dynamic)
# Run the tests in this file through run_tests() when executed as a script.
if __name__ == '__main__':
    run_tests()