mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Support python slicing with tensor inputs. (#165074)
when the slice is tensor, we decompose it to .item() call and pass the unbacked symbol to the slice to avoid DDE. the diff also fix an existing bug in codegen_dynamic_slice_size in the cpp wrapper. a +1 should be -1 making it match python codegen. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165074 Approved by: https://github.com/Lucaskabela
This commit is contained in:
parent
bea89d6060
commit
adedf26e21
|
|
@ -528,30 +528,6 @@ Attempted to call function marked as skipped
|
|||
f(x)
|
||||
self.assertEqual(len(ws), 2)
|
||||
|
||||
def test_slice_with_tensor(self):
|
||||
def fn(x, y):
|
||||
return x[:y]
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
||||
torch.randn(10),
|
||||
torch.tensor([3]),
|
||||
),
|
||||
"""\
|
||||
Dynamic slicing with Tensor arguments
|
||||
Explanation: Creating slices with Tensor arguments is not supported. e.g. `l[:x]`, where `x` is a 1-element tensor.
|
||||
Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.
|
||||
|
||||
Developer debug context: SliceVariable start: ConstantVariable(NoneType: None), stop: LazyVariableTracker(realized: TensorVariable()), step: ConstantVariable(NoneType: None)
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0038.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
return x[:y]""",
|
||||
)
|
||||
|
||||
def test_observed_exception(self):
|
||||
def fn():
|
||||
raise RuntimeError("test")
|
||||
|
|
|
|||
|
|
@ -2077,22 +2077,6 @@ def forward(self, l_x_):
|
|||
self.assertEqual(count, 1)
|
||||
self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)
|
||||
|
||||
def test_dynamic_slicing_invalid(self):
|
||||
def g(x, y):
|
||||
return x[y : x.shape[0]]
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.Unsupported,
|
||||
"Dynamic slicing with Tensor arguments",
|
||||
):
|
||||
torch._dynamo.export(
|
||||
g,
|
||||
aten_graph=True,
|
||||
)(
|
||||
torch.randn(4, 5),
|
||||
torch.tensor(2),
|
||||
)
|
||||
|
||||
@config.patch(capture_scalar_outputs=True)
|
||||
def test_dynamic_slicing_simple(self):
|
||||
def f(x):
|
||||
|
|
|
|||
|
|
@ -3803,6 +3803,74 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||
def test_unbacked_slice_with_step_cpp_wrapper(self):
|
||||
self.test_unbacked_slice_with_step()
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_slice_with_tensor_indices(self):
|
||||
for d in [True, False]:
|
||||
# Test slicing with tensor start/stop/step on RHS (reading)
|
||||
|
||||
# Test 1: Basic slice with tensor start and stop
|
||||
def f1(x, start_t, stop_t):
|
||||
return x[start_t:stop_t]
|
||||
|
||||
x = torch.randn(20)
|
||||
start_t = torch.tensor(5)
|
||||
stop_t = torch.tensor(15)
|
||||
fn1 = torch.compile(f1, fullgraph=True, dynamic=d, backend="inductor")
|
||||
self.assertTrue(
|
||||
torch.allclose(fn1(x, start_t, stop_t), f1(x, start_t, stop_t))
|
||||
)
|
||||
|
||||
# Test 2: Slice with tensor step
|
||||
def f2(x, start_t, stop_t, step_t):
|
||||
return x[start_t:stop_t:step_t]
|
||||
|
||||
step_t = torch.tensor(2)
|
||||
fn2 = torch.compile(f2, fullgraph=True, dynamic=d, backend="inductor")
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
fn2(x, start_t, stop_t, step_t), f2(x, start_t, stop_t, step_t)
|
||||
)
|
||||
)
|
||||
|
||||
# Test 3: Slice with only tensor start
|
||||
def f3(x, start_t):
|
||||
return x[start_t:]
|
||||
|
||||
fn3 = torch.compile(f3, fullgraph=True, dynamic=d, backend="inductor")
|
||||
self.assertTrue(torch.allclose(fn3(x, start_t), f3(x, start_t)))
|
||||
|
||||
# Test 4: Slice with only tensor stop
|
||||
def f4(x, stop_t):
|
||||
return x[:stop_t]
|
||||
|
||||
fn4 = torch.compile(f4, fullgraph=True, dynamic=d, backend="inductor")
|
||||
self.assertTrue(torch.allclose(fn4(x, stop_t), f4(x, stop_t)))
|
||||
|
||||
# Test 5: Negative indices with tensors
|
||||
def f5(x, start_t):
|
||||
return x[start_t:-1]
|
||||
|
||||
start_t_neg = torch.tensor(-10)
|
||||
fn5 = torch.compile(f5, fullgraph=True, dynamic=d, backend="inductor")
|
||||
self.assertTrue(torch.allclose(fn5(x, start_t_neg), f5(x, start_t_neg)))
|
||||
|
||||
# Test 6: Multidimensional slice with tensor indices
|
||||
def f6(x, start_t, stop_t):
|
||||
return x[:, start_t:stop_t]
|
||||
|
||||
x_2d = torch.randn(10, 20)
|
||||
fn6 = torch.compile(f6, fullgraph=True, dynamic=d, backend="inductor")
|
||||
self.assertTrue(
|
||||
torch.allclose(fn6(x_2d, start_t, stop_t), f6(x_2d, start_t, stop_t))
|
||||
)
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
@torch._inductor.config.patch("cpp_wrapper", True)
|
||||
def test_slice_with_tensor_indices_cpp_wrapper(self):
|
||||
self.test_slice_with_tensor_indices()
|
||||
|
||||
@fresh_cache()
|
||||
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
||||
def test_tensor_split(self):
|
||||
|
|
|
|||
|
|
@ -3243,7 +3243,7 @@ class InstructionTranslatorBase(
|
|||
|
||||
def BUILD_SLICE(self, inst: Instruction) -> None:
|
||||
items = self.popn(inst.argval)
|
||||
self.push(SliceVariable(items))
|
||||
self.push(SliceVariable(items, tx=self))
|
||||
|
||||
def BUILD_LIST(self, inst: Instruction) -> None:
|
||||
items = self.popn(inst.argval)
|
||||
|
|
|
|||
|
|
@ -1812,7 +1812,7 @@ class VariableBuilder:
|
|||
]
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
if isinstance(value, slice):
|
||||
return SliceVariable(items, source=self.source)
|
||||
return SliceVariable(items, self.tx, source=self.source)
|
||||
else:
|
||||
return RangeVariable(items, source=self.source)
|
||||
|
||||
|
|
|
|||
|
|
@ -1746,7 +1746,7 @@ class BuiltinVariable(VariableTracker):
|
|||
)
|
||||
|
||||
def call_slice(self, tx: "InstructionTranslator", *args):
|
||||
return variables.SliceVariable(args)
|
||||
return variables.SliceVariable(args, tx)
|
||||
|
||||
def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||
from .builder import wrap_fx_proxy
|
||||
|
|
|
|||
|
|
@ -1503,7 +1503,7 @@ class NamedTupleVariable(TupleVariable):
|
|||
|
||||
|
||||
class SliceVariable(VariableTracker):
|
||||
def __init__(self, items, **kwargs) -> None:
|
||||
def __init__(self, items, tx=None, **kwargs) -> None:
|
||||
items_to_map = items
|
||||
start, stop, step = [variables.ConstantVariable.create(None)] * 3
|
||||
|
||||
|
|
@ -1516,18 +1516,24 @@ class SliceVariable(VariableTracker):
|
|||
else:
|
||||
raise AssertionError
|
||||
|
||||
if isinstance(start, variables.TensorVariable) or isinstance(
|
||||
stop, variables.TensorVariable
|
||||
):
|
||||
unimplemented_v2(
|
||||
gb_type="Dynamic slicing with Tensor arguments",
|
||||
context=f"SliceVariable start: {start}, stop: {stop}, step: {step}",
|
||||
explanation="Creating slices with Tensor arguments is not supported. "
|
||||
"e.g. `l[:x]`, where `x` is a 1-element tensor.",
|
||||
hints=[
|
||||
*graph_break_hints.SUPPORTABLE,
|
||||
],
|
||||
# Convert TensorVariable to SymIntVariable by calling .item()
|
||||
# This decomposes a[:t] to u=t.item(); a[:u] at the dynamo level
|
||||
if isinstance(start, variables.TensorVariable):
|
||||
assert tx is not None, (
|
||||
"tx is required when slice indices are TensorVariables"
|
||||
)
|
||||
start = start.call_method(tx, "item", [], {})
|
||||
if isinstance(stop, variables.TensorVariable):
|
||||
assert tx is not None, (
|
||||
"tx is required when slice indices are TensorVariables"
|
||||
)
|
||||
stop = stop.call_method(tx, "item", [], {})
|
||||
if isinstance(step, variables.TensorVariable):
|
||||
assert tx is not None, (
|
||||
"tx is required when slice indices are TensorVariables"
|
||||
)
|
||||
step = step.call_method(tx, "item", [], {})
|
||||
|
||||
self.items = (start, stop, step)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
|
|
|||
|
|
@ -1551,7 +1551,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
step_str = f"{sym}_en_cl - {sym}_st_cl"
|
||||
else:
|
||||
step_str = (
|
||||
f"({sym}_en_cl - {sym}_st_cl + {step_cpp_str} + 1) / {step_cpp_str}"
|
||||
f"({sym}_en_cl - {sym}_st_cl + {step_cpp_str} - 1) / {step_cpp_str}"
|
||||
)
|
||||
self.writeline(f"int64_t {sym}_with_step = {step_str};")
|
||||
self.writeline(f"int64_t {sym} = {sym}_with_step < 0 ? 0 : {sym}_with_step;")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user