From ddeacf1565ca08015c20cb6d17ef0357937fff5c Mon Sep 17 00:00:00 2001
From: Heitor Schueroff
Date: Thu, 29 Oct 2020 17:10:31 -0700
Subject: [PATCH] Fix median bug on discontigous tensors (#46917)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46917

fixes https://github.com/pytorch/pytorch/issues/46814

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D24633412

Pulled By: heitorschueroff

fbshipit-source-id: 54732671b298bdc2b04b13ab3a373892ee0933c3
---
 aten/src/ATen/native/cuda/Sorting.cu        |  2 +-
 aten/src/ATen/native/cuda/SortingCommon.cuh |  1 +
 test/test_torch.py                          | 17 +++++++++++++++++
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu
index 0a5760580c0..c6688b28691 100644
--- a/aten/src/ATen/native/cuda/Sorting.cu
+++ b/aten/src/ATen/native/cuda/Sorting.cu
@@ -319,7 +319,7 @@ std::tuple<Tensor&, Tensor&> median_with_indices_impl(
   NoNamesGuard guard;
 
   dim = at::maybe_wrap_dim(dim, self.dim());
-  Tensor in = self.dim() > 0 ? self : self.unsqueeze(0);
+  Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0);
 
   int64_t size = in.size(dim);
   TORCH_CHECK(
diff --git a/aten/src/ATen/native/cuda/SortingCommon.cuh b/aten/src/ATen/native/cuda/SortingCommon.cuh
index 54513955e91..0e5cb7371d5 100644
--- a/aten/src/ATen/native/cuda/SortingCommon.cuh
+++ b/aten/src/ATen/native/cuda/SortingCommon.cuh
@@ -143,6 +143,7 @@ static uint64_t nextHighestPowerOf2(uint64_t n) {
 }
 
 
+// WARNING: This function assumes input tensors are contiguous
 template <typename scalar_t, typename index_t, int Dim>
 void run_launcher(
     Tensor& values,
diff --git a/test/test_torch.py b/test/test_torch.py
index 48810a2f92e..23ebf0d964b 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10536,6 +10536,23 @@ class TestTorchDeviceType(TestCase):
         check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
         check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])
 
+        # Discontiguous and strided tensors
+        a = torch.arange(12, device=device)
+        self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
+        self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))
+
+        a.resize_(3, 4)
+        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
+        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
+        self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
+        self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))
+
+        a.resize_(2, 3, 2)
+        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
+        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
+        self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
+        self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
+
     @onlyOnCPUAndCUDA
     @dtypes(torch.float, torch.double)