Fix median bug on discontigous tensors (#46917)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46917

fixes https://github.com/pytorch/pytorch/issues/46814

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D24633412

Pulled By: heitorschueroff

fbshipit-source-id: 54732671b298bdc2b04b13ab3a373892ee0933c3
This commit is contained in:
Heitor Schueroff 2020-10-29 17:10:31 -07:00 committed by Facebook GitHub Bot
parent 9bc8f071a3
commit ddeacf1565
3 changed files with 19 additions and 1 deletions

View File

@ -319,7 +319,7 @@ std::tuple<Tensor&, Tensor&> median_with_indices_impl(
NoNamesGuard guard;
dim = at::maybe_wrap_dim(dim, self.dim());
Tensor in = self.dim() > 0 ? self : self.unsqueeze(0);
Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0);
int64_t size = in.size(dim);
TORCH_CHECK(

View File

@ -143,6 +143,7 @@ static uint64_t nextHighestPowerOf2(uint64_t n) {
}
// WARNING: This function assumes input tensors are contiguous
template <typename scalar_t, typename index_t, typename Launcher>
void run_launcher(
Tensor& values,

View File

@ -10536,6 +10536,23 @@ class TestTorchDeviceType(TestCase):
check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])
# Discontiguous and strided tensors
a = torch.arange(12, device=device)
self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))
a.resize_(3, 4)
self.assertEqual(a.T.median(), torch.tensor(5, device=device))
self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))
a.resize_(2, 3, 2)
self.assertEqual(a.T.median(), torch.tensor(5, device=device))
self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.double)