Introduce guard_or_true, guard_or_false (#148430)

some context in this document:
https://docs.google.com/document/d/18nJsj-F2C_QXO7ClwzPcAUENQ-B440B43W7DdDnlDt4/edit?tab=t.0#heading=h.pgebnyi7pocj

But TLDR;
`guard_or_true`, `guard_or_false` are better than `guard_size_oblivious` due to :
- Easier to reason about what assumptions we are making while reading the code.
- Avoid size_oblivious complexity that is not needed.
- Avoid unsoundness that could make `guard_size_oblivious(a==1)` be true when its not true for some vaue `a` during runtime.
- Less data dependent errors for some cases: ex, when doing `guard_size_oblivious(a==1)` and we know `a` is a tensor size, if it's traced with `a=u1-u2` `guard_size_oblivious(a==1)` will throw a data dependent error but `guard_else_false` will just return `False`.

### How is it different from statically_known_true??
**`if(cond)`:** (normal guarding) will try to evaluate statically and guard on the condition, willing to restrict input space to evaluate cond. if it fails to evaluate due to data dependent error will throw an exception (that could be converted to graph break in some situations).

**`statically_known_true(cond)`:** would be used when you never want to add a guard (restrict your input space), but just want to do a best effort check to see if you can infer that something is true/false ONLY based on existing constraints.

**`guard_or_true(cond)`/`guard_or_false(cond)`:** Those would be used in situations you prefer to guard and know the result of the expression over not guarding, but in case you hit a data dependent error you are ok with just returning true or false.
Some reasons you might be ok with returning true/false instead could be:
1. It's an optimization I do not want to fail for not performing optimization.
2. I am willing to deviate from the normal semantics when I have unbacked for the benefit of not failing (See the doc above for more details).

**`definitely_true(cond)`**: same as `guard_or_false(cond)` except does not try to do static eval for unbacked (planning to deprecate it and replace uses with `guard_or_false` or make it alias to `guard_or_false`)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148430
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Laith Sakka 2025-03-26 22:13:06 -07:00 committed by PyTorch MergeBot
parent a9ee797e41
commit 6cbcdee944
3 changed files with 34 additions and 0 deletions

View File

@ -40,6 +40,8 @@ torch.fx.experimental.symbolic_shapes
has_free_unbacked_symbols
definitely_true
definitely_false
guard_or_true
guard_or_false
guard_size_oblivious
sym_eq
constrain_range

View File

@ -2009,6 +2009,8 @@
"constrain_unify",
"definitely_false",
"definitely_true",
"guard_or_true",
"guard_or_false",
"error",
"eval_guards",
"eval_is_non_overlapping_and_dense",

View File

@ -122,6 +122,8 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
aten = torch._ops.ops.aten # type: ignore[has-type]
__all__ = [
"guard_or_false",
"guard_or_true",
"has_symbolic_sizes_strides",
"create_contiguous",
"ShapeEnv",
@ -1181,6 +1183,34 @@ def compute_unbacked_bindings(
return symbol_to_path
# The following two functions are common utilities used while defining unbacked semantics
# of various framework code. Those would be used in situations you prefer to guard and know
# the result of the expression over not guarding, but in case you hit a data dependent error
# you are ok with just returning true or false.
# Some reasons you might be ok with returning true/false instead could be:
# (1) It's an optimization/additional check I do not want to fail for not performing it.
# (2) I am willing to deviate from the normal semantics when I have unbacked for the
# benefit of not failing.
def guard_or_false(a: BoolLikeType) -> bool:
"""
Try to guard a, if data dependent error encountered just return false.
"""
try:
return bool(guard_bool(a))
except GuardOnDataDependentSymNode:
return False
def guard_or_true(a: BoolLikeType) -> bool:
"""
Try to guard a, if data dependent error encountered just return true.
"""
try:
return bool(guard_bool(a))
except GuardOnDataDependentSymNode:
return True
def definitely_true(a: BoolLikeType) -> bool:
"""
Returns True only if we can tell that a is True, possibly introducing