mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamic shapes] unbacked-safe slicing (#157944)
Generates new unbacked symbols for slice output size & storage offset, when appropriate semantics are unclear. Teaches inductor to codegen the slice with flexible semantics. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157944 Approved by: https://github.com/laithsakka
This commit is contained in:
parent
0254646654
commit
56218d85e2
|
|
@ -3028,6 +3028,32 @@ def forward(self, causal_mask, fill_value):
|
|||
},
|
||||
)
|
||||
|
||||
def test_unbacked_slice_forward(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, xs):
|
||||
u0, u1 = xs.tolist()
|
||||
out = x[u0:u1]
|
||||
return out
|
||||
|
||||
x = torch.randn(10)
|
||||
idxs = torch.tensor([3, 6])
|
||||
mod = Foo()
|
||||
ep = export(mod, (x, idxs))
|
||||
for xs in [
|
||||
idxs,
|
||||
torch.tensor([-9, -1]),
|
||||
torch.tensor([-10000, 10000]),
|
||||
torch.tensor([0, -10]),
|
||||
]:
|
||||
self.assertTrue(torch.allclose(ep.module()(x, xs), mod(x, xs)))
|
||||
|
||||
# check unbacked bindings
|
||||
# should be 4 symbols: u0, u1, output size, output storage offset
|
||||
bound_unbacked = set()
|
||||
for node in ep.graph.nodes:
|
||||
bound_unbacked |= node.meta.get("unbacked_bindings", {}).keys()
|
||||
self.assertEqual(len(bound_unbacked), 4)
|
||||
|
||||
def test_dim_hint_ranges(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
|
|
@ -5704,7 +5730,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||
}
|
||||
self._test_export_same_as_eager(kw_func, args, kwargs)
|
||||
|
||||
def test_unbacked_slice(self):
|
||||
def test_unbacked_slice_simple(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
|
||||
valid_mask = scores > score_thr
|
||||
|
|
|
|||
|
|
@ -3391,6 +3391,119 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||
self.assertEqual(result_compiled, result_eager)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_slice(self):
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
# standard slice
|
||||
def f1(x, xs):
|
||||
u0, u1 = xs.tolist()
|
||||
torch._check_is_size(u0, max=x.size(0))
|
||||
torch._check_is_size(u1, max=x.size(0))
|
||||
torch._check(u0 <= u1)
|
||||
out = x[u0:u1]
|
||||
assert statically_known_true(out.size(0) == (u1 - u0))
|
||||
return out
|
||||
|
||||
x, xs = torch.randn(10), torch.tensor([3, 6])
|
||||
fn1 = torch.compile(f1, fullgraph=True, backend="inductor")
|
||||
self.assertEqual(fn1(x, xs).size(0), 3)
|
||||
self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs)))
|
||||
with self.assertRaises(RuntimeError):
|
||||
fn1(x, torch.tensor([-1, 5]))
|
||||
|
||||
# known negative slice
|
||||
def f2(x, n):
|
||||
u0 = n.item()
|
||||
torch._check(u0 > 1)
|
||||
torch._check(u0 <= x.size(0))
|
||||
out = x[-u0:]
|
||||
assert statically_known_true(out.size(0) == u0)
|
||||
return out
|
||||
|
||||
x, n = torch.randn(10), torch.tensor([5])
|
||||
fn2 = torch.compile(f2, fullgraph=True, backend="inductor")
|
||||
self.assertEqual(fn2(x, n).size(0), 5)
|
||||
self.assertTrue(torch.allclose(fn2(x, n), f2(x, n)))
|
||||
with self.assertRaises(RuntimeError):
|
||||
fn2(x, torch.tensor([-5]))
|
||||
|
||||
# general case: no known info
|
||||
def f3(x, xs):
|
||||
u0, u1 = xs.tolist()
|
||||
return x[u0:u1]
|
||||
|
||||
log_stream, ctx = logs_to_string(
|
||||
"torch._inductor.compile_fx", "post_grad_graphs"
|
||||
)
|
||||
cnts = CompileCounterWithBackend("inductor")
|
||||
x, xs = torch.randn(10), torch.tensor([3, 6])
|
||||
with ctx():
|
||||
fn3 = torch.compile(f3, fullgraph=True, backend=cnts)
|
||||
xs = torch.tensor([-9, -1]) # negative case
|
||||
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
|
||||
xs = torch.tensor([-1000, 1000]) # out of bounds
|
||||
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
|
||||
xs = torch.tensor([2, -2]) # mixed
|
||||
self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs)))
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
|
||||
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
||||
self.assertExpectedInline(
|
||||
aot_graphs,
|
||||
"""\
|
||||
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
|
||||
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
|
||||
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
|
||||
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
|
||||
slice_1: "f32[u2][1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, _local_scalar_dense, _local_scalar_dense_1); arg1_1 = _local_scalar_dense = _local_scalar_dense_1 = None
|
||||
sym_size_int: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0)
|
||||
ge_2: "Sym(u2 >= 0)" = sym_size_int >= 0
|
||||
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_2 = _assert_scalar = None
|
||||
le: "Sym(u2 <= 10)" = sym_size_int <= 10; sym_size_int = None
|
||||
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u2 <= 10 on node 'le'"); le = _assert_scalar_1 = None
|
||||
sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1)
|
||||
ge_3: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None
|
||||
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_2 = None
|
||||
return (slice_1,)""", # noqa: B950
|
||||
ignore_comments=True,
|
||||
ignore_empty_lines=True,
|
||||
)
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||
def test_unbacked_slice_cpp_wrapper(self):
|
||||
self.test_unbacked_slice()
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_tensor_split(self):
|
||||
def f1(x, xs):
|
||||
xs = torch.tensor(xs.tolist())
|
||||
return torch.tensor_split(x, xs)
|
||||
|
||||
x = torch.randn(20)
|
||||
xs = torch.tensor([5, 10, 15])
|
||||
fn = torch.compile(f1, fullgraph=True, backend="inductor")
|
||||
|
||||
def compare(x, xs):
|
||||
for i, j in zip(f1(x, xs), fn(x, xs)):
|
||||
self.assertTrue(torch.allclose(i, j))
|
||||
|
||||
compare(x, xs)
|
||||
xs = torch.tensor([-15, 9, 10, 11])
|
||||
compare(x, xs)
|
||||
xs = torch.tensor([-15, -10, -5, -2])
|
||||
compare(x, xs)
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||
def test_tensor_split_cpp_wrapper(self):
|
||||
self.test_tensor_split()
|
||||
|
||||
@unittest.skip("this test fails due to inductor/autograd issue #153041")
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_unbacked_non_contigious_reshape_failing(self):
|
||||
|
|
|
|||
|
|
@ -1973,7 +1973,6 @@ make_fx_failures = {
|
|||
skip('item'),
|
||||
xfail('cov'),
|
||||
xfail('nn.functional.gaussian_nll_loss'),
|
||||
xfail('tensor_split'),
|
||||
xfail('corrcoef'),
|
||||
xfail('quantile'),
|
||||
xfail('nanquantile'),
|
||||
|
|
@ -1993,10 +1992,12 @@ make_fx_failures = {
|
|||
|
||||
only_real_tensor_failures = {
|
||||
xfail('narrow'),
|
||||
xfail('tensor_split'),
|
||||
}
|
||||
|
||||
only_fake_tensor_failures = {
|
||||
xfail('narrow'),
|
||||
xfail('tensor_split'),
|
||||
}
|
||||
|
||||
fake_tensor_failures = set()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import numbers
|
|||
import operator
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from functools import partial, reduce
|
||||
from itertools import chain, product
|
||||
|
|
@ -721,10 +722,7 @@ def slice_forward(
|
|||
end: Optional[int] = None,
|
||||
step: int = 1,
|
||||
):
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_size_oblivious,
|
||||
statically_known_true,
|
||||
)
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
ndim = self.dim()
|
||||
if ndim == 0:
|
||||
|
|
@ -739,22 +737,22 @@ def slice_forward(
|
|||
start_val = start if start is not None else 0
|
||||
end_val = end if end is not None else sys.maxsize # 2^63 - 1
|
||||
|
||||
if guard_size_oblivious(start_val < 0):
|
||||
if start_val < 0:
|
||||
start_val += sizes[dim]
|
||||
|
||||
if guard_size_oblivious(end_val < 0):
|
||||
if end_val < 0:
|
||||
end_val += sizes[dim]
|
||||
|
||||
if guard_size_oblivious(start_val < 0):
|
||||
if start_val < 0:
|
||||
start_val = 0
|
||||
elif guard_size_oblivious(start_val > sizes[dim]):
|
||||
elif start_val > sizes[dim]:
|
||||
start_val = sizes[dim]
|
||||
|
||||
if statically_known_true(end_val == sys.maxsize):
|
||||
end_val = sizes[dim]
|
||||
elif guard_size_oblivious(end_val < start_val):
|
||||
elif end_val < start_val:
|
||||
end_val = start_val
|
||||
elif guard_size_oblivious(end_val > sizes[dim]):
|
||||
elif end_val > sizes[dim]:
|
||||
end_val = sizes[dim]
|
||||
|
||||
storage_offset = self.storage_offset() + start_val * strides[dim]
|
||||
|
|
@ -1438,7 +1436,17 @@ def tensor_split_tensor_indices_or_sections_py_impl(
|
|||
assert isinstance(sections, IntLike)
|
||||
return self.tensor_split(sections, dim)
|
||||
else:
|
||||
indices = [i.item() for i in tensor_indices_or_sections]
|
||||
ctx = nullcontext
|
||||
if (fake_mode := torch._guards.detect_fake_mode()) and (
|
||||
shape_env := fake_mode.shape_env
|
||||
):
|
||||
ctx = shape_env.ignore_fresh_unbacked_symbols # type: ignore[assignment]
|
||||
# In fake tensor prop, we end up calling slice() with these unbacked indices.
|
||||
# Because slice has flexible semantics, the unbacked handling generates new output sizes
|
||||
# for each slice, effectively clobbering over these index symbols.
|
||||
# To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
|
||||
with ctx():
|
||||
indices = [i.item() for i in tensor_indices_or_sections]
|
||||
# WARNING: Tempted to torch._check_is_size on the indices here? You
|
||||
# can't: tensor_split works with negative values in indices:
|
||||
#
|
||||
|
|
|
|||
|
|
@ -1456,19 +1456,51 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
|
||||
self.unbacked_symbol_decls.add(str(node.sym))
|
||||
|
||||
def codegen_dynamic_select_index(self, node):
|
||||
def codegen_dynamic_select_index(self, node, clamp):
|
||||
index_cpp_str = self.val_to_arg_str_for_prim_type(node.index, int)
|
||||
size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
|
||||
|
||||
index_compute_str = (
|
||||
# codegen index
|
||||
sym = node.unbacked_offset_symbol
|
||||
index_str = (
|
||||
f"{index_cpp_str} < 0 ? {index_cpp_str} + "
|
||||
f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
|
||||
f"{self.val_to_arg_str_for_prim_type(node.size, int)}: {index_cpp_str}"
|
||||
)
|
||||
self.writeline(f"auto {sym}_index = {index_str};")
|
||||
index_str_clamped = (
|
||||
f"{sym}_index < 0 ? 0 : ({sym}_index > {size_cpp_str} ? {size_cpp_str} : {sym}_index)"
|
||||
if clamp
|
||||
else f"{sym}_index"
|
||||
)
|
||||
self.writeline(f"auto {sym}_index_clamped = {index_str_clamped};")
|
||||
self.writeline(
|
||||
f"auto {node.unbacked_offset_symbol} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
|
||||
f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * ({index_compute_str});"
|
||||
f"auto {sym} = {self.val_to_arg_str_for_prim_type(node.base_offset, int)} + "
|
||||
f"{self.val_to_arg_str_for_prim_type(node.base_dim_stride, int)} * {sym}_index_clamped;"
|
||||
)
|
||||
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
|
||||
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
|
||||
self.unbacked_symbol_decls.add(str(sym))
|
||||
|
||||
def codegen_dynamic_slice_size(self, node):
|
||||
start_cpp_str = self.val_to_arg_str_for_prim_type(node.start, int)
|
||||
end_cpp_str = self.val_to_arg_str_for_prim_type(node.end, int)
|
||||
size_cpp_str = self.val_to_arg_str_for_prim_type(node.size, int)
|
||||
sym = node.unbacked_size_symbol
|
||||
|
||||
def codegen_clamp(index_str, start=True):
|
||||
suf = "start" if start else "end"
|
||||
index_ = f"{sym}_{suf}_index"
|
||||
self.writeline(
|
||||
f"auto {index_} = {index_str} < 0 ? {index_str} + {size_cpp_str} : {index_str};"
|
||||
)
|
||||
self.writeline(
|
||||
f"auto {sym}_{suf}_clamped = {index_} < 0 ? 0 : ({index_} > {size_cpp_str} ? {size_cpp_str} : {index_});"
|
||||
)
|
||||
|
||||
codegen_clamp(start_cpp_str, start=True)
|
||||
codegen_clamp(end_cpp_str, start=False)
|
||||
self.writeline(f"auto {sym}_raw = {sym}_end_clamped - {sym}_start_clamped;")
|
||||
self.writeline(f"auto {sym} = {sym}_raw < 0 ? 0 : {sym}_raw;")
|
||||
self.unbacked_symbol_decls.add(str(sym))
|
||||
|
||||
def make_buffer_free(self, buffer):
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -1817,14 +1817,33 @@ class PythonWrapperCodegen(CodeGen):
|
|||
arg_name = node.input_name(0)
|
||||
self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices))
|
||||
|
||||
def codegen_dynamic_select_index(self, node):
|
||||
def codegen_dynamic_select_index(self, node, clamp):
|
||||
index_str = f"{node.index} + {node.size} if {node.index} < 0 else {node.index}"
|
||||
if clamp:
|
||||
index_str = f"max(0, min({node.size}, {index_str}))"
|
||||
self.writeline(
|
||||
f"{node.unbacked_offset_symbol} = {node.base_offset} + {node.base_dim_stride} * ({index_str})"
|
||||
)
|
||||
# record in unbacked_symbol_decls so we won't generate a declaration of the symbol again
|
||||
self.unbacked_symbol_decls.add(str(node.unbacked_offset_symbol))
|
||||
|
||||
def codegen_dynamic_slice_size(self, node):
|
||||
def clamp_index(x):
|
||||
pos = self.codegen_sizevar(sympy.Max(0, sympy.Min(x, node.size)))
|
||||
neg = self.codegen_sizevar(
|
||||
sympy.Max(0, sympy.Min(x + node.size, node.size))
|
||||
)
|
||||
return f"{pos} if {x} >= 0 else {neg}"
|
||||
|
||||
# codegen start, end
|
||||
sym = node.unbacked_size_symbol
|
||||
start = clamp_index(node.start)
|
||||
end = clamp_index(node.end)
|
||||
self.writeline(f"{sym}_start = {start}")
|
||||
self.writeline(f"{sym}_end = {end}")
|
||||
self.writeline(f"{sym} = max(0, {sym}_end - {sym}_start)")
|
||||
self.unbacked_symbol_decls.add(str(node.unbacked_size_symbol))
|
||||
|
||||
def codegen_dynamic_scalar(self, node):
|
||||
(data,) = (t.codegen_reference() for t in node.inputs)
|
||||
if len(node.keypath) == 0:
|
||||
|
|
|
|||
|
|
@ -3437,7 +3437,6 @@ class SliceView(View):
|
|||
if val is None:
|
||||
# TODO(rec): can this really happen?
|
||||
return default
|
||||
val = cls.handle_negative_index(val, dim_size)
|
||||
return clamp(val, lower, upper)
|
||||
|
||||
start = clamp_wrap(start, 0, dim_size, 0)
|
||||
|
|
@ -3454,14 +3453,6 @@ class SliceView(View):
|
|||
step: int = 1,
|
||||
clamp: bool = True,
|
||||
) -> IRNode:
|
||||
step = sympy.expand(step)
|
||||
assert isinstance(step, Expr) or step > 0, step
|
||||
try:
|
||||
if start == 0 and end >= 2**63 - 1 and step == 1:
|
||||
return x
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
new_size = list(x.get_size())
|
||||
|
||||
# NB: Ordinarily we default to clamping.
|
||||
|
|
@ -7221,6 +7212,7 @@ class DynamicSelectStorageOffset(ExternKernel):
|
|||
base_offset: Union[sympy.Symbol, int],
|
||||
base_dim_stride: Union[sympy.Symbol, int],
|
||||
size: Union[sympy.Symbol, int],
|
||||
clamp: bool,
|
||||
) -> None:
|
||||
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
|
||||
# This node codegen the following:
|
||||
|
|
@ -7230,6 +7222,7 @@ class DynamicSelectStorageOffset(ExternKernel):
|
|||
self.base_offset = base_offset
|
||||
self.base_dim_stride = base_dim_stride
|
||||
self.size = size
|
||||
self.clamp = clamp
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet([self.unbacked_offset_symbol])
|
||||
|
|
@ -7240,7 +7233,57 @@ class DynamicSelectStorageOffset(ExternKernel):
|
|||
return get_free_symbols(self.index, unbacked_only)
|
||||
|
||||
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
|
||||
wrapper.codegen_dynamic_select_index(self)
|
||||
wrapper.codegen_dynamic_select_index(self, clamp=self.clamp)
|
||||
|
||||
|
||||
class DynamicSliceSize(ExternKernel):
|
||||
"""
|
||||
Computes the output size of a slice call, handling the correct semantics in codegen.
|
||||
We do this for flexible handling for unbacked indices (to not data-dependent error).
|
||||
|
||||
Slicing has 4 semantics for indices, i.e. x[start:] could be:
|
||||
1) start < -x.size(0) -> x[0:] # negative out-of-bounds
|
||||
2) start in [-x.size(0), 0) -> x[x.size(0) + start:] # negative slicing
|
||||
3) start in [0, x.size(0)) -> x[start:] # standard slicing
|
||||
4) start >= x.size(0) -> empty slice # positive out-of-bounds
|
||||
|
||||
If the appropriate semantics are known beforehand, the output size is computed based on
|
||||
the start & end indices. If not (with unbacked indices), a new unbacked symbol is created
|
||||
to represent the output size, and codegen handles computing the correct case.
|
||||
"""
|
||||
|
||||
def get_reads(self) -> OrderedSet[Dep]:
|
||||
return OrderedSet()
|
||||
|
||||
def should_allocate(self) -> bool:
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unbacked_size_symbol: sympy.Symbol,
|
||||
start: sympy.Symbol,
|
||||
end: Union[sympy.Symbol, int],
|
||||
size: Union[sympy.Symbol, int],
|
||||
):
|
||||
super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
|
||||
# This node codegen
|
||||
self.unbacked_size_symbol = unbacked_size_symbol
|
||||
self.start = start
|
||||
self.end = end
|
||||
self.size = size
|
||||
|
||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet([self.unbacked_size_symbol])
|
||||
|
||||
def get_free_symbol_uses(
|
||||
self, unbacked_only: bool = False
|
||||
) -> OrderedSet[sympy.Symbol]:
|
||||
return get_free_symbols(self.start, unbacked_only).union(
|
||||
get_free_symbols(self.end, unbacked_only)
|
||||
)
|
||||
|
||||
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
|
||||
wrapper.codegen_dynamic_slice_size(self)
|
||||
|
||||
|
||||
class DynamicScalar(ExternKernel):
|
||||
|
|
|
|||
|
|
@ -1172,9 +1172,130 @@ def permute(x, dims):
|
|||
|
||||
@register_lowering(aten.slice, type_promotion_kind=None)
|
||||
def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
|
||||
"""
|
||||
Lowers a slice call, creating ExternKernels for the output size & storage offset symbols,
|
||||
if the indices are unbacked and appropriate semantics aren't known.
|
||||
If they are known (indices are static/backed/unbacked with info), a SliceView is created.
|
||||
"""
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
CallMethodKey,
|
||||
resolve_unbacked_bindings,
|
||||
)
|
||||
|
||||
assert isinstance(x, TensorBox)
|
||||
dim = _validate_dim(x, dim, 0)
|
||||
return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp))
|
||||
size = x.get_size()[dim]
|
||||
step = sympy.expand(step)
|
||||
assert isinstance(step, sympy.Expr) or step > 0, step
|
||||
|
||||
# maybe apply slice optimization
|
||||
try:
|
||||
if (
|
||||
start == 0
|
||||
and V.graph.sizevars.statically_known_leq(size, end)
|
||||
and step == 1
|
||||
):
|
||||
return x
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
# try to avoid dynamic slice
|
||||
def handle_negative_index(idx, size, default):
|
||||
if idx is None:
|
||||
return default
|
||||
idx = sympy.expand(idx)
|
||||
size = sympy.expand(size)
|
||||
if V.graph.sizevars.guard_or_false(idx >= 0):
|
||||
return idx
|
||||
elif V.graph.sizevars.guard_or_false(idx < 0):
|
||||
return size + idx
|
||||
return None
|
||||
|
||||
ambiguous_slice = clamp
|
||||
if ambiguous_slice:
|
||||
start_index = handle_negative_index(start, size, 0)
|
||||
end_index = handle_negative_index(end, size, size)
|
||||
if start_index is not None and end_index is not None:
|
||||
start, end = start_index, end_index
|
||||
ambiguous_slice = False
|
||||
|
||||
# ambiguous_slice=False means we know what semantics this slice call follows,
|
||||
# and don't need to generate an extern kernel to represent the output size.
|
||||
# This is assumed True for clamp=False
|
||||
# (meant to follow standard indexing semantics: 0 <= index < size)
|
||||
if not ambiguous_slice:
|
||||
return TensorBox(
|
||||
ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
|
||||
) # go to SliceView/ReinterpretView
|
||||
|
||||
# unbacked territory: create DynamicSlice ExternKernel
|
||||
# clamp is True, unbacked start / end
|
||||
assert clamp
|
||||
unbacked_bindings = resolve_unbacked_bindings(
|
||||
V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
|
||||
)
|
||||
assert unbacked_bindings is not None
|
||||
assert len(unbacked_bindings) <= 2, unbacked_bindings
|
||||
sym_size, sym_storage = None, None
|
||||
for sym, keypath in unbacked_bindings.items():
|
||||
if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)):
|
||||
sym_size = sym
|
||||
elif keypath == (CallMethodKey("storage_offset"),):
|
||||
sym_storage = sym
|
||||
|
||||
def compute_slice_index(index, size):
|
||||
fn = lambda x: V.graph.sizevars.guard_or_false(x) # noqa: E731
|
||||
|
||||
if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)):
|
||||
return index
|
||||
elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)):
|
||||
return -index
|
||||
elif fn(sympy.Gt(index, size)):
|
||||
return size
|
||||
elif fn(sympy.Lt(index, -size)):
|
||||
return 0
|
||||
return None
|
||||
|
||||
start_index = compute_slice_index(start, size)
|
||||
end_index = compute_slice_index(end, size)
|
||||
if start_index is not None and end_index is not None:
|
||||
# we shouldn't have allocated size symbol, if output size was determinable from input indices
|
||||
assert sym_size is None
|
||||
new_size = sympy.Max(0, end_index - start_index)
|
||||
else:
|
||||
b_size = ir.DynamicSliceSize(
|
||||
sym_size,
|
||||
start,
|
||||
end,
|
||||
x.get_size()[dim],
|
||||
)
|
||||
b_size.name = V.graph.register_buffer(b_size)
|
||||
V.graph.register_operation(b_size)
|
||||
new_size = sym_size
|
||||
|
||||
if start_index is not None:
|
||||
# we shouldn't have allocated storage offset symbol if start index was determinable
|
||||
assert sym_storage is None
|
||||
new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim]
|
||||
else:
|
||||
b_storage = ir.DynamicSelectStorageOffset(
|
||||
sym_storage,
|
||||
start,
|
||||
x.get_layout().offset,
|
||||
x.get_stride()[dim],
|
||||
x.get_size()[dim],
|
||||
clamp=True,
|
||||
)
|
||||
b_storage.name = V.graph.register_buffer(b_storage)
|
||||
V.graph.register_operation(b_storage)
|
||||
new_storage_offset = sym_storage
|
||||
|
||||
new_sizes = list(x.get_size())
|
||||
new_strides = list(x.get_stride())
|
||||
new_sizes[dim] = new_size
|
||||
new_strides[dim] *= step
|
||||
return as_strided(x, new_sizes, new_strides, new_storage_offset)
|
||||
|
||||
|
||||
@register_lowering(aten.as_strided, type_promotion_kind=None)
|
||||
|
|
@ -1800,6 +1921,7 @@ def select(x, dim, idx):
|
|||
x.get_layout().offset,
|
||||
new_stride[dim],
|
||||
x.get_size()[dim],
|
||||
clamp=False,
|
||||
)
|
||||
buffer.name = V.graph.register_buffer(buffer)
|
||||
V.graph.register_operation(buffer)
|
||||
|
|
@ -2991,6 +3113,8 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
|
|||
dim = _validate_dim(x, dim, 0)
|
||||
dim_size = x.get_size()[dim]
|
||||
|
||||
start = ir.SliceView.handle_negative_index(start, dim_size)
|
||||
end = ir.SliceView.handle_negative_index(end, dim_size)
|
||||
start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
|
||||
|
||||
src_size = list(x.get_size())
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import math
|
|||
import operator
|
||||
import sys
|
||||
from functools import reduce
|
||||
from typing import Callable, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch._custom_op
|
||||
|
|
@ -15,6 +15,7 @@ import torch._prims_common as utils
|
|||
from torch._dispatch.python import no_python_dispatcher
|
||||
from torch._ops import OpOverload
|
||||
from torch._prims_common import (
|
||||
canonicalize_dim,
|
||||
contiguous_for_memory_format_or_false,
|
||||
elementwise_dtypes,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
|
|
@ -746,6 +747,88 @@ def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=No
|
|||
return padded.new_empty(output_shape)
|
||||
|
||||
|
||||
def _compute_slice_index(size, index):
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and
|
||||
|
||||
if guard_or_false(sym_and(index >= 0, index <= size)):
|
||||
return index
|
||||
elif guard_or_false(sym_and(index < 0, index >= -size)):
|
||||
return index + size
|
||||
elif guard_or_false(index < -size):
|
||||
return 0
|
||||
elif guard_or_false(index > size):
|
||||
return size
|
||||
return None
|
||||
|
||||
|
||||
@register_op_impl(torch.ops.aten.slice.Tensor)
|
||||
def slice_forward(
|
||||
fake_mode,
|
||||
func,
|
||||
self,
|
||||
dim: int = 0,
|
||||
start: Optional[int] = None,
|
||||
end: Optional[int] = None,
|
||||
step: int = 1,
|
||||
):
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
statically_known_true,
|
||||
)
|
||||
|
||||
shape_env = fake_mode.shape_env
|
||||
|
||||
ndim = self.dim()
|
||||
if ndim == 0:
|
||||
raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
|
||||
dim = canonicalize_dim(self.dim(), dim)
|
||||
sizes = list(self.size())
|
||||
strides = list(self.stride())
|
||||
|
||||
if step <= 0:
|
||||
raise RuntimeError("slice step must be positive")
|
||||
|
||||
# start, end
|
||||
start_index = 0 if start is None else _compute_slice_index(sizes[dim], start)
|
||||
end_index = (
|
||||
sizes[dim]
|
||||
if statically_known_true(end == sys.maxsize) or end is None
|
||||
else _compute_slice_index(sizes[dim], end)
|
||||
)
|
||||
|
||||
# size
|
||||
new_size = None
|
||||
if start_index is not None and end_index is not None:
|
||||
if guard_or_false(end_index >= start_index):
|
||||
new_size = (end_index - start_index + step - 1) // step
|
||||
elif guard_or_false(start_index >= end_index):
|
||||
new_size = 0
|
||||
|
||||
# create unbacked if case unknown
|
||||
if new_size is None:
|
||||
new_size = shape_env.create_unbacked_symint()
|
||||
torch._check_is_size(new_size, max=sizes[dim])
|
||||
|
||||
# stride
|
||||
new_stride = strides[dim] * step
|
||||
|
||||
# storage offset
|
||||
if start_index is not None:
|
||||
storage_offset = self.storage_offset() + start_index * strides[dim]
|
||||
else:
|
||||
storage_offset = shape_env.create_unbacked_symint()
|
||||
torch._check(storage_offset >= 0)
|
||||
|
||||
sizes[dim] = new_size
|
||||
strides[dim] = new_stride
|
||||
if self.is_quantized:
|
||||
raise NotImplementedError(
|
||||
"Slice decomposition for quantized tensors aren't implemented"
|
||||
)
|
||||
else:
|
||||
return self.as_strided(sizes, strides, storage_offset)
|
||||
|
||||
|
||||
@register_op_impl(torch.ops.aten.masked_select.default)
|
||||
def masked_select(fake_mode, func, self, mask):
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -2616,7 +2616,9 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
if (
|
||||
func not in meta_table
|
||||
and not self.cpp_meta_supports_symint(func)
|
||||
and not (has_symbolic_sizes and func in self._view_fake_tensor_impl_ops)
|
||||
and not (
|
||||
has_symbolic_sizes and func in self._unbacked_special_fake_handling_ops
|
||||
)
|
||||
):
|
||||
from torch._decomp import decomposition_table
|
||||
|
||||
|
|
@ -2925,8 +2927,10 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
aten._sparse_coo_tensor_with_dims_and_tensors.default,
|
||||
)
|
||||
|
||||
_view_fake_tensor_impl_ops = ordered_set(
|
||||
aten.view.default, aten._unsafe_view.default
|
||||
_unbacked_special_fake_handling_ops = ordered_set(
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten.slice.Tensor,
|
||||
)
|
||||
|
||||
def cpp_meta_supports_symint(self, func: OpOverload) -> bool:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user