diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index 7d1e66a0fdc..4d1787d55c7 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -13,22 +13,9 @@ namespace utils { inline bool obeys_layout_contract(const at::Tensor& grad, const at::Tensor& variable) { TORCH_INTERNAL_ASSERT(!grad.is_sparse()); TORCH_INTERNAL_ASSERT(!variable.is_sparse()); - if (variable.is_non_overlapping_and_dense()) { - // Only look at stride for dimensions that are not of size 1 - const auto& grad_sizes = grad.sizes(); - const auto& grad_strides = grad.strides(); - const auto& variable_strides = variable.strides(); - for (const auto idx : c10::irange(grad_sizes.size())) { - if (grad_sizes[idx] != 1) { - if (grad_strides[idx] != variable_strides[idx]) { - return false; - } - } - } - return true; - } else { - return grad.is_contiguous(at::MemoryFormat::Contiguous); - } + return variable.is_non_overlapping_and_dense() ? + (grad.strides() == variable.strides()) : + grad.is_contiguous(at::MemoryFormat::Contiguous); } // Creates a clone of new_grad that obeys the contract with variable.