Support python slicing with tensor inputs. (#165074)

When a slice argument is a tensor, we decompose it into a .item() call and pass the resulting unbacked symbol to the slice, avoiding a data-dependent error (DDE).
The diff also fixes an existing bug in codegen_dynamic_slice_size in the cpp wrapper: a +1 should be a -1, making it match the Python codegen.
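
A minimal sketch of the user-visible change (illustrative only, not part of the diff; it assumes capture_scalar_outputs is enabled, as the new tests do):

    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    def f(x, stop_t):
        # Before this PR: graph break "Dynamic slicing with Tensor arguments".
        # After: traced as `u = stop_t.item(); x[:u]`, where u is an
        # unbacked SymInt, so the slice stays in the graph without a DDE.
        return x[:stop_t]

    compiled = torch.compile(f, fullgraph=True)
    out = compiled(torch.randn(10), torch.tensor(3))  # same result as eager f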

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165074
Approved by: https://github.com/Lucaskabela
Laith Sakka 2025-10-27 17:18:15 -07:00 committed by PyTorch MergeBot
parent bea89d6060
commit adedf26e21
8 changed files with 90 additions and 56 deletions


@@ -528,30 +528,6 @@ Attempted to call function marked as skipped
             f(x)

         self.assertEqual(len(ws), 2)

-    def test_slice_with_tensor(self):
-        def fn(x, y):
-            return x[:y]
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(
-                torch.randn(10),
-                torch.tensor([3]),
-            ),
-            """\
-Dynamic slicing with Tensor arguments
-Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
-Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
-Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
-For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
-
-from user code:
-   File "test_error_messages.py", line N, in fn
-    return x[:y]""",
-        )
-
     def test_observed_exception(self):
         def fn():
             raise RuntimeError("test")


@@ -2077,22 +2077,6 @@ def forward(self, l_x_):
         self.assertEqual(count, 1)
         self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)

-    def test_dynamic_slicing_invalid(self):
-        def g(x, y):
-            return x[y : x.shape[0]]
-
-        with self.assertRaisesRegex(
-            torch._dynamo.exc.Unsupported,
-            "Dynamic slicing with Tensor arguments",
-        ):
-            torch._dynamo.export(
-                g,
-                aten_graph=True,
-            )(
-                torch.randn(4, 5),
-                torch.tensor(2),
-            )
-
     @config.patch(capture_scalar_outputs=True)
     def test_dynamic_slicing_simple(self):
         def f(x):


@@ -3803,6 +3803,74 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
     def test_unbacked_slice_with_step_cpp_wrapper(self):
         self.test_unbacked_slice_with_step()

+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_slice_with_tensor_indices(self):
+        for d in [True, False]:
+            # Test slicing with tensor start/stop/step on RHS (reading)
+            # Test 1: Basic slice with tensor start and stop
+            def f1(x, start_t, stop_t):
+                return x[start_t:stop_t]
+
+            x = torch.randn(20)
+            start_t = torch.tensor(5)
+            stop_t = torch.tensor(15)
+            fn1 = torch.compile(f1, fullgraph=True, dynamic=d, backend="inductor")
+            self.assertTrue(
+                torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t))
+            )
+
+            # Test 2: Slice with tensor step
+            def f2(x, start_t, stop_t, step_t):
+                return x[start_t:stop_t:step_t]
+
+            step_t = torch.tensor(2)
+            fn2 = torch.compile(f2, fullgraph=True, dynamic=d, backend="inductor")
+            self.assertTrue(
+                torch.allclose(
+                    fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
+                )
+            )
+
+            # Test 3: Slice with only tensor start
+            def f3(x, start_t):
+                return x[start_t:]
+
+            fn3 = torch.compile(f3, fullgraph=True, dynamic=d, backend="inductor")
+            self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))
+
+            # Test 4: Slice with only tensor stop
+            def f4(x, stop_t):
+                return x[:stop_t]
+
+            fn4 = torch.compile(f4, fullgraph=True, dynamic=d, backend="inductor")
+            self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))
+
+            # Test 5: Negative indices with tensors
+            def f5(x, start_t):
+                return x[start_t:-1]
+
+            start_t_neg = torch.tensor(-10)
+            fn5 = torch.compile(f5, fullgraph=True, dynamic=d, backend="inductor")
+            self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))
+
+            # Test 6: Multidimensional slice with tensor indices
+            def f6(x, start_t, stop_t):
+                return x[:, start_t:stop_t]
+
+            x_2d = torch.randn(10, 20)
+            fn6 = torch.compile(f6, fullgraph=True, dynamic=d, backend="inductor")
+            self.assertTrue(
+                torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
+            )
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_slice_with_tensor_indices_cpp_wrapper(self):
+        self.test_slice_with_tensor_indices()
+
     @fresh_cache()
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
     def test_tensor_split(self):


@@ -3243,7 +3243,7 @@ class InstructionTranslatorBase(
     def BUILD_SLICE(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)
-        self.push(SliceVariable(items))
+        self.push(SliceVariable(items, tx=self))

     def BUILD_LIST(self, inst: Instruction) -> None:
         items = self.popn(inst.argval)


@@ -1812,7 +1812,7 @@ class VariableBuilder:
         ]
         self.install_guards(GuardBuilder.TYPE_MATCH)
         if isinstance(value, slice):
-            return SliceVariable(items, source=self.source)
+            return SliceVariable(items, self.tx, source=self.source)
         else:
             return RangeVariable(items, source=self.source)


@@ -1746,7 +1746,7 @@ class BuiltinVariable(VariableTracker):
         )

     def call_slice(self, tx: "InstructionTranslator", *args):
-        return variables.SliceVariable(args)
+        return variables.SliceVariable(args, tx)

     def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
         from .builder import wrap_fx_proxy


@@ -1503,7 +1503,7 @@ class NamedTupleVariable(TupleVariable):


 class SliceVariable(VariableTracker):
-    def __init__(self, items, **kwargs) -> None:
+    def __init__(self, items, tx=None, **kwargs) -> None:
         items_to_map = items
         start, stop, step = [variables.ConstantVariable.create(None)] * 3
@@ -1516,18 +1516,24 @@ class SliceVariable(VariableTracker):
         else:
             raise AssertionError

-        if isinstance(start, variables.TensorVariable) or isinstance(
-            stop, variables.TensorVariable
-        ):
-            unimplemented_v2(
-                gb_type="Dynamic slicing with Tensor arguments",
-                context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
-                explanation="Creating slices with Tensor arguments is not supported. "
-                "e.g. `l[:x]`, where `x` is a 1-element tensor.",
-                hints=[
-                    *graph_break_hints.SUPPORTABLE,
-                ],
-            )
+        # Convert TensorVariable to SymIntVariable by calling .item()
+        # This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
+        if isinstance(start, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            start = start.call_method(tx, "item", [], {})
+        if isinstance(stop, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            stop = stop.call_method(tx, "item", [], {})
+        if isinstance(step, variables.TensorVariable):
+            assert tx is not None, (
+                "tx is required when slice indices are TensorVariables"
+            )
+            step = step.call_method(tx, "item", [], {})

         self.items = (start, stop, step)
         super().__init__(**kwargs)


@@ -1551,7 +1551,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
             step_str = f"{sym}_en_cl - {sym}_st_cl"
         else:
             step_str = (
-                f"({sym}_en_cl - {sym}_st_cl + {step_cpp_str} + 1) / {step_cpp_str}"
+                f"({sym}_en_cl - {sym}_st_cl + {step_cpp_str} - 1) / {step_cpp_str}"
             )
         self.writeline(f"int64_t {sym}_with_step = {step_str};")
         self.writeline(f"int64_t {sym} = {sym}_with_step < 0 ? 0 : {sym}_with_step;")