Unconditionally exclude upper bound in all size oblivious tests (#144867)

I was thinking about https://github.com/pytorch/pytorch/pull/144471 some more and I thought, "Hmm, why not just always exclude the constant upper bound." So here it is.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144867
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Edward Z. Yang 2025-01-21 07:39:46 -08:00 committed by PyTorch MergeBot
parent df67ac4c86
commit 323fb4dad0
2 changed files with 9 additions and 18 deletions

View File

@ -1664,8 +1664,9 @@ def _check_is_size(i, message=None, *, max=None):
When max is not None, this specifies an upper bound equivalent to When max is not None, this specifies an upper bound equivalent to
``_check(i <= max)``. This bound is also subject to alternate semantics: ``_check(i <= max)``. This bound is also subject to alternate semantics:
in ``guard_size_oblivious`` tests, we assume that the max bound is treated in ``guard_size_oblivious`` tests, we assume that a constant max bound is
equivalently to all other values. treated equivalently to all other values. Symbolic max bounds are not yet
supported.
NB: Do NOT use this in contexts where a -1 size would be valid (indicating NB: Do NOT use this in contexts where a -1 size would be valid (indicating
to infer the size from context, or if you should wrap-around or truncate). to infer the size from context, or if you should wrap-around or truncate).

View File

@ -1945,7 +1945,6 @@ class _SymbolInfo(NamedTuple):
vr: Optional[ValueRanges] vr: Optional[ValueRanges]
val: Optional[sympy.Integer] val: Optional[sympy.Integer]
is_size_like: bool is_size_like: bool
oblivious_upper_bound_exclusive: sympy.Integer
@lru_cache(None) @lru_cache(None)
@ -1967,7 +1966,7 @@ def _maybe_evaluate_static_worker(
new_shape_env = {} new_shape_env = {}
new_range_env = {} new_range_env = {}
for idx, sinfo in enumerate(symbol_info): for idx, sinfo in enumerate(symbol_info):
k, vr, val, is_size_like, oblivious_upper_bound_exclusive = sinfo k, vr, val, is_size_like = sinfo
if isinstance(val, SingletonInt): if isinstance(val, SingletonInt):
# Skip var_ranges logic for SingletonInt which is only used # Skip var_ranges logic for SingletonInt which is only used
# for jagged layout NestedTensors today # for jagged layout NestedTensors today
@ -1982,8 +1981,9 @@ def _maybe_evaluate_static_worker(
# This is similar to the flavor where size oblivious omits # This is similar to the flavor where size oblivious omits
# 0/1, it changes semantics but in a benign way. # 0/1, it changes semantics but in a benign way.
upper = min(2**48, vr.upper) upper = min(2**48, vr.upper)
if oblivious_upper_bound_exclusive is not None: # Excluding the very upper bound can be helpful
upper = min(upper, oblivious_upper_bound_exclusive - 1) if upper > lower:
upper = upper - 1
# This is a bit dodgy: what this means is that there was a # This is a bit dodgy: what this means is that there was a
# size-like unbacked symbol whose upper bound < 2. This # size-like unbacked symbol whose upper bound < 2. This
# causes... problems. # causes... problems.
@ -3175,13 +3175,6 @@ class ShapeEnv:
# practice # practice
self.var_to_range: dict[sympy.Symbol, ValueRanges] = {} self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {} self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
# When doing a size-oblivious test, exclude this integer and
# everything higher than it from the acceptable range. This solves
# https://github.com/pytorch/pytorch/issues/120288 for constant range
# case
# TODO: generalize this to work with expressions (in that case, we
# need to maintain a SET and we need extra symbolic reasoning on top)
self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {}
self.source_name_to_debug_name: dict[str, str] = {} self.source_name_to_debug_name: dict[str, str] = {}
self.var_to_sources: dict[sympy.Symbol, list[Source]] = {} self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {} self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
@ -3490,10 +3483,8 @@ class ShapeEnv:
@record_shapeenv_event() @record_shapeenv_event()
def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None: def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None:
self.oblivious_upper_bound_exclusive[a] = min( # TODO: Do something nontrivial when upper_bound is expression
self.oblivious_upper_bound_exclusive.get(a, int_oo), pass
sympy.Integer(upper_bound),
)
@record_shapeenv_event() @record_shapeenv_event()
def _constrain_range_for_size( def _constrain_range_for_size(
@ -5627,7 +5618,6 @@ class ShapeEnv:
var_ranges.get(s), var_ranges.get(s),
self.var_to_val.get(s), self.var_to_val.get(s),
s in self.size_like, s in self.size_like,
self.oblivious_upper_bound_exclusive.get(s),
) )
for s in sorted(fs, key=str) # TODO: speed up sort? for s in sorted(fs, key=str) # TODO: speed up sort?
) )