mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Avoid DDE in narrow with unbacked start (#166361)
Slice already knows how to handle an unbacked start, so narrow does not need to offset start before calling slice; it can leave that to slice. The only edge case is when start < 0 and start + length == 0: in that case slice and narrow would deviate (e.g. x.narrow(0, -2, 2) != x[-2:0]), so we pass dim_size instead of start + length as the end. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361 Approved by: https://github.com/aorenste
This commit is contained in:
parent
f0745ddb11
commit
1aef88c72d
|
|
@ -1,5 +1,6 @@
|
||||||
#include <ATen/core/ATen_fwd.h>
|
#include <ATen/core/ATen_fwd.h>
|
||||||
#include <c10/core/ScalarType.h>
|
#include <c10/core/ScalarType.h>
|
||||||
|
#include <c10/core/SymInt.h>
|
||||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||||
#include <ATen/AccumulateType.h>
|
#include <ATen/AccumulateType.h>
|
||||||
#include <ATen/Dispatch.h>
|
#include <ATen/Dispatch.h>
|
||||||
|
|
@ -1710,11 +1711,14 @@ Tensor narrow_symint(
|
||||||
"], but got ",
|
"], but got ",
|
||||||
start,
|
start,
|
||||||
")")
|
")")
|
||||||
if (start < 0) {
|
// Bounds check without converting start:
|
||||||
start = start + cur_size;
|
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
|
||||||
}
|
// length <= 0
|
||||||
|
// - If start >= 0: need start + length <= cur_size
|
||||||
|
auto end = start + length;
|
||||||
TORCH_SYM_CHECK(
|
TORCH_SYM_CHECK(
|
||||||
start.sym_le(cur_size - length),
|
(start.sym_lt(0).sym_and((end).sym_le(0)))
|
||||||
|
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
|
||||||
"start (",
|
"start (",
|
||||||
start,
|
start,
|
||||||
") + length (",
|
") + length (",
|
||||||
|
|
@ -1722,7 +1726,31 @@ Tensor narrow_symint(
|
||||||
") exceeds dimension size (",
|
") exceeds dimension size (",
|
||||||
cur_size,
|
cur_size,
|
||||||
").");
|
").");
|
||||||
return at::slice_symint(self, dim, start, start + length, 1);
|
|
||||||
|
if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) {
|
||||||
|
return at::slice_symint(self, dim, start, end, 1);
|
||||||
|
} else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) {
|
||||||
|
// Avoid the complex symbolic expressions path for non-unbacked.
|
||||||
|
return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1);
|
||||||
|
} else {
|
||||||
|
// Cannot statically determine the condition due to unbacked.
|
||||||
|
// This is an interesting situation; when start is negative and
|
||||||
|
// start + length == 0, slice and narrow do different things.
|
||||||
|
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
|
||||||
|
// pass cur_size instead of 0. Otherwise, they would do the same thing.
|
||||||
|
// This says at runtime: if start < 0 and end == 0, then pass cur_size
|
||||||
|
// instead of 0.
|
||||||
|
|
||||||
|
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
|
||||||
|
auto result =
|
||||||
|
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
|
||||||
|
|
||||||
|
// Ensure slice allocated unbacked size is specialized to length.
|
||||||
|
SymInt new_size = result.sym_size(dim);
|
||||||
|
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This overload exists purely for XLA, because they wanted to pass in
|
// This overload exists purely for XLA, because they wanted to pass in
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
#include <c10/core/SymBool.h>
|
#include <c10/core/SymBool.h>
|
||||||
|
#include <c10/core/SymInt.h>
|
||||||
#include <c10/core/SymNodeImpl.h>
|
#include <c10/core/SymNodeImpl.h>
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
|
||||||
return toSymNodeImpl()->has_hint();
|
return toSymNodeImpl()->has_hint();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts this SymBool into a SymInt whose value is 1 for true and 0 for
// false. A concretely known bool yields a plain constant SymInt; otherwise
// the conversion is expressed symbolically through sym_ite on this node.
SymInt SymBool::toSymInt() const {
  const auto known = maybe_as_bool();
  if (!known) {
    // Value is not known concretely: build "this ? 1 : 0" out of nodes
    // wrapped by the same SymNode implementation as the condition.
    auto cond_node = toSymNodeImpl();
    auto when_true = cond_node->wrap_int(1);
    auto when_false = cond_node->wrap_int(0);
    return SymInt(cond_node->sym_ite(when_true, when_false));
  }

  // Concrete bool: no symbolic node is required.
  return SymInt(*known ? 1 : 0);
}
|
||||||
|
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@
|
||||||
|
|
||||||
namespace c10 {
|
namespace c10 {
|
||||||
|
|
||||||
|
class SymInt;
|
||||||
|
|
||||||
class C10_API SymBool {
|
class C10_API SymBool {
|
||||||
public:
|
public:
|
||||||
/*implicit*/ SymBool(bool b) : data_(b) {}
|
/*implicit*/ SymBool(bool b) : data_(b) {}
|
||||||
|
|
@ -80,6 +82,10 @@ class C10_API SymBool {
|
||||||
return toSymNodeImplUnowned()->constant_bool();
|
return toSymNodeImplUnowned()->constant_bool();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert SymBool to SymInt (0 or 1)
|
||||||
|
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
|
||||||
|
SymInt toSymInt() const;
|
||||||
|
|
||||||
bool is_heap_allocated() const {
|
bool is_heap_allocated() const {
|
||||||
return ptr_;
|
return ptr_;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6093,26 +6093,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
retry_export(
|
retry_export(
|
||||||
cf_implicitsize(),
|
cf_implicitsize(),
|
||||||
(torch.tensor(2), torch.randn(10)),
|
(torch.tensor(2), torch.randn(10)),
|
||||||
fixes=[
|
fixes=[],
|
||||||
# Could not guard on data-dependent expression u0 < 0
|
|
||||||
"torch._check(i >= 0)",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
class cf_stacklist(torch.nn.Module):
    """Stacks ``xs`` and returns the single row selected by ``y.item()``."""

    def forward(self, xs, y, fixes):
        # The local must stay named `i`: `fixes` is evaluated in this frame
        # and fix snippets such as "torch._check(i >= 0)" refer to it by name.
        i = y.item()
        eval(fixes)
        stacked = torch.stack(xs, 0)
        selected = stacked.narrow(0, i, 1)
        return selected.squeeze()
|
||||||
|
|
||||||
retry_export(
|
retry_export(
|
||||||
cf_stacklist(),
|
cf_stacklist(),
|
||||||
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
|
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
|
||||||
fixes=[
|
fixes=[],
|
||||||
# Could not guard on data-dependent expression u0 < 0
|
|
||||||
"torch._check(i >= 0)",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
class cf_tensorsplit(torch.nn.Module):
|
class cf_tensorsplit(torch.nn.Module):
|
||||||
|
|
@ -6166,7 +6159,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
class cf_stacklist(torch.nn.Module):
    """Stacks ``xs`` and returns the row selected by ``y.item()``.

    The index is read inline (never bound to a local) and its sign is tested
    explicitly before calling ``narrow``.
    NOTE(review): the explicit branch presumably exists to keep a
    data-dependent guard in the exported program -- confirm against the
    ``assertRaisesRegex`` that follows this class.
    """

    def forward(self, xs, y):
        # y.item() is not a local, so we can't suggest a fix
        if y.item() < 0:
            # Wrap a negative index around the number of stacked inputs.
            # `xs` is the sequence handed to torch.stack, so its length comes
            # from len(); a Python list has no .size() method.
            return torch.stack(xs, 0).narrow(0, y.item() + len(xs), 1).squeeze()
        else:
            return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
error_type,
|
error_type,
|
||||||
|
|
@ -6196,7 +6194,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
def forward(self, xs, y):
    """Stack ``xs`` and return the row selected by the boxed ``y.item()``.

    The index lives in ``box.content`` rather than in a local variable, and
    its sign is tested explicitly before calling ``narrow``.
    """
    box = Box(y.item())
    # box.content is not a local, so we can't suggest a fix
    if box.content < 0:
        # Wrap a negative index around the number of stacked inputs.
        # `xs` is the sequence handed to torch.stack, so its length comes
        # from len(); a Python list has no .size() method.
        return torch.stack(xs, 0).narrow(0, box.content + len(xs), 1).squeeze()
    else:
        # Non-negative index: use it as-is. Adding the length here as well
        # would push every valid index out of range.
        return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
error_type,
|
error_type,
|
||||||
|
|
|
||||||
|
|
@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
||||||
|
|
||||||
self.assertEqual(compiled(a, b), func(a, b))
|
self.assertEqual(compiled(a, b), func(a, b))
|
||||||
|
|
||||||
|
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_narrow_unbacked_start(self):
    """Compiled ``torch.narrow`` with a start read via ``.item()`` must match eager."""

    def narrow_at(x, start, length):
        # unbacked start
        offset = start.item()
        return torch.narrow(x, 0, offset, length)

    compiled_narrow = torch.compile(narrow_at, fullgraph=True, backend="inductor")

    x = torch.tensor([1, 2, 3, 4, 5, 6])

    # Each entry is (start, length).
    negative_starts = [
        (-2, 2),  # Start from second-to-last element
        (-1, 1),  # Start from last element
        (-3, 3),  # Start from third-to-last element
        (-6, 2),  # Start from beginning (negative)
        (-4, 1),  # Start from fourth-to-last element
    ]
    positive_starts = [
        (0, 2),  # Start from beginning
        (1, 3),  # Start from second element
        (2, 2),  # Start from third element
        (4, 2),  # Start near end
    ]
    edge_cases = [
        (0, 6),  # Full tensor
        (0, 1),  # Single element from start
        (5, 1),  # Single element from end
    ]

    for start_val, length in negative_starts + positive_starts + edge_cases:
        with self.subTest(start=start_val, length=length):
            start = torch.tensor([start_val])

            # Run the compiled function first, then eager as the reference.
            actual = compiled_narrow(x, start, length)
            expected = narrow_at(x, start, length)

            self.assertEqual(actual, expected)
|
||||||
|
|
||||||
|
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch("cpp_wrapper", True)
def test_narrow_unbacked_start_cpp_wrapper(self):
    """Test narrow with unbacked start with cpp_wrapper"""
    # Re-run the shared scenario unchanged; the only difference is the
    # inductor "cpp_wrapper" option patched on by the decorator above.
    self.test_narrow_unbacked_start()
|
||||||
|
|
||||||
|
|
||||||
instantiate_parametrized_tests(TestUnbacked)
|
instantiate_parametrized_tests(TestUnbacked)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2058,7 +2058,8 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
neg = self.codegen_sizevar(
|
neg = self.codegen_sizevar(
|
||||||
sympy.Max(0, sympy.Min(x + node.size, node.size))
|
sympy.Max(0, sympy.Min(x + node.size, node.size))
|
||||||
)
|
)
|
||||||
return f"{pos} if {x} >= 0 else {neg}"
|
x_cond = self.codegen_sizevar(x)
|
||||||
|
return f"{pos} if {x_cond} >= 0 else {neg}"
|
||||||
|
|
||||||
def codegen_with_step(start_var, end_var, step):
|
def codegen_with_step(start_var, end_var, step):
|
||||||
if step == 1:
|
if step == 1:
|
||||||
|
|
|
||||||
|
|
@ -547,6 +547,7 @@ def rebind_unbacked(
|
||||||
assert shape_env is not None
|
assert shape_env is not None
|
||||||
for raw_u0, path in bindings.items():
|
for raw_u0, path in bindings.items():
|
||||||
u1 = pytree.key_get(result, path)
|
u1 = pytree.key_get(result, path)
|
||||||
|
|
||||||
# Sometimes, things were previously unbacked bindings become constants.
|
# Sometimes, things were previously unbacked bindings become constants.
|
||||||
# There are two situations this can happen.
|
# There are two situations this can happen.
|
||||||
#
|
#
|
||||||
|
|
@ -602,7 +603,23 @@ def rebind_unbacked(
|
||||||
if u1.node.hint is not None:
|
if u1.node.hint is not None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
raw_u1 = u1.node.expr
|
# unbacked symbols bindings might be replaced to other backed or
|
||||||
|
# unbacked replacements.
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# u = x.item()
|
||||||
|
# torch._check(u == 5)
|
||||||
|
#
|
||||||
|
# The safest approach is to retrieve raw_u1 from u1.node._expr
|
||||||
|
# and perform the rebinding on the original unbacked symbol,
|
||||||
|
# even if it's no longer directly referenced.
|
||||||
|
#
|
||||||
|
# In other words, we should always rebind the original symbol
|
||||||
|
# before any replacements are applied.
|
||||||
|
# u0 -> u0 == s1
|
||||||
|
raw_u1 = u1.node._expr
|
||||||
|
|
||||||
|
# TODO Do we still need this logic below?
|
||||||
# Simplify SymBool binding
|
# Simplify SymBool binding
|
||||||
if (
|
if (
|
||||||
isinstance(raw_u1, sympy.Piecewise)
|
isinstance(raw_u1, sympy.Piecewise)
|
||||||
|
|
|
||||||
|
|
@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
|
||||||
raise TypeError("ndigits must be an instance of sympy.Integer")
|
raise TypeError("ndigits must be an instance of sympy.Integer")
|
||||||
return f"round({self._print(number)}, {ndigits})"
|
return f"round({self._print(number)}, {ndigits})"
|
||||||
|
|
||||||
|
def _print_Piecewise(self, expr: sympy.Expr) -> str:
    """Render a Piecewise as nested Python conditional expressions.

    Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) is printed as
    ``(e1 if c1 else (e2 if c2 else (... else eN)))``.

    The pairs are walked from last to first so the innermost fallback is
    built before the branches that wrap it. A pair whose condition equals
    ``True`` restarts the chain with its expression as the default. If the
    last pair has any other condition, its expression is still used as the
    unconditional fallback. An empty result is printed as ``"0"``.
    """
    rendered = None
    for value, condition in expr.args[::-1]:
        value_src = self._print(value)
        if condition == True:  # noqa: E712
            # Default case: everything built so far is unreachable.
            rendered = value_src
            continue
        condition_src = self._print(condition)
        rendered = (
            value_src
            if rendered is None
            else f"({value_src} if {condition_src} else {rendered})"
        )
    return rendered or "0"
|
||||||
|
|
||||||
|
|
||||||
class CppPrinter(ExprPrinter):
|
class CppPrinter(ExprPrinter):
|
||||||
def _print_Integer(self, expr: sympy.Expr) -> str:
|
def _print_Integer(self, expr: sympy.Expr) -> str:
|
||||||
|
|
@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
|
||||||
)
|
)
|
||||||
return f"{c} ? {p} : {q}"
|
return f"{c} ? {p} : {q}"
|
||||||
|
|
||||||
|
def _print_Piecewise(self, expr: sympy.Expr) -> str:
    """Render a Piecewise as a chain of C++ ternary operators.

    Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) is printed as
    ``(c1 ? e1 : c2 ? e2 : ... : eN)``; the whole chain is wrapped in one
    pair of parentheses.

    The pairs are walked from last to first so the trailing fallback is
    built before the branches that precede it. A pair whose condition equals
    ``True`` restarts the chain with its expression as the default. If the
    last pair has any other condition, its expression is still used as the
    unconditional fallback. An empty result is printed as ``"0"``.
    """
    # Operands are parenthesized at just below "Atom" precedence.
    operand_level = PRECEDENCE["Atom"] - 0.5
    chain = None
    for value, condition in expr.args[::-1]:
        value_src = self.parenthesize(value, operand_level)
        if condition == True:  # noqa: E712
            # Default case: everything built so far is unreachable.
            chain = value_src
            continue
        condition_src = self.parenthesize(condition, operand_level)
        chain = (
            value_src
            if chain is None
            else f"{condition_src} ? {value_src} : {chain}"
        )
    if not chain:
        return "0"
    return f"({chain})"
|
||||||
|
|
||||||
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
||||||
x, div, mod = expr.args
|
x, div, mod = expr.args
|
||||||
x = self.doprint(x)
|
x = self.doprint(x)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user