From 39df901b2aa0b975e59683b0593f18b86251440c Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 27 May 2025 13:24:57 -0700 Subject: [PATCH] introduce definitely_contiguous and use it for reshape and tensor meta data computation. (#153432) When a tensor has unbacked symbols, it can be general enough to represent both contiguous and non-contiguous tensors. In that case we can't really evaluate is_contiguous. In many places in the code base, we check is_contiguous to take a fast path, but the general path usually works for both contiguous and non-contiguous tensors; in those places we probably want to use the definitely_contiguous API instead. This is applied to reshape in this PR and also to tensor metadata computation: the metadata now has an attribute saying that the tensor is contiguous only when it is always contiguous, i.e. we store it only if definitely_contiguous is true. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153432 Approved by: https://github.com/bobrenjc93 --- aten/src/ATen/native/TensorShape.cpp | 12 ++- .../pr_time_benchmarks/expected_results.csv | 24 ++--- c10/core/Contiguity.h | 47 +++++++--- test/export/test_export.py | 2 +- test/test_dynamic_shapes.py | 33 +++++++ test/test_proxy_tensor.py | 4 +- torch/_inductor/codegen/simd.py | 7 +- torch/_prims_common/__init__.py | 90 +++++++++++++++---- torch/_refs/__init__.py | 3 +- torch/fx/passes/shape_prop.py | 9 +- 10 files changed, 178 insertions(+), 53 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 27ab9c1e834..f04e6cac631 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -1993,11 +1994,15 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } - if (self.is_contiguous() && !self.is_mkldnn()) { + auto sym_sizes = 
self.sym_sizes(); + auto sym_strides = self.sym_strides(); + auto sym_numel = self.sym_numel(); + if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) && + !self.is_mkldnn()) { return self.view_symint(proposed_shape); } - c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel()); + c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel); if (self.is_mkldnn()) { return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape)); @@ -2005,8 +2010,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { // `computeStride` returns the proper strides to use if this // `reshape` can be just a view. - auto stride = - at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape); + auto stride = at::detail::computeStride(sym_sizes, sym_strides, shape); // NB: Even though we have viewable geometry and the target strides here, // we do not just call `as_strided` on `self` because the backward diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 645a4b6f469..2088dcf6d50 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025 @@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025 @@ -22,11 +22,11 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,939900000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18140000000,0.015 
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18270000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16220000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16310000000,0.015 @@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000 -update_hint_regression,compile_time_instruction_count,1681000000,0.02 +update_hint_regression,compile_time_instruction_count,1700000000,0.02 -float_args,compile_time_instruction_count,449800000,0.015 +float_args,compile_time_instruction_count,452500000,0.015 @@ -54,24 +54,24 @@ symint_sum_loop,compile_time_instruction_count,4262000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2112000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6022000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8672000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1917000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3859000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10420000000,0.015 diff --git a/c10/core/Contiguity.h b/c10/core/Contiguity.h index 36f41b6251c..276d2ce07b5 100644 --- a/c10/core/Contiguity.h +++ b/c10/core/Contiguity.h @@ -12,24 +12,49 @@ namespace c10 { template bool _compute_contiguous(ArrayRef sizes, 
ArrayRef strides, T numel) { - bool is_contiguous = true; if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) { - return is_contiguous; + return true; } - T z = 1; + + T expected_stride = 1; // NB: make sure we do signed arithmetic for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { const auto& size_d = sizes[d]; - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) { - z *= size_d; - } else { - is_contiguous = false; - break; - } + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) { + continue; } + + if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) { + return false; + } + expected_stride *= size_d; } - return is_contiguous; + return true; +} + +// This function will return True if the tensor is contiguous, and False if it +// is not, or if we can't determine whether it is contiguous due to unbacked symbols +// (it could be either in that case based on the actual runtime data). +template +bool definitely_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { + if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { + return true; + } + + T expected_stride = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + const auto& size_d = sizes[d]; + if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) { + continue; + } + + if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) { + return false; + } + expected_stride *= size_d; + } + return true; } template diff --git a/test/export/test_export.py b/test/export/test_export.py index 7db8fc5b349..02d9052a951 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2647,7 +2647,7 @@ graph(): with self.assertRaisesRegex( ValueError, r"Received user-specified .* \[None, 5\], conflicting with the inferred .*" - r"\[6, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]", + r"\[8, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]", ): export(Foo(), ({"data": [[x, 
y]]},), dynamic_shapes=shapes) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 9b92d556d57..a7ca37e2286 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3281,6 +3281,39 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", self.assertEqual(result_compiled, result_eager) self.assertEqual(cnt.frame_count, 1) + # Pass a contiguous tensor. A recompilation will happen due to 0/1 specialization on stride. + log_stream, ctx = logs_to_string( + "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + ) + with ctx(): + # This used to hit "Could not guard on data-dependent expression Eq(10, u3)" with x.stride[0]==10 and x.size()=[u2, u3], + # but not anymore since we use definitely_contiguous. + # We need a way to mark strides unbacked to avoid the recompilation here. + x = torch.randn(10, 10) + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(x, 1) + + aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + self.assertExpectedInline( + aot_graphs, + """""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + result_compiled = compiled_func(x, torch.tensor([2, 50])) + result_eager = func(x, torch.tensor([2, 50])) + + self.assertEqual(result_compiled, result_eager) + self.assertEqual(cnt.frame_count, 2) + + x = torch.randn(4, 4) + + result_eager = func(x, torch.tensor([2, 8])) + result_compiled = compiled_func(x, torch.tensor([2, 8])) + self.assertEqual(result_compiled, result_eager) + self.assertEqual(cnt.frame_count, 2) + @unittest.skip("this test fails due to inductor/autograd issue #153041") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_non_contigious_reshape_failing(self): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index c624c3e03f7..4704c9992d5 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1370,8 +1370,8 @@ 
def forward(self, crop_camera_1, mask_1): view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None - mul_4 = sym_size_int * 3 - view_3 = torch.ops.aten.view.default(view_2, [mul_4, 3]); view_2 = mul_4 = None + mul_6 = sym_size_int * 3 + view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 04c1a010fae..8f4dbda0fda 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -693,9 +693,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]): ) ) else: - return_getters.append( - operator.itemgetter(add_range(current_group, size)) - ) + if current_group < len(remaining): + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) return_getters_groups.append(return_getters) assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), ( diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 7b0c11488a3..d853a834f86 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -259,47 +259,64 @@ def check_all_strides( # This function is equivalent to compute_contiguous() from TensorImpl.cpp -def is_contiguous(a: TensorLikeType) -> bool: +def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool: """ Tests whether a tensor is contiguous or not. 
Tensors are contiguous when they have no elements, one element, or when they have "nested" strides. """ - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + ) - if guard_size_oblivious(a.numel() < 2): + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious + + if maybe_guard_or_false(a.numel() < 2): return True expected_stride = 1 for x, y in reversed(tuple(zip(a.shape, a.stride()))): # Skips checking strides when a dimension has length 1 - if guard_size_oblivious(x == 1): + if maybe_guard_or_false(x == 1): continue - if guard_size_oblivious(y != expected_stride): + if maybe_guard_or_true(y != expected_stride): return False - expected_stride = expected_stride * x + + # If x is 0 then a is contiguous anyway, so in the non-contiguity check above we + # can assume x is not 0 in the expected_stride equation. This is also consistent with make_contiguous_strides_for. 
+ expected_stride = expected_stride * sym_max(x, 1) return True # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp -def is_channels_last_contiguous_2d(a: Tensor) -> bool: +def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool: # NHWC or not channels last 2D contiguous if a.ndim != 4: return False - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious expected_stride = 1 for idx in (1, 3, 2, 0): length = a.shape[idx] - if guard_size_oblivious(length == 1): + if maybe_guard_or_false(length == 1): continue stride = a.stride()[idx] - if guard_size_oblivious(stride != expected_stride): + if maybe_guard_or_true(stride != expected_stride): return False expected_stride *= length @@ -307,21 +324,28 @@ def is_channels_last_contiguous_2d(a: Tensor) -> bool: return True -def is_channels_last_contiguous_3d(a: Tensor) -> bool: +def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool: # NDHWC or not channels last 3D contiguous if a.ndim != 5: return False - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious expected_stride = 1 for idx in (1, 4, 3, 2, 0): length = a.shape[idx] - if guard_size_oblivious(length == 1): + if maybe_guard_or_false(length == 1): continue stride = a.stride()[idx] - if guard_size_oblivious(stride != expected_stride): + if maybe_guard_or_true(stride != expected_stride): return False 
expected_stride *= length @@ -345,16 +369,16 @@ def validate_memory_format(memory_format: torch.memory_format): def is_contiguous_for_memory_format( # type: ignore[return] - a: Tensor, *, memory_format: torch.memory_format + a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False ) -> bool: validate_memory_format(memory_format) if memory_format == torch.contiguous_format: - return is_contiguous(a) + return is_contiguous(a, false_if_dde) if memory_format == torch.channels_last: - return is_channels_last_contiguous_2d(a) + return is_channels_last_contiguous_2d(a, false_if_dde) if memory_format == torch.channels_last_3d: - return is_channels_last_contiguous_3d(a) + return is_channels_last_contiguous_3d(a, false_if_dde) torch._check( False, @@ -362,6 +386,29 @@ def is_contiguous_for_memory_format( # type: ignore[return] ) +def definitely_contiguous(a: TensorLikeType) -> bool: + return is_contiguous(a, false_if_dde=True) + + +# similar to is_channels_last_contiguous_2d but return false on data dependency. +def is_known_channels_last_contiguous_2d(a: Tensor) -> bool: + return is_channels_last_contiguous_2d(a, false_if_dde=True) + + +# similar to is_channels_last_contiguous_3d but return false on data dependency. +def is_known_channels_last_contiguous_3d(a: Tensor) -> bool: + return is_channels_last_contiguous_3d(a, false_if_dde=True) + + +# similar to is_contiguous_for_memory_format but return false on data dependency. +def definitely_contiguous_for_memory_format( # type: ignore[return] + a: Tensor, *, memory_format: torch.memory_format +) -> bool: + return is_contiguous_for_memory_format( + a, memory_format=memory_format, false_if_dde=True + ) + + # NOTE: that tensors with no elements and channels last is ??? 
def is_channels_last_contiguous(a: Tensor) -> bool: """ @@ -379,6 +426,13 @@ def is_channels_last_contiguous(a: Tensor) -> bool: return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) +# similar to is_channels_last_contiguous but return false on data dependency. +def is_known_channels_last_contiguous(a: Tensor) -> bool: + return is_known_channels_last_contiguous_2d( + a + ) or is_known_channels_last_contiguous_3d(a) + + def is_non_overlapping_and_dense(a: Tensor) -> bool: """ True when a tensor is non-overlapping and dense. diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index e3783286807..e128a3b5f81 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -19,6 +19,7 @@ import torch.utils._pytree as pytree from torch import sym_float, sym_int from torch._prims_common import ( BoolLike, + definitely_contiguous, DeviceLikeType, Dim, DimsSequenceType, @@ -3824,7 +3825,7 @@ def _view_simple(a: TensorLikeType, shape, data_dependent_error) -> TensorLikeTy if new_strides is not None: return a.as_strided(shape, new_strides) - if a.is_contiguous(): + if definitely_contiguous(a): return a.as_strided(shape, utils.make_contiguous_strides_for(shape)) raise data_dependent_error diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index 1a88b73bba1..05fb3b5dbaf 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -7,6 +7,7 @@ import torch import torch.fx from torch._dispatch.python import enable_python_dispatcher from torch._guards import detect_fake_mode +from torch._prims_common import definitely_contiguous_for_memory_format from torch._subclasses.meta_utils import is_sparse_any from torch.fx._compatibility import compatibility from torch.fx.node import map_aggregate, Node @@ -32,6 +33,10 @@ class TensorMetadata(NamedTuple): qparams: dict[str, Any] +# When include_contiguity is True, we will set contiguity when its always true for the tensor. 
+# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3). +# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous, +# contiguous, and unknown). def _extract_tensor_metadata( result: torch.Tensor, include_contiguity=True ) -> TensorMetadata: @@ -52,7 +57,9 @@ def _extract_tensor_metadata( torch.channels_last_3d, } for query_format in memory_formats: - if result.is_contiguous(memory_format=query_format): + if definitely_contiguous_for_memory_format( + result, memory_format=query_format + ): memory_format = query_format break