diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 6df7761d822..6136a6aa8c5 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -1,5 +1,6 @@
 #include
 #include
+#include
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include
 #include
@@ -1710,11 +1711,14 @@ Tensor narrow_symint(
       "], but got ",
       start,
       ")")
-  if (start < 0) {
-    start = start + cur_size;
-  }
+  // Bounds check without converting start:
+  // - If start < 0: need (start + cur_size) + length <= cur_size, i.e.,
+  //   start + length <= 0
+  // - If start >= 0: need start + length <= cur_size
+  auto end = start + length;
   TORCH_SYM_CHECK(
-      start.sym_le(cur_size - length),
+      (start.sym_lt(0).sym_and((end).sym_le(0)))
+          .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
       "start (",
       start,
       ") + length (",
@@ -1722,7 +1726,31 @@ Tensor narrow_symint(
       ") exceeds dimension size (",
       cur_size,
       ").");
-  return at::slice_symint(self, dim, start, start + length, 1);
+
+  if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) {
+    return at::slice_symint(self, dim, start, end, 1);
+  } else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) {
+    // Avoid the complex symbolic expressions path for non-unbacked.
+    return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1);
+  } else {
+    // Cannot statically determine the condition due to unbacked.
+    // This is an interesting situation: when start is negative and
+    // start + length == 0, slice and narrow do different things.
+    // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
+    // pass cur_size instead of 0. Otherwise, they would do the same thing.
+    // This says at runtime: if start < 0 and end == 0, then pass cur_size
+    // instead of 0.
+
+    auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
+    auto result =
+        at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
+
+    // Ensure the slice's allocated unbacked size is specialized to length.
+    SymInt new_size = result.sym_size(dim);
+    TORCH_SYM_CHECK(new_size.sym_eq(length), "")
+
+    return result;
+  }
 }

 // This overload exists purely for XLA, because they wanted to pass in
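For intuition about the new comments in narrow_symint: narrow wraps a negative start around, while the naive slice translation x[start : start + length] does not when the end lands exactly on 0. A quick eager-mode illustration with stock PyTorch (this snippet is illustrative, not part of the patch):

import torch

x = torch.arange(6)

# narrow wraps the negative start: the last two elements.
print(torch.narrow(x, 0, -2, 2))  # tensor([4, 5])

# The naive slice translation ends at 0, which Python reads as
# "up to the beginning", so it is empty instead.
print(x[-2:0])  # tensor([], dtype=torch.int64)

# Substituting cur_size (here 6) for the 0 end recovers narrow's result,
# which is what the `end + use_different * cur_size` term arranges.
print(x[-2:6])  # tensor([4, 5])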
diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp
index d804eb9d274..48c407b8b06 100644
--- a/c10/core/SymBool.cpp
+++ b/c10/core/SymBool.cpp
@@ -1,4 +1,5 @@
 #include <c10/core/SymBool.h>
+#include <c10/core/SymInt.h>
 #include <c10/core/SymNodeImpl.h>

 namespace c10 {
@@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
   return toSymNodeImpl()->has_hint();
 }

+SymInt SymBool::toSymInt() const {
+  // If concrete bool, return concrete SymInt.
+  if (auto ma = maybe_as_bool()) {
+    return SymInt(*ma ? 1 : 0);
+  }
+
+  // Symbolic case: use sym_ite to convert the bool to an int (0 or 1).
+  auto node = toSymNodeImpl();
+  auto one_node = node->wrap_int(1);
+  auto zero_node = node->wrap_int(0);
+  return SymInt(node->sym_ite(one_node, zero_node));
+}
+
 } // namespace c10
diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h
index d5d509e239b..a27a28a5bf8 100644
--- a/c10/core/SymBool.h
+++ b/c10/core/SymBool.h
@@ -12,6 +12,8 @@

 namespace c10 {

+class SymInt;
+
 class C10_API SymBool {
  public:
   /*implicit*/ SymBool(bool b) : data_(b) {}
@@ -80,6 +82,10 @@ class C10_API SymBool {
     return toSymNodeImplUnowned()->constant_bool();
   }

+  // Convert SymBool to SymInt (0 or 1).
+  // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless.
+  SymInt toSymInt() const;
+
   bool is_heap_allocated() const {
     return ptr_;
   }
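The new header comment references Python's cast_symbool_to_symint_guardless; the same conversion is expressible through the public torch.sym_ite, which is what node->sym_ite(one_node, zero_node) mirrors. A minimal sketch of the equivalence (the helper name is illustrative, not the library's exact code):

import torch

def symbool_to_symint(b):
    # Select 1 or 0 without guarding on (and thereby specializing) b.
    return torch.sym_ite(b, 1, 0)

# On plain Python bools this degenerates to the concrete branch,
# matching the maybe_as_bool() fast path in SymBool::toSymInt.
assert symbool_to_symint(True) == 1
assert symbool_to_symint(False) == 0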
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 3908f03b11e..cdc18b1d4c5 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -6093,26 +6093,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         retry_export(
             cf_implicitsize(),
             (torch.tensor(2), torch.randn(10)),
-            fixes=[
-                # Could not guard on data-dependent expression u0 < 0
-                "torch._check(i >= 0)",
-            ],
+            fixes=[],
         )

         class cf_stacklist(torch.nn.Module):
             def forward(self, xs, y, fixes):
                 i = y.item()
                 eval(fixes)
-                # instead of xs[i]
                 return torch.stack(xs, 0).narrow(0, i, 1).squeeze()

         retry_export(
             cf_stacklist(),
             ([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
-            fixes=[
-                # Could not guard on data-dependent expression u0 < 0
-                "torch._check(i >= 0)",
-            ],
+            fixes=[],
         )

         class cf_tensorsplit(torch.nn.Module):
@@ -6166,7 +6159,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         class cf_stacklist(torch.nn.Module):
             def forward(self, xs, y):
                 # y.item() is not a local, so we can't suggest a fix
-                return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()
+                if y.item() < 0:
+                    return (
+                        torch.stack(xs, 0).narrow(0, y.item() + len(xs), 1).squeeze()
+                    )
+                else:
+                    return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze()

         with self.assertRaisesRegex(
             error_type,
@@ -6196,7 +6194,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             def forward(self, xs, y):
                 box = Box(y.item())
                 # box.content is not a local, so we can't suggest a fix
-                return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()
+                if box.content < 0:
+                    return (
+                        torch.stack(xs, 0)
+                        .narrow(0, box.content + len(xs), 1)
+                        .squeeze()
+                    )
+                else:
+                    return (
+                        torch.stack(xs, 0)
+                        .narrow(0, box.content, 1)
+                        .squeeze()
+                    )

         with self.assertRaisesRegex(
             error_type,
diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py
index fb1d22805d5..b63e0427c26 100644
--- a/test/test_dynamic_shapes.py
+++ b/test/test_dynamic_shapes.py
@@ -4401,6 +4401,57 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
         self.assertEqual(compiled(a, b), func(a, b))

+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    def test_narrow_unbacked_start(self):
+        def func(x, start, length):
+            # unbacked start
+            u0 = start.item()
+            return torch.narrow(x, 0, u0, length)
+
+        compiled_func = torch.compile(func, fullgraph=True, backend="inductor")
+
+        x = torch.tensor([1, 2, 3, 4, 5, 6])
+
+        # Test cases: (start, length)
+        test_cases = [
+            # Negative starts
+            (-2, 2),  # Start from second-to-last element
+            (-1, 1),  # Start from last element
+            (-3, 3),  # Start from third-to-last element
+            (-6, 2),  # Start from beginning (negative)
+            (-4, 1),  # Start from fourth-to-last element
+            # Positive starts
+            (0, 2),  # Start from beginning
+            (1, 3),  # Start from second element
+            (2, 2),  # Start from third element
+            (4, 2),  # Start near end
+            # Edge cases
+            (0, 6),  # Full tensor
+            (0, 1),  # Single element from start
+            (5, 1),  # Single element from end
+        ]
+
+        for start_val, length in test_cases:
+            with self.subTest(start=start_val, length=length):
+                start = torch.tensor([start_val])
+
+                # Test with compiled function
+                result_compiled = compiled_func(x, start, length)
+
+                # Test with eager function (expected behavior)
+                result_eager = func(x, start, length)
+
+                # Compare results
+                self.assertEqual(result_compiled, result_eager)
+
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @torch._inductor.config.patch("cpp_wrapper", True)
+    def test_narrow_unbacked_start_cpp_wrapper(self):
+        """Test narrow with unbacked start with cpp_wrapper"""
+        self.test_narrow_unbacked_start()
+

 instantiate_parametrized_tests(TestUnbacked)
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 829f3ac974d..af0e9ed2ebe 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -2058,7 +2058,8 @@ class PythonWrapperCodegen(CodeGen):
             neg = self.codegen_sizevar(
                 sympy.Max(0, sympy.Min(x + node.size, node.size))
             )
-            return f"{pos} if {x} >= 0 else {neg}"
+            x_cond = self.codegen_sizevar(x)
+            return f"{pos} if {x_cond} >= 0 else {neg}"

         def codegen_with_step(start_var, end_var, step):
             if step == 1:
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index aeccdfbe000..693d25aea61 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -547,6 +547,7 @@ def rebind_unbacked(
     assert shape_env is not None
     for raw_u0, path in bindings.items():
         u1 = pytree.key_get(result, path)
+
         # Sometimes, things were previously unbacked bindings become constants.
         # There are two situations this can happen.
         #
@@ -602,7 +603,23 @@ def rebind_unbacked(
         if u1.node.hint is not None:
             continue

-        raw_u1 = u1.node.expr
+        # Unbacked symbol bindings might have been replaced by other backed
+        # or unbacked expressions.
+        #
+        # Example:
+        #   u = x.item()
+        #   torch._check(u == 5)
+        #
+        # The safest approach is to retrieve raw_u1 from u1.node._expr and
+        # perform the rebinding on the original unbacked symbol, even if
+        # it's no longer directly referenced.
+        #
+        # In other words, we should always rebind the original symbol
+        # before any replacements are applied
+        # (in the example above, the replacement is u0 -> 5).
+        raw_u1 = u1.node._expr
+
+        # TODO: Do we still need this logic below?
         # Simplify SymBool binding
         if (
             isinstance(raw_u1, sympy.Piecewise)
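The rebind_unbacked change above targets patterns like the following, where a torch._check can install a replacement for a freshly created unbacked symbol (a sketch; the u0 name in the comments is illustrative):

import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x, t):
    u = x.item()          # creates an unbacked symbol, say u0
    torch._check(u == 5)  # may replace u0 with the constant 5
    # After the check, u1.node.expr can already reflect the replacement,
    # so rebinding must go through u1.node._expr to see the original u0.
    return t * u

print(f(torch.tensor(5), torch.ones(3)))  # tensor([5., 5., 5.])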
diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py
index 526443577b3..bf9eaf71e42 100644
--- a/torch/utils/_sympy/printers.py
+++ b/torch/utils/_sympy/printers.py
@@ -306,6 +306,24 @@ class PythonPrinter(ExprPrinter):
             raise TypeError("ndigits must be an instance of sympy.Integer")
         return f"round({self._print(number)}, {ndigits})"

+    def _print_Piecewise(self, expr: sympy.Expr) -> str:
+        # Convert Piecewise(expr_cond_pairs) to nested ternary expressions:
+        #   Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
+        # becomes: e1 if c1 else (e2 if c2 else (... else eN))
+        result = None
+        for expr_i, cond_i in reversed(expr.args):
+            expr_str = self._print(expr_i)
+            if cond_i == True:  # noqa: E712
+                # This is the default case
+                result = expr_str
+            else:
+                cond_str = self._print(cond_i)
+                if result is None:
+                    result = expr_str
+                else:
+                    result = f"({expr_str} if {cond_str} else {result})"
+        return result if result else "0"
+

 class CppPrinter(ExprPrinter):
     def _print_Integer(self, expr: sympy.Expr) -> str:
@@ -327,6 +345,24 @@ class CppPrinter(ExprPrinter):
         )
         return f"{c} ? {p} : {q}"

+    def _print_Piecewise(self, expr: sympy.Expr) -> str:
+        # Convert Piecewise(expr_cond_pairs) to nested ternary operators:
+        #   Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
+        # becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
+        result = None
+        for expr_i, cond_i in reversed(expr.args):
+            expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
+            if cond_i == True:  # noqa: E712
+                # This is the default case
+                result = expr_str
+            else:
+                cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
+                if result is None:
+                    result = expr_str
+                else:
+                    result = f"{cond_str} ? {expr_str} : {result}"
+        return f"({result})" if result else "0"
+
     def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
         x, div, mod = expr.args
         x = self.doprint(x)
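As a quick sanity check of both printers: the kind of Piecewise the narrow fix creates for an unbacked start (s0 when u0 < 0, else 0) now prints as nested conditionals instead of erroring. Symbol names here are illustrative, and this requires the patch to be applied:

import sympy

from torch.utils._sympy.printers import CppPrinter, PythonPrinter

u0, s0 = sympy.symbols("u0 s0", integer=True)
expr = sympy.Piecewise((s0, sympy.Lt(u0, 0)), (0, sympy.true))

print(PythonPrinter().doprint(expr))  # (s0 if u0 < 0 else 0)
print(CppPrinter().doprint(expr))     # ((u0 < 0) ? s0 : 0)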