From a2ab06514b7a47059808cee1fb597c4cc74665f6 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 8 Nov 2021 17:55:03 -0800
Subject: [PATCH] Fixes CUDA vs CPU consistency for index_put_ when accumulating (part 2) (#67189)

Summary:
Description:
- Follow up PR to https://github.com/pytorch/pytorch/issues/66790 to fix the tests on functorch, https://github.com/pytorch/functorch/issues/195

In functorch, a null tensor is added to the list of indices for the batch dimension in C++, but I cannot find an equivalent of that in Python without using `torch.jit.script`. If a better approach can be suggested, I'd be happy to replace the current way of testing.

cc ngimel zou3519

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67189

Reviewed By: suo

Differential Revision: D31966686

Pulled By: ngimel

fbshipit-source-id: a14b9e5d77d9f43cd728d474e2976d84a87a6ff4
---
aten/src/ATen/native/cuda/Indexing.cu | 18 +++++++++--
test/test_indexing.py | 44 +++++++++++++++++++++++++++
2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index e24aa33b477..4872865d401 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -209,6 +209,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List (size_t)self.dim()) {
TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
}
+ if (!self.is_contiguous()) {
+ self = self.contiguous();
+ }
Tensor linearIndex, src, expandedValue = value;
int64_t nElemBefore, strideBefore, sliceSize;
std::vector inversePerm;
@@ -216,7 +219,15 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List 1) {
+ expanded_size.insert(expanded_size.begin(), nElemBefore);
+ }
expandedValue = expandedValue.expand(expanded_size);
}
expandedValue = expandedValue.contiguous();
@@ -277,8 +288,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List