mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
ENH Adds no batch dim support for AvgPool1d (#61860)
Summary: Towards https://github.com/pytorch/pytorch/issues/60585 Pull Request resolved: https://github.com/pytorch/pytorch/pull/61860 Reviewed By: albanD Differential Revision: D29826382 Pulled By: jbschlosser fbshipit-source-id: 47e12073d866f0604310fc1ff270cde9907e516d
This commit is contained in:
parent
5a00152a3d
commit
0309c5780d
|
|
@ -91,20 +91,20 @@ Tensor avg_pool1d(
|
|||
if (stride.empty()) {
|
||||
stride = kernel_size;
|
||||
}
|
||||
checkDim("avg_pool1d", TensorArg(self, "self", 1), 3);
|
||||
checkDimRange("avg_pool1d", TensorArg(self, "self", 1), 2, 4 /* exclusive */);
|
||||
check1d("avg_pool1d", "kernel_size", kernel_size);
|
||||
check1d("avg_pool1d", "stride", stride);
|
||||
check1d("avg_pool1d", "padding", padding);
|
||||
|
||||
auto output = at::avg_pool2d(
|
||||
self.unsqueeze(2),
|
||||
self.unsqueeze(-2),
|
||||
{1, kernel_size[0]},
|
||||
{1, stride[0]},
|
||||
{0, padding[0]},
|
||||
ceil_mode,
|
||||
count_include_pad);
|
||||
|
||||
return output.squeeze(2);
|
||||
return output.squeeze(-2);
|
||||
}
|
||||
|
||||
Tensor max_pool2d(
|
||||
|
|
|
|||
|
|
@ -502,8 +502,8 @@ class AvgPool1d(_AvgPoolNd):
|
|||
count_include_pad: when True, will include the zero-padding in the averaging calculation
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C, L_{in})`
|
||||
- Output: :math:`(N, C, L_{out})`, where
|
||||
- Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in}`.
|
||||
- Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
|
||||
|
||||
.. math::
|
||||
L_{out} = \left\lfloor \frac{L_{in} +
|
||||
|
|
|
|||
|
|
@ -2109,6 +2109,14 @@ new_module_tests = [
|
|||
input_size=(2, 3, 6),
|
||||
desc='stride_pad',
|
||||
),
|
||||
dict(
|
||||
module_name='AvgPool1d',
|
||||
constructor_args=(2,),
|
||||
cpp_constructor_args='torch::nn::AvgPool1dOptions(2)',
|
||||
input_size=(3, 6),
|
||||
reference_fn=single_batch_reference_fn,
|
||||
desc='no_batch_dim',
|
||||
),
|
||||
dict(
|
||||
module_name='AvgPool2d',
|
||||
constructor_args=((2, 2),),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user