Fix median bug on discontiguous tensors (#46917)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46917

fixes https://github.com/pytorch/pytorch/issues/46814

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D24633412

Pulled By: heitorschueroff

fbshipit-source-id: 54732671b298bdc2b04b13ab3a373892ee0933c3
This commit is contained in:
parent 9bc8f071a3
commit ddeacf1565
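For readers skimming the summary: the bug class being fixed is that median/nanmedian could return wrong results when the input was not contiguous in memory (for example, a strided view). A minimal sketch of the affected call pattern, using values taken from the test plan below (the device argument is omitted for brevity):

import torch

# a[::2] is a strided, non-contiguous view of the underlying storage.
a = torch.arange(12)
view = a[::2]                 # tensor([ 0,  2,  4,  6,  8, 10])
assert not view.is_contiguous()

# With this fix, the memory layout no longer affects the result.
print(view.median())          # tensor(4)
print(view.nanmedian())       # tensor(4)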
@@ -319,7 +319,7 @@ std::tuple<Tensor&, Tensor&> median_with_indices_impl(
   NoNamesGuard guard;
 
   dim = at::maybe_wrap_dim(dim, self.dim());
-  Tensor in = self.dim() > 0 ? self : self.unsqueeze(0);
+  Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0);
 
   int64_t size = in.size(dim);
   TORCH_CHECK(
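The one-line change above is the substance of the fix: the downstream launcher assumes contiguous inputs (see the warning comment added in the next hunk), so the caller now passes self.contiguous(), which returns the tensor unchanged when it is already contiguous and otherwise materializes a dense row-major copy. A small sketch of that behaviour on a strided view (illustrative only, not part of the patch):

import torch

a = torch.arange(12).reshape(3, 4)
t = a.T                       # transposed view: shape (4, 3), strides (1, 4)

print(t.is_contiguous())      # False
c = t.contiguous()            # dense row-major copy with strides (3, 1)
print(c.is_contiguous())      # True
print(torch.equal(t, c))      # True: same values, different memory layout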
@@ -143,6 +143,7 @@ static uint64_t nextHighestPowerOf2(uint64_t n) {
 }
 
 
+// WARNING: This function assumes input tensors are contiguous
 template <typename scalar_t, typename index_t, typename Launcher>
 void run_launcher(
     Tensor& values,
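The new warning documents the assumption that makes the contiguous() call in the previous hunk necessary: run_launcher expects its input tensors to be contiguous, i.e. indexable as if laid out densely in row-major order. A rough Python illustration of why flat, contiguity-assuming indexing goes wrong on a strided view (run_launcher itself is CUDA C++; the names below are only for illustration):

import torch

a = torch.arange(12).reshape(3, 4)
t = a[::2]                    # rows 0 and 2; shape (2, 4), non-contiguous

row, col, ncol = 1, 2, t.size(1)

# Dense row-major indexing recovers t[row, col] only from a contiguous copy.
dense = t.contiguous().flatten()
assert dense[row * ncol + col] == t[row, col]        # both are 10

# The same arithmetic applied to the view's backing storage (here, `a`)
# lands on element 6 instead of 10 -- the kind of mismatch the fix avoids.
assert a.flatten()[row * ncol + col] != t[row, col]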
@@ -10536,6 +10536,23 @@ class TestTorchDeviceType(TestCase):
         check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
         check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])
 
+        # Discontiguous and strided tensors
+        a = torch.arange(12, device=device)
+        self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
+        self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))
+
+        a.resize_(3, 4)
+        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
+        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
+        self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
+        self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))
+
+        a.resize_(2, 3, 2)
+        self.assertEqual(a.T.median(), torch.tensor(5, device=device))
+        self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
+        self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
+        self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
+
 
     @onlyOnCPUAndCUDA
     @dtypes(torch.float, torch.double)