mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamic shapes] guard individual terms in sym_and; user-code-friendly sym_and/sym_or (#154737)
Previously when processing `sym_and(a, b, c)`, symbolic shapes wouldn't individually process a, b, and c and store their implications. This would lead us to data-dependent error on individual checks, e.g. we stored `u0 >= 0 & u0 <= 10`, but then couldn't figure out `u0 <= 10`. This handles that, and also makes `sym_and/or` user-code friendly, for testing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154737 Approved by: https://github.com/laithsakka
This commit is contained in:
parent
c1cbaca7fd
commit
247f83e0a4
|
|
@ -7813,6 +7813,29 @@ utils_device.CURRENT_DEVICE == None""".split(
|
|||
self.assertEqual(fn(torch.tensor([4])).size(0), 1)
|
||||
self.assertEqual(fn(torch.tensor([1])).size(0), 0)
|
||||
|
||||
def test_sym_and_terms(self):
|
||||
from torch.fx.experimental.symbolic_shapes import sym_and
|
||||
|
||||
@torch.compile(fullgraph=True, dynamic=True, backend="eager")
|
||||
def fn(xs):
|
||||
u0, u1 = xs.tolist()
|
||||
torch._check(sym_and(u0 >= 3, u0 <= 10, u1 >= 2))
|
||||
|
||||
# test individual checks
|
||||
n = 0
|
||||
if u0 >= 3:
|
||||
n += 1
|
||||
if u0 <= 11:
|
||||
n += 1
|
||||
if u1 >= 1:
|
||||
n += 1
|
||||
return u0 + u1 + n
|
||||
|
||||
fn(torch.tensor([5, 6]))
|
||||
fn(torch.tensor([8, 7]))
|
||||
with self.assertRaises(RuntimeError):
|
||||
fn(torch.tensor([9, 0]))
|
||||
|
||||
def test_unbacked_2d_expand(self):
|
||||
@torch.compile(fullgraph=True, dynamic=True, backend="inductor")
|
||||
def func(a, b):
|
||||
|
|
|
|||
|
|
@ -7602,9 +7602,7 @@ def forward(self, x):
|
|||
RuntimeError, r".* expression Eq\(u0, 2\) \| Eq\(u0, 4\) \| Eq\(u0, 6\) .*"
|
||||
):
|
||||
ep.module()(torch.tensor([3, 6, 5]))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r".* expression Eq\(u2, 5\) & \(4 <= u1\) & \(u1 <= 8\) .*"
|
||||
):
|
||||
with self.assertRaisesRegex(RuntimeError, r".* expression u[\d]+ <= 5 .*"):
|
||||
ep.module()(torch.tensor([6, 6, 6]))
|
||||
|
||||
def test_redundant_assert_max_upper_bound(self):
|
||||
|
|
|
|||
|
|
@ -338,6 +338,8 @@ manual_torch_name_rule_map: dict[str, Any] = {
|
|||
"torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.statically_known_true": TorchInGraphFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.statically_known_false": TorchInGraphFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.sym_and": TorchInGraphFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.sym_or": TorchInGraphFunctionVariable,
|
||||
"torch.fx.experimental.symbolic_shapes.has_static_value": TorchInGraphFunctionVariable,
|
||||
"torch.cuda._get_device_properties": TorchInGraphFunctionVariable,
|
||||
"torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable,
|
||||
|
|
|
|||
|
|
@ -946,6 +946,28 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
elif isinstance(expr, ConstantVariable):
|
||||
return expr
|
||||
|
||||
@register(torch.fx.experimental.symbolic_shapes.sym_and)
|
||||
def handle_sym_and(self, tx: "InstructionTranslator", *terms):
|
||||
if all(isinstance(x, SymNodeVariable) for x in terms):
|
||||
return SymNodeVariable.create(
|
||||
tx,
|
||||
torch.fx.experimental.symbolic_shapes.sym_and(
|
||||
*(x.as_proxy() for x in terms)
|
||||
),
|
||||
sym_num=None,
|
||||
)
|
||||
|
||||
@register(torch.fx.experimental.symbolic_shapes.sym_or)
|
||||
def handle_sym_or(self, tx: "InstructionTranslator", *terms):
|
||||
if all(isinstance(x, SymNodeVariable) for x in terms):
|
||||
return SymNodeVariable.create(
|
||||
tx,
|
||||
torch.fx.experimental.symbolic_shapes.sym_or(
|
||||
*(x.as_proxy() for x in terms)
|
||||
),
|
||||
sym_num=None,
|
||||
)
|
||||
|
||||
@register(torch.fx.experimental.symbolic_shapes.has_static_value)
|
||||
def handle_has_static_value(self, tx: "InstructionTranslator", expr):
|
||||
if isinstance(expr, SymNodeVariable):
|
||||
|
|
|
|||
|
|
@ -1462,11 +1462,9 @@ def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
|
|||
"""
|
||||
and, but for symbolic expressions, without bool casting.
|
||||
"""
|
||||
assert isinstance(x, (bool, SymBool))
|
||||
if len(others) == 0:
|
||||
return x
|
||||
for y in others:
|
||||
assert isinstance(y, (bool, SymBool))
|
||||
x = operator.and_(x, y)
|
||||
return x
|
||||
|
||||
|
|
@ -1490,11 +1488,9 @@ def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
|
|||
"""
|
||||
or, but for symbolic expressions, without bool casting.
|
||||
"""
|
||||
assert isinstance(x, (bool, SymBool))
|
||||
if len(others) == 0:
|
||||
return x
|
||||
for y in others:
|
||||
assert isinstance(y, (bool, SymBool))
|
||||
x = operator.or_(x, y)
|
||||
return x
|
||||
|
||||
|
|
@ -6799,12 +6795,20 @@ class ShapeEnv:
|
|||
return self.replacements[a]
|
||||
|
||||
@lru_cache(256)
|
||||
def _maybe_guard_rel(self, expr: sympy.Rel) -> None:
|
||||
def _maybe_guard_rel(self, expr: sympy.Expr) -> None:
|
||||
"""
|
||||
The relational guard is guarded to be true. Use this information to
|
||||
simplify shapes (i.e. a == b or a % 5 == 0)
|
||||
"""
|
||||
assert isinstance(expr, sympy.Rel)
|
||||
if isinstance(expr, sympy.And):
|
||||
for arg in expr.args:
|
||||
self._maybe_guard_rel(arg)
|
||||
return
|
||||
elif not isinstance(expr, sympy.Rel):
|
||||
log.warning(
|
||||
"_maybe_guard_rel() was called on non-relation expression %s", expr
|
||||
)
|
||||
return
|
||||
|
||||
# A good example of what goes wrong if you don't do this is
|
||||
# python test/functorch/test_aotdispatch.py -k
|
||||
|
|
@ -7586,15 +7590,14 @@ class ShapeEnv:
|
|||
if not self._suppress_guards_tls():
|
||||
self._log_guard("eval", g, forcing_spec=forcing_spec)
|
||||
|
||||
if isinstance(g, sympy.Rel):
|
||||
# TODO: If we successfully eliminate a symbol via equality, it
|
||||
# is not actually necessary to save a guard for the equality,
|
||||
# as we will implicitly generate a guard when we match that
|
||||
# input against the symbol. Probably the easiest way to
|
||||
# implement this is to have maybe_guard_rel return a bool
|
||||
# saying if it "subsumed" the guard (and therefore the guard
|
||||
# is no longer necessary)
|
||||
self._maybe_guard_rel(g)
|
||||
# TODO: If we successfully eliminate a symbol via equality, it
|
||||
# is not actually necessary to save a guard for the equality,
|
||||
# as we will implicitly generate a guard when we match that
|
||||
# input against the symbol. Probably the easiest way to
|
||||
# implement this is to have maybe_guard_rel return a bool
|
||||
# saying if it "subsumed" the guard (and therefore the guard
|
||||
# is no longer necessary)
|
||||
self._maybe_guard_rel(g)
|
||||
|
||||
if not self.allow_complex_guards_as_runtime_asserts:
|
||||
# at this point, we've evaluated the concrete expr value, and have
|
||||
|
|
@ -7708,8 +7711,7 @@ class ShapeEnv:
|
|||
log.debug("runtime_asserts_frozen but then got %s", expr)
|
||||
self._check_frozen(expr, sympy.true)
|
||||
# eliminate symbols on equality tests / refine ranges
|
||||
if isinstance(expr, sympy.Rel):
|
||||
self._maybe_guard_rel(expr)
|
||||
self._maybe_guard_rel(expr)
|
||||
|
||||
# canonicalise to remove equations that are trivially equal
|
||||
orig_expr = expr
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user