[dynamic shapes] add backed_size_oblivious option (#148696)
Adds the option `torch.fx.experimental._config.backed_size_oblivious = True` to allocate `[0, inf]` instead of `[2, inf]` ranges for backed size symbols, opting them into size-oblivious semantics. This helps in a number of cases:

- Keeps `[0, inf]` bounds for unbacked symbols when we make an unbacked -> backed replacement
- More sound handling of 0/1 inputs at runtime when lowering from export
- Avoids ends-of-bounds, sys.maxsize constraint violations when exporting with named Dims (https://github.com/pytorch/pytorch/issues/146315, https://github.com/pytorch/pytorch/issues/146046)

We may look toward turning this on globally for export.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148696
Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent 53a1a022a9
commit a6459afb0e
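For context, a minimal usage sketch of the new flag, modeled on the tests added below (the toy module and inputs here are illustrative, not part of this commit):

    import torch
    from torch.export import Dim, export
    from torch.fx.experimental import _config as fx_config

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 2

    # Patch the new experimental flag so backed size symbols get [0, inf]
    # ranges and size-oblivious semantics instead of the usual [2, inf].
    with fx_config.patch(backed_size_oblivious=True):
        ep = export(M(), (torch.randn(10),), dynamic_shapes={"x": {0: Dim("dx")}})

    # Size-0/1 inputs no longer trip 0/1 specialization guards at runtime.
    ep.module()(torch.randn(1))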
@@ -2598,6 +2598,50 @@ def forward(self, p_linear_weight, p_linear_bias, x):
             ep.graph_module.false_graph_0.code
         )

+    def test_ends_of_bounds_oblivious(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", torch.zeros(10))
+
+            def forward(self, x, y):
+                self.buf[0 : x.shape[0]] = x
+                return x + 2, y[:, ::1]
+
+        inps = (torch.randn(10), torch.randn(32, 36))
+        dynamic_shapes = {
+            "x": {0: Dim("dx", min=1, max=10)},
+            "y": {0: Dim("dy0"), 1: Dim("dy1")},
+        }
+        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+            ep = export(Foo(), inps, dynamic_shapes=dynamic_shapes)
+        ep.module()(torch.randn(9), torch.randn(4, 4))
+        ep.module()(torch.randn(1), torch.randn(1, 1))
+
+    def test_colin_unbacked_backed_vr_sub(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, b, c):
+                nz = torch.nonzero(a)
+                ones = a.new_ones([nz.size(0), b.size(0)])
+                torch._check(ones.size(0) >= 1)
+                equals = torch.add(ones, c)
+                return equals
+
+        model = Model()
+        example_inputs = (
+            torch.ones(64),
+            torch.randn(32),
+            torch.randn(64, 32),
+        )
+        dynamic_shapes = {"a": None, "b": None, "c": (Dim.DYNAMIC, Dim.STATIC)}
+        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
+            ep = export(model, example_inputs, dynamic_shapes=dynamic_shapes)
+
+        # check lower bound
+        for sym, vr in ep.range_constraints.items():
+            if str(sym) in ["u0", "s0"]:
+                self.assertEqual(vr.lower, 1)
+
     def test_duplicate_modules_with_non_persistent_buffers(self):
         class FooWithBuf(torch.nn.Module):
             def __init__(self):

@@ -432,11 +432,12 @@ class TestPySymInt(TestCase):
         else:
             result = expand_x + expand_x

-        gt_op, _bt = shape_env.guards[-1]
+        gt_op, _bt, is_size_obv = shape_env.guards[-1]
         self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
         self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
         self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
         self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
+        self.assertFalse(is_size_obv)

     def test_floordiv_static(self):
         shape_env = ShapeEnv()

@@ -230,6 +230,7 @@ class SLoc:
 class ShapeGuard(NamedTuple):
     expr: sympy.logic.boolalg.Boolean
     sloc: SLoc
+    size_oblivious: bool


 @dataclass_slots

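Since ShapeGuard is a NamedTuple, the new field changes its arity; here is a standalone sketch of the consequence (types simplified, the real class uses sympy's Boolean and SLoc):

    from typing import NamedTuple

    class ShapeGuardSketch(NamedTuple):
        expr: object          # sympy.logic.boolalg.Boolean in the real class
        sloc: object          # SLoc in the real class
        size_oblivious: bool  # new field added by this commit

    g = ShapeGuardSketch("s0 > 1", "<sloc>", size_oblivious=False)

    # Unpacking now yields three values, which is why the TestPySymInt hunk
    # above updates the `shape_env.guards[-1]` destructuring.
    expr, sloc, is_size_obv = g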
@@ -92,6 +92,11 @@ use_duck_shape = True
 # Default is False to prevent unintended registration. Set to True to enable.
 meta_nonzero_assume_all_nonzero = False

+# Applies size-oblivious reasoning to backed symbols. This allocates a [0, inf] range for backed size symbols,
+# and relies on size-oblivious semantics to avoid 0/1 specialization guards by marking them size-like.
+# Currently an experimental option for export.
+backed_size_oblivious = False
+
 from torch.utils._config_module import install_config_module

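Because _config.py hands itself to install_config_module, the new flag participates in the standard torch config machinery; a small sketch of reading and temporarily patching it (assuming only that machinery, which the tests above exercise via `.patch`):

    from torch.fx.experimental import _config as fx_config

    print(fx_config.backed_size_oblivious)      # False by default
    with fx_config.patch(backed_size_oblivious=True):
        print(fx_config.backed_size_oblivious)  # True inside the context
    print(fx_config.backed_size_oblivious)      # restored to False on exit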
@@ -3880,15 +3880,21 @@ class ShapeEnv:
         constraint_dims = symbolic_context.constraint_sizes  # type: ignore[attr-defined]
         size = []
         for i, val in enumerate(tensor_size):
-            size.append(
-                self.create_symbol(
-                    val,
-                    TensorPropertySource(source, TensorProperty.SIZE, i),
-                    dynamic_dims[i],
-                    constraint_dims[i],
-                    symbolic_context=symbolic_context,
-                )
+            sym = self.create_symbol(
+                val,
+                TensorPropertySource(source, TensorProperty.SIZE, i),
+                dynamic_dims[i],
+                constraint_dims[i],
+                do_not_specialize_zero_one=config.backed_size_oblivious,
+                symbolic_context=symbolic_context,
             )
+            if (
+                config.backed_size_oblivious
+                and isinstance(sym, sympy.Symbol)  # could be static
+                and symbol_is_type(sym, SymT.SIZE)
+            ):
+                self.size_like.add(sym)
+            size.append(sym)
         return size

     def create_symbolic_sizes_strides_storage_offset(

@@ -4534,7 +4540,9 @@ class ShapeEnv:
             self._add_assertion(sympy_expr > 1)

         # Apply default range, which assumes not zero-one
-        self.var_to_range[sympy_expr] = self._default_value_range()
+        self.var_to_range[sympy_expr] = self._default_value_range(
+            do_not_specialize_zero_one
+        )
         self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(
             self._get_sloc(
                 "user code shown is first use of this value--the guard itself is not "

@@ -5284,7 +5292,12 @@ class ShapeEnv:
         # This removes all the checks that follow from bounds
         # We could simply emit those and also the bounds 2 <= size when necessary
         for guard in guards if guards is not None else self.guards:
-            if self._maybe_evaluate_static(guard.expr, axioms=()) is not None:
+            if (
+                self._maybe_evaluate_static(
+                    guard.expr, axioms=(), size_oblivious=guard.size_oblivious
+                )
+                is not None
+            ):
                 continue
             issue_guard(guard)

@@ -5587,7 +5600,10 @@ class ShapeEnv:
         return [
             self.simplify(guard.expr)
             for guard in self.guards
-            if self._maybe_evaluate_static(guard.expr, axioms=()) is None
+            if self._maybe_evaluate_static(
+                guard.expr, axioms=(), size_oblivious=guard.size_oblivious
+            )
+            is None
         ]

     def format_guards(self, verbose: bool = False) -> str:

@@ -6404,8 +6420,10 @@ class ShapeEnv:
             return

     # See: Note - On 0/1 specialization
-    def _default_value_range(self) -> ValueRanges:
-        lower = 2 if self.specialize_zero_one else 0
+    def _default_value_range(
+        self, do_not_specialize_zero_one: bool = False
+    ) -> ValueRanges:
+        lower = 0 if (do_not_specialize_zero_one or not self.specialize_zero_one) else 2
         return ValueRanges(lower, int_oo)

     def _default_unspecified_value_range(self) -> ValueRanges:

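The new lower bound in _default_value_range can be read as a small truth table; a hedged, standalone restatement (the helper name is hypothetical, the real logic lives inside ShapeEnv):

    def default_lower(specialize_zero_one: bool, do_not_specialize_zero_one: bool) -> int:
        return 0 if (do_not_specialize_zero_one or not specialize_zero_one) else 2

    assert default_lower(True, False) == 2   # classic backed range: [2, inf]
    assert default_lower(True, True) == 0    # backed_size_oblivious path: [0, inf]
    assert default_lower(False, False) == 0  # 0/1 specialization already disabled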
@@ -6773,7 +6791,13 @@ class ShapeEnv:
                 else size_oblivious,
                 static_expr,
             )
-            if hint is not None:
+            if (
+                not size_oblivious
+                and config.backed_size_oblivious
+                and hint is not None
+            ):
+                # TODO: maybe reconcile this with use of counterfactual hints
+                # in unbacked case
                 assert static_expr == hint, f"{static_expr} != {hint}"
             return static_expr

@@ -6919,7 +6943,9 @@ class ShapeEnv:
                 # at this point, we've evaluated the concrete expr value, and have
                 # flipped/negated the guard if necessary. Now we know what to guard
                 # or defer to runtime assert on.
-                guard = ShapeGuard(g, self._get_sloc())
+                guard = ShapeGuard(
+                    g, self._get_sloc(), size_oblivious=size_oblivious
+                )
                 self.guards.append(guard)
                 self.axioms.update(dict(self.get_implications(self.simplify(g))))
             else: