From 3f1a97a99cad4cc682b20b43c1178ed9e1b81f24 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Fri, 22 Aug 2025 20:48:46 +0000
Subject: [PATCH] Revert "[dynamic shapes] unbacked-safe slicing (#157944)"

This reverts commit 44549c7146bd6c4166f97e856037babe1b7f4f49.

Reverted https://github.com/pytorch/pytorch/pull/157944 on behalf of https://github.com/pianpwk due to this PR & internal diff landed out of sync, just reverted internal with D80720654, will revert this & reland as codev ([comment](https://github.com/pytorch/pytorch/pull/157944#issuecomment-3215610135))
---
 test/export/test_draft_export.py           |   9 +-
 test/export/test_export.py                 |  28 +----
 test/test_dynamic_shapes.py                | 113 ------------------
 test/test_proxy_tensor.py                  |   3 +-
 torch/_decomp/decompositions.py            |  30 ++---
 torch/_inductor/codegen/cpp_wrapper_cpu.py |  44 +------
 torch/_inductor/codegen/wrapper.py         |  21 +---
 torch/_inductor/ir.py                      |  63 ++---------
 torch/_inductor/lowering.py                | 126 +--------------------
 torch/_subclasses/fake_impls.py            |  85 +-------------
 torch/_subclasses/fake_tensor.py           |  10 +-
 11 files changed, 39 insertions(+), 493 deletions(-)

diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py
index fe95d9538fe..6cf819958fc 100644
--- a/test/export/test_draft_export.py
+++ b/test/export/test_draft_export.py
@@ -296,8 +296,7 @@ class TestDraftExport(TestCase):
                 res = torch.ops.mylib.foo1(a, b)
 
                 c_item = c.item()
-                if c_item > 0:
-                    return res[:c_item]
+                return res[:c_item]
 
         inp = (torch.ones(3, 3), torch.ones(3, 3), torch.tensor(3))
 
@@ -368,8 +367,8 @@ class TestDraftExport(TestCase):
                 a = a + 5
 
                 z = torch.cat([y, y])
-                if a > 0:
-                    return z[:a]
+
+                return z[:a]
 
         ep = draft_export(
             M(),
@@ -387,7 +386,7 @@ class TestDraftExport(TestCase):
             for node in _ep.graph.nodes:
                 if bindings := node.meta.get("unbacked_bindings"):
                     unbacked_binding_symbols.update(bindings.keys())
-            self.assertEqual(len(unbacked_binding_symbols), 2)
+            self.assertEqual(len(unbacked_binding_symbols), 1)
 
     def test_offsets(self):
         class M(torch.nn.Module):
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 5695f4eaf7e..78d968ae6c7 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -3089,32 +3089,6 @@ def forward(self, causal_mask, fill_value):
             },
         )
 
-    def test_unbacked_slice_forward(self):
-        class Foo(torch.nn.Module):
-            def forward(self, x, xs):
-                u0, u1 = xs.tolist()
-                out = x[u0:u1]
-                return out
-
-        x = torch.randn(10)
-        idxs = torch.tensor([3, 6])
-        mod = Foo()
-        ep = export(mod, (x, idxs))
-        for xs in [
-            idxs,
-            torch.tensor([-9, -1]),
-            torch.tensor([-10000, 10000]),
-            torch.tensor([0, -10]),
-        ]:
-            self.assertTrue(torch.allclose(ep.module()(x, xs), mod(x, xs)))
-
-        # check unbacked bindings
-        # should be 4 symbols: u0, u1, output size, output storage offset
-        bound_unbacked = set()
-        for node in ep.graph.nodes:
-            bound_unbacked |= node.meta.get("unbacked_bindings", {}).keys()
-        self.assertEqual(len(bound_unbacked), 4)
-
     def test_dim_hint_ranges(self):
         class Foo(torch.nn.Module):
             def forward(self, x, y):
@@ -5791,7 +5765,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         }
         self._test_export_same_as_eager(kw_func, args, kwargs)
 
-    def test_unbacked_slice_simple(self):
+    def test_unbacked_slice(self):
         class M(torch.nn.Module):
             def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
                 valid_mask = scores > score_thr
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index 6a23915c56e..7ba466119da 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -3449,119 +3449,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         self.assertEqual(result_compiled, result_eager)
         self.assertEqual(cnt.frame_count, 2)
 
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_unbacked_slice(self):
-        from torch.fx.experimental.symbolic_shapes import statically_known_true
-
-        # standard slice
-        def f1(x, xs):
-            u0, u1 = xs.tolist()
-            torch._check_is_size(u0, max=x.size(0))
-            torch._check_is_size(u1, max=x.size(0))
-            torch._check(u0 <= u1)
-            out = x[u0:u1]
-            assert statically_known_true(out.size(0) == (u1 - u0))
-            return out
-
-        x, xs = torch.randn(10), torch.tensor([3, 6])
-        fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
-        self.assertEqual(fn1(x, xs).size(0), 3)
-        self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs)))
-        with self.assertRaises(RuntimeError):
-            fn1(x, torch.tensor([-1, 5]))
-
-        # known negative slice
-        def f2(x, n):
-            u0 = n.item()
-            torch._check(u0 > 1)
-            torch._check(u0 <= x.size(0))
-            out = x[-u0:]
-            assert statically_known_true(out.size(0) == u0)
-            return out
-
-        x, n = torch.randn(10), torch.tensor([5])
-        fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
-        self.assertEqual(fn2(x, n).size(0), 5)
-        self.assertTrue(torch.allclose(fn2(x, n), f2(x, n)))
-        with self.assertRaises(RuntimeError):
-            fn2(x, torch.tensor([-5]))
-
-        # general case: no known info
-        def f3(x, xs):
-            u0, u1 = xs.tolist()
-            return x[u0:u1]
-
-        log_stream, ctx = logs_to_string(
-            "torch._inductor.compile_fx", "post_grad_graphs"
-        )
-        cnts = CompileCounterWithBackend("inductor")
-        x, xs = torch.randn(10), torch.tensor([3, 6])
-        with ctx():
-            fn3 = torch.compile(f3, fullgraph=True, backend=cnts)
-            xs = torch.tensor([-9, -1])  # negative case
-            self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
-            xs = torch.tensor([-1000, 1000])  # out of bounds
-            self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
-            xs = torch.tensor([2, -2])  # mixed
-            self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
-        self.assertEqual(cnts.frame_count, 1)
-
-        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
-        self.assertExpectedInline(
-            aot_graphs,
-            """\
-select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
-_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
-select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
-_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
-slice_1: "f32[u2][1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, _local_scalar_dense, _local_scalar_dense_1); arg1_1 = _local_scalar_dense = _local_scalar_dense_1 = None
-sym_size_int: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
-ge_2: "Sym(u2 >= 0)" = sym_size_int >= 0
-_assert_scalar = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_2 = _assert_scalar = None
-le: "Sym(u2 <= 10)" = sym_size_int <= 10; sym_size_int = None
-_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u2 <= 10 on node 'le'"); le = _assert_scalar_1 = None
-sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
-ge_3: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None
-_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_2 = None
-return (slice_1,)""",  # noqa: B950
-            ignore_comments=True,
-            ignore_empty_lines=True,
-        )
-
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    @torch._inductor.config.patch("cpp_wrapper", True)
-    def test_unbacked_slice_cpp_wrapper(self):
-        self.test_unbacked_slice()
-
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_tensor_split(self):
-        def f1(x, xs):
-            xs = torch.tensor(xs.tolist())
-            return torch.tensor_split(x, xs)
-
-        x = torch.randn(20)
-        xs = torch.tensor([5, 10, 15])
-        fn = torch.compile(f1, fullgraph=True, backend="inductor")
-
-        def compare(x, xs):
-            for i, j in zip(f1(x, xs), fn(x, xs)):
-                self.assertTrue(torch.allclose(i, j))
-
-        compare(x, xs)
-        xs = torch.tensor([-15, 9, 10, 11])
-        compare(x, xs)
-        xs = torch.tensor([-15, -10, -5, -2])
-        compare(x, xs)
-
-    @fresh_cache()
-    @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    @torch._inductor.config.patch("cpp_wrapper", True)
-    def test_tensor_split_cpp_wrapper(self):
-        self.test_tensor_split()
-
     @unittest.skip("this test fails due to inductor/autograd issue #153041")
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_unbacked_non_contigious_reshape_failing(self):
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index f278eb33be1..6d36b36996c 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1973,6 +1973,7 @@ make_fx_failures = {
     skip('item'),
     xfail('cov'),
     xfail('nn.functional.gaussian_nll_loss'),
+    xfail('tensor_split'),
     xfail('corrcoef'),
     xfail('quantile'),
     xfail('nanquantile'),
@@ -1992,12 +1993,10 @@ make_fx_failures = {
 
 only_real_tensor_failures = {
     xfail('narrow'),
-    xfail('tensor_split'),
 }
 
 only_fake_tensor_failures = {
     xfail('narrow'),
-    xfail('tensor_split'),
 }
 
 fake_tensor_failures = set()
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 954950318b6..ba09c6173c5 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -6,7 +6,6 @@ import numbers
 import operator
 import sys
 from collections.abc import Iterable
-from contextlib import nullcontext
 from enum import Enum
 from functools import partial, reduce
 from itertools import chain, product
@@ -722,7 +721,10 @@ def slice_forward(
     end: Optional[int] = None,
     step: int = 1,
 ):
-    from torch.fx.experimental.symbolic_shapes import statically_known_true
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_size_oblivious,
+        statically_known_true,
+    )
 
     ndim = self.dim()
     if ndim == 0:
@@ -737,22 +739,22 @@ def slice_forward(
     start_val = start if start is not None else 0
     end_val = end if end is not None else sys.maxsize  # 2^63 - 1
 
-    if start_val < 0:
+    if guard_size_oblivious(start_val < 0):
         start_val += sizes[dim]
 
-    if end_val < 0:
+    if guard_size_oblivious(end_val < 0):
         end_val += sizes[dim]
 
-    if start_val < 0:
+    if guard_size_oblivious(start_val < 0):
         start_val = 0
-    elif start_val > sizes[dim]:
+    elif guard_size_oblivious(start_val > sizes[dim]):
         start_val = sizes[dim]
 
     if statically_known_true(end_val == sys.maxsize):
         end_val = sizes[dim]
-    elif end_val < start_val:
+    elif guard_size_oblivious(end_val < start_val):
         end_val = start_val
-    elif end_val > sizes[dim]:
+    elif guard_size_oblivious(end_val > sizes[dim]):
         end_val = sizes[dim]
 
     storage_offset = self.storage_offset() + start_val * strides[dim]
@@ -1436,17 +1438,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
         assert isinstance(sections, IntLike)
         return self.tensor_split(sections, dim)
     else:
-        ctx = nullcontext
-        if (fake_mode := torch._guards.detect_fake_mode()) and (
-            shape_env := fake_mode.shape_env
-        ):
-            ctx = shape_env.ignore_fresh_unbacked_symbols  # type: ignore[assignment]
-        # In fake tensor prop, we end up calling slice() with these unbacked indices.
-        # Because slice has flexible semantics, the unbacked handling generates new output sizes
-        # for each slice, effectively clobbering over these index symbols.
-        # To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
-        with ctx():
-            indices = [i.item() for i in tensor_indices_or_sections]
+        indices = [i.item() for i in tensor_indices_or_sections]
         # WARNING: Tempted to torch._check_is_size on the indices here? You
         # can't: tensor_split works with negative values in indices:
         #
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index ea1cf09c1b8..6fa08465ce2 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -1456,51 +1456,19 @@ class CppWrapperCpu(PythonWrapperCodegen):
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
         self.unbacked_symbol_decls.add(str(node.sym))
 
-    def codegen_dynamic_select_index(self, node, clamp):
+    def codegen_dynamic_select_index(self, node):
         index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
-        size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
 
-        # codegen index
-        sym = node.unbacked_offset_symbol
-        index_str = (
+        index_compute_str = (
             f"{index_cpp_str} < 0 ? {index_cpp_str} + "
-            f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
+            f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
         )
-        self.writeline(f"auto {sym}_index = {index_str};")
-        index_str_clamped = (
-            f"{sym}_index < 0 ? 0 : ({sym}_index > {size_cpp_str} ? {size_cpp_str} : {sym}_index)"
-            if clamp
-            else f"{sym}_index"
-        )
-        self.writeline(f"auto {sym}_index_clamped = {index_str_clamped};")
         self.writeline(
-            f"auto {sym} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
-            f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * {sym}_index_clamped;"
+            f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
+            f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
        )
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
-        self.unbacked_symbol_decls.add(str(sym))
-
-    def codegen_dynamic_slice_size(self, node):
-        start_cpp_str = self.val_to_arg_str_for_prim_type(node.start, int)
-        end_cpp_str = self.val_to_arg_str_for_prim_type(node.end, int)
-        size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
-        sym = node.unbacked_size_symbol
-
-        def codegen_clamp(index_str, start=True):
-            suf = "start" if start else "end"
-            index_ = f"{sym}_{suf}_index"
-            self.writeline(
-                f"auto {index_} = {index_str} < 0 ? {index_str} + {size_cpp_str} : {index_str};"
-            )
-            self.writeline(
-                f"auto {sym}_{suf}_clamped = {index_} < 0 ? 0 : ({index_} > {size_cpp_str} ? {size_cpp_str} : {index_});"
-            )
-
-        codegen_clamp(start_cpp_str, start=True)
-        codegen_clamp(end_cpp_str, start=False)
-        self.writeline(f"auto {sym}_raw = {sym}_end_clamped - {sym}_start_clamped;")
-        self.writeline(f"auto {sym} = {sym}_raw < 0 ? 0 : {sym}_raw;")
-        self.unbacked_symbol_decls.add(str(sym))
+        self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
 
     def make_buffer_free(self, buffer):
         return (
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index b6b8075e928..27d8a28cb96 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1887,33 +1887,14 @@ class PythonWrapperCodegen(CodeGen):
         arg_name = node.input_name(0)
         self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
 
-    def codegen_dynamic_select_index(self, node, clamp):
+    def codegen_dynamic_select_index(self, node):
         index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
-        if clamp:
-            index_str = f"max(0, min({node.size}, {index_str}))"
         self.writeline(
             f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
         )
         # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
         self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
 
-    def codegen_dynamic_slice_size(self, node):
-        def clamp_index(x):
-            pos = self.codegen_sizevar(sympy.Max(0, sympy.Min(x, node.size)))
-            neg = self.codegen_sizevar(
-                sympy.Max(0, sympy.Min(x + node.size, node.size))
-            )
-            return f"{pos} if {x} >= 0 else {neg}"
-
-        # codegen start, end
-        sym = node.unbacked_size_symbol
-        start = clamp_index(node.start)
-        end = clamp_index(node.end)
-        self.writeline(f"{sym}_start = {start}")
-        self.writeline(f"{sym}_end = {end}")
-        self.writeline(f"{sym} = max(0, {sym}_end - {sym}_start)")
-        self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))
-
     def codegen_dynamic_scalar(self, node):
         (data,) = (t.codegen_reference() for t in node.inputs)
         if len(node.keypath) == 0:
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index e8449d30972..6255bdb6fcc 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3437,6 +3437,7 @@ class SliceView(View):
             if val is None:
                 # TODO(rec): can this really happen?
                 return default
+            val = cls.handle_negative_index(val, dim_size)
             return clamp(val, lower, upper)
 
         start = clamp_wrap(start, 0, dim_size, 0)
@@ -3453,6 +3454,14 @@ class SliceView(View):
         step: int = 1,
         clamp: bool = True,
     ) -> IRNode:
+        step = sympy.expand(step)
+        assert isinstance(step, Expr) or step > 0, step
+        try:
+            if start == 0 and end >= 2**63 - 1 and step == 1:
+                return x
+        except TypeError:
+            pass
+
         new_size = list(x.get_size())
 
         # NB: Ordinarily we default to clamping.
@@ -7212,7 +7221,6 @@ class DynamicSelectStorageOffset(ExternKernel):
         base_offset: Union[sympy.Symbol, int],
         base_dim_stride: Union[sympy.Symbol, int],
         size: Union[sympy.Symbol, int],
-        clamp: bool,
     ) -> None:
         super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
         # This node codegen the following:
@@ -7222,7 +7230,6 @@ class DynamicSelectStorageOffset(ExternKernel):
         self.base_offset = base_offset
         self.base_dim_stride = base_dim_stride
         self.size = size
-        self.clamp = clamp
 
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         return OrderedSet([self.unbacked_offset_symbol])
@@ -7233,57 +7240,7 @@ class DynamicSelectStorageOffset(ExternKernel):
         return get_free_symbols(self.index, unbacked_only)
 
     def codegen(self, wrapper: PythonWrapperCodegen) -> None:
-        wrapper.codegen_dynamic_select_index(self, clamp=self.clamp)
-
-
-class DynamicSliceSize(ExternKernel):
-    """
-    Computes the output size of a slice call, handling the correct semantics in codegen.
-    We do this for flexible handling for unbacked indices (to not data-dependent error).
-
-    Slicing has 4 semantics for indices, i.e. x[start:] could be:
-    1) start < -x.size(0)       -> x[0:]                 # negative out-of-bounds
-    2) start in [-x.size(0), 0) -> x[x.size(0) + start:] # negative slicing
-    3) start in [0, x.size(0))  -> x[start:]             # standard slicing
-    4) start >= x.size(0)       -> empty slice           # positive out-of-bounds
-
-    If the appropriate semantics are known beforehand, the output size is computed based on
-    the start & end indices. If not (with unbacked indices), a new unbacked symbol is created
-    to represent the output size, and codegen handles computing the correct case.
-    """
-
-    def get_reads(self) -> OrderedSet[Dep]:
-        return OrderedSet()
-
-    def should_allocate(self) -> bool:
-        return False
-
-    def __init__(
-        self,
-        unbacked_size_symbol: sympy.Symbol,
-        start: sympy.Symbol,
-        end: Union[sympy.Symbol, int],
-        size: Union[sympy.Symbol, int],
-    ):
-        super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
-        # This node codegen
-        self.unbacked_size_symbol = unbacked_size_symbol
-        self.start = start
-        self.end = end
-        self.size = size
-
-    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
-        return OrderedSet([self.unbacked_size_symbol])
-
-    def get_free_symbol_uses(
-        self, unbacked_only: bool = False
-    ) -> OrderedSet[sympy.Symbol]:
-        return get_free_symbols(self.start, unbacked_only).union(
-            get_free_symbols(self.end, unbacked_only)
-        )
-
-    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
-        wrapper.codegen_dynamic_slice_size(self)
+        wrapper.codegen_dynamic_select_index(self)
 
 
 class DynamicScalar(ExternKernel):
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index e708355e3f6..b29732eb67e 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1172,130 +1172,9 @@ def permute(x, dims):
 
 @register_lowering(aten.slice, type_promotion_kind=None)
 def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
-    """
-    Lowers a slice call, creating ExternKernels for the output size & storage offset symbols,
-    if the indices are unbacked and appropriate semantics aren't known.
-    If they are known (indices are static/backed/unbacked with info), a SliceView is created.
-    """
-
-    from torch.fx.experimental.symbolic_shapes import (
-        CallMethodKey,
-        resolve_unbacked_bindings,
-    )
-
     assert isinstance(x, TensorBox)
     dim = _validate_dim(x, dim, 0)
-    size = x.get_size()[dim]
-    step = sympy.expand(step)
-    assert isinstance(step, sympy.Expr) or step > 0, step
-
-    # maybe apply slice optimization
-    try:
-        if (
-            start == 0
-            and V.graph.sizevars.statically_known_leq(size, end)
-            and step == 1
-        ):
-            return x
-    except TypeError:
-        pass
-
-    # try to avoid dynamic slice
-    def handle_negative_index(idx, size, default):
-        if idx is None:
-            return default
-        idx = sympy.expand(idx)
-        size = sympy.expand(size)
-        if V.graph.sizevars.guard_or_false(idx >= 0):
-            return idx
-        elif V.graph.sizevars.guard_or_false(idx < 0):
-            return size + idx
-        return None
-
-    ambiguous_slice = clamp
-    if ambiguous_slice:
-        start_index = handle_negative_index(start, size, 0)
-        end_index = handle_negative_index(end, size, size)
-        if start_index is not None and end_index is not None:
-            start, end = start_index, end_index
-            ambiguous_slice = False
-
-    # ambiguous_slice=False means we know what semantics this slice call follows,
-    # and don't need to generate an extern kernel to represent the output size.
-    # This is assumed True for clamp=False
-    # (meant to follow standard indexing semantics: 0 <= index < size)
-    if not ambiguous_slice:
-        return TensorBox(
-            ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
-        )  # go to SliceView/ReinterpretView
-
-    # unbacked territory: create DynamicSlice ExternKernel
-    # clamp is True, unbacked start / end
-    assert clamp
-    unbacked_bindings = resolve_unbacked_bindings(
-        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
-    )
-    assert unbacked_bindings is not None
-    assert len(unbacked_bindings) <= 2, unbacked_bindings
-    sym_size, sym_storage = None, None
-    for sym, keypath in unbacked_bindings.items():
-        if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)):
-            sym_size = sym
-        elif keypath == (CallMethodKey("storage_offset"),):
-            sym_storage = sym
-
-    def compute_slice_index(index, size):
-        fn = lambda x: V.graph.sizevars.guard_or_false(x)  # noqa: E731
-
-        if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)):
-            return index
-        elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)):
-            return -index
-        elif fn(sympy.Gt(index, size)):
-            return size
-        elif fn(sympy.Lt(index, -size)):
-            return 0
-        return None
-
-    start_index = compute_slice_index(start, size)
-    end_index = compute_slice_index(end, size)
-    if start_index is not None and end_index is not None:
-        # we shouldn't have allocated size symbol, if output size was determinable from input indices
-        assert sym_size is None
-        new_size = sympy.Max(0, end_index - start_index)
-    else:
-        b_size = ir.DynamicSliceSize(
-            sym_size,
-            start,
-            end,
-            x.get_size()[dim],
-        )
-        b_size.name = V.graph.register_buffer(b_size)
-        V.graph.register_operation(b_size)
-        new_size = sym_size
-
-    if start_index is not None:
-        # we shouldn't have allocated storage offset symbol if start index was determinable
-        assert sym_storage is None
-        new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim]
-    else:
-        b_storage = ir.DynamicSelectStorageOffset(
-            sym_storage,
-            start,
-            x.get_layout().offset,
-            x.get_stride()[dim],
-            x.get_size()[dim],
-            clamp=True,
-        )
-        b_storage.name = V.graph.register_buffer(b_storage)
-        V.graph.register_operation(b_storage)
-        new_storage_offset = sym_storage
-
-    new_sizes = list(x.get_size())
-    new_strides = list(x.get_stride())
-    new_sizes[dim] = new_size
-    new_strides[dim] *= step
-    return as_strided(x, new_sizes, new_strides, new_storage_offset)
+    return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp))
 
 
 @register_lowering(aten.as_strided, type_promotion_kind=None)
@@ -1921,7 +1800,6 @@ def select(x, dim, idx):
             x.get_layout().offset,
             new_stride[dim],
             x.get_size()[dim],
-            clamp=False,
         )
         buffer.name = V.graph.register_buffer(buffer)
         V.graph.register_operation(buffer)
@@ -3113,8 +2991,6 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
     dim = _validate_dim(x, dim, 0)
     dim_size = x.get_size()[dim]
 
-    start = ir.SliceView.handle_negative_index(start, dim_size)
-    end = ir.SliceView.handle_negative_index(end, dim_size)
     start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
 
     src_size = list(x.get_size())
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index 10ba37b3611..7ebd2ec92d1 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -6,7 +6,7 @@ import math
 import operator
 import sys
 from functools import reduce
-from typing import Callable, Optional, Union
+from typing import Callable, Union
 
 import torch
 import torch._custom_op
@@ -15,7 +15,6 @@ import torch._prims_common as utils
 from torch._dispatch.python import no_python_dispatcher
 from torch._ops import OpOverload
 from torch._prims_common import (
-    canonicalize_dim,
     contiguous_for_memory_format_or_false,
     elementwise_dtypes,
     ELEMENTWISE_TYPE_PROMOTION_KIND,
@@ -747,88 +746,6 @@ def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=No
     return padded.new_empty(output_shape)
 
 
-def _compute_slice_index(size, index):
-    from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and
-
-    if guard_or_false(sym_and(index >= 0, index <= size)):
-        return index
-    elif guard_or_false(sym_and(index < 0, index >= -size)):
-        return index + size
-    elif guard_or_false(index < -size):
-        return 0
-    elif guard_or_false(index > size):
-        return size
-    return None
-
-
-@register_op_impl(torch.ops.aten.slice.Tensor)
-def slice_forward(
-    fake_mode,
-    func,
-    self,
-    dim: int = 0,
-    start: Optional[int] = None,
-    end: Optional[int] = None,
-    step: int = 1,
-):
-    from torch.fx.experimental.symbolic_shapes import (
-        guard_or_false,
-        statically_known_true,
-    )
-
-    shape_env = fake_mode.shape_env
-
-    ndim = self.dim()
-    if ndim == 0:
-        raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
-    dim = canonicalize_dim(self.dim(), dim)
-    sizes = list(self.size())
-    strides = list(self.stride())
-
-    if step <= 0:
-        raise RuntimeError("slice step must be positive")
-
-    # start, end
-    start_index = 0 if start is None else _compute_slice_index(sizes[dim], start)
-    end_index = (
-        sizes[dim]
-        if statically_known_true(end == sys.maxsize) or end is None
-        else _compute_slice_index(sizes[dim], end)
-    )
-
-    # size
-    new_size = None
-    if start_index is not None and end_index is not None:
-        if guard_or_false(end_index >= start_index):
-            new_size = (end_index - start_index + step - 1) // step
-        elif guard_or_false(start_index >= end_index):
-            new_size = 0
-
-    # create unbacked if case unknown
-    if new_size is None:
-        new_size = shape_env.create_unbacked_symint()
-        torch._check_is_size(new_size, max=sizes[dim])
-
-    # stride
-    new_stride = strides[dim] * step
-
-    # storage offset
-    if start_index is not None:
-        storage_offset = self.storage_offset() + start_index * strides[dim]
-    else:
-        storage_offset = shape_env.create_unbacked_symint()
-        torch._check(storage_offset >= 0)
-
-    sizes[dim] = new_size
-    strides[dim] = new_stride
-    if self.is_quantized:
-        raise NotImplementedError(
-            "Slice decomposition for quantized tensors aren't implemented"
-        )
-    else:
-        return self.as_strided(sizes, strides, storage_offset)
-
-
 @register_op_impl(torch.ops.aten.masked_select.default)
 def masked_select(fake_mode, func, self, mask):
     if (
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 6da4bd98eca..52b776946b3 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -2616,9 +2616,7 @@ class FakeTensorMode(TorchDispatchMode):
         if (
             func not in meta_table
             and not self.cpp_meta_supports_symint(func)
-            and not (
-                has_symbolic_sizes and func in self._unbacked_special_fake_handling_ops
-            )
+            and not (has_symbolic_sizes and func in self._view_fake_tensor_impl_ops)
         ):
             from torch._decomp import decomposition_table
 
@@ -2927,10 +2925,8 @@ class FakeTensorMode(TorchDispatchMode):
         aten._sparse_coo_tensor_with_dims_and_tensors.default,
     )
 
-    _unbacked_special_fake_handling_ops = ordered_set(
-        aten.view.default,
-        aten._unsafe_view.default,
-        aten.slice.Tensor,
+    _view_fake_tensor_impl_ops = ordered_set(
+        aten.view.default, aten._unsafe_view.default
     )
 
     def cpp_meta_supports_symint(self, func: OpOverload) -> bool: