From 7cfd054075b0713642fab8dc4313fc1e5d992599 Mon Sep 17 00:00:00 2001
From: Laith Sakka
Date: Wed, 2 Jul 2025 23:12:29 +0000
Subject: [PATCH] [attempt 2] Compute contiguity symbolically to avoid dde, and introduce c++ sym_is_contiguous (#157472)

Summary:
When we compute contiguity for a tensor with dynamic shapes, we first:
1) Try to compute it without guarding.
2) If all shapes are hinted, compute it, potentially adding guards.
3) If any input is not hinted, compute it symbolically.

sym_is_contiguous returns a SymBool that is then either evaluated, or
guard_or_false can be called on it to avoid data-dependent errors, e.g.:

bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__);

is_contiguous_or_false is a helper function that does exactly that.

In this PR I only handle default contiguity; I will follow up with changes
for other formats like channels_last. We use this pattern in several
locations in this PR to avoid DDEs.

Test Plan: contbuild & OSS CI,

Rollback Plan:

Reviewed By: malfet

Differential Revision: D77639021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157472
Approved by: https://github.com/aorenste
---
 .github/ci_commit_pins/xla.txt | 2 +-
 aten/src/ATen/FunctionalTensorWrapper.cpp | 4 +-
 aten/src/ATen/FunctionalTensorWrapper.h | 3 +-
 aten/src/ATen/FunctionalizeFallbackKernel.cpp | 11 +-
 aten/src/ATen/LegacyBatchedTensorImpl.cpp | 2 +-
 aten/src/ATen/LegacyBatchedTensorImpl.h | 3 +-
 aten/src/ATen/MemoryOverlap.cpp | 4 +-
 aten/src/ATen/NestedTensorImpl.cpp | 2 +-
 aten/src/ATen/NestedTensorImpl.h | 2 +-
 aten/src/ATen/SparseCsrTensorImpl.cpp | 3 +-
 aten/src/ATen/SparseCsrTensorImpl.h | 2 +-
 aten/src/ATen/core/TensorBase.h | 21 +++-
 aten/src/ATen/functorch/BatchedTensorImpl.cpp | 2 +-
 aten/src/ATen/functorch/BatchedTensorImpl.h | 2 +-
 aten/src/ATen/native/Linear.cpp | 2 +-
 aten/src/ATen/native/TensorProperties.cpp | 2 +-
 aten/src/ATen/native/TensorShape.cpp | 9 +-
 aten/src/ATen/native/metal/MetalTensorImpl.h | 2 +-
 .../ATen/native/transformers/attention.cpp | 2 +-
 .../native/vulkan/VulkanOpaqueTensorImpl.h | 3 +-
 c10/core/Contiguity.h | 77 +++++++++---
 c10/core/SymbolicShapeMeta.cpp | 66 ++++++++--
 c10/core/SymbolicShapeMeta.h | 11 ++
 c10/core/TensorImpl.cpp | 10 +-
 c10/core/TensorImpl.h | 118 ++++++++++++------
 c10/core/UndefinedTensorImpl.cpp | 3 +-
 c10/core/UndefinedTensorImpl.h | 2 +-
 test/export/test_export.py | 42 +++++++
 test/test_dynamic_shapes.py | 73 ++++++++++-
 test/test_proxy_tensor.py | 4 +-
 .../templates/python_variable_methods.cpp | 2 +-
 torch/csrc/lazy/core/tensor_impl.cpp | 5 +-
 torch/csrc/lazy/core/tensor_impl.h | 5 +-
 torch/utils/_sympy/printers.py | 3 +
 34 files changed, 390 insertions(+), 114 deletions(-)

diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 4ee628b53e2..dfbc78d8884 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-55a75404c9b75cd5fd62ab5d4deafc8c506b3af2
+926700d7832caa552ba2e1fc8302f6a2f4d2f6d8
diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index 0c5df695fa9..c3ca626bc9e 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -499,8 +499,8 @@ int64_t FunctionalTensorWrapper::dim_custom() const {
 int64_t FunctionalTensorWrapper::numel_custom() const {
   return value_.unsafeGetTensorImpl()->numel();
 }
-bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
-  return 
value_.unsafeGetTensorImpl()->is_contiguous(memory_format); +c10::SymBool FunctionalTensorWrapper::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { + return value_.unsafeGetTensorImpl()->sym_is_contiguous(memory_format); } c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const { return value_.unsafeGetTensorImpl()->sym_sizes(); diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h index 3dcc82bdd04..bec2d463196 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.h +++ b/aten/src/ATen/FunctionalTensorWrapper.h @@ -236,7 +236,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { at::IntArrayRef strides_custom() const override; int64_t dim_custom() const override; int64_t numel_custom() const override; - bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymInt sym_size_custom(int64_t d) const override; c10::SymIntArrayRef sym_strides_custom() const override; diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp index bc2170b7ba0..97094c9f125 100644 --- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp +++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp @@ -320,11 +320,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size); if (!stride.has_value()) { - // With unbacked symints, computeStride could fail even on contiguous - // tensors. In this case, we can use the strides of an empty tensor of - // inferred_size. - TORCH_CHECK( - self.is_contiguous(), + + TORCH_SYM_CHECK( + self.sym_is_contiguous(), "View is not valid from size:", self.sym_sizes(), " stride: ", @@ -333,6 +331,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt inferred_size, " in case of unbacked symbols consider adding torch.check to guide computing strides."); + // With unbacked symints, computeStride could fail even on contiguous + // tensors. In this case, we can use the strides of an empty tensor of + // inferred_size. stride = at::detail::empty_symint_meta( inferred_size, std::nullopt, diff --git a/aten/src/ATen/LegacyBatchedTensorImpl.cpp b/aten/src/ATen/LegacyBatchedTensorImpl.cpp index 12c562f5d8e..cceefe985a7 100644 --- a/aten/src/ATen/LegacyBatchedTensorImpl.cpp +++ b/aten/src/ATen/LegacyBatchedTensorImpl.cpp @@ -84,7 +84,7 @@ IntArrayRef BatchedTensorImpl::strides_custom() const { // TODO: implement proper contiguity on batched tensor, then put // sizes_strides_policy back to Default -bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: querying is_contiguous inside of vmap for memory_format ", "other than torch.contiguous_format"); diff --git a/aten/src/ATen/LegacyBatchedTensorImpl.h b/aten/src/ATen/LegacyBatchedTensorImpl.h index fa6c472e1fa..798e3535af3 100644 --- a/aten/src/ATen/LegacyBatchedTensorImpl.h +++ b/aten/src/ATen/LegacyBatchedTensorImpl.h @@ -82,7 +82,8 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { IntArrayRef strides_custom() const override; // Override a bunch of methods inherited from TensorImpl to return error // messages. 
- bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; void set_storage_offset(int64_t storage_offset) override; diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 5e165d3064a..1bc8c30158a 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -24,7 +24,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) { } } - if (t->is_non_overlapping_and_dense()) { + if (t->is_non_overlapping_and_dense_or_false()) { return MemOverlap::No; } @@ -63,7 +63,7 @@ MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) { if (a->numel() == 0 || b->numel() == 0) { return MemOverlapStatus::No; } - if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) { + if (!a->is_non_overlapping_and_dense_or_false() || !b->is_non_overlapping_and_dense_or_false()) { return MemOverlapStatus::TooHard; } // Test for storage equality, rather than pointer equality. diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index b35a7ef9401..647b2f1685d 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -273,7 +273,7 @@ c10::SymInt NestedTensorImpl::sym_numel_custom() const { return NestedTensorImpl::numel_custom(); } -bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const { +c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { return nested_tensor_impl_is_contiguous(this); } IntArrayRef NestedTensorImpl::sizes_custom() const { diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h index 697969edbbd..f40684ce0ba 100644 --- a/aten/src/ATen/NestedTensorImpl.h +++ b/aten/src/ATen/NestedTensorImpl.h @@ -115,7 +115,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { // with real implementations int64_t numel_custom() const override; c10::SymInt sym_numel_custom() const override; - bool is_contiguous_custom(MemoryFormat) const override; + c10::SymBool sym_is_contiguous_custom(MemoryFormat) const override; int64_t size_custom(int64_t d) const override { return this->size(d); } diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp index 0ec3c97a2da..f73d75ab53a 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.cpp +++ b/aten/src/ATen/SparseCsrTensorImpl.cpp @@ -252,8 +252,7 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) { void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) { TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset."); } -bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const { +c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous"); } - } // namespace at diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h index 94ac1e1c393..14688163a37 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.h +++ b/aten/src/ATen/SparseCsrTensorImpl.h @@ -86,7 +86,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl { protected: IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; - bool 
is_contiguous_custom(MemoryFormat) const override; + SymBool sym_is_contiguous_custom(MemoryFormat) const override; public: void set_size(int64_t dim, int64_t new_size) override; diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 3dc2ec66f81..8463379149e 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -124,7 +124,7 @@ class TORCH_API TensorBase { } TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { - if (is_contiguous(memory_format)) { + if (is_contiguous_or_false(memory_format)) { return *this; } else { return __dispatch_contiguous(memory_format); @@ -265,6 +265,25 @@ class TORCH_API TensorBase { return impl_->is_contiguous(memory_format); } + // Like is_contiguous, but more dynamic shape-friendly. May return a symbolic representation of + // contiguity instead of SymTrue SymFalse, when results are data-dependent. + c10::SymBool sym_is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { + if (impl_->has_symbolic_sizes_strides()) { + return impl_->sym_is_contiguous(memory_format); + } + return impl_->is_contiguous(memory_format); + } + + // Like is_contiguous, but more dynamic shape-friendly. Can returns + // false instead of throwing data-dependent errors for tensors with unbacked + // sizes or strides. + bool is_contiguous_or_false(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { + if (impl_->has_symbolic_sizes_strides()) { + return impl_->sym_is_contiguous(memory_format).guard_or_false(__FILE__, __LINE__); + } + return impl_->is_contiguous(memory_format); + } + bool is_non_overlapping_and_dense() const { return impl_->is_non_overlapping_and_dense(); } diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index c3c85144565..ee222b4e61a 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -126,7 +126,7 @@ SymIntArrayRef BatchedTensorImpl::sym_strides_custom() const { // TODO: implement proper contiguity on batched tensor, then put // sizes_strides_policy back to Default -bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: querying is_contiguous inside of vmap for memory_format ", "other than torch.contiguous_format"); diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h index ce3d2900841..3eccc94d3ea 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.h +++ b/aten/src/ATen/functorch/BatchedTensorImpl.h @@ -69,7 +69,7 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; // Override a bunch of methods inherited from TensorImpl to return error messages. 
- bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; + c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override; void set_size(int64_t dim, int64_t new_size) override; void set_stride(int64_t dim, int64_t new_stride) override; c10::intrusive_ptr shallow_copy_and_detach( diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 121d38b55d5..5d3a84ea39f 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -93,7 +93,7 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optionaldefined() && !input.is_xla()) { // Also hit the fused path for contiguous 3D input, if not using xla // backend. Reshaping/flattening has some performance implications on xla. - bool is_contiguous = definitely_contiguous(input.sym_sizes(), input.sym_strides(), input.sym_numel()); + bool is_contiguous = input.is_contiguous_or_false(); if (is_contiguous && input_dim == 3) { return _flatten_nd_linear(input, weight, *bias); } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 5a4d55e0e3c..77acfe47363 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -113,7 +113,7 @@ Tensor& detach_(Tensor& self) { } Tensor contiguous(const Tensor& self, MemoryFormat memory_format) { - if (self.is_contiguous(memory_format)) { + if (self.is_contiguous_or_false(memory_format)) { return self; } TORCH_CHECK( diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 921c7d9c744..958c80f2c7f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1998,19 +1998,18 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); } - auto sym_sizes = self.sym_sizes(); - auto sym_strides = self.sym_strides(); - auto sym_numel = self.sym_numel(); - if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) && - !self.is_mkldnn()) { + if (self.is_contiguous_or_false() && !self.is_mkldnn()) { return self.view_symint(proposed_shape); } + auto sym_numel = self.sym_numel(); c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel); if (self.is_mkldnn()) { return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape)); } + auto sym_sizes = self.sym_sizes(); + auto sym_strides = self.sym_strides(); // `computeStride` returns the proper strides to use if this // `reshape` can be just a view. 
diff --git a/aten/src/ATen/native/metal/MetalTensorImpl.h b/aten/src/ATen/native/metal/MetalTensorImpl.h index 2fb87b2f4f8..44152dd3c6d 100644 --- a/aten/src/ATen/native/metal/MetalTensorImpl.h +++ b/aten/src/ATen/native/metal/MetalTensorImpl.h @@ -35,7 +35,7 @@ struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl { return c10::fromIntArrayRefKnownNonNegative(strides_); } - bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { + c10::SymBool sym_is_contiguous_custom(c10::MemoryFormat memory_format) const override { return true; } diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 332d4a2ebfe..8647a199ad8 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -776,7 +776,7 @@ Tensor scaled_dot_product_attention( #ifdef USE_MPS const auto any_nested = query_.is_nested() || key.is_nested() || value.is_nested(); const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad(); - const auto all_contiguous = query_.is_contiguous() && key.is_contiguous() && value.is_contiguous(); + const auto all_contiguous = query_.is_contiguous_or_false() && key.is_contiguous_or_false() && value.is_contiguous_or_false(); if (query_device_type == DeviceType::MPS && dropout_p == 0.0 && !(GradMode::is_enabled() && any_inputs_require_grad) && (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) diff --git a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h index 04823c592cc..532caa62687 100644 --- a/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h +++ b/aten/src/ATen/native/vulkan/VulkanOpaqueTensorImpl.h @@ -33,7 +33,8 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl { return c10::fromIntArrayRefKnownNonNegative(strides_); } - bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { + c10::SymBool sym_is_contiguous_custom( + c10::MemoryFormat memory_format) const override { (void)memory_format; return true; } diff --git a/c10/core/Contiguity.h b/c10/core/Contiguity.h index 276d2ce07b5..279a795583b 100644 --- a/c10/core/Contiguity.h +++ b/c10/core/Contiguity.h @@ -12,7 +12,7 @@ namespace c10 { template bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) { + if (numel == 0) { return true; } @@ -20,11 +20,11 @@ bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { // NB: make sure we do signed arithmetic for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { const auto& size_d = sizes[d]; - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) { + if (size_d == 1) { continue; } - if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) { + if (strides[d] != expected_stride) { return false; } expected_stride *= size_d; @@ -32,29 +32,66 @@ bool _compute_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { return true; } -// This function will return True if the tensor is contiguous, and False if the -// its not or if we can't determine if it is contiguous due to unbacked symbols -// (it could be either in that case based on the actual runtime data). -template -bool definitely_contiguous(ArrayRef sizes, ArrayRef strides, T numel) { - if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { +// Return a SymBool with underlying symbolic expression that represents +// contiguity. Guaranteed not to add guards. 
+inline static c10::SymBool _compute_contiguous_sym( + ArrayRef sizes, + ArrayRef strides, + const c10::SymInt& numel) { + // If this return true, the tensor is contiguous indeed. Otherwise it could be + // either. + auto is_contiguous_or_false = [&]() { + if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { + return true; + } + + // When calculating the expected stride, we can choose to multiply + // with max(1, size[d]) or size[d]. Regardless, this is ok for this + // function. Why? + // (1) If size[d] == 0, then the tensor is contiguous and if + // we return true or false it won't break this function. + // (2) If size[d] is not 0, then max(1,size[d]) and size[d] are equal. + // Therefore, if we choose to use max(1, size[d]) or size[d] to + // calculate the expected stride, the result is the same. + // + // We symbolically check both paths to maximize the cases where this + // function returns true. This is because make_contiguous_strides_for adds + // the max symbolically, and in some other situations the max might not be + // there. And we want to ensure we return true in both cases. + c10::SymInt expected_stride = 1; + c10::SymInt expected_stride_max = 1; + // NB: make sure we do signed arithmetic + for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { + if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) { + continue; + } + + if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride)) && + TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride_max))) { + return false; + } + expected_stride_max *= sizes[d].max(1); + expected_stride *= sizes[d]; + } return true; + }; + + if (is_contiguous_or_false()) { + return c10::SymBool(true); } - T expected_stride = 1; - // NB: make sure we do signed arithmetic + // Build a single expression that represents contiguity and return it. + c10::SymBool is_empty = sym_eq(numel, 0); + c10::SymBool is_contiguous_cond = true; + + c10::SymInt expected_stride = 1; for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { const auto& size_d = sizes[d]; - if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) { - continue; - } - - if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) { - return false; - } - expected_stride *= size_d; + is_contiguous_cond = is_contiguous_cond.sym_and( + size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride))); + expected_stride = expected_stride * size_d; } - return true; + return is_contiguous_cond.sym_or(is_empty); } template diff --git a/c10/core/SymbolicShapeMeta.cpp b/c10/core/SymbolicShapeMeta.cpp index 3becf927cd5..6fa2ab0ed4f 100644 --- a/c10/core/SymbolicShapeMeta.cpp +++ b/c10/core/SymbolicShapeMeta.cpp @@ -79,18 +79,51 @@ SymBool SymbolicShapeMeta::compute_contiguous() const { } c10::SymIntArrayRef sizes(sizes_); c10::SymIntArrayRef strides(strides_); - return _compute_contiguous(sizes, strides, numel()); + + auto result = _compute_contiguous_sym(sizes, strides, numel()); + + // If the result is already determined without guarding, just return it. + auto maybe_as_bool = result.maybe_as_bool(); + if (maybe_as_bool.has_value()) { + return maybe_as_bool.value(); + } + + auto all_hinted = true; + for (const auto& s : sizes) { + if (!s.has_hint()) { + all_hinted = false; + break; + } + } + + if (all_hinted) { + for (const auto& s : strides) { + if (!s.has_hint()) { + all_hinted = false; + break; + } + } + } + + if (all_hinted) { + // We avoid going through the slow path if everything is hinted, + // because evaluating a large SymPy expression can be expensive. + // TODO exclude backed_size_oblivious from this path. 
+ return _compute_contiguous(sizes_, strides_, numel()); + } + + return result; } // The rest of them -#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ - SymBool SymbolicShapeMeta::name() const { \ - if (!strides_valid_) { \ - return false; \ - } \ - c10::SymIntArrayRef sizes(sizes_); \ - c10::SymIntArrayRef strides(strides_); \ - return fallback(sizes, strides); \ +#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \ + SymBool SymbolicShapeMeta::name() const { \ + if (!strides_valid_) { \ + return false; \ + } \ + c10::SymIntArrayRef sizes(sizes_); \ + c10::SymIntArrayRef strides(strides_); \ + return fallback(sizes, strides); \ } #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ @@ -110,11 +143,13 @@ SymBool SymbolicShapeMeta::compute_contiguous() const { } // clang-format off -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d) -DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d) +DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d) + DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) + // clang-format on #undef DEFINE_SYMBOOL_COMPUTE @@ -192,6 +227,7 @@ void SymbolicShapeMeta::set_numel(SymInt val) const { numel_ = std::move(val); available_.fetch_or(numel_avail); } + void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_contiguous()) { @@ -200,6 +236,7 @@ void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { is_contiguous_ = std::move(val); available_.fetch_or(is_contiguous_avail); } + void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_contiguous()) { @@ -208,6 +245,7 @@ void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { is_channels_last_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_contiguous_avail); } + void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d_contiguous()) { @@ -216,6 +254,7 @@ void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { is_channels_last_3d_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_3d_contiguous_avail); } + void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last()) { @@ -224,6 +263,7 @@ void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { is_channels_last_ = std::move(val); available_.fetch_or(is_channels_last_avail); } + void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const { std::scoped_lock lock(mutables_); if (has_is_channels_last_3d()) { diff --git 
a/c10/core/SymbolicShapeMeta.h b/c10/core/SymbolicShapeMeta.h index ce0769a8074..0820038968a 100644 --- a/c10/core/SymbolicShapeMeta.h +++ b/c10/core/SymbolicShapeMeta.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -82,6 +83,15 @@ class C10_API SymbolicShapeMeta { return numel_; } + const SymBool& is_contiguous(at::MemoryFormat memory_format) const { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return this->is_channels_last_contiguous(); + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return this->is_channels_last_3d_contiguous(); + } + return this->is_contiguous(); + } + const SymBool& is_contiguous() const { if (C10_UNLIKELY(!has_is_contiguous())) { init_is_contiguous(); @@ -194,6 +204,7 @@ class C10_API SymbolicShapeMeta { // Lazily initialized variables, with the corresponding available_ flag // indicating whether the value has been initialized mutable std::atomic available_{0}; + enum avail { numel_avail = 1 << 0, is_contiguous_avail = 1 << 1, diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index f1c9eafb179..f3ec2f2d46e 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -310,12 +310,14 @@ void TensorImpl::throw_data_ptr_access_error() const { false, "Cannot access data pointer of Tensor that doesn't have storage"); } -bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { +c10::SymBool TensorImpl::sym_is_contiguous_custom( + at::MemoryFormat memory_format) const { if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { return pyobj_slot_.load_pyobj_interpreter()->is_contiguous( this, memory_format); } - return is_contiguous_default(memory_format); + + return sym_is_contiguous_default(memory_format); } bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { @@ -326,12 +328,12 @@ bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { return is_strides_like_default(memory_format); } -bool TensorImpl::is_non_overlapping_and_dense_custom() const { +c10::SymBool TensorImpl::sym_is_non_overlapping_and_dense_custom() const { if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense( this); } - return is_non_overlapping_and_dense_default(); + return sym_is_non_overlapping_and_dense_default(); } IntArrayRef TensorImpl::sizes_custom() const { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index d903bfa4e68..381bc65b27f 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -812,6 +812,43 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } + c10::SymBool sym_is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_is_contiguous_custom(memory_format); + } + return sym_is_contiguous_default(memory_format); + } + + template + T is_contiguous_default_impl(at::MemoryFormat memory_format) const { + if (!has_symbolic_sizes_strides_) { + if (memory_format == at::MemoryFormat::ChannelsLast) { + return is_channels_last_contiguous_; + } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { + return is_channels_last_3d_contiguous_; + } + return is_contiguous_; + } + + // Handle dynamic shapes. 
+ const auto& symbolic = symbolic_shape_meta().is_contiguous(memory_format); + + if constexpr (std::is_same_v) { + return symbolic.guard_bool(__FILE__, __LINE__); + } else { + return symbolic; + } + } + + bool is_contiguous_default(at::MemoryFormat memory_format) const { + return is_contiguous_default_impl(memory_format); + } + + c10::SymBool sym_is_contiguous_default(at::MemoryFormat memory_format) const { + return is_contiguous_default_impl(memory_format); + } + /** * Whether or not a tensor is laid out in contiguous memory. * @@ -827,30 +864,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_contiguous_default(memory_format); } - // These are factored into separate functions in case subclasses - // want to use them - bool is_contiguous_default(at::MemoryFormat memory_format) const { - if (has_symbolic_sizes_strides_) { - if (memory_format == at::MemoryFormat::ChannelsLast) { - return symbolic_shape_meta().is_channels_last_contiguous().guard_bool( - __FILE__, __LINE__); - } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { - return symbolic_shape_meta() - .is_channels_last_3d_contiguous() - .guard_bool(__FILE__, __LINE__); - } - return symbolic_shape_meta().is_contiguous().guard_bool( - __FILE__, __LINE__); - } - - if (memory_format == at::MemoryFormat::ChannelsLast) { - return is_channels_last_contiguous_; - } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { - return is_channels_last_3d_contiguous_; - } - return is_contiguous_; - } - bool is_strides_like_default(at::MemoryFormat memory_format) const { if (has_symbolic_sizes_strides_) { if (memory_format == at::MemoryFormat::ChannelsLast) { @@ -873,9 +886,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } } + SymBool sym_is_non_overlapping_and_dense_default() const { + if (has_symbolic_sizes_strides_) { + return symbolic_shape_meta().is_non_overlapping_and_dense(); + } else { + return is_non_overlapping_and_dense_; + } + } + bool is_non_overlapping_and_dense_default() const { if (has_symbolic_sizes_strides_) { - return symbolic_shape_meta().is_non_overlapping_and_dense().guard_bool( + return sym_is_non_overlapping_and_dense_default().guard_bool( __FILE__, __LINE__); } else { return is_non_overlapping_and_dense_; @@ -968,9 +989,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * for a tensor to have rank, but not well defined sizes. */ // sizes_strides_policy_ >= CustomStrides - virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const; + virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; - virtual bool is_non_overlapping_and_dense_custom() const; + + virtual c10::SymBool sym_is_non_overlapping_and_dense_custom() const; + + bool is_non_overlapping_and_dense_custom() const { + return sym_is_non_overlapping_and_dense_custom().guard_bool( + __FILE__, __LINE__); + } + + virtual c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const; + + bool is_contiguous_custom(at::MemoryFormat memory_format) const { + return sym_is_contiguous_custom(memory_format) + .guard_bool(__FILE__, __LINE__); + } + // sizes_strides_policy_ >= CustomSizes // Currently this method only exists to be overwritten by subclasses such as // NestedTensorImpl. @@ -1004,9 +1040,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual c10::SymInt sym_storage_offset_custom() const; public: - /** - * True if this tensor has storage. See storage() for details. - */ +/** + * True if this tensor has storage. 
See storage() for details. + */ #ifdef DEBUG // Allow subclasses to check that their storage_ is never getting set in debug // builds. @@ -1016,11 +1052,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { #endif bool has_storage() const - // NOTE: we devirtualize this because it arguably shouldn't be an - // error just to ask subclasses if they have storage. - // This used to throw for most subclasses, but OpaqueTensorImpl - // wanted it to successfully return false, so we went ahead and made - // it a non-error. +// NOTE: we devirtualize this because it arguably shouldn't be an +// error just to ask subclasses if they have storage. +// This used to throw for most subclasses, but OpaqueTensorImpl +// wanted it to successfully return false, so we went ahead and made +// it a non-error. #ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY { return storage_; @@ -2447,6 +2483,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_strides_like(at::MemoryFormat::ChannelsLast3d); } + bool is_non_overlapping_and_dense_or_false() const { + return sym_is_non_overlapping_and_dense().guard_or_false( + __FILE__, __LINE__); + } + bool is_non_overlapping_and_dense() const { if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { return is_non_overlapping_and_dense_custom(); @@ -2454,6 +2495,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return is_non_overlapping_and_dense_default(); } + SymBool sym_is_non_overlapping_and_dense() const { + if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { + return sym_is_non_overlapping_and_dense_custom(); + } + return sym_is_non_overlapping_and_dense_default(); + } + // if this returns true, then it is guaranteed that this tensor has symbolic // sizes/strides bool has_symbolic_sizes_strides() const { diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index d781ddf9e97..b42d3a92545 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -12,7 +12,8 @@ UndefinedTensorImpl::UndefinedTensorImpl() set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); } -bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const { +c10::SymBool UndefinedTensorImpl::sym_is_contiguous_custom( + MemoryFormat format) const { return is_contiguous_default(format); } IntArrayRef UndefinedTensorImpl::strides_custom() const { diff --git a/c10/core/UndefinedTensorImpl.h b/c10/core/UndefinedTensorImpl.h index 33ac4e7f868..6b7573a6938 100644 --- a/c10/core/UndefinedTensorImpl.h +++ b/c10/core/UndefinedTensorImpl.h @@ -32,7 +32,7 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { void set_storage_offset(int64_t offset) override; protected: - bool is_contiguous_custom(MemoryFormat format) const override; + c10::SymBool sym_is_contiguous_custom(MemoryFormat format) const override; IntArrayRef strides_custom() const override; SymIntArrayRef sym_strides_custom() const override; diff --git a/test/export/test_export.py b/test/export/test_export.py index 2988f4e7c5e..33c432d3104 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -15467,6 +15467,48 @@ class TestExportCustomClass(TorchTestCase): MyModel(), inps, dynamic_shapes=spec, strict=True ).run_decompositions({}) + def test_unbacked_contiguous(self): + class MyModel(torch.nn.Module): + def forward(self, x, mask): + masked_select = x.masked_select(mask) + view = masked_select.view(-1, 1548) + contig = view.contiguous() + return contig + 1 + + 
example_inputs = ( + torch.randn((768, 1548), dtype=torch.bfloat16), + torch.randint(low=0, high=1, size=(768, 1), dtype=torch.bool), + ) + spec = { + "x": [Dim.STATIC, Dim.STATIC], + "mask": [Dim.STATIC, Dim.STATIC], + } + + traced = export(MyModel(), example_inputs, strict=True) + self.assertExpectedInline( + traced.graph_module.code, + """\ +def forward(self, x, mask): + masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None + sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0) + sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None + ge = sym_size_int_1 >= 0 + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None + le = sym_size_int_1 <= 1188864 + _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 1188864 on node 'le'"); le = _assert_scalar_default_1 = None + mod = sym_size_int_1 % 1548 + eq_2 = mod == 0; mod = None + _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(Mod(u0, 1548), 0) on node 'eq_2'"); eq_2 = _assert_scalar_default_2 = None + floordiv = sym_size_int_1 // 1548 + mul_2 = 1548 * floordiv; floordiv = None + eq_3 = sym_size_int_1 == mul_2; sym_size_int_1 = mul_2 = None + _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(u0, 1548*((u0//1548))) on node 'eq_3'"); eq_3 = _assert_scalar_default_3 = None + view = torch.ops.aten.view.default(masked_select, [-1, 1548]); masked_select = None + add = torch.ops.aten.add.Tensor(view, 1); view = None + return (add,)""", + ignore_empty_lines=True, + ) + if __name__ == "__main__": run_tests() diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d3a6f4a1a27..f9fc61af81d 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3336,8 +3336,8 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None - mul_19: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None - return (mul_19,)""", # noqa: B950 + mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None + return (mul_21,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) @@ -3460,6 +3460,75 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", func(torch.ones(5, 6, 9, 8)) self.assertEqual(cnt.frame_count, 3) + @skipIfTorchDynamo("not allowed to trace mark_unbacked") + @fresh_cache() + def test_unbacked_contiguous(self): + cnt = CompileCounterWithBackend("inductor") + + def func(x): + contig = x.contiguous() + return (contig + 1) * 100 + + compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) + + x = torch.randn(10, 10) + # make x not contiguous. 
+ x = x.t_() + torch._dynamo.decorators.mark_unbacked(x, 0) + torch._dynamo.decorators.mark_unbacked(x, 1) + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + compiled_func(x) + self.assertEqual(compiled_func(x), func(x)) + y = torch.rand(20, 20).t() + self.assertEqual(compiled_func(y), func(y)) + self.assertEqual(cnt.frame_count, 1) + output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + self.assertExpectedInline( + output, + """\ + ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None + clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None + add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None + mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None + return (mul_6,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + log_stream, ctx = logs_to_string( + "torch._inductor.compile_fx", "post_grad_graphs" + ) + with ctx(): + # recompilation will happen due to stride specialization. + y = torch.rand(20, 20) + torch._dynamo.decorators.mark_unbacked(y, 0) + torch._dynamo.decorators.mark_unbacked(y, 1) + self.assertEqual(compiled_func(y), func(y)) + self.assertEqual(cnt.frame_count, 2) + + output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() + + # No clone this time since input is contiguous. 
+ self.assertExpectedInline( + output, + """\ + ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None + add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None + mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None + return (mul_5,)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 4704c9992d5..6d36b36996c 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1): view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); expand_1 = sym_size_int_1 = sym_size_int_2 = None bmm = torch.ops.aten.bmm.default(view, view_1); view = view_1 = None view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]); bmm = None - mul_6 = sym_size_int * 3 - view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]); view_2 = mul_6 = None + mul_9 = sym_size_int * 3 + view_3 = torch.ops.aten.view.default(view_2, [mul_9, 3]); view_2 = mul_9 = None mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index cd7bc028198..bfc5b80835c 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -264,7 +264,7 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec auto& self_ = THPVariable_Unpack(self); auto memory_format = r.memoryformat(0); // avoids touching the GIL or current device if self is already contiguous - if (self_.is_contiguous(memory_format)) { + if (self_.is_contiguous_or_false(memory_format)) { // NOTE: this logic is duplicated from VariableType.cpp. Since we need to // record this call to contiguous() in the trace regardless of whether // we actually call contiguous here, we need to record this information diff --git a/torch/csrc/lazy/core/tensor_impl.cpp b/torch/csrc/lazy/core/tensor_impl.cpp index 04730d55295..ce49338936e 100644 --- a/torch/csrc/lazy/core/tensor_impl.cpp +++ b/torch/csrc/lazy/core/tensor_impl.cpp @@ -195,13 +195,14 @@ bool LTCTensorImpl::is_strides_like_custom( return false; } -bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const { +c10::SymBool LTCTensorImpl::sym_is_non_overlapping_and_dense_custom() const { // This should be true, but false as a temporary fix for a PyTorch core issue, // according to https://github.com/pytorch/xla/pull/2682. 
return false; } -bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const { +c10::SymBool LTCTensorImpl::sym_is_contiguous_custom( + c10::MemoryFormat _unused) const { // TODO(ezyang): I don't think this branch is actually necessary // TODO(ezyang): I don't think this logic is right, shouldn't we pass on // the memory format? diff --git a/torch/csrc/lazy/core/tensor_impl.h b/torch/csrc/lazy/core/tensor_impl.h index d5e937fc3dc..02f68c01c6f 100644 --- a/torch/csrc/lazy/core/tensor_impl.h +++ b/torch/csrc/lazy/core/tensor_impl.h @@ -41,10 +41,11 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl { int64_t numel_custom() const override; int64_t storage_offset_custom() const override; int64_t dim_custom() const override; - bool is_contiguous_custom(at::MemoryFormat memory_format) const override; bool is_strides_like_custom(at::MemoryFormat memory_format) const override; - bool is_non_overlapping_and_dense_custom() const override; + c10::SymBool sym_is_non_overlapping_and_dense_custom() const override; + c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; c10::SymIntArrayRef sym_sizes_custom() const override; c10::SymIntArrayRef sym_strides_custom() const override; c10::SymInt sym_numel_custom() const override; diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index e9c6b8b0e93..acfcc596bd4 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -20,6 +20,9 @@ class ExprPrinter(StrPrinter): def _print_Mul(self, expr: sympy.Expr) -> str: return self.stringify(expr.args, "*", precedence(expr)) + def _print_Not(self, expr: sympy.Expr) -> str: + return f"not ({self._print(expr.args[0])})" + def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: return self.stringify(expr.args, " + ", precedence(expr))
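
For reference, a minimal usage sketch of the API this patch introduces (illustrative only, not part of the diff), based on the TensorBase and TensorImpl signatures added above:

#include <ATen/ATen.h>

// Sketch: querying contiguity without risking a data-dependent error (DDE)
// when sizes/strides contain unbacked symints.
void example(const at::Tensor& input) {
  // Symbolic query: does not guard, may return a symbolic expression.
  c10::SymBool sym = input.sym_is_contiguous();

  // Treat "unknown" as "not contiguous" and take the general path.
  bool definitely_contiguous = sym.guard_or_false(__FILE__, __LINE__);

  // Convenience helper added in this PR that wraps the two calls above.
  bool definitely_contiguous2 = input.is_contiguous_or_false();

  // Or assert contiguity symbolically; with unbacked symints this can become
  // a runtime assertion instead of a trace-time guard.
  TORCH_SYM_CHECK(input.sym_is_contiguous(), "expected contiguous input");

  (void)definitely_contiguous;
  (void)definitely_contiguous2;
}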