Revert D23468286: [pytorch][PR] Optimize code path for adaptive_avg_pool2d when output size is (1, 1)

Test Plan: revert-hammer

Differential Revision: D23468286 (f8f35fddd4)

Original commit changeset: cc181f705fea

fbshipit-source-id: 3a1db0eef849e0c2f3c0c64040d2a8b799644fa3
Authored by Natalia Gimelshein on 2020-09-04 11:25:08 -07:00; committed by Facebook GitHub Bot
parent 6474057c76
commit 0c2bc4fe20
2 changed files with 6 additions and 37 deletions

File 1 of 2:

@@ -329,17 +329,14 @@ namespace {
       return at::mkldnn_adaptive_avg_pool2d(input, output_size);
     }
-    if (!input.is_quantized() && output_size[0] == 1 && output_size[1] == 1) {
+    // TODO: fastpath for Channels_last should be explored later;
+    if (input.suggest_memory_format() == at::MemoryFormat::Contiguous && !input.is_quantized() && output_size[0] == 1 && output_size[1] == 1) {
       // in this case, adaptive pooling is just computing mean over hw
       // dimensions, which can be done more efficiently
-      Tensor out = input.mean({-1, -2}, /* keepdim = */ true);
-      if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast) {
-        // assert ndim == 4, since ndim = 3 doesn't give channels_last
-        const int n = input.size(0);
-        const int c = input.size(1);
-        out.as_strided_({n, c, 1, 1}, {c, 1, c, c});
-      }
-      return out;
+      int64_t mean_size = input.size(-1) * input.size(-2);
+      Tensor out = input.contiguous().view({-1, mean_size}).mean(-1);
+      return input.dim() == 3 ? out.view({input.size(0), 1, 1})
+                              : out.view({input.size(0), input.size(1), 1, 1});
     } else {
       return _adaptive_avg_pool2d(input, output_size);
     }
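As a side note (not part of the commit), here is a minimal Python sketch of what the two fast paths compute; the input shape (2, 3, 6, 6) and variable names are illustrative only. Both variants reduce H and W to a single mean per (N, C) slice; the reverted variant additionally restrides the result so a channels_last input yields a channels_last output.

import torch

x = torch.randn(2, 3, 6, 6).to(memory_format=torch.channels_last)

# Reverted fast path: mean directly over the H and W dimensions.
reverted = x.mean((-1, -2), keepdim=True)

# Restored fast path: flatten each (N, C) slice to H*W elements, then average.
n, c, hw = x.size(0), x.size(1), x.size(-1) * x.size(-2)
restored = x.contiguous().view(-1, hw).mean(-1).view(n, c, 1, 1)

print(torch.allclose(reverted, restored))  # True: identical values

# The reverted path then applied as_strided so a channels_last input gives a
# channels_last (NHWC-contiguous) output of shape (n, c, 1, 1) with strides (c, 1, c, c).
out = reverted.as_strided((n, c, 1, 1), (c, 1, c, c))
print(out.is_contiguous(memory_format=torch.channels_last))  # True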

File 2 of 2:

@@ -9845,34 +9845,6 @@ class TestNNDeviceType(NNTestCase):
         test('threshold', 3, 2)
         test('threshold', 3, 2, inplace=True)
 
-    @onlyOnCPUAndCUDA  # TODO: fix on XLA
-    def test_adaptive_avg_pool2d_output_size_one(self, device):
-        def helper(size, memory_format):
-            x = torch.randint(1, 10, size, dtype=torch.float, device=device, requires_grad=True)
-            if memory_format == 'non_contiguous':
-                x = x[::2, ::2, ::2, ::2]
-            else:
-                x = x.to(memory_format=memory_format)
-
-            net = torch.nn.AdaptiveAvgPool2d((1, 1))
-            out = net(x)
-            ref_out = x.contiguous().mean((-1, -2)).view((x.size(0), x.size(1), 1, 1))
-
-            out.sum().backward()  # make sure it doesn't crash
-            self.assertEqual(out, ref_out)
-            if memory_format == torch.channels_last:
-                self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
-                c = out.size(1)
-                self.assertEqual(out.stride(), [c, 1, c, c])
-            else:
-                self.assertTrue(out.is_contiguous())
-                c = out.size(1)
-                self.assertEqual(out.stride(), [c, 1, 1, 1])
-
-        for mf in (torch.contiguous_format, torch.channels_last, 'non_contiguous'):
-            helper((2, 3, 6, 6), mf)
-
     @onlyCUDA
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     def test_avg_pool2d_nhwc(self, device, dtype):
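The removed test only checks that out.sum().backward() runs without crashing. As an aside (not from the commit), the expected gradient through the (1, 1) pooling is easy to state: every input position in an H x W plane receives 1 / (H * W), which a short sketch can confirm.

import torch

x = torch.randn(2, 3, 6, 6, requires_grad=True)
out = torch.nn.AdaptiveAvgPool2d((1, 1))(x)
out.sum().backward()
# Each of the 36 positions per channel contributes 1/36 of that channel's mean.
print(torch.allclose(x.grad, torch.full_like(x, 1.0 / 36)))  # True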