[dynamic shapes] unbacked-safe slicing (#157944)

Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944
Approved by: https://github.com/laithsakka
This commit is contained in:
Pian Pawakapan 2025-08-18 22:38:12 +00:00 committed by PyTorch MergeBot
parent 0254646654
commit 56218d85e2
10 changed files with 488 additions and 35 deletions

View File

@ -3028,6 +3028,32 @@ def forward(self, causal_mask, fill_value):
},
)
def test_unbacked_slice_forward(self):
class Foo(torch.nn.Module):
def forward(self, x, xs):
u0, u1 = xs.tolist()
out = x[u0:u1]
return out
x = torch.randn(10)
idxs = torch.tensor([3, 6])
mod = Foo()
ep = export(mod, (x, idxs))
for xs in [
idxs,
torch.tensor([-9, -1]),
torch.tensor([-10000, 10000]),
torch.tensor([0, -10]),
]:
self.assertTrue(torch.allclose(ep.module()(x, xs), mod(x, xs)))
# check unbacked bindings
# should be 4 symbols: u0, u1, output size, output storage offset
bound_unbacked = set()
for node in ep.graph.nodes:
bound_unbacked |= node.meta.get("unbacked_bindings", {}).keys()
self.assertEqual(len(bound_unbacked), 4)
def test_dim_hint_ranges(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
@ -5704,7 +5730,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
}
self._test_export_same_as_eager(kw_func, args, kwargs)
def test_unbacked_slice(self):
def test_unbacked_slice_simple(self):
class M(torch.nn.Module):
def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
valid_mask = scores > score_thr

View File

@ -3391,6 +3391,119 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_slice(self):
from torch.fx.experimental.symbolic_shapes import statically_known_true
# standard slice
def f1(x, xs):
u0, u1 = xs.tolist()
torch._check_is_size(u0, max=x.size(0))
torch._check_is_size(u1, max=x.size(0))
torch._check(u0 <= u1)
out = x[u0:u1]
assert statically_known_true(out.size(0) == (u1 - u0))
return out
x, xs = torch.randn(10), torch.tensor([3, 6])
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
self.assertEqual(fn1(x, xs).size(0), 3)
self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs)))
with self.assertRaises(RuntimeError):
fn1(x, torch.tensor([-1, 5]))
# known negative slice
def f2(x, n):
u0 = n.item()
torch._check(u0 > 1)
torch._check(u0 <= x.size(0))
out = x[-u0:]
assert statically_known_true(out.size(0) == u0)
return out
x, n = torch.randn(10), torch.tensor([5])
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
self.assertEqual(fn2(x, n).size(0), 5)
self.assertTrue(torch.allclose(fn2(x, n), f2(x, n)))
with self.assertRaises(RuntimeError):
fn2(x, torch.tensor([-5]))
# general case: no known info
def f3(x, xs):
u0, u1 = xs.tolist()
return x[u0:u1]
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
cnts = CompileCounterWithBackend("inductor")
x, xs = torch.randn(10), torch.tensor([3, 6])
with ctx():
fn3 = torch.compile(f3, fullgraph=True, backend=cnts)
xs = torch.tensor([-9, -1]) # negative case
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
xs = torch.tensor([-1000, 1000]) # out of bounds
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
xs = torch.tensor([2, -2]) # mixed
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
self.assertEqual(cnts.frame_count, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
slice_1: "f32[u2][1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, _local_scalar_dense, _local_scalar_dense_1); arg1_1 = _local_scalar_dense = _local_scalar_dense_1 = None
sym_size_int: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
ge_2: "Sym(u2 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_2 = _assert_scalar = None
le: "Sym(u2 <= 10)" = sym_size_int <= 10; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u2 <= 10 on node 'le'"); le = _assert_scalar_1 = None
sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
ge_3: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_2 = None
return (slice_1,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_unbacked_slice_cpp_wrapper(self):
self.test_unbacked_slice()
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_tensor_split(self):
def f1(x, xs):
xs = torch.tensor(xs.tolist())
return torch.tensor_split(x, xs)
x = torch.randn(20)
xs = torch.tensor([5, 10, 15])
fn = torch.compile(f1, fullgraph=True, backend="inductor")
def compare(x, xs):
for i, j in zip(f1(x, xs), fn(x, xs)):
self.assertTrue(torch.allclose(i, j))
compare(x, xs)
xs = torch.tensor([-15, 9, 10, 11])
compare(x, xs)
xs = torch.tensor([-15, -10, -5, -2])
compare(x, xs)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_tensor_split_cpp_wrapper(self):
self.test_tensor_split()
@unittest.skip("this test fails due to inductor/autograd issue #153041")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_non_contigious_reshape_failing(self):

View File

@ -1973,7 +1973,6 @@ make_fx_failures = {
skip('item'),
xfail('cov'),
xfail('nn.functional.gaussian_nll_loss'),
xfail('tensor_split'),
xfail('corrcoef'),
xfail('quantile'),
xfail('nanquantile'),
@ -1993,10 +1992,12 @@ make_fx_failures = {
only_real_tensor_failures = {
xfail('narrow'),
xfail('tensor_split'),
}
only_fake_tensor_failures = {
xfail('narrow'),
xfail('tensor_split'),
}
fake_tensor_failures = set()

View File

@ -6,6 +6,7 @@ import numbers
import operator
import sys
from collections.abc import Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
@ -721,10 +722,7 @@ def slice_forward(
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import (
guard_size_oblivious,
statically_known_true,
)
from torch.fx.experimental.symbolic_shapes import statically_known_true
ndim = self.dim()
if ndim == 0:
@ -739,22 +737,22 @@ def slice_forward(
start_val = start if start is not None else 0
end_val = end if end is not None else sys.maxsize # 2^63 - 1
if guard_size_oblivious(start_val < 0):
if start_val < 0:
start_val += sizes[dim]
if guard_size_oblivious(end_val < 0):
if end_val < 0:
end_val += sizes[dim]
if guard_size_oblivious(start_val < 0):
if start_val < 0:
start_val = 0
elif guard_size_oblivious(start_val > sizes[dim]):
elif start_val > sizes[dim]:
start_val = sizes[dim]
if statically_known_true(end_val == sys.maxsize):
end_val = sizes[dim]
elif guard_size_oblivious(end_val < start_val):
elif end_val < start_val:
end_val = start_val
elif guard_size_oblivious(end_val > sizes[dim]):
elif end_val > sizes[dim]:
end_val = sizes[dim]
storage_offset = self.storage_offset() + start_val * strides[dim]
@ -1438,7 +1436,17 @@ def tensor_split_tensor_indices_or_sections_py_impl(
assert isinstance(sections, IntLike)
return self.tensor_split(sections, dim)
else:
indices = [i.item() for i in tensor_indices_or_sections]
ctx = nullcontext
if (fake_mode := torch._guards.detect_fake_mode()) and (
shape_env := fake_mode.shape_env
):
ctx = shape_env.ignore_fresh_unbacked_symbols # type: ignore[assignment]
# In fake tensor prop, we end up calling slice() with these unbacked indices.
# Because slice has flexible semantics, the unbacked handling generates new output sizes
# for each slice, effectively clobbering over these index symbols.
# To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
with ctx():
indices = [i.item() for i in tensor_indices_or_sections]
# WARNING: Tempted to torch._check_is_size on the indices here? You
# can't: tensor_split works with negative values in indices:
#

View File

@ -1456,19 +1456,51 @@ class CppWrapperCpu(PythonWrapperCodegen):
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.sym))
def codegen_dynamic_select_index(self, node):
def codegen_dynamic_select_index(self, node, clamp):
index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
index_compute_str = (
# codegen index
sym = node.unbacked_offset_symbol
index_str = (
f"{index_cpp_str} < 0 ? {index_cpp_str} + "
f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
)
self.writeline(f"auto {sym}_index = {index_str};")
index_str_clamped = (
f"{sym}_index < 0 ? 0 : ({sym}_index > {size_cpp_str} ? {size_cpp_str} : {sym}_index)"
if clamp
else f"{sym}_index"
)
self.writeline(f"auto {sym}_index_clamped = {index_str_clamped};")
self.writeline(
f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
f"auto {sym} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * {sym}_index_clamped;"
)
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
self.unbacked_symbol_decls.add(str(sym))
def codegen_dynamic_slice_size(self, node):
start_cpp_str = self.val_to_arg_str_for_prim_type(node.start, int)
end_cpp_str = self.val_to_arg_str_for_prim_type(node.end, int)
size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
sym = node.unbacked_size_symbol
def codegen_clamp(index_str, start=True):
suf = "start" if start else "end"
index_ = f"{sym}_{suf}_index"
self.writeline(
f"auto {index_} = {index_str} < 0 ? {index_str} + {size_cpp_str} : {index_str};"
)
self.writeline(
f"auto {sym}_{suf}_clamped = {index_} < 0 ? 0 : ({index_} > {size_cpp_str} ? {size_cpp_str} : {index_});"
)
codegen_clamp(start_cpp_str, start=True)
codegen_clamp(end_cpp_str, start=False)
self.writeline(f"auto {sym}_raw = {sym}_end_clamped - {sym}_start_clamped;")
self.writeline(f"auto {sym} = {sym}_raw < 0 ? 0 : {sym}_raw;")
self.unbacked_symbol_decls.add(str(sym))
def make_buffer_free(self, buffer):
return (

View File

@ -1817,14 +1817,33 @@ class PythonWrapperCodegen(CodeGen):
arg_name = node.input_name(0)
self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
def codegen_dynamic_select_index(self, node):
def codegen_dynamic_select_index(self, node, clamp):
index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
if clamp:
index_str = f"max(0, min({node.size}, {index_str}))"
self.writeline(
f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
)
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
def codegen_dynamic_slice_size(self, node):
def clamp_index(x):
pos = self.codegen_sizevar(sympy.Max(0, sympy.Min(x, node.size)))
neg = self.codegen_sizevar(
sympy.Max(0, sympy.Min(x + node.size, node.size))
)
return f"{pos} if {x} >= 0 else {neg}"
# codegen start, end
sym = node.unbacked_size_symbol
start = clamp_index(node.start)
end = clamp_index(node.end)
self.writeline(f"{sym}_start = {start}")
self.writeline(f"{sym}_end = {end}")
self.writeline(f"{sym} = max(0, {sym}_end - {sym}_start)")
self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))
def codegen_dynamic_scalar(self, node):
(data,) = (t.codegen_reference() for t in node.inputs)
if len(node.keypath) == 0:

View File

@ -3437,7 +3437,6 @@ class SliceView(View):
if val is None:
# TODO(rec): can this really happen?
return default
val = cls.handle_negative_index(val, dim_size)
return clamp(val, lower, upper)
start = clamp_wrap(start, 0, dim_size, 0)
@ -3454,14 +3453,6 @@ class SliceView(View):
step: int = 1,
clamp: bool = True,
) -> IRNode:
step = sympy.expand(step)
assert isinstance(step, Expr) or step > 0, step
try:
if start == 0 and end >= 2**63 - 1 and step == 1:
return x
except TypeError:
pass
new_size = list(x.get_size())
# NB: Ordinarily we default to clamping.
@ -7221,6 +7212,7 @@ class DynamicSelectStorageOffset(ExternKernel):
base_offset: Union[sympy.Symbol, int],
base_dim_stride: Union[sympy.Symbol, int],
size: Union[sympy.Symbol, int],
clamp: bool,
) -> None:
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
# This node codegen the following:
@ -7230,6 +7222,7 @@ class DynamicSelectStorageOffset(ExternKernel):
self.base_offset = base_offset
self.base_dim_stride = base_dim_stride
self.size = size
self.clamp = clamp
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet([self.unbacked_offset_symbol])
@ -7240,7 +7233,57 @@ class DynamicSelectStorageOffset(ExternKernel):
return get_free_symbols(self.index, unbacked_only)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
wrapper.codegen_dynamic_select_index(self)
wrapper.codegen_dynamic_select_index(self, clamp=self.clamp)
class DynamicSliceSize(ExternKernel):
"""
Computes the output size of a slice call, handling the correct semantics in codegen.
We do this for flexible handling for unbacked indices (to not data-dependent error).
Slicing has 4 semantics for indices, i.e. x[start:] could be:
1) start < -x.size(0) -> x[0:] # negative out-of-bounds
2) start in [-x.size(0), 0) -> x[x.size(0) + start:] # negative slicing
3) start in [0, x.size(0)) -> x[start:] # standard slicing
4) start >= x.size(0) -> empty slice # positive out-of-bounds
If the appropriate semantics are known beforehand, the output size is computed based on
the start & end indices. If not (with unbacked indices), a new unbacked symbol is created
to represent the output size, and codegen handles computing the correct case.
"""
def get_reads(self) -> OrderedSet[Dep]:
return OrderedSet()
def should_allocate(self) -> bool:
return False
def __init__(
self,
unbacked_size_symbol: sympy.Symbol,
start: sympy.Symbol,
end: Union[sympy.Symbol, int],
size: Union[sympy.Symbol, int],
):
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
# This node codegen
self.unbacked_size_symbol = unbacked_size_symbol
self.start = start
self.end = end
self.size = size
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet([self.unbacked_size_symbol])
def get_free_symbol_uses(
self, unbacked_only: bool = False
) -> OrderedSet[sympy.Symbol]:
return get_free_symbols(self.start, unbacked_only).union(
get_free_symbols(self.end, unbacked_only)
)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
wrapper.codegen_dynamic_slice_size(self)
class DynamicScalar(ExternKernel):

View File

@ -1172,9 +1172,130 @@ def permute(x, dims):
@register_lowering(aten.slice, type_promotion_kind=None)
def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
"""
Lowers a slice call, creating ExternKernels for the output size & storage offset symbols,
if the indices are unbacked and appropriate semantics aren't known.
If they are known (indices are static/backed/unbacked with info), a SliceView is created.
"""
from torch.fx.experimental.symbolic_shapes import (
CallMethodKey,
resolve_unbacked_bindings,
)
assert isinstance(x, TensorBox)
dim = _validate_dim(x, dim, 0)
return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp))
size = x.get_size()[dim]
step = sympy.expand(step)
assert isinstance(step, sympy.Expr) or step > 0, step
# maybe apply slice optimization
try:
if (
start == 0
and V.graph.sizevars.statically_known_leq(size, end)
and step == 1
):
return x
except TypeError:
pass
# try to avoid dynamic slice
def handle_negative_index(idx, size, default):
if idx is None:
return default
idx = sympy.expand(idx)
size = sympy.expand(size)
if V.graph.sizevars.guard_or_false(idx >= 0):
return idx
elif V.graph.sizevars.guard_or_false(idx < 0):
return size + idx
return None
ambiguous_slice = clamp
if ambiguous_slice:
start_index = handle_negative_index(start, size, 0)
end_index = handle_negative_index(end, size, size)
if start_index is not None and end_index is not None:
start, end = start_index, end_index
ambiguous_slice = False
# ambiguous_slice=False means we know what semantics this slice call follows,
# and don't need to generate an extern kernel to represent the output size.
# This is assumed True for clamp=False
# (meant to follow standard indexing semantics: 0 <= index < size)
if not ambiguous_slice:
return TensorBox(
ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
) # go to SliceView/ReinterpretView
# unbacked territory: create DynamicSlice ExternKernel
# clamp is True, unbacked start / end
assert clamp
unbacked_bindings = resolve_unbacked_bindings(
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
)
assert unbacked_bindings is not None
assert len(unbacked_bindings) <= 2, unbacked_bindings
sym_size, sym_storage = None, None
for sym, keypath in unbacked_bindings.items():
if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)):
sym_size = sym
elif keypath == (CallMethodKey("storage_offset"),):
sym_storage = sym
def compute_slice_index(index, size):
fn = lambda x: V.graph.sizevars.guard_or_false(x) # noqa: E731
if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)):
return index
elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)):
return -index
elif fn(sympy.Gt(index, size)):
return size
elif fn(sympy.Lt(index, -size)):
return 0
return None
start_index = compute_slice_index(start, size)
end_index = compute_slice_index(end, size)
if start_index is not None and end_index is not None:
# we shouldn't have allocated size symbol, if output size was determinable from input indices
assert sym_size is None
new_size = sympy.Max(0, end_index - start_index)
else:
b_size = ir.DynamicSliceSize(
sym_size,
start,
end,
x.get_size()[dim],
)
b_size.name = V.graph.register_buffer(b_size)
V.graph.register_operation(b_size)
new_size = sym_size
if start_index is not None:
# we shouldn't have allocated storage offset symbol if start index was determinable
assert sym_storage is None
new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim]
else:
b_storage = ir.DynamicSelectStorageOffset(
sym_storage,
start,
x.get_layout().offset,
x.get_stride()[dim],
x.get_size()[dim],
clamp=True,
)
b_storage.name = V.graph.register_buffer(b_storage)
V.graph.register_operation(b_storage)
new_storage_offset = sym_storage
new_sizes = list(x.get_size())
new_strides = list(x.get_stride())
new_sizes[dim] = new_size
new_strides[dim] *= step
return as_strided(x, new_sizes, new_strides, new_storage_offset)
@register_lowering(aten.as_strided, type_promotion_kind=None)
@ -1800,6 +1921,7 @@ def select(x, dim, idx):
x.get_layout().offset,
new_stride[dim],
x.get_size()[dim],
clamp=False,
)
buffer.name = V.graph.register_buffer(buffer)
V.graph.register_operation(buffer)
@ -2991,6 +3113,8 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
dim = _validate_dim(x, dim, 0)
dim_size = x.get_size()[dim]
start = ir.SliceView.handle_negative_index(start, dim_size)
end = ir.SliceView.handle_negative_index(end, dim_size)
start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
src_size = list(x.get_size())

View File

@ -6,7 +6,7 @@ import math
import operator
import sys
from functools import reduce
from typing import Callable, Union
from typing import Callable, Optional, Union
import torch
import torch._custom_op
@ -15,6 +15,7 @@ import torch._prims_common as utils
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
canonicalize_dim,
contiguous_for_memory_format_or_false,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
@ -746,6 +747,88 @@ def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=No
return padded.new_empty(output_shape)
def _compute_slice_index(size, index):
from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and
if guard_or_false(sym_and(index >= 0, index <= size)):
return index
elif guard_or_false(sym_and(index < 0, index >= -size)):
return index + size
elif guard_or_false(index < -size):
return 0
elif guard_or_false(index > size):
return size
return None
@register_op_impl(torch.ops.aten.slice.Tensor)
def slice_forward(
fake_mode,
func,
self,
dim: int = 0,
start: Optional[int] = None,
end: Optional[int] = None,
step: int = 1,
):
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
statically_known_true,
)
shape_env = fake_mode.shape_env
ndim = self.dim()
if ndim == 0:
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
dim = canonicalize_dim(self.dim(), dim)
sizes = list(self.size())
strides = list(self.stride())
if step <= 0:
raise RuntimeError("slice step must be positive")
# start, end
start_index = 0 if start is None else _compute_slice_index(sizes[dim], start)
end_index = (
sizes[dim]
if statically_known_true(end == sys.maxsize) or end is None
else _compute_slice_index(sizes[dim], end)
)
# size
new_size = None
if start_index is not None and end_index is not None:
if guard_or_false(end_index >= start_index):
new_size = (end_index - start_index + step - 1) // step
elif guard_or_false(start_index >= end_index):
new_size = 0
# create unbacked if case unknown
if new_size is None:
new_size = shape_env.create_unbacked_symint()
torch._check_is_size(new_size, max=sizes[dim])
# stride
new_stride = strides[dim] * step
# storage offset
if start_index is not None:
storage_offset = self.storage_offset() + start_index * strides[dim]
else:
storage_offset = shape_env.create_unbacked_symint()
torch._check(storage_offset >= 0)
sizes[dim] = new_size
strides[dim] = new_stride
if self.is_quantized:
raise NotImplementedError(
"Slice decomposition for quantized tensors aren't implemented"
)
else:
return self.as_strided(sizes, strides, storage_offset)
@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
if (

View File

@ -2616,7 +2616,9 @@ class FakeTensorMode(TorchDispatchMode):
if (
func not in meta_table
and not self.cpp_meta_supports_symint(func)
and not (has_symbolic_sizes and func in self._view_fake_tensor_impl_ops)
and not (
has_symbolic_sizes and func in self._unbacked_special_fake_handling_ops
)
):
from torch._decomp import decomposition_table
@ -2925,8 +2927,10 @@ class FakeTensorMode(TorchDispatchMode):
aten._sparse_coo_tensor_with_dims_and_tensors.default,
)
_view_fake_tensor_impl_ops = ordered_set(
aten.view.default, aten._unsafe_view.default
_unbacked_special_fake_handling_ops = ordered_set(
aten.view.default,
aten._unsafe_view.default,
aten.slice.Tensor,
)
def cpp_meta_supports_symint(self, func: OpOverload) -> bool: