support backed_size_oblivious in guard_or_false/guard_or_true (#150231)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150231
Approved by: https://github.com/pianpwk
This commit is contained in:
Laith Sakka 2025-04-09 09:02:22 -07:00 committed by PyTorch MergeBot
parent 31fe258efc
commit 087e8587cd
2 changed files with 71 additions and 19 deletions

View File

@ -2880,6 +2880,22 @@ class TestGuardsExpressions(TestCase):
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
)
# Test backed_size_oblivious
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
def func3(a, b):
if guard_or_true(a.size()[0] != 9):
return b * 10
else:
return b * 20
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
a = torch.rand(9, 2)
b = torch.rand(3, 4)
self.assertEqual(func3(a, b), b * 20)
self.assertEqual(compiled(a, b), b * 10)
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_guard_or_false(self):
@ -2929,6 +2945,22 @@ class TestGuardsExpressions(TestCase):
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
)
# Test backed_size_oblivious
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
def func3(a, b):
if guard_or_false(a.size()[0] == 9):
return b * 10
else:
return b * 20
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
a = torch.rand(9, 2)
b = torch.rand(3, 4)
self.assertEqual(func3(a, b), b * 10)
self.assertEqual(compiled(a, b), b * 20)
def test_guards_float_div(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 8)

View File

@ -1195,20 +1195,30 @@ def guard_or_false(a: BoolLikeType) -> bool:
"""
Try to guard a, if data dependent error encountered just return false.
"""
try:
return bool(guard_bool(a))
except GuardOnDataDependentSymNode:
return False
if torch.fx.experimental._config.backed_size_oblivious:
return statically_known_true(a)
else:
try:
return bool(guard_bool(a))
except GuardOnDataDependentSymNode:
return False
def guard_or_true(a: BoolLikeType) -> bool:
"""
Try to guard a, if data dependent error encountered just return true.
"""
try:
return bool(guard_bool(a))
except GuardOnDataDependentSymNode:
return True
if torch.fx.experimental._config.backed_size_oblivious:
result = _static_eval(a)
if result is not None:
return result
else:
return True
else:
try:
return bool(guard_bool(a))
except GuardOnDataDependentSymNode:
return True
def definitely_true(a: BoolLikeType) -> bool:
@ -1253,6 +1263,23 @@ def definitely_false(a: BoolLikeType) -> bool:
return not bool(a)
def _static_eval(x: Union[bool, SymBool]) -> Optional[bool]:
if isinstance(x, SymBool):
expr = x.node.expr
shape_env = x.node.shape_env
try:
simplified = shape_env._maybe_evaluate_static(expr)
if simplified is not None:
return bool(simplified)
else:
return None
except Exception:
log.debug("Could not simplify %s", expr)
return None
assert isinstance(x, bool)
return x
def statically_known_true(x: Union[bool, SymBool]) -> bool:
"""
Returns True if x can be simplified to a constant and is true.
@ -1264,18 +1291,11 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
Args:
x (bool, SymBool): The expression to try statically evaluating
"""
if isinstance(x, SymBool):
expr = x.node.expr
shape_env = x.node.shape_env
try:
simplified = shape_env._maybe_evaluate_static(expr)
if simplified is not None:
return bool(simplified)
except Exception:
log.debug("Could not simplify %s", expr)
result = _static_eval(x)
if result is None:
return False
assert isinstance(x, bool)
return x
else:
return result
def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]: