[dynamic shapes] guard individual terms in sym_and; user-code-friendly sym_and/sym_or (#154737)

Previously when processing `sym_and(a, b, c)`, symbolic shapes wouldn't individually process a, b, and c and store their implications. This would lead us to data-dependent error on individual checks, e.g. we stored `u0 >= 0 & u0 <= 10`, but then couldn't figure out `u0 <= 10`.

This handles that, and also makes `sym_and/or` user-code friendly, for testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154737
Approved by: https://github.com/laithsakka
This commit is contained in:
Pian Pawakapan 2025-06-11 18:08:06 +00:00 committed by PyTorch MergeBot
parent c1cbaca7fd
commit 247f83e0a4
5 changed files with 67 additions and 20 deletions

View File

@ -7813,6 +7813,29 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(fn(torch.tensor([4])).size(0), 1)
self.assertEqual(fn(torch.tensor([1])).size(0), 0)
def test_sym_and_terms(self):
from torch.fx.experimental.symbolic_shapes import sym_and
@torch.compile(fullgraph=True, dynamic=True, backend="eager")
def fn(xs):
u0, u1 = xs.tolist()
torch._check(sym_and(u0 >= 3, u0 <= 10, u1 >= 2))
# test individual checks
n = 0
if u0 >= 3:
n += 1
if u0 <= 11:
n += 1
if u1 >= 1:
n += 1
return u0 + u1 + n
fn(torch.tensor([5, 6]))
fn(torch.tensor([8, 7]))
with self.assertRaises(RuntimeError):
fn(torch.tensor([9, 0]))
def test_unbacked_2d_expand(self):
@torch.compile(fullgraph=True, dynamic=True, backend="inductor")
def func(a, b):

View File

@ -7602,9 +7602,7 @@ def forward(self, x):
RuntimeError, r".* expression Eq\(u0, 2\) \| Eq\(u0, 4\) \| Eq\(u0, 6\) .*"
):
ep.module()(torch.tensor([3, 6, 5]))
with self.assertRaisesRegex(
RuntimeError, r".* expression Eq\(u2, 5\) & \(4 <= u1\) & \(u1 <= 8\) .*"
):
with self.assertRaisesRegex(RuntimeError, r".* expression u[\d]+ <= 5 .*"):
ep.module()(torch.tensor([6, 6, 6]))
def test_redundant_assert_max_upper_bound(self):

View File

@ -338,6 +338,8 @@ manual_torch_name_rule_map: dict[str, Any] = {
"torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.statically_known_true": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.statically_known_false": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.sym_and": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.sym_or": TorchInGraphFunctionVariable,
"torch.fx.experimental.symbolic_shapes.has_static_value": TorchInGraphFunctionVariable,
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,

View File

@ -946,6 +946,28 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
elif isinstance(expr, ConstantVariable):
return expr
@register(torch.fx.experimental.symbolic_shapes.sym_and)
def handle_sym_and(self, tx: "InstructionTranslator", *terms):
if all(isinstance(x, SymNodeVariable) for x in terms):
return SymNodeVariable.create(
tx,
torch.fx.experimental.symbolic_shapes.sym_and(
*(x.as_proxy() for x in terms)
),
sym_num=None,
)
@register(torch.fx.experimental.symbolic_shapes.sym_or)
def handle_sym_or(self, tx: "InstructionTranslator", *terms):
if all(isinstance(x, SymNodeVariable) for x in terms):
return SymNodeVariable.create(
tx,
torch.fx.experimental.symbolic_shapes.sym_or(
*(x.as_proxy() for x in terms)
),
sym_num=None,
)
@register(torch.fx.experimental.symbolic_shapes.has_static_value)
def handle_has_static_value(self, tx: "InstructionTranslator", expr):
if isinstance(expr, SymNodeVariable):

View File

@ -1462,11 +1462,9 @@ def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
"""
and, but for symbolic expressions, without bool casting.
"""
assert isinstance(x, (bool, SymBool))
if len(others) == 0:
return x
for y in others:
assert isinstance(y, (bool, SymBool))
x = operator.and_(x, y)
return x
@ -1490,11 +1488,9 @@ def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
"""
or, but for symbolic expressions, without bool casting.
"""
assert isinstance(x, (bool, SymBool))
if len(others) == 0:
return x
for y in others:
assert isinstance(y, (bool, SymBool))
x = operator.or_(x, y)
return x
@ -6799,12 +6795,20 @@ class ShapeEnv:
return self.replacements[a]
@lru_cache(256)
def _maybe_guard_rel(self, expr: sympy.Rel) -> None:
def _maybe_guard_rel(self, expr: sympy.Expr) -> None:
"""
The relational guard is guarded to be true. Use this information to
simplify shapes (i.e. a == b or a % 5 == 0)
"""
assert isinstance(expr, sympy.Rel)
if isinstance(expr, sympy.And):
for arg in expr.args:
self._maybe_guard_rel(arg)
return
elif not isinstance(expr, sympy.Rel):
log.warning(
"_maybe_guard_rel() was called on non-relation expression %s", expr
)
return
# A good example of what goes wrong if you don't do this is
# python test/functorch/test_aotdispatch.py -k
@ -7586,15 +7590,14 @@ class ShapeEnv:
if not self._suppress_guards_tls():
self._log_guard("eval", g, forcing_spec=forcing_spec)
if isinstance(g, sympy.Rel):
# TODO: If we successfully eliminate a symbol via equality, it
# is not actually necessary to save a guard for the equality,
# as we will implicitly generate a guard when we match that
# input against the symbol. Probably the easiest way to
# implement this is to have maybe_guard_rel return a bool
# saying if it "subsumed" the guard (and therefore the guard
# is no longer necessary)
self._maybe_guard_rel(g)
# TODO: If we successfully eliminate a symbol via equality, it
# is not actually necessary to save a guard for the equality,
# as we will implicitly generate a guard when we match that
# input against the symbol. Probably the easiest way to
# implement this is to have maybe_guard_rel return a bool
# saying if it "subsumed" the guard (and therefore the guard
# is no longer necessary)
self._maybe_guard_rel(g)
if not self.allow_complex_guards_as_runtime_asserts:
# at this point, we've evaluated the concrete expr value, and have
@ -7708,8 +7711,7 @@ class ShapeEnv:
log.debug("runtime_asserts_frozen but then got %s", expr)
self._check_frozen(expr, sympy.true)
# eliminate symbols on equality tests / refine ranges
if isinstance(expr, sympy.Rel):
self._maybe_guard_rel(expr)
self._maybe_guard_rel(expr)
# canonicalise to remove equations that are trivially equal
orig_expr = expr