Avoid DDE in narrow with unbacked start (#166361)

Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate,
for that case we shall pass dim_size instead of start+length

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
This commit is contained in:
Laith Sakka 2025-10-29 15:27:14 -07:00 committed by PyTorch MergeBot
parent f0745ddb11
commit 1aef88c72d
8 changed files with 180 additions and 18 deletions

View File

@ -1,5 +1,6 @@
#include <ATen/core/ATen_fwd.h> #include <ATen/core/ATen_fwd.h>
#include <c10/core/ScalarType.h> #include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h> #include <ATen/Dispatch.h>
@ -1710,11 +1711,14 @@ Tensor narrow_symint(
"], but got ", "], but got ",
start, start,
")") ")")
if (start < 0) { // Bounds check without converting start:
start = start + cur_size; // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
} // length <= 0
// - If start >= 0: need start + length <= cur_size
auto end = start + length;
TORCH_SYM_CHECK( TORCH_SYM_CHECK(
start.sym_le(cur_size - length), (start.sym_lt(0).sym_and((end).sym_le(0)))
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
"start (", "start (",
start, start,
") + length (", ") + length (",
@ -1722,7 +1726,31 @@ Tensor narrow_symint(
") exceeds dimension size (", ") exceeds dimension size (",
cur_size, cur_size,
")."); ").");
return at::slice_symint(self, dim, start, start + length, 1);
if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) {
return at::slice_symint(self, dim, start, end, 1);
} else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) {
// Avoid the complex symbolic expressions path for non-unbacked.
return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1);
} else {
// Cannot statically determine the condition due to unbacked.
// This is an interesting situation; when start is negative and
// start + length == 0, slice and narrow do different things.
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
// pass curr_size instead of 0. Otherwise, they would do the same thing.
// This says at runtime: if start < 0 and end == 0, then pass curr_size
// instead of 0.
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
auto result =
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
// Ensure slice allocated unbacked size is specialized to length.
SymInt new_size = result.sym_size(dim);
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
return result;
}
} }
// This overload exists purely for XLA, because they wanted to pass in // This overload exists purely for XLA, because they wanted to pass in

View File

@ -1,4 +1,5 @@
#include <c10/core/SymBool.h> #include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h> #include <c10/core/SymNodeImpl.h>
namespace c10 { namespace c10 {
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
return toSymNodeImpl()->has_hint(); return toSymNodeImpl()->has_hint();
} }
SymInt SymBool::toSymInt() const {
// If concrete bool, return concrete SymInt
if (auto ma = maybe_as_bool()) {
return SymInt(*ma ? 1 : 0);
}
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
auto node = toSymNodeImpl();
auto one_node = node->wrap_int(1);
auto zero_node = node->wrap_int(0);
return SymInt(node->sym_ite(one_node, zero_node));
}
} // namespace c10 } // namespace c10

View File

@ -12,6 +12,8 @@
namespace c10 { namespace c10 {
class SymInt;
class C10_API SymBool { class C10_API SymBool {
public: public:
/*implicit*/ SymBool(bool b) : data_(b) {} /*implicit*/ SymBool(bool b) : data_(b) {}
@ -80,6 +82,10 @@ class C10_API SymBool {
return toSymNodeImplUnowned()->constant_bool(); return toSymNodeImplUnowned()->constant_bool();
} }
// Convert SymBool to SymInt (0 or 1)
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
SymInt toSymInt() const;
bool is_heap_allocated() const { bool is_heap_allocated() const {
return ptr_; return ptr_;
} }

View File

@ -6093,26 +6093,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
retry_export( retry_export(
cf_implicitsize(), cf_implicitsize(),
(torch.tensor(2), torch.randn(10)), (torch.tensor(2), torch.randn(10)),
fixes=[ fixes=[],
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
) )
class cf_stacklist(torch.nn.Module): class cf_stacklist(torch.nn.Module):
def forward(self, xs, y, fixes): def forward(self, xs, y, fixes):
i = y.item() i = y.item()
eval(fixes) eval(fixes)
# instead of xs[i]
return torch.stack(xs, 0).narrow(0, i, 1).squeeze() return torch.stack(xs, 0).narrow(0, i, 1).squeeze()
retry_export( retry_export(
cf_stacklist(), cf_stacklist(),
([torch.ones(5) * i for i in range(10)], torch.tensor(2)), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
fixes=[ fixes=[],
# Could not guard on data-dependent expression u0 < 0
"torch._check(i >= 0)",
],
) )
class cf_tensorsplit(torch.nn.Module): class cf_tensorsplit(torch.nn.Module):
@ -6166,6 +6159,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
class cf_stacklist(torch.nn.Module): class cf_stacklist(torch.nn.Module):
def forward(self, xs, y): def forward(self, xs, y):
# y.item() is not a local, so we can't suggest a fix # y.item() is not a local, so we can't suggest a fix
if y.item() < 0:
return (
torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
)
else:
return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
with self.assertRaisesRegex( with self.assertRaisesRegex(
@ -6196,7 +6194,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
def forward(self, xs, y): def forward(self, xs, y):
box = Box(y.item()) box = Box(y.item())
# box.content is not a local, so we can't suggest a fix # box.content is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() if box.content < 0:
return (
torch.stack(xs, 0)
.narrow(0, box.content + xs.size(), 1)
.squeeze()
)
else:
return (
torch.stack(xs, 0)
.narrow(0, box.content + xs.size(), 1)
.squeeze()
)
with self.assertRaisesRegex( with self.assertRaisesRegex(
error_type, error_type,

View File

@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
self.assertEqual(compiled(a, b), func(a, b)) self.assertEqual(compiled(a, b), func(a, b))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_narrow_unbacked_start(self):
def func(x, start, length):
# unbacked start
u0 = start.item()
return torch.narrow(x, 0, u0, length)
compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
x = torch.tensor([1, 2, 3, 4, 5, 6])
# Test cases: (start, length)
test_cases = [
# Negative starts
(-2, 2), # Start from second-to-last element
(-1, 1), # Start from last element
(-3, 3), # Start from third-to-last element
(-6, 2), # Start from beginning (negative)
(-4, 1), # Start from fourth-to-last element
# Positive starts
(0, 2), # Start from beginning
(1, 3), # Start from second element
(2, 2), # Start from third element
(4, 2), # Start near end
# Edge cases
(0, 6), # Full tensor
(0, 1), # Single element from start
(5, 1), # Single element from end
]
for start_val, length in test_cases:
with self.subTest(start=start_val, length=length):
start = torch.tensor([start_val])
# Test with compiled function
result_compiled = compiled_func(x, start, length)
# Test with eager function (expected behavior)
result_eager = func(x, start, length)
# Compare results
self.assertEqual(result_compiled, result_eager)
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_narrow_unbacked_start_cpp_wrapper(self):
"""Test narrow with unbacked start with cpp_wrapper"""
self.test_narrow_unbacked_start()
instantiate_parametrized_tests(TestUnbacked) instantiate_parametrized_tests(TestUnbacked)

View File

@ -2058,7 +2058,8 @@ class PythonWrapperCodegen(CodeGen):
neg = self.codegen_sizevar( neg = self.codegen_sizevar(
sympy.Max(0, sympy.Min(x + node.size, node.size)) sympy.Max(0, sympy.Min(x + node.size, node.size))
) )
return f"{pos} if {x} >= 0 else {neg}" x_cond = self.codegen_sizevar(x)
return f"{pos} if {x_cond} >= 0 else {neg}"
def codegen_with_step(start_var, end_var, step): def codegen_with_step(start_var, end_var, step):
if step == 1: if step == 1:

View File

@ -547,6 +547,7 @@ def rebind_unbacked(
assert shape_env is not None assert shape_env is not None
for raw_u0, path in bindings.items(): for raw_u0, path in bindings.items():
u1 = pytree.key_get(result, path) u1 = pytree.key_get(result, path)
# Sometimes, things were previously unbacked bindings become constants. # Sometimes, things were previously unbacked bindings become constants.
# There are two situations this can happen. # There are two situations this can happen.
# #
@ -602,7 +603,23 @@ def rebind_unbacked(
if u1.node.hint is not None: if u1.node.hint is not None:
continue continue
raw_u1 = u1.node.expr # unbacked symbols bindings might be replaced to other backed or
# unbacked replacements.
#
# Example:
# u = x.item()
# torch._check(u == 5)
#
# The safest approach is to retrieve raw_u1 from u1.node._expr
# and perform the rebinding on the original unbacked symbol,
# even if its no longer directly referenced.
#
# In other words, we should always rebind the original symbol
# before any replacements are applied.
# u0 -> u0 == s1
raw_u1 = u1.node._expr
# TODO Do we still need this logic below?
# Simplify SymBool binding # Simplify SymBool binding
if ( if (
isinstance(raw_u1, sympy.Piecewise) isinstance(raw_u1, sympy.Piecewise)

View File

@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
raise TypeError("ndigits must be an instance of sympy.Integer") raise TypeError("ndigits must be an instance of sympy.Integer")
return f"round({self._print(number)}, {ndigits})" return f"round({self._print(number)}, {ndigits})"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary expressions
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: e1 if c1 else (e2 if c2 else (... else eN))
result = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self._print(expr_i)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self._print(cond_i)
if result is None:
result = expr_str
else:
result = f"({expr_str} if {cond_str} else {result})"
return result if result else "0"
class CppPrinter(ExprPrinter): class CppPrinter(ExprPrinter):
def _print_Integer(self, expr: sympy.Expr) -> str: def _print_Integer(self, expr: sympy.Expr) -> str:
@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
) )
return f"{c} ? {p} : {q}" return f"{c} ? {p} : {q}"
def _print_Piecewise(self, expr: sympy.Expr) -> str:
# Convert Piecewise(expr_cond_pairs) to nested ternary operators
# Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
# becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
result = None
for expr_i, cond_i in reversed(expr.args):
expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
if cond_i == True: # noqa: E712
# This is the default case
result = expr_str
else:
cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
if result is None:
result = expr_str
else:
result = f"{cond_str} ? {expr_str} : {result}"
return f"({result})" if result else "0"
def _print_ModularIndexing(self, expr: sympy.Expr) -> str: def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
x, div, mod = expr.args x, div, mod = expr.args
x = self.doprint(x) x = self.doprint(x)