mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
support backed_size_oblivious in guard_or_false/guard_or_true (#150231)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150231 Approved by: https://github.com/pianpwk
This commit is contained in:
parent
31fe258efc
commit
087e8587cd
|
|
@ -2880,6 +2880,22 @@ class TestGuardsExpressions(TestCase):
|
|||
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
|
||||
)
|
||||
|
||||
# Test backed_size_oblivious
|
||||
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
|
||||
|
||||
def func3(a, b):
|
||||
if guard_or_true(a.size()[0] != 9):
|
||||
return b * 10
|
||||
else:
|
||||
return b * 20
|
||||
|
||||
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
|
||||
a = torch.rand(9, 2)
|
||||
b = torch.rand(3, 4)
|
||||
|
||||
self.assertEqual(func3(a, b), b * 20)
|
||||
self.assertEqual(compiled(a, b), b * 10)
|
||||
|
||||
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_guard_or_false(self):
|
||||
|
|
@ -2929,6 +2945,22 @@ class TestGuardsExpressions(TestCase):
|
|||
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
|
||||
)
|
||||
|
||||
# Test backed_size_oblivious
|
||||
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
|
||||
|
||||
def func3(a, b):
|
||||
if guard_or_false(a.size()[0] == 9):
|
||||
return b * 10
|
||||
else:
|
||||
return b * 20
|
||||
|
||||
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
|
||||
a = torch.rand(9, 2)
|
||||
b = torch.rand(3, 4)
|
||||
|
||||
self.assertEqual(func3(a, b), b * 10)
|
||||
self.assertEqual(compiled(a, b), b * 20)
|
||||
|
||||
def test_guards_float_div(self):
|
||||
shape_env = ShapeEnv()
|
||||
s0 = create_symint(shape_env, 8)
|
||||
|
|
|
|||
|
|
@ -1195,20 +1195,30 @@ def guard_or_false(a: BoolLikeType) -> bool:
|
|||
"""
|
||||
Try to guard a, if data dependent error encountered just return false.
|
||||
"""
|
||||
try:
|
||||
return bool(guard_bool(a))
|
||||
except GuardOnDataDependentSymNode:
|
||||
return False
|
||||
if torch.fx.experimental._config.backed_size_oblivious:
|
||||
return statically_known_true(a)
|
||||
else:
|
||||
try:
|
||||
return bool(guard_bool(a))
|
||||
except GuardOnDataDependentSymNode:
|
||||
return False
|
||||
|
||||
|
||||
def guard_or_true(a: BoolLikeType) -> bool:
|
||||
"""
|
||||
Try to guard a, if data dependent error encountered just return true.
|
||||
"""
|
||||
try:
|
||||
return bool(guard_bool(a))
|
||||
except GuardOnDataDependentSymNode:
|
||||
return True
|
||||
if torch.fx.experimental._config.backed_size_oblivious:
|
||||
result = _static_eval(a)
|
||||
if result is not None:
|
||||
return result
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
try:
|
||||
return bool(guard_bool(a))
|
||||
except GuardOnDataDependentSymNode:
|
||||
return True
|
||||
|
||||
|
||||
def definitely_true(a: BoolLikeType) -> bool:
|
||||
|
|
@ -1253,6 +1263,23 @@ def definitely_false(a: BoolLikeType) -> bool:
|
|||
return not bool(a)
|
||||
|
||||
|
||||
def _static_eval(x: Union[bool, SymBool]) -> Optional[bool]:
|
||||
if isinstance(x, SymBool):
|
||||
expr = x.node.expr
|
||||
shape_env = x.node.shape_env
|
||||
try:
|
||||
simplified = shape_env._maybe_evaluate_static(expr)
|
||||
if simplified is not None:
|
||||
return bool(simplified)
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
log.debug("Could not simplify %s", expr)
|
||||
return None
|
||||
assert isinstance(x, bool)
|
||||
return x
|
||||
|
||||
|
||||
def statically_known_true(x: Union[bool, SymBool]) -> bool:
|
||||
"""
|
||||
Returns True if x can be simplified to a constant and is true.
|
||||
|
|
@ -1264,18 +1291,11 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool:
|
|||
Args:
|
||||
x (bool, SymBool): The expression to try statically evaluating
|
||||
"""
|
||||
if isinstance(x, SymBool):
|
||||
expr = x.node.expr
|
||||
shape_env = x.node.shape_env
|
||||
try:
|
||||
simplified = shape_env._maybe_evaluate_static(expr)
|
||||
if simplified is not None:
|
||||
return bool(simplified)
|
||||
except Exception:
|
||||
log.debug("Could not simplify %s", expr)
|
||||
result = _static_eval(x)
|
||||
if result is None:
|
||||
return False
|
||||
assert isinstance(x, bool)
|
||||
return x
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user