mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix median bug on discontigous tensors (#46917)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46917 fixes https://github.com/pytorch/pytorch/issues/46814 Test Plan: Imported from OSS Reviewed By: albanD Differential Revision: D24633412 Pulled By: heitorschueroff fbshipit-source-id: 54732671b298bdc2b04b13ab3a373892ee0933c3
This commit is contained in:
parent
9bc8f071a3
commit
ddeacf1565
|
|
@ -319,7 +319,7 @@ std::tuple<Tensor&, Tensor&> median_with_indices_impl(
|
|||
NoNamesGuard guard;
|
||||
|
||||
dim = at::maybe_wrap_dim(dim, self.dim());
|
||||
Tensor in = self.dim() > 0 ? self : self.unsqueeze(0);
|
||||
Tensor in = self.dim() > 0 ? self.contiguous() : self.unsqueeze(0);
|
||||
|
||||
int64_t size = in.size(dim);
|
||||
TORCH_CHECK(
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ static uint64_t nextHighestPowerOf2(uint64_t n) {
|
|||
}
|
||||
|
||||
|
||||
// WARNING: This function assumes input tensors are contiguous
|
||||
template <typename scalar_t, typename index_t, typename Launcher>
|
||||
void run_launcher(
|
||||
Tensor& values,
|
||||
|
|
|
|||
|
|
@ -10536,6 +10536,23 @@ class TestTorchDeviceType(TestCase):
|
|||
check(torch.median, [[nan, nan], [1, 2]], [1], [[nan, 1]])
|
||||
check(torch.nanmedian, [[nan, nan], [1, 2]], [1], [[nan, 1.]])
|
||||
|
||||
# Discontiguous and strided tensors
|
||||
a = torch.arange(12, device=device)
|
||||
self.assertEqual(a[::2].median(), torch.tensor(4, device=device))
|
||||
self.assertEqual(a[::2].nanmedian(), torch.tensor(4, device=device))
|
||||
|
||||
a.resize_(3, 4)
|
||||
self.assertEqual(a.T.median(), torch.tensor(5, device=device))
|
||||
self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
|
||||
self.assertEqual(a[::2, ::2].median(-1)[0], torch.tensor([0, 8], device=device))
|
||||
self.assertEqual(a[::2, ::2].nanmedian(-1)[0], torch.tensor([0, 8], device=device))
|
||||
|
||||
a.resize_(2, 3, 2)
|
||||
self.assertEqual(a.T.median(), torch.tensor(5, device=device))
|
||||
self.assertEqual(a.T.nanmedian(), torch.tensor(5, device=device))
|
||||
self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
|
||||
self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device))
|
||||
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(torch.float, torch.double)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user