Avoid DDE in narrow with unbacked start (#166361)
Slice already knows how to handle an unbacked start, so we do not need to offset start before calling slice; we can leave that to slice. The only edge case is when start < 0 and start + length == 0: there, slice and narrow would deviate, so for that case we pass dim_size instead of start + length.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
This commit is contained in: parent f0745ddb11, commit 1aef88c72d
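For illustration, a minimal eager-mode sketch of the deviating edge case described in the commit message (assuming a 1-D tensor of size 6):

import torch

x = torch.arange(6)   # tensor([0, 1, 2, 3, 4, 5])
x.narrow(0, -2, 2)    # tensor([4, 5]); narrow wraps the negative start to 4
x[-2:0]               # empty; slice does not wrap an end of 0
x[-2:6]               # tensor([4, 5]); passing dim_size as the end matches narrow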
@@ -1,5 +1,6 @@
#include <ATen/core/ATen_fwd.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymInt.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
@@ -1710,11 +1711,14 @@ Tensor narrow_symint(
    "], but got ",
    start,
    ")")
if (start < 0) {
  start = start + cur_size;
}
// Bounds check without converting start:
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
// length <= 0
// - If start >= 0: need start + length <= cur_size
auto end = start + length;
TORCH_SYM_CHECK(
    start.sym_le(cur_size - length),
    (start.sym_lt(0).sym_and((end).sym_le(0)))
        .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
    "start (",
    start,
    ") + length (",
@@ -1722,7 +1726,31 @@ Tensor narrow_symint(
    ") exceeds dimension size (",
    cur_size,
    ").");
return at::slice_symint(self, dim, start, start + length, 1);

if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) {
  return at::slice_symint(self, dim, start, end, 1);
} else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) {
  // Avoid the complex symbolic expressions path for non-unbacked.
  return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1);
} else {
  // Cannot statically determine the condition due to unbacked.
  // This is an interesting situation: when start is negative and
  // start + length == 0, slice and narrow do different things,
  // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case we want to
  // pass cur_size instead of 0. Otherwise, they would do the same thing.
  // This says at runtime: if start < 0 and end == 0, then pass cur_size
  // instead of 0.

  auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
  auto result =
      at::slice_symint(self, dim, start, end + use_different * cur_size, 1);

  // Ensure the unbacked size allocated by slice is specialized to length.
  SymInt new_size = result.sym_size(dim);
  TORCH_SYM_CHECK(new_size.sym_eq(length), "");

  return result;
}
}

// This overload exists purely for XLA, because they wanted to pass in
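A rough Python sketch of the end selection performed in the unbacked branch above; the helper name is hypothetical, and the real code builds a symbolic expression via SymBool::toSymInt() rather than branching on the runtime value:

def narrow_end_for_slice(start, length, cur_size):
    # End passed to slice: normally start + length, but cur_size when
    # start < 0 and start + length == 0, so that slice matches narrow.
    end = start + length
    use_different = 1 if (start < 0 and end == 0) else 0  # stands in for SymBool -> SymInt
    return end + use_different * cur_size

narrow_end_for_slice(-2, 2, 6)  # 6: slice(x, -2, 6) matches x.narrow(0, -2, 2)
narrow_end_for_slice(1, 3, 6)   # 4: unchanged from start + length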
@@ -1,4 +1,5 @@
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>

namespace c10 {

@@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
  return toSymNodeImpl()->has_hint();
}

SymInt SymBool::toSymInt() const {
  // If concrete bool, return concrete SymInt
  if (auto ma = maybe_as_bool()) {
    return SymInt(*ma ? 1 : 0);
  }

  // Symbolic case: use sym_ite to convert bool to int (0 or 1)
  auto node = toSymNodeImpl();
  auto one_node = node->wrap_int(1);
  auto zero_node = node->wrap_int(0);
  return SymInt(node->sym_ite(one_node, zero_node));
}

} // namespace c10
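Roughly, the symbolic branch above corresponds at the sympy level to an if-then-else over 1 and 0, which shows up downstream as a Piecewise expression (hence the printer changes later in this commit). A standalone sketch, not the actual helper:

import sympy

def symbool_to_symint(b):
    # sym_ite(b, 1, 0): 1 where the boolean holds, 0 otherwise
    return sympy.Piecewise((1, b), (0, True))

u0 = sympy.Symbol("u0", integer=True)
print(symbool_to_symint(u0 < 0))  # Piecewise((1, u0 < 0), (0, True))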
@@ -12,6 +12,8 @@

namespace c10 {

class SymInt;

class C10_API SymBool {
 public:
  /*implicit*/ SymBool(bool b) : data_(b) {}

@@ -80,6 +82,10 @@ class C10_API SymBool {
    return toSymNodeImplUnowned()->constant_bool();
  }

  // Convert SymBool to SymInt (0 or 1)
  // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
  SymInt toSymInt() const;

  bool is_heap_allocated() const {
    return ptr_;
  }
@@ -6093,26 +6093,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
retry_export(
    cf_implicitsize(),
    (torch.tensor(2), torch.randn(10)),
    fixes=[
        # Could not guard on data-dependent expression u0 < 0
        "torch._check(i >= 0)",
    ],
    fixes=[],
)

class cf_stacklist(torch.nn.Module):
    def forward(self, xs, y, fixes):
        i = y.item()
        eval(fixes)
        # instead of xs[i]
        return torch.stack(xs, 0).narrow(0, i, 1).squeeze()

retry_export(
    cf_stacklist(),
    ([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
    fixes=[
        # Could not guard on data-dependent expression u0 < 0
        "torch._check(i >= 0)",
    ],
    fixes=[],
)

class cf_tensorsplit(torch.nn.Module):
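The fixes removed above reflect that narrow with an unbacked index no longer raises the data-dependent error on u0 < 0. A sketch of the exported pattern (module name and shapes are illustrative):

import torch

class StackList(torch.nn.Module):  # hypothetical stand-in for cf_stacklist
    def forward(self, xs, y):
        i = y.item()
        # previously needed torch._check(i >= 0) as a hint; now exports without it
        return torch.stack(xs, 0).narrow(0, i, 1).squeeze()

ep = torch.export.export(
    StackList(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2))
)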
@@ -6166,6 +6159,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
class cf_stacklist(torch.nn.Module):
    def forward(self, xs, y):
        # y.item() is not a local, so we can't suggest a fix
        if y.item() < 0:
            return (
                torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze()
            )
        else:
            return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()

with self.assertRaisesRegex(
@@ -6196,7 +6194,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
def forward(self, xs, y):
    box = Box(y.item())
    # box.content is not a local, so we can't suggest a fix
    return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
    if box.content < 0:
        return (
            torch.stack(xs, 0)
            .narrow(0, box.content + xs.size(), 1)
            .squeeze()
        )
    else:
        return (
            torch.stack(xs, 0)
            .narrow(0, box.content + xs.size(), 1)
            .squeeze()
        )

with self.assertRaisesRegex(
    error_type,
@@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]

self.assertEqual(compiled(a, b), func(a, b))

@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_narrow_unbacked_start(self):
    def func(x, start, length):
        # unbacked start
        u0 = start.item()
        return torch.narrow(x, 0, u0, length)

    compiled_func = torch.compile(func, fullgraph=True, backend="inductor")

    x = torch.tensor([1, 2, 3, 4, 5, 6])

    # Test cases: (start, length)
    test_cases = [
        # Negative starts
        (-2, 2),  # Start from second-to-last element
        (-1, 1),  # Start from last element
        (-3, 3),  # Start from third-to-last element
        (-6, 2),  # Start from beginning (negative)
        (-4, 1),  # Start from fourth-to-last element
        # Positive starts
        (0, 2),  # Start from beginning
        (1, 3),  # Start from second element
        (2, 2),  # Start from third element
        (4, 2),  # Start near end
        # Edge cases
        (0, 6),  # Full tensor
        (0, 1),  # Single element from start
        (5, 1),  # Single element from end
    ]

    for start_val, length in test_cases:
        with self.subTest(start=start_val, length=length):
            start = torch.tensor([start_val])

            # Test with compiled function
            result_compiled = compiled_func(x, start, length)

            # Test with eager function (expected behavior)
            result_eager = func(x, start, length)

            # Compare results
            self.assertEqual(result_compiled, result_eager)

@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_narrow_unbacked_start_cpp_wrapper(self):
    """Test narrow with unbacked start with cpp_wrapper"""
    self.test_narrow_unbacked_start()


instantiate_parametrized_tests(TestUnbacked)
@@ -2058,7 +2058,8 @@ class PythonWrapperCodegen(CodeGen):
neg = self.codegen_sizevar(
    sympy.Max(0, sympy.Min(x + node.size, node.size))
)
return f"{pos} if {x} >= 0 else {neg}"
x_cond = self.codegen_sizevar(x)
return f"{pos} if {x_cond} >= 0 else {neg}"

def codegen_with_step(start_var, end_var, step):
    if step == 1:
@@ -547,6 +547,7 @@ def rebind_unbacked(
assert shape_env is not None
for raw_u0, path in bindings.items():
    u1 = pytree.key_get(result, path)

    # Sometimes, things that were previously unbacked bindings become constants.
    # There are two situations in which this can happen.
    #

@@ -602,7 +603,23 @@ def rebind_unbacked(
if u1.node.hint is not None:
    continue

raw_u1 = u1.node.expr
# Unbacked symbol bindings might be replaced with other backed or
# unbacked replacements.
#
# Example:
#   u = x.item()
#   torch._check(u == 5)
#
# The safest approach is to retrieve raw_u1 from u1.node._expr
# and perform the rebinding on the original unbacked symbol,
# even if it's no longer directly referenced.
#
# In other words, we should always rebind the original symbol
# before any replacements are applied.
# u0 -> u0 == s1
raw_u1 = u1.node._expr

# TODO: Do we still need this logic below?
# Simplify SymBool binding
if (
    isinstance(raw_u1, sympy.Piecewise)
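A minimal pattern that can trigger the replacement described in the comment above: after torch._check(u == 5), the unbacked u0 may be replaced in the shape environment, but rebinding still has to target the original symbol.

import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True, backend="eager")
def f(x):
    u = x.item()          # creates an unbacked symbol u0
    torch._check(u == 5)  # may replace u0 with a constant in the shape env
    return torch.zeros(u)

f(torch.tensor(5))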
@@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
        raise TypeError("ndigits must be an instance of sympy.Integer")
    return f"round({self._print(number)}, {ndigits})"

def _print_Piecewise(self, expr: sympy.Expr) -> str:
    # Convert Piecewise(expr_cond_pairs) to nested ternary expressions
    # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
    # becomes: e1 if c1 else (e2 if c2 else (... else eN))
    result = None
    for expr_i, cond_i in reversed(expr.args):
        expr_str = self._print(expr_i)
        if cond_i == True:  # noqa: E712
            # This is the default case
            result = expr_str
        else:
            cond_str = self._print(cond_i)
            if result is None:
                result = expr_str
            else:
                result = f"({expr_str} if {cond_str} else {result})"
    return result if result else "0"


class CppPrinter(ExprPrinter):
    def _print_Integer(self, expr: sympy.Expr) -> str:

@@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
    )
    return f"{c} ? {p} : {q}"

def _print_Piecewise(self, expr: sympy.Expr) -> str:
    # Convert Piecewise(expr_cond_pairs) to nested ternary operators
    # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
    # becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
    result = None
    for expr_i, cond_i in reversed(expr.args):
        expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
        if cond_i == True:  # noqa: E712
            # This is the default case
            result = expr_str
        else:
            cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
            if result is None:
                result = expr_str
            else:
                result = f"{cond_str} ? {expr_str} : {result}"
    return f"({result})" if result else "0"

def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
    x, div, mod = expr.args
    x = self.doprint(x)
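To illustrate the lowering both printers implement, a standalone sketch of the nested-ternary conversion, with str() standing in for the printers' recursive _print calls:

import sympy

def piecewise_to_ternary(expr):
    # Walk the (expr, cond) pairs in reverse, treating a True condition
    # as the default case, and nest the rest as conditional expressions.
    result = None
    for e, c in reversed(expr.args):
        if c == True:  # noqa: E712
            result = str(e)
        elif result is None:
            result = str(e)
        else:
            result = f"({e} if {c} else {result})"
    return result if result else "0"

u0 = sympy.Symbol("u0", integer=True)
print(piecewise_to_ternary(sympy.Piecewise((u0 + 6, u0 < 0), (u0, True))))
# (u0 + 6 if u0 < 0 else u0)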