Unconditionally exclude upper bound in all size oblivious tests (#144867)

I was thinking about https://github.com/pytorch/pytorch/pull/144471 some more and I thought, "Hmm, why not just always exclude the constant upper bound." So here it is.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144867
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Edward Z. Yang 2025-01-21 07:39:46 -08:00 committed by PyTorch MergeBot
parent df67ac4c86
commit 323fb4dad0
2 changed files with 9 additions and 18 deletions

View File

@ -1664,8 +1664,9 @@ def _check_is_size(i, message=None, *, max=None):
When max is not None, this specifies an upper bound equivalent to
``_check(i <= max)``. This bound is also subject to alternate semantics:
in ``guard_size_oblivious`` tests, we assume that the max bound is treated
equivalently to all other values.
in ``guard_size_oblivious`` tests, we assume that a constant max bound is
treated equivalently to all other values. Symbolic max bounds are not yet
supported.
NB: Do NOT use this in contexts where a -1 size would be valid (indicating
to infer the size from context, or if you should wrap-around or truncate).

View File

@ -1945,7 +1945,6 @@ class _SymbolInfo(NamedTuple):
vr: Optional[ValueRanges]
val: Optional[sympy.Integer]
is_size_like: bool
oblivious_upper_bound_exclusive: sympy.Integer
@lru_cache(None)
@ -1967,7 +1966,7 @@ def _maybe_evaluate_static_worker(
new_shape_env = {}
new_range_env = {}
for idx, sinfo in enumerate(symbol_info):
k, vr, val, is_size_like, oblivious_upper_bound_exclusive = sinfo
k, vr, val, is_size_like = sinfo
if isinstance(val, SingletonInt):
# Skip var_ranges logic for SingletonInt which is only used
# for jagged layout NestedTensors today
@ -1982,8 +1981,9 @@ def _maybe_evaluate_static_worker(
# This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way.
upper = min(2**48, vr.upper)
if oblivious_upper_bound_exclusive is not None:
upper = min(upper, oblivious_upper_bound_exclusive - 1)
# Excluding the very upper bound can be helpful
if upper > lower:
upper = upper - 1
# This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This
# causes... problems.
@ -3175,13 +3175,6 @@ class ShapeEnv:
# practice
self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
# When doing a size-oblivious test, exclude this integer and
# everything higher than it from the acceptable range. This solves
# https://github.com/pytorch/pytorch/issues/120288 for constant range
# case
# TODO: generalize this to work with expressions (in that case, we
# need to maintain a SET and we need extra symbolic reasoning on top)
self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {}
self.source_name_to_debug_name: dict[str, str] = {}
self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
@ -3490,10 +3483,8 @@ class ShapeEnv:
@record_shapeenv_event()
def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None:
self.oblivious_upper_bound_exclusive[a] = min(
self.oblivious_upper_bound_exclusive.get(a, int_oo),
sympy.Integer(upper_bound),
)
# TODO: Do something nontrivial when upper_bound is expression
pass
@record_shapeenv_event()
def _constrain_range_for_size(
@ -5627,7 +5618,6 @@ class ShapeEnv:
var_ranges.get(s),
self.var_to_val.get(s),
s in self.size_like,
self.oblivious_upper_bound_exclusive.get(s),
)
for s in sorted(fs, key=str) # TODO: speed up sort?
)