pytorch/test/test_dynamic_shapes.py
Laith Sakka adedf26e21 Support python slicing with tensor inputs. (#165074)
when the slice is tensor, we decompose it to .item() call and pass the unbacked symbol to the slice to avoid DDE.
the diff also fix an existing bug in codegen_dynamic_slice_size in the cpp wrapper.  a +1 should be -1 making it match
python codegen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165074
Approved by: https://github.com/Lucaskabela
2025-10-29 01:18:45 +00:00

4410 lines
163 KiB
Python

# Owner(s): ["oncall: jit"]
# ruff: noqa: F841
import contextlib
import copy
import itertools
import math
import operator
import unittest
import numpy as np
import pytest
import sympy
import torch
import torch.fx
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend
from torch._inductor.utils import fresh_cache
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimConstraints,
DimDynamic,
expect_true,
guard_bool,
guard_float,
guard_int,
GuardOnDataDependentSymNode,
has_free_symbols,
hint_int,
is_symbolic,
ShapeEnv,
StatelessSymbolicContext,
statically_known_false,
statically_known_true,
)
from torch.testing._internal.common_dtype import all_types_and
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TEST_XPU,
TestCase,
)
from torch.testing._internal.logging_utils import logs_to_string
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._sympy.functions import (
CleanDiv,
FloorDiv,
IsNonOverlappingAndDenseIndicator,
Mod,
)
aten = torch.ops.aten
meta_funcs = {}
def register_meta(op):
def decorator(f):
def add_func(op):
meta_funcs[op] = f
pytree.tree_map_(add_func, op)
return f
return decorator
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
return a.new_empty(a.shape)
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
concat_length = 0
shape = tensors[0].shape
for tensor in tensors:
for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
if idx == dim:
concat_length = concat_length + length
else:
assert length == common_length
new_shape = list(shape)
new_shape[dim] = concat_length
return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
shape = []
for i, x in enumerate(a.shape):
if i == dim:
shape.append(length)
else:
shape.append(x)
return a.new_empty(tuple(shape))
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
return a.new_empty(size)
def create_contiguous(shape):
strides = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
@staticmethod
def __new__(
cls,
sym_shape,
sym_strides,
dtype,
layout,
requires_grad,
device,
storage_offset=0,
):
# TODO: this is wrong in general
sym_stride = create_contiguous(sym_shape)
r = torch.Tensor._make_wrapper_subclass(
cls,
sym_shape,
sym_stride,
storage_offset,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device=device,
)
return r
__torch_function__ = _disabled_torch_function_impl
def new_empty(self, shape):
return FakeSymbolicTensor(
shape, None, self.dtype, self.layout, self.requires_grad, self.device
)
@classmethod
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
if func_overload in meta_funcs:
return meta_funcs[func_overload](*args, **kwargs)
if func_overload == torch.ops.aten.new_empty.default:
self = args[0]
shape = args[1]
return FakeSymbolicTensor(
shape,
self.stride(),
self.dtype,
self.layout,
self.requires_grad,
self.device,
)
raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
from torch._dynamo.source import ConstantSource
if source is None:
source = ConstantSource(name)
constraint_dims = [None] * arg.dim()
if dynamic_dims is None:
dynamic_dims = [DimDynamic.DUCK] * arg.dim()
(
sym_shapes,
sym_strides,
sym_storage_offset,
) = shape_env.create_symbolic_sizes_strides_storage_offset(
arg,
source=source,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims
),
)
return FakeSymbolicTensor(
sym_shapes,
sym_strides,
arg.dtype,
arg.layout,
arg.requires_grad,
arg.device,
sym_storage_offset,
)
def create_fake_tensor_with_dynamic_size(x, shape_env, dynamic_sizes, dynamic_strides):
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode(shape_env=shape_env) as fake_mode:
return fake_mode.from_tensor(
x,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
dynamic_strides=dynamic_strides,
),
)
def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
from torch._dynamo.source import ConstantSource
symbol = shape_env.create_symbol(
val,
source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
constraint_dim=None,
**kwargs,
)
return cls(SymNode(symbol, shape_env, pytype, hint=val))
# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs)
def create_symbool(shape_env, b: bool) -> SymBool:
return create_symtype(SymBool, bool, shape_env, b)
def create_symfloat(shape_env, f: float) -> SymFloat:
return create_symtype(SymFloat, float, shape_env, f)
@skipIfTorchDynamo(
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestPySymInt(TestCase):
def test_arith_ops(self):
shape_env = ShapeEnv()
symints = []
for i in range(2, 5):
symints.append((i, create_symint(shape_env, i)))
ops = [
operator.add,
operator.sub,
operator.floordiv,
operator.mul,
operator.mod,
]
for op in ops:
for args in itertools.permutations(symints, 2):
if not isinstance(args[0][1], int) and (
(op != operator.mod or op != operator.floordiv) and args[1][0] != 0
):
self.assertTrue(
op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])
)
def test_reverse_arith_ops(self):
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
self.assertTrue(5 // a == 5 // 2)
a = create_symint(shape_env, 2)
self.assertTrue(5 * a == 5 * 2)
def test_sympify_symint(self):
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
self.assertIs(sympy.sympify(a), a.node.expr)
b = create_symfloat(shape_env, 3.0)
self.assertIs(sympy.sympify(b), b.node.expr)
c = create_symbool(shape_env, True)
self.assertIs(sympy.sympify(c), c.node.expr)
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.shape[0], SymNode))
self.assertTrue(isinstance(x.shape[0], SymInt))
self.assertTrue(x.shape[0] == 5)
self.assertTrue(x.shape[1] == 4)
self.assertTrue(x.shape[2], 3)
self.assertTrue(x.size()[0], 5)
self.assertTrue(x.size()[1], 4)
# Should be simplifiable to an integer.
# Ref: https://github.com/pytorch/pytorch/pull/107492
self.assertTrue(isinstance(x.size()[1], SymInt))
self.assertTrue(
isinstance(x.size()[1].node.maybe_as_int(), int)
) # due to guard above
self.assertTrue(x.size()[2] == 3)
self.assertTrue(x.size(0) == 5)
self.assertTrue(x.size(1) == 4)
self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), SymInt))
self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))
y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
self.assertTrue(isinstance(y.storage_offset(), SymInt))
self.assertTrue(y.storage_offset() == 12)
def test_binary(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# broadcasting
y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
def test_symint_args(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == y.shape[2])
# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == 2)
# arithmetic expr with a symint and python int
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
self.assertTrue(z.shape[2] == 2)
def test_symint_vargs(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
# varargs
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# shape list
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(x.shape[0], y.shape[1], 3)
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints in a list
z = y.expand((x.shape[0], y.shape[1], 3))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(5, y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python ints and symints in a list
z = y.expand((5, y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
z = y.expand((y.shape[1],))
z = y.expand(y.shape[1])
def test_symint_bitwise_and(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 0b1100)
b0 = create_symint(shape_env, 0b1010)
res_and = a0 & b0
self.assertEqual(res_and, 0b1000)
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
)
a1 = create_symint(shape_env, 3)
b1 = create_symbool(shape_env, True)
self.assertEqual(a1 & b1, 1)
a2 = create_symint(shape_env, 0b1100)
self.assertEqual(a2 & 0b1010, 0b1000)
a3 = create_symbool(shape_env, True)
b3 = create_symbool(shape_env, True)
self.assertEqual(a3 & b3, True)
def test_symint_bitwise_or(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 0b1100)
b0 = create_symint(shape_env, 0b1010)
res_or = a0 | b0
self.assertEqual(res_or, 0b1110)
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
)
def test_stride(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], SymInt)
def test_size_expressions(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
expand_x = x.expand(x.shape[0], x.shape[0])
if expand_x.shape[0] > 3:
result = expand_x + expand_x
else:
result = expand_x + expand_x
gt_op, _bt, is_size_obv = shape_env.guards[-1]
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
self.assertFalse(is_size_obv)
def test_floordiv_static(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 8)
# This was extracted from
# python test/inductor/test_cuda_cpp_wrapper.py -k
# DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
bool(s0 % 2 == 0)
bool(s0 % (s0 // 2) == 0)
bool(2 * (s0 // 2) == s0)
self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
def test_numel(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
self.assertIsInstance(x.numel(), torch.SymInt)
self.assertIsInstance(torch.numel(x), torch.SymInt)
x = torch.rand(3, 3)
self.assertIsInstance(x.numel(), int)
self.assertIsInstance(torch.numel(x), int)
def test_int_to_float(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = torch.sym_float(x.shape[0])
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
def test_aten_ops(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
shape_env = ShapeEnv()
x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
def forward(self, x):
bs, c, h, w = x.shape
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
m = CustomModule()
x = torch.rand(1, 3, 4, 4)
# should not TypeError: pad(): argument 'pad' (position 2) must be
# tuple of ints, not tuple
torch.fx.symbolic_trace(m)
def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device="meta")
self.assertIsInstance(r.shape[0], SymInt)
def test_guard_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(a0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_sym_sum(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 2)
s1 = create_symint(shape_env, 3)
s2 = create_symint(shape_env, 4)
self.assertEqual(
(s0 + s1 + s2).node.expr, torch.sym_sum([s0, s1, s2]).node.expr
)
def test_prefer_deferred_runtime_assertions_over_guards(self):
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(s0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
self.assertTrue(expect_true(s0 == 2))
self.assertEqual(len(shape_env.guards), 0)
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
"""[Eq(s97, 2)]""",
)
def test_sym_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = sym_int(a0)
self.assertEqual(r, 5)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")
a1 = create_symint(shape_env, 7)
r = sym_int(a1 / 2)
self.assertEqual(guard_int(r), 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
)
a3 = create_symint(shape_env, 3)
r = sym_int(2.0 * torch.sym_float(a3))
self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
)
def test_sym_log2(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 4)
r = torch._sym_log2(a0)
self.assertEqual(r, 2.0)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
)
def test_sym_sqrt(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 4)
r = torch._sym_sqrt(a0)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
)
def test_sym_floor(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = math.floor(a0 / 2)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
)
r = math.floor(3.0 * a0)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
)
def test_sym_trunc(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = math.trunc(a0 / 2)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
)
r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
)
def test_sym_ceil(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = math.ceil(a0 / 2)
self.assertEqual(r, 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
)
r1 = 3.0 * a0
r = math.floor(r1)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
)
def test_sym_ite(self):
shape_env = ShapeEnv()
t = create_symint(shape_env, 5)
f = create_symint(shape_env, 4)
b1 = True
r1 = torch.sym_ite(b1, t, f)
self.assertTrue(r1 is t)
b2 = False
r2 = torch.sym_ite(b2, t, f)
self.assertTrue(r2 is f)
b3 = t == 5
r3 = torch.sym_ite(b3, t, f)
self.assertEqual(len(shape_env.guards), 0)
self.assertEqual(r3, 5)
self.assertEqual(type(t), type(r3))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
)
b4 = f == 5
r4 = torch.sym_ite(b4, t, f)
self.assertEqual(len(shape_env.guards), 1)
self.assertEqual(r4, 4)
self.assertEqual(type(f), type(r4))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
)
def test_tracing_sym_ite(self):
def f(x):
b = x.shape[0] == 5
ret = torch.sym_ite(b, x.shape[0], x.shape[1])
return ret
gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
self.assertEqual(len(gm.shape_env.guards), 0)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 5
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
return sym_ite""",
)
r1 = gm(torch.ones(4, 5))
self.assertIsInstance(r1, int)
self.assertEqual(r1, 5)
r2 = gm(torch.ones(5, 4))
self.assertIsInstance(r2, int)
self.assertEqual(r2, 5)
def test_int_conversion(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
int(a0)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_data_dependent_guard(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
def test_data_dependent_guard_propagate_real_tensors(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
self.assertEqual(bool(s0 == 0), True)
def test_expect_true_basic(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
i0_sym = i0.node.expr
# This doesn't error
self.assertTrue(expect_true(i0 == 0))
# This generates a deferred runtime assert via replacement
self.assertEqual(shape_env.replacements[i0_sym], 0)
# After expecting true, guards now resolve given the runtime assert
bool(i0 == 0)
def test_expect_true_with_s0(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 5)
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(i0 < s0))
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
"""[u0 < s97]""",
)
self.assertTrue(i0 < s0)
self.assertTrue(i0 != s0)
self.assertFalse(i0 > s0)
self.assertFalse(i0 >= s0)
def test_expect_true_prefer_later(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
i1 = shape_env.create_unbacked_symint()
i1_sym = i1.node.expr
self.assertTrue(expect_true(i0 + i1 == 10))
# Importantly, this is put in i1, not i0!
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
"""[Eq(u0 + u1, 10)]""",
)
self.assertTrue(i0 + i1 == 10)
# NB: We currently don't support deriving that we can substitute
# i0 + i1 with 10; maybe we should, but this means our rewriting
# system is no longer confluent (it's probably OK though, because
# you're unlikely to get other equalities like this on the
# unbacked SymInts.)
def test_unbacked_substitution(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
i1 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i0)
_constrain_range_for_size(i1)
self.assertTrue(expect_true(i0 == i1 * 4))
self.assertExpectedInline(str(i0), """u0""")
i2 = shape_env.create_unbacked_symint()
i3 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i2)
_constrain_range_for_size(i3)
self.assertTrue(expect_true(i2 * 4 == i3))
self.assertExpectedInline(str(i3), """u3""")
def test_avoid_unbacked_substitution(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i0)
i1 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i1)
self.assertTrue(expect_true(i0 == 10 - i1))
self.assertExpectedInline(str(i0), """u0""")
def test_expect_true_double_digits(self):
shape_env = ShapeEnv()
ia = [shape_env.create_unbacked_symint() for _ in range(11)] # allocate 10
self.assertEqual(str(ia[-1]), "u10")
self.assertTrue(expect_true(sum(ia) == 20))
self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)
def test_expect_true_refine_range(self):
shape_env = ShapeEnv()
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(rel(i0)))
self.assertTrue(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertFalse(statically_known_true(i0 != 5))
self.assertFalse(statically_known_true(i0 != 6))
self.assertTrue(statically_known_true(i0 > 4))
self.assertTrue(statically_known_true(i0 >= 5))
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(rel(i0)))
self.assertFalse(statically_known_true(i0 != 2))
self.assertFalse(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertTrue(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 < 4))
self.assertTrue(statically_known_true(i0 <= 5))
def test_guard_refine_range(self):
shape_env = ShapeEnv()
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 10, duck=False)
self.assertTrue(bool(rel(i0)))
self.assertTrue(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertFalse(statically_known_true(i0 != 5))
self.assertFalse(statically_known_true(i0 != 6))
self.assertTrue(statically_known_true(i0 > 4))
self.assertTrue(statically_known_true(i0 >= 5))
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 2, duck=False)
self.assertFalse(bool(rel(i0)))
self.assertFalse(statically_known_true(i0 != 3))
self.assertFalse(statically_known_true(i0 != 4))
self.assertTrue(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 != 6))
self.assertTrue(statically_known_true(i0 <= 4))
self.assertTrue(statically_known_true(i0 < 5))
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 2, duck=False)
self.assertTrue(bool(rel(i0)))
self.assertFalse(statically_known_true(i0 != 2))
self.assertFalse(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertTrue(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 < 4))
self.assertTrue(statically_known_true(i0 <= 3))
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 10, duck=False)
self.assertFalse(bool(rel(i0)))
self.assertTrue(statically_known_true(i0 != 2))
self.assertTrue(statically_known_true(i0 != 3))
self.assertFalse(statically_known_true(i0 != 4))
self.assertFalse(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 >= 4))
self.assertTrue(statically_known_true(i0 > 3))
def test_mul_int_oo_nan(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 5, duck=False)
s1 = create_symint(shape_env, 6, duck=False)
s2 = create_symint(shape_env, 5, duck=False)
bool(s0 * (s1 // s0) == s2)
def test_non_overlapping_and_dense_backed(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = torch.empty_strided((a0, 7), (1, a0), device="meta")
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
def test_non_overlapping_and_dense_unbacked(self):
shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
cf = torch.ops.aten.is_non_overlapping_and_dense.default
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
Max = torch.sym_max
# NB: This only works because we're able to determine this tensor is
# contiguous. transpose(0, 1) makes it stop working
self.assertTrue(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
device="meta",
)
)
)
def test_prims_non_overlapping_and_dense(self):
shape_env = ShapeEnv()
cf = torch._prims_common.is_non_overlapping_and_dense
# backed case
a0 = create_symint(shape_env, 5)
self.assertTrue(cf(torch.empty_strided((a0, 7), (1, a0), device="meta")))
# unbacked
u0 = shape_env.create_unbacked_symint()
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
Max = torch.sym_max
self.assertTrue(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
device="meta",
)
)
)
self.assertFalse(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(Max(1, u0), Max(1, u0), 1, 3 * Max(1, u0)),
device="meta",
)
)
)
# return False on arbitrary strides
u1 = shape_env.create_unbacked_symint()
self.assertFalse(
cf(
torch.empty_strided(
(2 * u0, u0, 1),
(u1, u0, u0 + u1),
device="meta",
)
)
)
self.assertFalse(
cf(
torch.empty_strided(
(2, 3, u0),
(u1, 3, 1),
device="meta",
)
)
)
def test_sympy_optimized_add_binary_search(self):
import sympy
from torch.fx.experimental.sym_node import _binary_search_insert_arg
a = sympy.Symbol("a")
b = sympy.Symbol("b")
c = sympy.Symbol("c")
args = []
args = _binary_search_insert_arg([], b)
self.assertEqual(args, [b])
self.assertEqual(_binary_search_insert_arg(args, b), None)
args = _binary_search_insert_arg(args, a)
self.assertEqual(args, [a, b])
self.assertEqual(_binary_search_insert_arg(args, b), None)
self.assertEqual(_binary_search_insert_arg(args, a), None)
args = _binary_search_insert_arg(args, c)
self.assertEqual(args, [a, b, c])
self.assertEqual(_binary_search_insert_arg(args, a), None)
self.assertEqual(_binary_search_insert_arg(args, b), None)
self.assertEqual(_binary_search_insert_arg(args, c), None)
a1 = sympy.Symbol("a1")
a2 = sympy.Symbol("a2")
args = _binary_search_insert_arg(args, a1)
self.assertEqual(args, [a, a1, b, c])
args = _binary_search_insert_arg(args, a2)
self.assertEqual(args, [a, a1, a2, b, c])
c1 = sympy.Symbol("c1")
args = _binary_search_insert_arg(args, c1)
self.assertEqual(args, [a, a1, a2, b, c, c1])
# insert to front
_a = sympy.Symbol("_a")
args = _binary_search_insert_arg(args, _a)
self.assertEqual(args, [_a, a, a1, a2, b, c, c1])
def test_floor_clean_div_axioms(self):
# Test that if we add an axiom that have FloorDiv, after which the
# shapeEnv changed such that it can be simplified it to CleanDiv, then
# We still correctly replace CleanDiv with the axiom value of FloorDiv.
shape_env = ShapeEnv()
a = shape_env.create_unbacked_symint()
shape_env.guard_or_defer_runtime_assert((a // 3 == 1).node.expr, " test")
from sympy import Eq
test1 = Eq(FloorDiv(a.node.expr, 3), 1)
test2 = Eq(CleanDiv(a.node.expr, 3), 1)
self.assertTrue(shape_env.evaluate_expr(test1))
self.assertEqual(shape_env._maybe_evaluate_static(test2), None)
# After this FloorDiv(a, 3) is simplified to CleanDiv(a, 3)
shape_env.guard_or_defer_runtime_assert(Eq(Mod(a, 3), 0), " test")
self.assertEqual(test2, shape_env.simplify(test1))
self.assertTrue(shape_env.evaluate_expr(test1))
self.assertTrue(shape_env.evaluate_expr(test2))
def test_sympy_optimized_add(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 2)
s1 = create_symint(shape_env, 3)
s2 = create_symint(shape_env, 4)
sum = s0 + s1
self.assertTrue(sum.node._optimized_summation)
def assert_optimized(sym):
self.assertTrue(sym.node._optimized_summation)
def assert_not_optimized(sym):
self.assertFalse(getattr(sym.node, "_optimized_summation", False))
assert_optimized(sum)
# add duplicate symbol
assert_not_optimized(sum + s0)
# add constant.
assert_not_optimized(sum + 1)
# add new unique symbol, should maintain _optimized_summation property.
assert_optimized(sum + s2)
assert_optimized(s2 + sum)
# add x + (a+b) with no _optimized_summation on the rhs sum.
a = create_symint(shape_env, 10)
b = create_symint(shape_env, 11)
two_sum = torch.sym_sum([a, b])
assert_not_optimized(two_sum)
assert_optimized(sum + two_sum)
# adding two expressions of length >2 that are _optimized_summation.
a = s0 + s1 + s2
s3 = create_symint(shape_env, 10)
s4 = create_symint(shape_env, 20)
s5 = create_symint(shape_env, 30)
b = s3 + s4 + s5
assert_optimized(a)
assert_optimized(b)
assert_not_optimized(a + b)
assert_not_optimized(b + a)
assert_not_optimized(b + a + b)
def test_max_of_unique_summation_opt(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
s1 = shape_env.create_unbacked_symint()
s2 = shape_env.create_unbacked_symint()
s3 = shape_env.create_unbacked_symint()
s4 = shape_env.create_unbacked_symint()
s5 = shape_env.create_unbacked_symint()
s7 = shape_env.create_unbacked_symint()
def assert_optimized(sym):
self.assertTrue(sym.node.expr.unique_summations_symbols is not None)
def assert_not_optimized(sym):
getattr(sym.node.expr, "unique_summations_symbols", None)
mx1 = torch.sym_max(s0, s1)
assert_not_optimized(mx1)
mx2 = torch.sym_max(s0 + s1, s2 + s3)
assert_optimized(mx2)
mx3 = torch.sym_max(mx2, s4 + s5)
assert_optimized(mx3)
assert_optimized(torch.sym_max(s4 + s5, mx2))
assert_not_optimized(torch.sym_max(mx3, s7))
assert_not_optimized(torch.sym_max(mx3, 10))
assert_not_optimized(torch.sym_max(mx3, s3 + s7))
assert_not_optimized(torch.sym_max(mx3, s7 * 2))
def test_sym_max_multi_max_simplify(self):
shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
self.assertTrue(
statically_known_true(
torch.sym_max(1, torch.sym_max(257, u0)) == torch.sym_max(257, u0)
)
)
def test_numpy_sym_max(self):
self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)
def test_numpy_sym_min(self):
self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5)
self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0)
self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0)
def test_debug_has_internal_overlap_unbacked(self):
shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
cf = torch._debug_has_internal_overlap
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2)
Max = torch.sym_max
self.assertEqual(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
device="meta",
)
),
2,
)
# Wobbling these to zero is OK too
self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2)
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2)
def test_specialize_zero_one(self):
shape_env = ShapeEnv(specialize_zero_one=True)
a0 = create_symint(shape_env, 5)
assert a0 != 1
self.assertEqual(len(shape_env.guards), 0)
shape_env = ShapeEnv(specialize_zero_one=False)
a0 = create_symint(shape_env, 5)
assert a0 != 1
self.assertEqual(len(shape_env.guards), 1)
def test_duck_shape(self):
shape_env = ShapeEnv(duck_shape=True)
a0 = create_symint(shape_env, 5)
a1 = create_symint(shape_env, 5)
assert a0 == a1
self.assertEqual(len(shape_env.guards), 0)
shape_env = ShapeEnv(duck_shape=False)
a0 = create_symint(shape_env, 5)
a1 = create_symint(shape_env, 5)
assert a0 == a1
self.assertEqual(len(shape_env.guards), 1)
def test_int_bool(self):
# See https://github.com/pytorch/pytorch/issues/95981
shape_env = ShapeEnv(duck_shape=True)
a0 = create_symint(shape_env, 5)
assert a0
self.assertEqual(len(shape_env.guards), 0)
def test_symint_as_scalar(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
sym_int_encountered = False
class TestSymInt(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
assert func == torch.ops.aten.add.Tensor
nonlocal sym_int_encountered
# WARNING: do not do identity tests on the outer
# SymInt/SymFloat, they are NOT STABLE
sym_int_encountered = kwargs["alpha"].node is a0.node
kwargs["alpha"] = 0
return func(*args)
x = torch.rand([4, 4])
with TestSymInt():
y = torch.add(x, x, alpha=a0)
self.assertTrue(sym_int_encountered)
def test_deepcopy(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
assert a0 < 4
new_shape_env = copy.deepcopy(shape_env)
self.assertEqual(len(new_shape_env.guards), 1)
def test_print_readable_with_symints(self):
def f(a, b):
dim0 = a.shape[0] + b.shape[0]
dim1 = a.shape[1] + b.shape[1]
d = a.new_empty(dim0, dim1)
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
return d
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
out = fx_g.print_readable(print_output=False)
self.assertExpectedInline(
out.strip(),
"""\
class f(torch.nn.Module):
def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
# No stacktrace found for following nodes
sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""", # noqa: B950
)
def test_statically_known_true(self):
shape_env = ShapeEnv()
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
# Statically known true
self.assertTrue(statically_known_true(True))
self.assertTrue(statically_known_true(s2 == s2))
self.assertTrue(statically_known_true(s2 * s3 > s3))
self.assertTrue(statically_known_true(s3 * s4 > s4))
self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))
# Statically known false
self.assertFalse(statically_known_true(False))
self.assertFalse(statically_known_true(s3 * s4 <= s4))
self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))
# True for hints, but not known statically
self.assertFalse(statically_known_true(s2 + s2 == s4))
self.assertFalse(statically_known_true(s4 % s2 == 0))
self.assertFalse(statically_known_true(s2 != s3))
self.assertFalse(statically_known_true(s3 * s4 > s2))
# False for hints, but not known statically
self.assertFalse(statically_known_true(s2 == s3))
self.assertFalse(statically_known_true(s2 > s3))
self.assertFalse(statically_known_true(s3 + s3 == s4))
# No guards should be generated
self.assertEqual(len(shape_env.guards), 0)
def test_statically_known_false(self):
shape_env = ShapeEnv()
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
# Statically known true
self.assertFalse(statically_known_false(True))
self.assertFalse(statically_known_false(s2 == s2))
self.assertFalse(statically_known_false(s2 * s3 > s3))
self.assertFalse(statically_known_false(s3 * s4 > s4))
self.assertFalse(statically_known_false((s3 + s3) % 2 == 0))
# Statically known false
self.assertTrue(statically_known_false(False))
self.assertTrue(statically_known_false(s3 * s4 <= s4))
self.assertTrue(statically_known_false((s3 + s3) % 2 == 1))
# True for hints, but not known statically
self.assertFalse(statically_known_false(s2 + s2 == s4))
self.assertFalse(statically_known_false(s4 % s2 == 0))
self.assertFalse(statically_known_false(s2 != s3))
self.assertFalse(statically_known_false(s3 * s4 > s2))
# False for hints, but not known statically
self.assertFalse(statically_known_false(s2 == s3))
self.assertFalse(statically_known_false(s2 > s3))
self.assertFalse(statically_known_false(s3 + s3 == s4))
# No guards should be generated
self.assertEqual(len(shape_env.guards), 0)
def test_ephemeral_source_simplification(self):
from torch._dynamo.source import EphemeralSource
# For full robustness, ensure the ephemeral source symbols are simplified out regardless
# of construction order or check order.
for construct_ephemeral_first, x_first_in_check in itertools.product(
[False, True], [False, True]
):
shape_env = ShapeEnv()
shape = (5, 10)
dynamic_dims = [DimDynamic.DYNAMIC for _ in shape]
x = create_symbolic_tensor(
"x",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if construct_ephemeral_first else None),
dynamic_dims=dynamic_dims,
)
y = create_symbolic_tensor(
"y",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if not construct_ephemeral_first else None),
dynamic_dims=dynamic_dims,
)
t_with_ephemeral = x if construct_ephemeral_first else y
def _get_ephemeral_source_symbols(t):
return [
s.node.expr
for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
if isinstance(s, torch.SymInt)
and s.node.expr in shape_env.var_to_sources
and any(
source.is_ephemeral()
for source in shape_env.var_to_sources[s.node.expr]
)
]
# these checks should simplify out the ephemeral symbols, regardless of the
# ordering x == y or y == x
self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
if x_first_in_check:
torch._check(x.size() == y.size())
torch._check(x.stride() == y.stride())
torch._check(x.storage_offset() == y.storage_offset())
else:
torch._check(y.size() == x.size())
torch._check(y.stride() == x.stride())
torch._check(y.storage_offset() == x.storage_offset())
self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
def test_ephemeral_source_unified_with_non_ephemeral_source(self):
from torch._dynamo.source import EphemeralSource
for construct_ephemeral_first in (False, True):
shape_env = ShapeEnv()
shape = (5, 10)
# use duck sizing here to ensure symbol reuse across x and y
duck_dims = [DimDynamic.DUCK for _ in shape]
x = create_symbolic_tensor(
"x",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if construct_ephemeral_first else None),
dynamic_dims=duck_dims,
)
y = create_symbolic_tensor(
"y",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if not construct_ephemeral_first else None),
dynamic_dims=duck_dims,
)
# regardless of construction order, non-ephemeral sources should be preferred
# first in the var_to_sources list for potential guarding later on
for source_list in shape_env.var_to_sources.values():
self.assertFalse(source_list[0].is_ephemeral())
self.assertEqual(x.size(), y.size())
self.assertEqual(x.stride(), y.stride())
self.assertEqual(x.storage_offset(), y.storage_offset())
def test_tensor_factory_with_symint(self):
args = list(range(3))
expected = torch.tensor(args)
shape_env = ShapeEnv()
sym_args = [create_symint(shape_env, i) for i in args]
# test tensor factories
for dt in all_types_and(torch.half, torch.bfloat16):
res = torch.tensor(sym_args, dtype=dt)
self.assertEqual(res, expected, exact_dtype=False)
# test legacy tensor factories
legacy_ctors = [
torch.Tensor,
torch.LongTensor,
torch.DoubleTensor,
torch.FloatTensor,
torch.IntTensor,
torch.ShortTensor,
torch.HalfTensor,
torch.ByteTensor,
]
for Tensor in legacy_ctors:
res = Tensor(sym_args)
self.assertEqual(res, expected, exact_dtype=False)
def test_backed_size_oblivious_01_spec(self):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@torch.compile(dynamic=True, fullgraph=True)
def f(a, b):
if guard_size_oblivious(a.size(0) == 1):
return b * 10
else:
return b * 20
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
# always go to the >= 2 branch.
self.assertEqual(
f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
)
@fresh_cache()
def test_slice_backed_size_oblivious(self):
@torch.compile(backend="inductor", fullgraph=True, dynamic=True)
def f(x):
return x[:5]
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
f(torch.randn(10, 10))
def test_baddbmm_symint(self):
from torch._subclasses.fake_tensor import FakeTensorMode
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
B, M, K, N = [shape_env.create_unbacked_symint() for _ in range(4)]
with fake_mode:
A = torch.empty((B, M, K), device="meta")
Bmat = torch.empty((B, K, N), device="meta")
bias3 = torch.empty((B, M, N), device="meta")
_ = torch.baddbmm(bias3, A, Bmat)
@skipIfTorchDynamo(
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
# Helper function
# NB: don't use one as that will get specialized
# TODO: We don't have to circuitously create the float, can just
# create a symfloat directly
seed_node = (create_symint(shape_env, 2) / 2.0).node
bool_seed_node = (create_symint(shape_env, 2) == 2).node
def get_sym_inp(inp):
# NB: this must come before int
if isinstance(inp, bool):
return torch.SymBool(to_node(bool_seed_node, inp))
elif isinstance(inp, int):
return torch.SymInt(to_node(seed_node, inp))
else:
return torch.SymFloat(to_node(seed_node, inp))
if fn == "float_pow":
if inp1 < 0:
return
if fn == "pow_by_natural":
if isinstance(inp1, float) or isinstance(inp2, float):
return
if inp2 < 0:
return
def maybe_xfail(inp1, inp2):
if fn == "sym_sqrt" and inp1 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
elif (
fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
and inp2 == 0
):
# ZeroDivisionError: division by zero
return self.assertRaises((ZeroDivisionError,))
elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
# ZeroDivisionError: 0.0 cannot be raised to a negative power
return self.assertRaises((ZeroDivisionError,))
elif (
# TODO: dear catastrophe waitress,
# this doesn't work
fn in ["float_pow", "pow_by_natural"]
and inp1 < 0
and (
type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
)
and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
):
# Complex result, which we do not support:
# TypeError: Cannot convert complex to float
return self.assertRaises((RuntimeError,))
elif fn in ("lshift", "rshift") and not (
isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
):
# TypeError: unsupported operand type(s)
return self.assertRaises((TypeError,))
elif fn in ("lshift", "rshift") and inp2 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
else:
return contextlib.nullcontext()
lambda_apply = method_to_operator(fn)
def guard_fn(v):
if type(v) in (SymBool, bool):
return guard_bool(v)
elif type(v) in (SymFloat, float):
return guard_float(v)
else: # SymInt, int
return guard_int(v)
# Get reference result
with maybe_xfail(inp1, inp2):
if is_unary_fn:
ref_out = lambda_apply(inp1)
else:
ref_out = lambda_apply(inp1, inp2)
# Symified first arg
sym_inp1 = get_sym_inp(inp1)
with maybe_xfail(sym_inp1, inp2):
if is_unary_fn:
out = lambda_apply(sym_inp1)
else:
out = lambda_apply(sym_inp1, inp2)
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
if is_unary_fn:
return
# Symified second arg
sym_inp2 = get_sym_inp(inp2)
with maybe_xfail(inp1, sym_inp2):
out = lambda_apply(inp1, sym_inp2)
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
# Symified both args
with maybe_xfail(sym_inp1, sym_inp2):
out = lambda_apply(sym_inp1, sym_inp2)
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
@parametrize("fn", list(sym_node.magic_methods.keys()))
def test_bool_method(self, fn):
# sym_ite has its own tests
if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
self.skipTest(f"{fn} is non-bool")
is_unary_fn = fn in sym_node.unary_methods
shape_env = ShapeEnv()
self._do_test(fn, True, False, shape_env, is_unary_fn)
@parametrize("fn", list(sym_node.magic_methods.keys()))
@parametrize("first_type", ["int", "float"])
@parametrize("second_type", ["int", "float"])
def test_method(self, fn, first_type, second_type):
if first_type == "float":
# TODO: Hmm, this looks like we skip all floats
self.skipTest(f"{fn} is not a float magic method")
if (
first_type == "int" or second_type == "int"
) and fn in sym_node.only_float_magic_methods:
self.skipTest(f"{fn} is not an int method")
if second_type == "float" and fn in ["mod"]:
self.skipTest(f"{fn} only handles int")
if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"):
self.skipTest(f"{fn} is a bitwise op, only handles int")
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
# Second argument is ignored for unary function. So only run for one type
if is_unary_fn and second_type == "float":
self.skipTest(f"{fn} is unary and already tested")
if fn in sym_node.bool_magic_methods:
self.skipTest(f"{fn} is bool")
# Only floats here since these will be converted to int if necessary.
# We also ignore complex and bool.
values = (
0.0,
1.0,
0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error
)
neg_values = tuple(-x for x in values)
for inp1, inp2 in itertools.chain(
itertools.product(values, values),
itertools.product(values, neg_values),
itertools.product(neg_values, values),
itertools.product(neg_values, neg_values),
):
if first_type == "int":
inp1 = int(inp1)
if second_type == "int":
inp2 = int(inp2)
shape_env = ShapeEnv()
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
def get_constant_bool(self, val):
return SymBool(torch._C._get_constant_bool_symnode(val))
@unittest.expectedFailure
def test_symint_hashing(self):
shape_env = ShapeEnv()
hash(create_symint(shape_env, 3))
def test_symnode_hashing(self):
shape_env = ShapeEnv()
# These all trigger specialization when hashed
hash(create_symbool(shape_env, True))
# We should be passing in float here, but create_symbol currently
# only supports int
hash(create_symfloat(shape_env, 3.0))
# NestedInt (SymInt), constant SymBool, SymNode are hashable
j1 = torch._C._get_nested_int(1, 1)
j1_copy = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(2, 1)
t = self.get_constant_bool(True)
t_copy = self.get_constant_bool(True)
f = self.get_constant_bool(False)
n = create_symint(shape_env, 3).node
m = self.get_constant_bool(True).node
self.assertIs(j1 == j1_copy, True)
self.assertEqual(hash(j1), hash(j1_copy))
self.assertIs(j1 == j2, False)
self.assertNotEqual(hash(j1), hash(j2))
self.assertIs(t == t_copy, True)
self.assertEqual(hash(t), hash(t_copy))
self.assertIs(t == f, False)
self.assertNotEqual(hash(t), hash(f))
hash(n)
hash(m)
def test_symint_deepcopy(self):
shape_env = ShapeEnv()
symnodes = (torch._C._get_nested_int(1, 1),)
deepcopied_symnodes = copy.deepcopy(symnodes)
self.assertEqual(symnodes, deepcopied_symnodes)
def test_non_symbolic_symnode(self):
j1 = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(1, 1)
j3 = torch._C._get_nested_int(3, 1)
self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)
with self.assertRaisesRegex(
RuntimeError, "add not supported by NestedIntSymNode"
):
j1 + 3
self.assertFalse(j1 == 3)
with self.assertRaisesRegex(RuntimeError, "indeterminate"):
self.assertFalse(3 >= j2)
self.assertIs(j1 == j1, True)
self.assertIs(j1 == j2, True)
self.assertIs(j1 == j3, False)
self.assertIs(j1 != j3, True)
self.assertIs(j1 != j2, False)
x = self.get_constant_bool(True)
#
# Unary
#
# op(constant SymBool)
self.assertIs(x.__sym_not__(), False)
#
# Binary
#
# op(constant SymBool, bool)
# op(constant SymBool, constant SymBool)
# op(bool, constant SymBool)
self.assertIs(operator.and_(x, True), True)
self.assertIs(operator.and_(x, x), True)
self.assertIs(operator.and_(True, x), True)
# op(symbolic SymBool, constant Symbool)
# op(constant SymBool, symbolic Symbool)
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
b = create_symint(shape_env, 2)
c = a == b # symbolic SymBool
d = self.get_constant_bool(True)
e = operator.and_(c, d)
f = operator.and_(d, c)
self.assertTrue(is_symbolic(e))
self.assertTrue(is_symbolic(f))
self.assertIs(e.node.guard_bool("", 0), True)
self.assertIs(f.node.guard_bool("", 0), True)
# Comparing sizes
sz1 = torch.Size([j1, j1, j1])
sz2 = torch.Size([j1, j1, j1])
self.assertIs(sz1 == sz2, True)
sz1 = torch.Size([3, j1, 4])
sz2 = torch.Size([3, j2, 4])
self.assertIs(sz1 == sz2, True)
self.assertIs(sz1 != sz2, False)
def test_stride_symnode(self):
shape_env = ShapeEnv()
# check everything static
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.STATIC,
DimDynamic.STATIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
self.assertTrue(all(isinstance(size, int) for size in t.size()))
self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))
# check dynamic size but static dims
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
# Expect stride to be inferred
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, torch.SymInt))
self.assertTrue(isinstance(s1, torch.SymInt))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(s1 == s2)
self.assertEqual(s3, 1)
# Check dynamic stride but static dims
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.STATIC,
DimDynamic.STATIC,
],
dynamic_strides=[
DimDynamic.DYNAMIC,
DimDynamic.INFER_STRIDE,
],
)
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, int))
self.assertTrue(isinstance(s1, int))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(isinstance(s3, int))
# Check dynamic sizes and dims, and ensure different symbol
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.DYNAMIC,
DimDynamic.INFER_STRIDE,
],
)
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, torch.SymInt))
self.assertTrue(isinstance(s1, torch.SymInt))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(isinstance(s3, int))
self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_dynamic_int_basic_compile(self, backend):
from torch.fx.experimental.sym_node import DynamicInt
cnt = CompileCounterWithBackend(backend)
# test scalar inputs to function
def f(x, y, z):
out = torch.tensor([x + y + z])
out = out + torch.zeros(abs(x) + 2).sum() # test out tensor construction
return out
fn = torch.compile(f, fullgraph=True, backend=cnt)
x = DynamicInt(1)
z = DynamicInt(3)
self.assertEqual(fn(x, x, z), f(1, 1, 3)) # guard: x == y
self.assertEqual(fn(2, 2, 0), f(2, 2, 0))
self.assertEqual(fn(-1, -1, 2), f(-1, -1, 2))
self.assertEqual(cnt.frame_count, 1) # no recompiles
self.assertEqual(fn(3, 4, 5), f(3, 4, 5)) # now we recompile
self.assertEqual(cnt.frame_count, 2)
# test nn module property
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.i = DynamicInt(1)
def forward(self, x):
return torch.tensor([x + self.i])
cnt.clear()
m = Foo()
mc = torch.compile(m, backend=cnt, fullgraph=True)
self.assertEqual(mc(DynamicInt(0)), m(0))
mc.i = -2 # override attribute
self.assertEqual(mc(-1), m(-1))
self.assertEqual(cnt.frame_count, 1)
def test_dynamic_int_eager_usage(self):
from torch.fx.experimental.sym_node import DynamicInt
w = DynamicInt(-1)
x = DynamicInt(0)
y = DynamicInt(1)
z = DynamicInt(2)
def check(l, r):
self.assertTrue(isinstance(l, DynamicInt))
self.assertEqual(l, r)
# test arithmetic
check(2 * y + z, 4)
check((10 - z) // 2, 4)
check(1 // z, 0)
check(-w + w**2, 2)
check(x % z, 0)
check(1 << z, 4)
check(z | y, 3)
check(min(y, z), 1)
self.assertTrue(z > -2)
with self.assertRaises(ZeroDivisionError):
y % x
# math, numpy
self.assertEqual(math.cos(x), y)
self.assertEqual(math.prod([z, z], start=z), 8)
self.assertEqual(np.arange(z)[y], 1)
self.assertTrue(np.allclose(np.ones([y, z]).sum(axis=x), np.ones(z)))
# test conversions
self.assertTrue(isinstance(x + 2, int))
self.assertTrue(isinstance(x + 2, DynamicInt))
self.assertEqual(y / 2.0, 0.5) # this could return DynamicFloat in future
self.assertEqual(float(z), 2.0)
self.assertFalse(bool(x))
self.assertEqual(DynamicInt(x).real, x.real)
# torch functions, scalar inputs
self.assertEqual(torch.arange(z)[:w][x], 0)
self.assertEqual(torch.add(torch.tensor(w), torch.tensor(w), alpha=z), -3)
self.assertEqual(
list(torch.nn.Linear(z, y)(torch.randn(z * 2, z)).shape), [4, 1]
)
self.assertEqual(z * torch.ones(z).sum(dim=x), 4)
instantiate_parametrized_tests(TestSymNumberMagicMethods)
class TestFloorDiv(TestCase):
@staticmethod
def python_floordiv(x, y):
return x // y
@staticmethod
def torch_floordiv(x, y):
# Note: we fully evaluate here since FloorDiv might not always do
# that.
shape_env = ShapeEnv()
return shape_env.evaluate_expr(FloorDiv(x, y))
@staticmethod
def yield_test_cases(values, negate=True):
for x, y in values:
yield (x, y)
if negate:
yield (-x, y)
yield (x, -y)
yield (-x, -y)
def test_floordiv_float_int(self):
values = ((7, 2),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_div_by_one(self):
values = ((2, 1),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_div_does_not_generate_non_int_rational(self):
s14 = sympy.Symbol("s14", integer=True, positive=True)
s37 = sympy.Symbol("s37", integer=True, positive=True)
inner_expr = FloorDiv(s14, 2016)
middle_expr = (24 * s37 + 672) * inner_expr
numerator = middle_expr + 21
denominator = 22
result = FloorDiv(numerator, denominator)
rationals = result.atoms(sympy.Rational)
all_rationals_ints = all(r.q == 1 for r in rationals)
self.assertTrue(all_rationals_ints)
def test_floordiv_simplify(self):
# Tests how we simplify or evaluate FloorDiv without free variables
shape_env = ShapeEnv()
result = 21
exprs = (7 * FloorDiv(6, 2),)
for expr in exprs:
self.assertEqual(expr, result)
self.assertEqual(expr.doit(deep=False), result)
self.assertEqual(expr.doit(deep=True), result)
self.assertEqual(sympy.simplify(expr), result)
self.assertEqual(shape_env.simplify(expr), result)
self.assertEqual(shape_env.evaluate_expr(expr), result)
def test_floordiv_assumptions(self):
cases = (
sympy.Symbol("i1", integer=True),
sympy.Symbol("i2", integer=True),
)
for base, divisor in itertools.product(cases, repeat=2):
def op():
return FloorDiv(base, divisor)
def is_complex(x):
return x.is_integer is False and x.is_real is False and x.is_complex
if is_complex(base) or is_complex(divisor):
self.assertRaisesRegex(
TypeError,
(
r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"
),
op,
)
continue
op = op()
# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
# always returns an integer 1 when both args are the same object.
# This even works for Symbols with no assumptions specified.
if base is divisor:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
elif base.is_integer and divisor.is_integer:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
else:
self.assertEqual(op.is_integer, None)
self.assertTrue(op.is_real)
class TestDimConstraints(TestCase):
@skipIfTorchDynamo("mark_dynamic not supported")
def test_simplify_max_1_0(self):
x = torch.rand(10)
torch._dynamo.mark_dynamic(x, 0, max=20, min=5)
@torch.compile(fullgraph=True)
def func(x, v):
# test that statically_known_true
if (v == 0 or v == 1) and not statically_known_true(
max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2)
):
raise AssertionError("error")
if max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2):
return x * 400
else:
return (x * 10) * 100
# testing that this does not throw constraint violation error.
self.assertEqual(func(x, 1), x * 400)
self.assertEqual(func(x, 0), x * 400)
def test_dim_constraints_reduce_congruences_simple(self):
from sympy import Symbol
s = Symbol("s", positive=True, integer=True)
dim_constraints = DimConstraints({}, {}, set(), {})
dim_constraints._congruences[s] = {
(s / 2) % 2,
(s / 2) % 8,
(s / 2) % 4,
s % 2,
((s / 16) + 2) % 4,
}
congruences = dim_constraints._reduce_congruences()
self.assertEqual(congruences[s], {(s + 32) % 64})
def test_dim_constraints_reduce_inequalities_simple(self):
from sympy import Eq, Interval, Ne, Symbol
from sympy.solvers.inequalities import reduce_inequalities
s = Symbol("s", positive=True, integer=True)
exprs = {
s >= 2,
Ne(8 * s, 16),
Ne(s / 2, 1),
Ne(16 * s, 32),
s < 16,
Ne(s, 2),
s / 2 < 16,
s / 2 > 1,
s / 2 >= 2,
Ne(3 * s / 2, 3),
}
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, Interval.Ropen(4, 16))
exprs.add(Eq(s / 2, 4))
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, {8})
def test_dim_constraints_reduce_inequalities_error(self):
from collections import defaultdict
from sympy import Symbol
from sympy.solvers.inequalities import reduce_inequalities
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
s0 = Symbol("s0", positive=True, integer=True)
exprs = {
4 * s0**3 - 4 * s0**2 + s0 <= 2147483647,
s0 >= 2,
s0**3 <= 2147483647,
s0 <= 2147483647,
}
answer = reduce_inequalities(exprs, s0)
symbol_to_source = defaultdict(list)
symbol_to_source[s0].append(
TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
)
dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
with self.assertRaisesRegex(
AssertionError,
"Unknown symbol.*created by constraints solver",
):
dcp.doprint(answer)
def test_dim_constraints_solve_full(self):
from sympy import Eq, Integer, Ne, Symbol
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
src0 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
src2 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
)
src3 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
)
src4 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
)
src1 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
)
src7 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
)
src5 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
)
src8 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
)
src6 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
)
src9 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
)
src10 = TensorPropertySource(
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
)
src11 = TensorPropertySource(
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
)
src12 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
)
s0 = Symbol("s0", positive=True, integer=True)
s1 = Symbol("s1", positive=True, integer=True)
s5 = Symbol("s5", positive=True, integer=True)
s6 = Symbol("s6", positive=True, integer=True)
symbol_to_source = {
s0: [src0, src2, src3, src4],
s1: [src1, src7],
s5: [src5, src8],
s6: [src6, src9, src10],
}
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
marked_dynamic = {s0, s1, s5, s6}
dim_constraints = DimConstraints(
symbol_to_source, var_to_val, marked_dynamic, {}
)
dim_constraints.add_equality(src2, s0)
dim_constraints.add_equality(src3, s0)
dim_constraints.add_equality(src4, s0)
dim_constraints.add_equality(src7, s1)
dim_constraints.add_equality(src8, s5)
dim_constraints.add_equality(src9, s6)
dim_constraints.add_equality(src10, s6)
dim_constraints.add_equality(src11, Integer(1))
dim_constraints.add_equality(src12, Integer(3))
dim_constraints.add(s1**2 <= 2147483647)
dim_constraints.add(32 * s1**2 <= 2147483647)
dim_constraints.add(s0 < 16)
dim_constraints.add(Eq(Mod(s1, 2), 0))
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 64
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1
> 1
)
dim_constraints.add(
128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 128
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1
> 1
)
dim_constraints.add(
256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 256
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1
> 1
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Eq(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
+ 60 * (Mod(s0, 2)),
0,
)
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
0,
)
)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
0,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= 0
)
dim_constraints.add(
1
< 60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
)
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(FloorDiv(s0, 2) >= 2)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(1 < FloorDiv(s0, 2))
dim_constraints.add(Ne(s0, 2))
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
>= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 0
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
)
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
0,
)
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
20,
)
)
dim_constraints.add(
Ne(
20
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
),
0,
)
)
dim_constraints.add(
Ne(
20
* (FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
)
)
- 20
* Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 2
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
> 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
60,
)
)
dim_constraints.add(
Ne(
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
)
dim_constraints.add(
Eq(
(FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
)
)
- Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
),
0,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
)
)
dim_constraints.add(Ne(8 * s0, 16))
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
<= 2147483647
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 2147483647
)
dim_constraints.add(FloorDiv(s0, 2) < 16)
dim_constraints.add(FloorDiv(s0, 2) > 1)
dim_constraints.add(
Ne(
90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
> 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
3 * (FloorDiv(s0, 2)),
)
)
dim_constraints.add(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
1
< 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
6,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 2147483647
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 20480
)
dim_constraints.add(
Ne(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 1
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 20480
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 20480
)
dim_constraints.add(
Ne(
240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 240,
0,
)
)
dim_constraints.add(Eq(6 * s5, 132))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
<= 2147483647
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
> 1
)
dim_constraints.add(
62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
<= 2147483647
)
dim_constraints.add(
Ne(
62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
)
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
256,
)
)
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
4,
)
),
0,
)
)
dim_constraints.add(
Eq(
FloorDiv(s0, 2),
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
)
)
dim_constraints.add(
Eq(
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
FloorDiv(s0, 2),
)
)
dim_constraints.add(
Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
)
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
4,
)
),
0,
)
)
dim_constraints.add(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 1
)
dim_constraints.add(
Eq(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576,
256,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
<= 2147483647
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
<= 2147483647
)
dim_constraints.add(
Ne(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
> 1
)
dim_constraints.add(s0 >= 2)
dim_constraints.add(s1 >= 2)
dim_constraints.add(s6 >= 2)
dim_constraints.add(s5 >= 2)
dim_constraints.solve()
self.assertEqual(
dim_constraints._static_results,
{
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
"L['a'].size()[2] == 96",
"L['f'].size()[1] == 1",
"L['a'].size()[3] == 96",
"L['b'].size()[2] == 3",
"L['b'].size()[1] == 22",
"L['b'].size()[0] == 8",
"L['a'].size()[1] == 22",
"L['a'].size()[0] == 8",
},
)
self.assertEqual(
dim_constraints._dynamic_results,
{
"2 <= L['c'].size()[1]",
"L['d'].size()[1] == L['c'].size()[1]",
"L['e'].size()[1] == L['c'].size()[1]",
},
)
class TestGuardsExpressions(TestCase):
"""
Tests the guards-related methods used by the inductor FX graph cache.
"""
def test_guards_gt_lt(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 6)
s1 = create_symint(shape_env, 7)
s2 = create_symint(shape_env, 5)
guard_int(sym_int(s0 > 5))
guard_int(sym_int(s0 < 7))
guards = shape_env.produce_guards_expression([s0])
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)]))
def test_guards_float_print(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 3)
guard_bool(2 / s0 == 2 / 3)
guards = shape_env.produce_guards_expression([s0])
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_guard_or_true(self):
from torch.fx.experimental.symbolic_shapes import guard_or_true
def func(a, b):
x = a.item()
if guard_or_true(x == 1):
return b * 10
else:
return b * 20
# eager.
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
# compile with unbacked.
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func(a, b)
# always return b*10
self.assertEqual(
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
)
self.assertEqual(
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
)
# Test that statically known true works.
def func2(a, b):
x = a.item()
if guard_or_true(x != x):
return b * 10
else:
return b * 20
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func2(a, b)
# always return b*20
self.assertEqual(
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
)
self.assertEqual(
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
)
# Test backed_size_oblivious
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
def func3(a, b):
if guard_or_true(a.size()[0] != 9):
return b * 10
else:
return b * 20
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
a = torch.rand(9, 2)
b = torch.rand(3, 4)
self.assertEqual(func3(a, b), b * 20)
self.assertEqual(compiled(a, b), b * 10)
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_guard_or_false(self):
from torch.fx.experimental.symbolic_shapes import guard_or_false
def func(a, b):
x = a.item()
if guard_or_false(x == 1):
return b * 10
else:
return b * 20
# eager.
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
# compile with unbacked.
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func(a, b)
# always return b*20
self.assertEqual(
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
)
self.assertEqual(
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
)
# Test that statically known true works.
def func2(a, b):
x = a.item()
if guard_or_false(x == x):
return b * 10
else:
return b * 20
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func2(a, b)
# always return b*10
self.assertEqual(
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
)
self.assertEqual(
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
)
# Test backed_size_oblivious
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
def func3(a, b):
if guard_or_false(a.size()[0] == 9):
return b * 10
else:
return b * 20
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
a = torch.rand(9, 2)
b = torch.rand(3, 4)
self.assertEqual(func3(a, b), b * 10)
self.assertEqual(compiled(a, b), b * 20)
def test_guards_float_div(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 8)
s1 = create_symint(shape_env, 7)
guard_int(sym_int(s0 / 2.0))
guards = shape_env.produce_guards_expression([s0])
self.assertIn("math.trunc(", guards)
self.assertIn("float(", guards)
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
@unittest.skipIf(
TEST_XPU, "Skipped on XPU"
) # https://github.com/intel/torch-xpu-ops/issues/2169"
@skipIfTorchDynamo("Attempt to trace generator")
@torch.fx.experimental._config.patch("use_duck_shape", False)
def test_size_comparison_no_recompile(self):
"""
Test that size comparisons don't cause recompilation.
When comparing x.size() == b.size() with different sizes,
the compiled function should only compile once.
We should not guard in sizes of the inner elements.
"""
cnt = CompileCounter()
@torch.compile(fullgraph=True, dynamic=True, backend=cnt)
def f(x, b):
if x.size() == b.size():
return x
return x * 2
# First call: shapes differ (1, 2) vs (2, 4, 9), so if branch is False
f(torch.rand(10, 2), torch.rand(20, 4, 9))
# Second call: shapes differ again (1, 2) vs (1, 4, 9), so if branch is False
f(torch.rand(10, 2), torch.rand(10, 4, 9))
# Should only compile once despite different input shapes
self.assertEqual(
cnt.frame_count,
1,
f"Expected 1 compilation, got {cnt.frame_count}. "
f"Size comparison should not cause recompilation.",
)
def test_remove_symbols_without_guarding(self):
from torch._functorch.partitioners import _remove_symbols_without_guarding
shape_env = ShapeEnv()
x = create_fake_tensor_with_dynamic_size(
torch.randn(5, 8),
shape_env,
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
self.assertEqual(f"{x.stride()}", "(s49, 1)")
self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])")
x_clean = _remove_symbols_without_guarding(x, 4096)
self.assertEqual(f"{x_clean.stride()}", "(8, 1)")
self.assertEqual(f"{x_clean.shape}", "torch.Size([5, 8])")
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
for node in graph.nodes:
if node.name == "arg3_1":
assert node.meta["val"].size()[0] == 2
return graph
class TestUnbacked(TestCase):
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_neq_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a):
torch._check(a.item() != 5)
return a.item() * 10
func(torch.tensor([100]))
with self.assertRaises(RuntimeError):
func(torch.tensor([5]))
# Test a situation where we generate a runtime assert i.e: u1==s1, then we specialize s1
# later on to a constant.
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_post_specialize_runtime_assert1(self, backend):
@torch.compile(dynamic=True, backend=backend)
def func(x, y):
u0 = y.item()
s0 = x.size()[0]
s1 = x.size()[1]
torch._check(u0 + s0 + s1 == 102)
assert s0 == 2
return x * 10
func(torch.rand(2, 50), torch.tensor([50]))
with self.assertRaises(RuntimeError):
func(torch.rand(2, 50), torch.tensor([51]))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass)
@parametrize("backend", ["inductor", "eager"])
def test_post_specialize_runtime_assert2(self, backend):
@torch.compile(dynamic=True, backend=backend)
def func(x, y):
u0 = y.item()
s0 = x.size()[0]
s1 = x.size()[1]
torch._check(u0 + s0 + s1 == 102)
return x * 10
func(torch.rand(2, 50), torch.tensor([50]))
with self.assertRaises(RuntimeError):
func(torch.rand(2, 50), torch.tensor([51]))
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_or_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a, b):
torch._check(operator.or_(a.item() == 5, b.item() == 5))
return a.item() * 10
func(torch.tensor([5]), torch.tensor([100]))
func(torch.tensor([100]), torch.tensor([5]))
def test_has_free_symbols(self):
self.assertFalse(has_free_symbols(sympy.S.true))
self.assertFalse(has_free_symbols(sympy.Max(1, 10, evaluate=False)))
self.assertFalse(has_free_symbols(sympy.sympify("1")))
self.assertFalse(has_free_symbols(sympy.sympify("1.1")))
self.assertTrue(has_free_symbols(sympy.sympify("a")))
self.assertTrue(has_free_symbols(sympy.sympify("a*2")))
self.assertTrue(has_free_symbols(sympy.sympify("a+b")))
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_eq_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a, b):
torch._check(b.item() == 5)
return a * 10
func(torch.tensor([5]), torch.tensor([5]))
with self.assertRaises(RuntimeError):
func(torch.tensor([100]), torch.tensor([1]))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_deferred_with_unbacked_input(self, backend):
@torch.compile(fullgraph=True, dynamic=True, backend=backend)
def func(a, b):
torch._check(a.size()[0] == b.size()[0])
return a * 10
a = torch.rand(1, 1)
b = torch.rand(1, 1)
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(b, 0)
func(a, b)
# inductor adds the check sometimes itself so it will be reflected
# as AssertionError.
with self.assertRaises((AssertionError, RuntimeError)):
func(a, torch.rand(2, 1))
@pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/163785")
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_do_not_guard_unbacked_inputs(self):
@torch.compile(fullgraph=True, dynamic=True, backend="inductor")
def func(a, b):
a.expand(b.shape)
return a * 10
a = torch.rand(1, 1)
b = torch.rand(1, 1)
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(a, 1)
torch._dynamo.decorators.mark_unbacked(b, 0)
torch._dynamo.decorators.mark_unbacked(b, 1)
log_stream, ctx = logs_to_string("torch._dynamo.guards", "guards")
with ctx():
func(a, b)
func(torch.rand(4, 5), torch.rand(4, 5))
guards = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertFalse("SYMBOLIC_SHAPE_GUARD" in guards)
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_div_unbacked_eq_input_tensors(self):
@torch.compile(fullgraph=True)
def func(a, b):
x = a.size()[0]
y = b.size()[0]
torch._check(x == y)
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
a = torch.randn(10, 10)
b = torch.randn(10, 20)
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(b, 0)
func(a, b)
@torch.compiler.config.patch(unbacked_sources="L['x'],L['y']")
def test_div_unbacked_eq_input_ints(self):
@torch.compile(fullgraph=True)
def func(x, y):
a = torch.rand(1)
torch._check(x == y)
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
func(10, 10)
@skipIfTorchDynamo("mark_unbacked is not traceable")
@torch.compiler.config.patch(unbacked_sources="L['y']")
def test_div_unbacked_eq_globals(self):
tensor = torch.rand(10, 44)
y = 10
@torch.compile(fullgraph=True, dynamic=True)
def func():
a = torch.rand(1)
x = tensor.size()[0]
torch._check(x == y)
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
torch._dynamo.decorators.mark_unbacked(tensor, 0)
func()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_div_unbacked_eq_item(self):
@torch.compile(fullgraph=True)
def func(a, b):
x = a.item()
y = b.item()
torch._check(x == y)
# TODO we should not need those torch checks.
torch._check(x // y == 1)
torch._check(2 * x // y == 2)
if x // y == 1:
a = a * 10
if 2 * x // y == 2:
a = a * 20
return a
a = torch.tensor([1])
b = torch.tensor([1])
func(a, b)
class TestUbackedOps(TestCase):
@fresh_cache()
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_reshape1(self):
cnt = CompileCounterWithBackend("inductor")
# Reshape happens in place reshape (no-clone)
# reshape u1 -> (u0*u0)
def func(x, y):
f = y.item()
t1 = x.view((f, f))
t2 = x.reshape((f, f))
t3 = torch._ops.ops.aten.view_copy(x, (f, f))
return t1 * 10, t2 * 10, t3
compiled_func = torch.compile(
fullgraph=True,
backend=cnt,
dynamic=True,
)(func)
# create a non-contiguous with data being even numbers in [0:cnt-1]
# and reshape it into sqrt(cnt)*sqrt(cnt)
def make_non_contiguous_tensor_and_test(cnt):
# create a non-contiguous tensor x that is skipping odd indices.
x = torch.arange(cnt * 2)
x = x.as_strided((x.size()[0] // 2,), (2,))
torch._dynamo.decorators.mark_unbacked(x, 0)
sz = torch.tensor([int(math.sqrt(cnt))])
compiled_result = compiled_func(x, sz)
eager_result = func(x, sz)
self.assertEqual(compiled_result, eager_result)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
make_non_contiguous_tensor_and_test(4)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
view_2: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None
mul_11: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
mul_14: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
return (mul_11, mul_14, clone)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
make_non_contiguous_tensor_and_test(49)
self.assertEqual(cnt.frame_count, 1)
# Pass in a contiguous tensor, it will recompile due to stride being 1 (0/1 specialization).
# marking strides unbacked would have avoided the recompilation here.
x = torch.arange(100)
torch._dynamo.decorators.mark_unbacked(x, 0)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
compiled_result = compiled_func(x, torch.tensor([10]))
eager_result = func(x, torch.tensor([10]))
self.assertEqual(compiled_result, eager_result)
self.assertEqual(cnt.frame_count, 2)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
view_2: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None
mul_6: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
return (mul_6, mul_9, clone)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
x = torch.arange(25)
compiled_result = compiled_func(x, torch.tensor([5]))
eager_result = func(x, torch.tensor([5]))
self.assertEqual(cnt.frame_count, 2)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_reshape2(self):
cnt = CompileCounterWithBackend("inductor")
# This reshape requires a clone when the input is not contiguous and we can't compute strides.
# reshape (u2, u3) -> (u0, u1)
def func(x, y):
u0, u1 = y.tolist()
result1 = torch.reshape(x, (u0, u1))
return result1 * 10
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.randn(10, 10)
# make x not contiguous.
x = x.t_()
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
result_eager = func(x, torch.tensor([5, 20]))
result_compiled = compiled_func(x, torch.tensor([5, 20]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
ge_4: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_4 = _assert_scalar_2 = None
sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense))
gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None
_assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
ge_5: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_5 = _assert_scalar_4 = None
sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1))
gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None
_assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None
mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
_assert_scalar_6 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_6 = None
clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
return (mul_21,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
result_eager = func(x, torch.tensor([2, 50]))
result_compiled = compiled_func(x, torch.tensor([2, 50]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 1)
x = torch.randn(4, 4).t_()
result_eager = func(x, torch.tensor([2, 8]))
result_compiled = compiled_func(x, torch.tensor([2, 8]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 1)
# Pass a contiguous tensor. A recompilation will happen due to 0/1 specialization on stride.
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
# This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3].
# but not anymore since we use contiguous_or_false .
# We need a way to mark strides unbacked to avoid the recompilation here.
x = torch.randn(10, 10)
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
result_compiled = compiled_func(x, torch.tensor([2, 50]))
result_eager = func(x, torch.tensor([2, 50]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
x = torch.randn(4, 4)
result_eager = func(x, torch.tensor([2, 8]))
result_compiled = compiled_func(x, torch.tensor([2, 8]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_slice(self):
from torch.fx.experimental.symbolic_shapes import statically_known_true
# standard slice
def f1(x, xs):
u0, u1 = xs.tolist()
# in this test we add the torch checks not to avoid DDE but to ensure
# that we pick specific path during compilation.
torch._check(u0 >= 0)
torch._check(u0 <= x.size(0))
torch._check(u1 >= 0)
torch._check(u1 <= x.size(0))
torch._check(u0 <= u1)
out = x[u0:u1]
assert statically_known_true(out.size(0) == (u1 - u0))
return out
x, xs = torch.randn(10), torch.tensor([3, 6])
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
self.assertEqual(fn1(x, xs).size(0), 3)
self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs)))
with self.assertRaises(RuntimeError):
fn1(x, torch.tensor([-1, 5]))
# known negative slice
def f2(x, n):
u0 = n.item()
torch._check(u0 > 1)
torch._check(u0 <= x.size(0))
out = x[-u0:]
assert statically_known_true(out.size(0) == u0)
return out
x, n = torch.randn(10), torch.tensor([5])
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
self.assertEqual(fn2(x, n).size(0), 5)
self.assertTrue(torch.allclose(fn2(x, n), f2(x, n)))
with self.assertRaises(RuntimeError):
fn2(x, torch.tensor([-5]))
# general case: no known info
def f3(x, xs):
u0, u1 = xs.tolist()
return x[u0:u1]
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
cnts = CompileCounterWithBackend("inductor")
x, xs = torch.randn(10), torch.tensor([3, 6])
with ctx():
fn3 = torch.compile(f3, fullgraph=True, backend=cnts)
xs = torch.tensor([-9, -1]) # negative case
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
xs = torch.tensor([-1000, 1000]) # out of bounds
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
xs = torch.tensor([2, -2]) # mixed
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
self.assertEqual(cnts.frame_count, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
slice_1: "f32[u2][1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, _local_scalar_dense, _local_scalar_dense_1); arg1_1 = _local_scalar_dense = _local_scalar_dense_1 = None
sym_size_int: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
ge_1: "Sym(u2 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
le: "Sym(u2 <= 10)" = sym_size_int <= 10; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u2 <= 10 on node 'le'"); le = _assert_scalar_1 = None
sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
ge_2: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_2 = None
return (slice_1,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_unbacked_slice_cpp_wrapper(self):
self.test_unbacked_slice()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_slice_with_step(self):
def f1(x, xs):
u0, u1 = xs.tolist()
out = x[u0:u1:5]
return out
x, xs = torch.randn(10), torch.tensor([2, -2])
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs)))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_unbacked_slice_with_step_cpp_wrapper(self):
self.test_unbacked_slice_with_step()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_slice_with_tensor_indices(self):
for d in [True, False]:
# Test slicing with tensor start/stop/step on RHS (reading)
# Test 1: Basic slice with tensor start and stop
def f1(x, start_t, stop_t):
return x[start_t:stop_t]
x = torch.randn(20)
start_t = torch.tensor(5)
stop_t = torch.tensor(15)
fn1 = torch.compile(f1, fullgraph=True, dynamic=d, backend="inductor")
self.assertTrue(
torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t))
)
# Test 2: Slice with tensor step
def f2(x, start_t, stop_t, step_t):
return x[start_t:stop_t:step_t]
step_t = torch.tensor(2)
fn2 = torch.compile(f2, fullgraph=True, dynamic=d, backend="inductor")
self.assertTrue(
torch.allclose(
fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
)
)
# Test 3: Slice with only tensor start
def f3(x, start_t):
return x[start_t:]
fn3 = torch.compile(f3, fullgraph=True, dynamic=d, backend="inductor")
self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))
# Test 4: Slice with only tensor stop
def f4(x, stop_t):
return x[:stop_t]
fn4 = torch.compile(f4, fullgraph=True, dynamic=d, backend="inductor")
self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))
# Test 5: Negative indices with tensors
def f5(x, start_t):
return x[start_t:-1]
start_t_neg = torch.tensor(-10)
fn5 = torch.compile(f5, fullgraph=True, dynamic=d, backend="inductor")
self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))
# Test 6: Multidimensional slice with tensor indices
def f6(x, start_t, stop_t):
return x[:, start_t:stop_t]
x_2d = torch.randn(10, 20)
fn6 = torch.compile(f6, fullgraph=True, dynamic=d, backend="inductor")
self.assertTrue(
torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_slice_with_tensor_indices_cpp_wrapper(self):
self.test_slice_with_tensor_indices()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_tensor_split(self):
def f1(x, xs):
xs = torch.tensor(xs.tolist())
return torch.tensor_split(x, xs)
x = torch.randn(20)
xs = torch.tensor([5, 10, 15])
fn = torch.compile(f1, fullgraph=True, backend="inductor")
def compare(x, xs):
for i, j in zip(f1(x, xs), fn(x, xs)):
self.assertTrue(torch.allclose(i, j))
compare(x, xs)
xs = torch.tensor([-15, 9, 10, 11])
compare(x, xs)
xs = torch.tensor([-15, -10, -5, -2])
compare(x, xs)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_tensor_split_cpp_wrapper(self):
self.test_tensor_split()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_nonzero_slice(self):
def f(x):
nz = x.nonzero()
return nz[:-1]
x = torch.randn(3, 4)
fn = torch.compile(f, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(f(x), fn(x)))
y = torch.zeros(3, 4)
self.assertTrue(torch.allclose(f(y), fn(y)))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_nonzero_slice_cpp_wrapper(self):
self.test_nonzero_slice()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_nonzero_select(self):
def f(x):
nz = x.nonzero()
return nz[-1] + nz[0]
x = torch.randn(3, 4)
fn = torch.compile(f, fullgraph=True, backend="inductor")
self.assertTrue(torch.allclose(f(x), fn(x)))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_nonzero_select_cpp_wrapper(self):
self.test_nonzero_select()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True)
def test_padnd(self):
import torch.nn.functional as F
def f(x, xs, y):
u0, u1 = xs.tolist()
for u in [u0, u1]:
torch._check(u >= 0)
z = F.pad(x, (u0, u1, u0, u1))
return z @ y
x = torch.randn(8, 8)
xs = torch.tensor([2, 2])
y = torch.randn(12, 4)
fn = torch.compile(f, fullgraph=True, backend="inductor")
fn(x, xs, y)
@unittest.skip("this test fails due to inductor/autograd issue #153041")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_non_contigious_reshape_failing(self):
# reshape u1 -> (u0*u0)
# this result in the tensor "i64[u0, u0][s7*u0, s7].
# reshape happens in place reshape (no-clone)
def func(x, y):
f = y.item()
t1 = x.view((f, f))
t2 = x.reshape((f, f))
return t1, t2
# create a non-contiguous with data being even numbers in [0:cnt-1]
def make_non_contiguous_tensor(cnt):
# create a non-contiguous tensor x that is skipping odd indices.
x = torch.arange(cnt * 2)
x = x.as_strided((x.size()[0] // 2,), (2,))
return x
x = make_non_contiguous_tensor(4)
torch._dynamo.decorators.mark_unbacked(x, 0)
compiled_func = torch.compile(
fullgraph=True,
backend="inductor",
)(func)
compiled_result = compiled_func(x, torch.tensor([2]))
eager_result = func(x, torch.tensor([2]))
self.assertEqual(compiled_result, eager_result)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_invalid_view_unbacked_view(self):
cnt = CompileCounterWithBackend("inductor")
# This view (u2, u3) -> (u0, u1) can't happen in general unless we know that input is contiguous or we have
# hints to to compute strides.
def func(x, y):
u0, u1 = y.tolist()
result2 = x.view(u0, u1) * 10
return result2
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.randn(10, 10)
# make x not contiguous.
x = x.t_()
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
with self.assertRaises(torch._dynamo.exc.UserError):
# throws a data dependent error.
compiled_func(x, torch.tensor([5, 20]))
@skipIfTorchDynamo()
def test_unbind_not_dynamic(self):
cnt = CompileCounter()
@torch.compile(fullgraph=True, dynamic=True, backend=cnt)
def func(y):
return y.unbind(dim=2), y * 10
func(torch.ones(5, 6, 7, 8))
self.assertEqual(cnt.frame_count, 1)
# it can be dynamic in all dimensions except dim=2
func(torch.ones(4, 9, 7, 10))
self.assertEqual(cnt.frame_count, 1)
func(torch.ones(5, 6, 8, 8))
func(torch.ones(5, 6, 9, 8))
self.assertEqual(cnt.frame_count, 3)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@fresh_cache()
def test_unbacked_contiguous(self):
cnt = CompileCounterWithBackend("inductor")
def func(x):
contig = x.contiguous()
return (contig + 1) * 100
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.randn(10, 10)
# make x not contiguous.
x = x.t_()
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
compiled_func(x)
self.assertEqual(compiled_func(x), func(x))
y = torch.rand(20, 20).t()
self.assertEqual(compiled_func(y), func(y))
self.assertEqual(cnt.frame_count, 1)
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
return (mul_6,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
# recompilation will happen due to stride specialization.
y = torch.rand(20, 20)
torch._dynamo.decorators.mark_unbacked(y, 0)
torch._dynamo.decorators.mark_unbacked(y, 1)
self.assertEqual(compiled_func(y), func(y))
self.assertEqual(cnt.frame_count, 2)
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
# No clone this time since input is contiguous.
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
return (mul_5,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_index(self):
cnt = CompileCounterWithBackend("inductor")
def func(x, y):
u0 = y.item()
return (
torch.select(x, 0, u0),
torch.select(x, 1, u0),
torch.select(x, 2, u0),
)
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.rand(3, 3, 3)
zero = torch.tensor([0])
pos = torch.tensor([1])
# code can handle both negative and positive indices.
neg = torch.tensor([-1])
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
self.assertEqual(compiled_func(x, zero), func(x, zero))
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
output,
"""\
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense)
select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None
return (select, select_1, select_2)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
self.assertEqual(compiled_func(x, pos), func(x, pos))
self.assertEqual(compiled_func(x, neg), func(x, neg))
self.assertEqual(cnt.frame_count, 1)
def func2(x, y):
u0, u1 = y.tolist()
return torch.select(x, 0, u0 + u1)
compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)(
func2
)
zero = torch.tensor([0, 0])
pos = torch.tensor([1, 1])
neg = torch.tensor([-1, -1])
self.assertEqual(compiled_func2(x, pos), func2(x, pos))
self.assertEqual(compiled_func2(x, neg), func2(x, neg))
self.assertEqual(compiled_func2(x, zero), func2(x, zero))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_2(self):
class M(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
return nz[-1]
mod = M()
x = torch.randn(4)
self.assertEqual(torch.compile(mod)(x), mod(x))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_index_with_check(self):
def func3(x, y):
u0 = y.item()
# Test that taking the non-unbacked path works fine also.
torch._check(u0 >= 0)
return (torch.select(x, 1, u0),)
compiled_func3 = torch.compile(
fullgraph=True, backend="inductor", dynamic=True
)(func3)
x = torch.rand(3, 3, 3)
zero = torch.tensor([0])
pos = torch.tensor([1])
print(compiled_func3(x, pos))
self.assertEqual(compiled_func3(x, pos), func3(x, pos))
self.assertEqual(compiled_func3(x, zero), func3(x, zero))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_unbacked_select_index_cpp_wrapper(self):
self.test_unbacked_select_index()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select2(self):
def f(idx, x):
x = x.select(0, idx.item())
return x @ x
x = torch.randn(3, 3, 3)
idx = torch.tensor(1, dtype=torch.int64)
out = torch.compile(f)(idx, x)
self.assertEqual(out, f(idx, x))
def test_trunc_int_div_true(self):
@torch.compile(backend="inductor", dynamic=True, fullgraph=True)
def f(x, s13, s57, s77):
torch._check(s13 >= 0)
torch._check(s57 >= 0)
torch._check(s77 >= 0)
if int(s13 * ((s57 // s13) + (s77 // s13)) / s13) >= 1:
return x * 2
else:
return x * 100
# ensure we compile this with no errors.
x = torch.rand(10)
f(x, 4, 4096, 3920)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_reshape3(self):
def func(x):
x = x.as_strided([x.size()[0], 1536], [2048, 1])
result1 = x.view(x.size()[0], -1, 128)
return result1 * 10
compiled = torch.compile(fullgraph=True, backend="inductor")(func)
x = torch.randn(10, 2048)
torch._dynamo.decorators.mark_unbacked(x, 0)
self.assertEqual(func(x), compiled(x))
@fresh_cache()
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_reshape_copy(self):
cnt = CompileCounterWithBackend("inductor")
# Reshape happens in place reshape (no-clone)
# reshape u1 -> (u0*u0)
def func(x, y):
f = y.item()
t3 = torch._ops.ops.aten._reshape_copy(x, (f, f))
return t3
compiled_func = torch.compile(
fullgraph=True,
backend=cnt,
dynamic=True,
)(func)
# create a non-contiguous with data being even numbers in [0:cnt-1]
# and reshape it into sqrt(cnt)*sqrt(cnt)
def make_non_contiguous_tensor_and_test(cnt):
# create a non-contiguous tensor x that is skipping odd indices.
x = torch.arange(cnt * 2)
x = x.as_strided((x.size()[0] // 2,), (2,))
torch._dynamo.decorators.mark_unbacked(x, 0)
sz = torch.tensor([int(math.sqrt(cnt))])
compiled_result = compiled_func(x, sz)
eager_result = func(x, sz)
self.assertEqual(compiled_result, eager_result)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
make_non_contiguous_tensor_and_test(4)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
_reshape_copy: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten._reshape_copy.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
return (_reshape_copy,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
make_non_contiguous_tensor_and_test(49)
self.assertEqual(cnt.frame_count, 1)
# Pass in a contiguous tensor, it will recompile due to stride being 1 (0/1 specialization).
# marking strides unbacked would have avoided the recompilation here.
x = torch.arange(100)
torch._dynamo.decorators.mark_unbacked(x, 0)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.graph_capture", "aot_graphs"
)
with ctx():
compiled_result = compiled_func(x, torch.tensor([10]))
eager_result = func(x, torch.tensor([10]))
self.assertEqual(compiled_result, eager_result)
self.assertEqual(cnt.frame_count, 2)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
_reshape_copy: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten._reshape_copy.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
return (_reshape_copy,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
x = torch.arange(25)
compiled_result = compiled_func(x, torch.tensor([5]))
eager_result = func(x, torch.tensor([5]))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_item(self):
def func():
_x_ms = torch.tensor([True, False], dtype=torch.int64)
_mask_ms = torch.zeros_like(_x_ms, dtype=torch.bool)
_mask_ms[:1] = True
var_node_2 = torch.masked_select(_x_ms, _mask_ms)
var_node_0 = var_node_2.item()
return var_node_0
result_original = func()
compiled_program = torch.compile(func, fullgraph=True, dynamic=True)
result_compiled = compiled_program()
self.assertEqual(result_original, result_compiled)
def test_unbacked_item_set_item(self):
def my_arithmetic(a, b):
wrk = torch.zeros(a.size(0))
for i in range(a.size(0)):
idx = b[i].item()
wrk[idx] += 1
return wrk
compiled = torch.compile(my_arithmetic, fullgraph=True, disable=False)
a = torch.randn([9])
b = torch.ones(9, dtype=torch.int32)
compiled(a, b)
self.assertEqual(compiled(a, b), my_arithmetic(a, b))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_item_set_item2(self):
def accumulate(X0, start):
start = start.item()
N = 3
result = X0[start]
for i in range(N):
result += X0[start + 1 + i]
return result
compiled = torch.compile(accumulate, fullgraph=True)
X0 = torch.randn(10, 10)
self.assertEqual(
accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_item_set_item3(self):
def func(x, y):
u0 = y.item()
x[u0] = 0
return x
compiled = torch.compile(func, fullgraph=True, disable=False)
b = torch.tensor([0])
a = torch.ones(9, dtype=torch.int32)
compiled(a, b)
self.assertEqual(compiled(a, b), func(a, b))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_select_scatter_unbacked_index(self):
def func(x, y):
u0 = y.item()
# Create a scalar tensor to scatter into the selected index
scalar_src = torch.tensor(42, dtype=x.dtype)
return x.select_scatter(scalar_src, 0, u0)
compiled = torch.compile(func, fullgraph=True, dynamic=True, backend="inductor")
b = torch.tensor([0])
a = torch.ones(9, dtype=torch.int32)
self.assertEqual(compiled(a, b), func(a, b))
instantiate_parametrized_tests(TestUnbacked)
if __name__ == "__main__":
run_tests()