[dynamic shapes] add backed_size_oblivious option (#148696)

Adds an option, `torch.fx.experimental._config.backed_size_oblivious = True`, that allocates `[0, inf]` instead of `[2, inf]` ranges for backed size symbols and opts them into size-oblivious semantics.

This helps in a number of cases, for example:
- Keeping `[0, inf]` bounds for unbacked symbols when we make an unbacked -> backed replacement
- Handling 0/1 inputs at runtime more soundly when we lower from export
- Avoiding end-of-bounds / `sys.maxsize` constraint violations when exporting with named Dims (https://github.com/pytorch/pytorch/issues/146315, https://github.com/pytorch/pytorch/issues/146046)

We may look into turning this on globally for export.
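For illustration, a minimal usage sketch (the module `M` and dim name `dx` are made up; the pattern mirrors the new tests below):

```python
import torch
import torch.fx.experimental._config  # module where the new flag lives

from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


# Enable [0, inf] ranges + size-oblivious semantics for backed size symbols
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = export(M(), (torch.randn(4),), dynamic_shapes={"x": {0: Dim("dx")}})

# With the option on, 0/1-sized inputs should no longer hit 0/1 specialization
ep.module()(torch.randn(1))
```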

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148696
Approved by: https://github.com/bobrenjc93
Pian Pawakapan 2025-03-11 21:52:34 +00:00 committed by PyTorch MergeBot
parent 53a1a022a9
commit a6459afb0e
5 changed files with 93 additions and 16 deletions

View File

@@ -2598,6 +2598,50 @@ def forward(self, p_linear_weight, p_linear_bias, x):
             ep.graph_module.false_graph_0.code
         )
 
+    def test_ends_of_bounds_oblivious(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.zeros(10))
+
+            def forward(self, x, y):
+                self.buf[0 : x.shape[0]] = x
+                return x + 2, y[:, ::1]
+
+        inps = (torch.randn(10), torch.randn(32, 36))
+        dynamic_shapes = {
+            "x": {0: Dim("dx", min=1, max=10)},
+            "y": {0: Dim("dy0"), 1: Dim("dy1")},
+        }
+        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+            ep = export(Foo(), inps, dynamic_shapes=dynamic_shapes)
+        ep.module()(torch.randn(9), torch.randn(4, 4))
+        ep.module()(torch.randn(1), torch.randn(1, 1))
+
+    def test_colin_unbacked_backed_vr_sub(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, b, c):
+                nz = torch.nonzero(a)
+                ones = a.new_ones([nz.size(0), b.size(0)])
+                torch._check(ones.size(0) >= 1)
+                equals = torch.add(ones, c)
+                return equals
+
+        model = Model()
+        example_inputs = (
+            torch.ones(64),
+            torch.randn(32),
+            torch.randn(64, 32),
+        )
+        dynamic_shapes = {"a": None, "b": None, "c": (Dim.DYNAMIC, Dim.STATIC)}
+        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+            ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
+
+        # check lower bound
+        for sym, vr in ep.range_constraints.items():
+            if str(sym) in ["u0", "s0"]:
+                self.assertEqual(vr.lower, 1)
+
     def test_duplicate_modules_with_non_persistent_buffers(self):
         class FooWithBuf(torch.nn.Module):
             def __init__(self):

View File

@@ -432,11 +432,12 @@ class TestPySymInt(TestCase):
         else:
             result = expand_x + expand_x
 
-        gt_op, _bt = shape_env.guards[-1]
+        gt_op, _bt, is_size_obv = shape_env.guards[-1]
         self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
         self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
         self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
         self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
+        self.assertFalse(is_size_obv)
 
     def test_floordiv_static(self):
         shape_env = ShapeEnv()

View File

@@ -230,6 +230,7 @@ class SLoc:
 class ShapeGuard(NamedTuple):
     expr: sympy.logic.boolalg.Boolean
     sloc: SLoc
+    size_oblivious: bool
 
 
 @dataclass_slots

View File

@@ -92,6 +92,11 @@ use_duck_shape = True
 # Default is False to prevent unintended registration. Set to True to enable.
 meta_nonzero_assume_all_nonzero = False
 
+# Applies size-oblivious reasoning to backed symbols. This allocates a [0, inf] range for backed size symbols,
+# and relies on size-oblivious semantics to avoid 0/1 specialization guards by marking them size-like.
+# Currently an experimental option for export.
+backed_size_oblivious = False
+
 from torch.utils._config_module import install_config_module

View File

@@ -3880,15 +3880,21 @@ class ShapeEnv:
         constraint_dims = symbolic_context.constraint_sizes  # type: ignore[attr-defined]
         size = []
         for i, val in enumerate(tensor_size):
-            size.append(
-                self.create_symbol(
-                    val,
-                    TensorPropertySource(source, TensorProperty.SIZE, i),
-                    dynamic_dims[i],
-                    constraint_dims[i],
-                    symbolic_context=symbolic_context,
-                )
+            sym = self.create_symbol(
+                val,
+                TensorPropertySource(source, TensorProperty.SIZE, i),
+                dynamic_dims[i],
+                constraint_dims[i],
+                do_not_specialize_zero_one=config.backed_size_oblivious,
+                symbolic_context=symbolic_context,
             )
+            if (
+                config.backed_size_oblivious
+                and isinstance(sym, sympy.Symbol)  # could be static
+                and symbol_is_type(sym, SymT.SIZE)
+            ):
+                self.size_like.add(sym)
+            size.append(sym)
         return size
 
     def create_symbolic_sizes_strides_storage_offset(
@@ -4534,7 +4540,9 @@ class ShapeEnv:
             self._add_assertion(sympy_expr > 1)
 
             # Apply default range, which assumes not zero-one
-            self.var_to_range[sympy_expr] = self._default_value_range()
+            self.var_to_range[sympy_expr] = self._default_value_range(
+                do_not_specialize_zero_one
+            )
             self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(
                 self._get_sloc(
                     "user code shown is first use of this value--the guard itself is not "
@@ -5284,7 +5292,12 @@ class ShapeEnv:
         # This removes all the checks that follow from bounds
         # We could simply emit those and also the bounds 2 <= size when necessary
         for guard in guards if guards is not None else self.guards:
-            if self._maybe_evaluate_static(guard.expr, axioms=()) is not None:
+            if (
+                self._maybe_evaluate_static(
+                    guard.expr, axioms=(), size_oblivious=guard.size_oblivious
+                )
+                is not None
+            ):
                 continue
 
             issue_guard(guard)
@@ -5587,7 +5600,10 @@ class ShapeEnv:
         return [
             self.simplify(guard.expr)
             for guard in self.guards
-            if self._maybe_evaluate_static(guard.expr, axioms=()) is None
+            if self._maybe_evaluate_static(
+                guard.expr, axioms=(), size_oblivious=guard.size_oblivious
+            )
+            is None
         ]
 
     def format_guards(self, verbose: bool = False) -> str:
@@ -6404,8 +6420,10 @@ class ShapeEnv:
             return
 
     # See: Note - On 0/1 specialization
-    def _default_value_range(self) -> ValueRanges:
-        lower = 2 if self.specialize_zero_one else 0
+    def _default_value_range(
+        self, do_not_specialize_zero_one: bool = False
+    ) -> ValueRanges:
+        lower = 0 if (do_not_specialize_zero_one or not self.specialize_zero_one) else 2
         return ValueRanges(lower, int_oo)
 
     def _default_unspecified_value_range(self) -> ValueRanges:
@@ -6773,7 +6791,13 @@ class ShapeEnv:
                     else size_oblivious,
                     static_expr,
                 )
-                if hint is not None:
+                if (
+                    not size_oblivious
+                    and config.backed_size_oblivious
+                    and hint is not None
+                ):
+                    # TODO: maybe reconcile this with use of counterfactual hints
+                    # in unbacked case
                     assert static_expr == hint, f"{static_expr} != {hint}"
                 return static_expr
 
@@ -6919,7 +6943,9 @@ class ShapeEnv:
                 # at this point, we've evaluated the concrete expr value, and have
                 # flipped/negated the guard if necessary. Now we know what to guard
                 # or defer to runtime assert on.
-                guard = ShapeGuard(g, self._get_sloc())
+                guard = ShapeGuard(
+                    g, self._get_sloc(), size_oblivious=size_oblivious
+                )
                 self.guards.append(guard)
                 self.axioms.update(dict(self.get_implications(self.simplify(g))))
             else: