mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Unconditionally exclude upper bound in all size oblivious tests (#144867)
I was thinking about https://github.com/pytorch/pytorch/pull/144471 some more and I thought, "Hmm, why not just always exclude the constant upper bound." So here it is. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/144867 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
df67ac4c86
commit
323fb4dad0
|
|
@ -1664,8 +1664,9 @@ def _check_is_size(i, message=None, *, max=None):
|
||||||
|
|
||||||
When max is not None, this specifies an upper bound equivalent to
|
When max is not None, this specifies an upper bound equivalent to
|
||||||
``_check(i <= max)``. This bound is also subject to alternate semantics:
|
``_check(i <= max)``. This bound is also subject to alternate semantics:
|
||||||
in ``guard_size_oblivious`` tests, we assume that the max bound is treated
|
in ``guard_size_oblivious`` tests, we assume that a constant max bound is
|
||||||
equivalently to all other values.
|
treated equivalently to all other values. Symbolic max bounds are not yet
|
||||||
|
supported.
|
||||||
|
|
||||||
NB: Do NOT use this in contexts where a -1 size would be valid (indicating
|
NB: Do NOT use this in contexts where a -1 size would be valid (indicating
|
||||||
to infer the size from context, or if you should wrap-around or truncate).
|
to infer the size from context, or if you should wrap-around or truncate).
|
||||||
|
|
|
||||||
|
|
@ -1945,7 +1945,6 @@ class _SymbolInfo(NamedTuple):
|
||||||
vr: Optional[ValueRanges]
|
vr: Optional[ValueRanges]
|
||||||
val: Optional[sympy.Integer]
|
val: Optional[sympy.Integer]
|
||||||
is_size_like: bool
|
is_size_like: bool
|
||||||
oblivious_upper_bound_exclusive: sympy.Integer
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(None)
|
@lru_cache(None)
|
||||||
|
|
@ -1967,7 +1966,7 @@ def _maybe_evaluate_static_worker(
|
||||||
new_shape_env = {}
|
new_shape_env = {}
|
||||||
new_range_env = {}
|
new_range_env = {}
|
||||||
for idx, sinfo in enumerate(symbol_info):
|
for idx, sinfo in enumerate(symbol_info):
|
||||||
k, vr, val, is_size_like, oblivious_upper_bound_exclusive = sinfo
|
k, vr, val, is_size_like = sinfo
|
||||||
if isinstance(val, SingletonInt):
|
if isinstance(val, SingletonInt):
|
||||||
# Skip var_ranges logic for SingletonInt which is only used
|
# Skip var_ranges logic for SingletonInt which is only used
|
||||||
# for jagged layout NestedTensors today
|
# for jagged layout NestedTensors today
|
||||||
|
|
@ -1982,8 +1981,9 @@ def _maybe_evaluate_static_worker(
|
||||||
# This is similar to the flavor where size oblivious omits
|
# This is similar to the flavor where size oblivious omits
|
||||||
# 0/1, it changes semantics but in a benign way.
|
# 0/1, it changes semantics but in a benign way.
|
||||||
upper = min(2**48, vr.upper)
|
upper = min(2**48, vr.upper)
|
||||||
if oblivious_upper_bound_exclusive is not None:
|
# Excluding the very upper bound can be helpful
|
||||||
upper = min(upper, oblivious_upper_bound_exclusive - 1)
|
if upper > lower:
|
||||||
|
upper = upper - 1
|
||||||
# This is a bit dodgy: what this means is that there was a
|
# This is a bit dodgy: what this means is that there was a
|
||||||
# size-like unbacked symbol whose upper bound < 2. This
|
# size-like unbacked symbol whose upper bound < 2. This
|
||||||
# causes... problems.
|
# causes... problems.
|
||||||
|
|
@ -3175,13 +3175,6 @@ class ShapeEnv:
|
||||||
# practice
|
# practice
|
||||||
self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
|
self.var_to_range: dict[sympy.Symbol, ValueRanges] = {}
|
||||||
self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
|
self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {}
|
||||||
# When doing a size-oblivious test, exclude this integer and
|
|
||||||
# everything higher than it from the acceptable range. This solves
|
|
||||||
# https://github.com/pytorch/pytorch/issues/120288 for constant range
|
|
||||||
# case
|
|
||||||
# TODO: generalize this to work with expressions (in that case, we
|
|
||||||
# need to maintain a SET and we need extra symbolic reasoning on top)
|
|
||||||
self.oblivious_upper_bound_exclusive: dict[sympy.Symbol, sympy.Integer] = {}
|
|
||||||
self.source_name_to_debug_name: dict[str, str] = {}
|
self.source_name_to_debug_name: dict[str, str] = {}
|
||||||
self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
|
self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
|
||||||
self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
|
self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
|
||||||
|
|
@ -3490,10 +3483,8 @@ class ShapeEnv:
|
||||||
|
|
||||||
@record_shapeenv_event()
|
@record_shapeenv_event()
|
||||||
def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None:
|
def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None:
|
||||||
self.oblivious_upper_bound_exclusive[a] = min(
|
# TODO: Do something nontrivial when upper_bound is expression
|
||||||
self.oblivious_upper_bound_exclusive.get(a, int_oo),
|
pass
|
||||||
sympy.Integer(upper_bound),
|
|
||||||
)
|
|
||||||
|
|
||||||
@record_shapeenv_event()
|
@record_shapeenv_event()
|
||||||
def _constrain_range_for_size(
|
def _constrain_range_for_size(
|
||||||
|
|
@ -5627,7 +5618,6 @@ class ShapeEnv:
|
||||||
var_ranges.get(s),
|
var_ranges.get(s),
|
||||||
self.var_to_val.get(s),
|
self.var_to_val.get(s),
|
||||||
s in self.size_like,
|
s in self.size_like,
|
||||||
self.oblivious_upper_bound_exclusive.get(s),
|
|
||||||
)
|
)
|
||||||
for s in sorted(fs, key=str) # TODO: speed up sort?
|
for s in sorted(fs, key=str) # TODO: speed up sort?
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user