mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Align max_pool1d Error Checking between CPU and CUDA/CPU requires_grad (#90211)
Fixes https://github.com/pytorch/pytorch/issues/85712 Standardizes error checking for max_pool1d between CPU and CPU requires_grad/CUDA. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90211 Approved by: https://github.com/mruberry
This commit is contained in:
parent
3859aace20
commit
df58020bb6
|
|
@ -24,14 +24,13 @@ DEFINE_DISPATCH(max_pool1d_stub);
|
|||
|
||||
namespace {
|
||||
|
||||
Tensor max_pool1d_impl(
|
||||
static void check_max_pool1d(
|
||||
const Tensor& self,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
NoNamesGuard guard;
|
||||
|
||||
TORCH_CHECK(
|
||||
self.dim() == 2 || self.dim() == 3,
|
||||
|
|
@ -58,6 +57,45 @@ Tensor max_pool1d_impl(
|
|||
stride = kernel_size;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
kernel_size[0] > 0,
|
||||
"max_pool1d() kernel_size must be greater than zero, but got ",
|
||||
kernel_size[0]);
|
||||
TORCH_CHECK(
|
||||
stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
|
||||
TORCH_CHECK(
|
||||
padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
|
||||
TORCH_CHECK(
|
||||
padding[0] <= kernel_size[0] / 2,
|
||||
"max_pool1d() padding should be at most half of kernel size, but got padding=",
|
||||
padding[0],
|
||||
" and kernel_size=",
|
||||
kernel_size[0]);
|
||||
TORCH_CHECK(
|
||||
dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
|
||||
|
||||
const int64_t OW = pooling_output_shape(self.size(-1), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
|
||||
TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
Tensor max_pool1d_impl(
|
||||
const Tensor& self,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
NoNamesGuard guard;
|
||||
|
||||
// If stride=None then set it to kernel_size
|
||||
if (stride.empty()) {
|
||||
stride = kernel_size;
|
||||
}
|
||||
|
||||
const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
|
||||
const int64_t NC = self.size(-2);
|
||||
const int64_t IW = self.size(-1);
|
||||
|
|
@ -66,25 +104,7 @@ Tensor max_pool1d_impl(
|
|||
const int64_t PJ = padding[0];
|
||||
const int64_t DJ = dilation[0];
|
||||
|
||||
TORCH_CHECK(
|
||||
KW > 0,
|
||||
"max_pool1d() kernel_size must be greater than zero, but got ",
|
||||
KW);
|
||||
TORCH_CHECK(
|
||||
SJ > 0, "max_pool1d() stride must be greater than zero, but got ", SJ);
|
||||
TORCH_CHECK(
|
||||
PJ >= 0, "max_pool1d() padding must be non-negative, but got ", PJ);
|
||||
TORCH_CHECK(
|
||||
PJ <= KW / 2,
|
||||
"max_pool1d() padding should be at most half of kernel size, but got padding=",
|
||||
PJ,
|
||||
" and kernel_size=",
|
||||
KW);
|
||||
TORCH_CHECK(
|
||||
DJ > 0, "max_pool1d() dilation must be greater than zero, but got ", DJ);
|
||||
|
||||
const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
|
||||
TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
|
||||
Tensor output = at::empty({NB, NC, OW}, self.options());
|
||||
|
||||
PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
|
||||
|
|
@ -121,6 +141,8 @@ Tensor max_pool1d(
|
|||
return at::quantized_max_pool1d(
|
||||
self, kernel_size, stride, padding, dilation, ceil_mode);
|
||||
}
|
||||
|
||||
check_max_pool1d(self, kernel_size, stride, padding, dilation, ceil_mode);
|
||||
if ((self.requires_grad() && at::GradMode::is_enabled()) ||
|
||||
self._fw_grad(/*level */ 0).defined() ||
|
||||
!self.device().is_cpu() ||
|
||||
|
|
|
|||
|
|
@ -3032,40 +3032,29 @@ def error_inputs_max_pool1d(op_info, device, **kwargs):
|
|||
error_regex=error_msg)
|
||||
|
||||
# error inputs for empty input with stride=0
|
||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
||||
error_msg = 'stride must be greater than zero, but got 0' if torch.device(
|
||||
device).type == 'cpu' and not requires_grad else 'stride should not be zero'
|
||||
error_msg = 'stride must be greater than zero, but got 0'
|
||||
yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
|
||||
error_regex=error_msg)
|
||||
|
||||
# error inputs for empty input with dilation=0
|
||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
||||
error_msg = 'dilation must be greater than zero, but got 0' if torch.device(
|
||||
device).type == 'cpu' and not requires_grad else 'dilation should be greater than zero, but got dilation'
|
||||
error_msg = 'dilation must be greater than zero, but got 0'
|
||||
yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
|
||||
kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
|
||||
error_regex=error_msg)
|
||||
|
||||
# error inputs for invalid output size
|
||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
||||
error_msg = 'Invalid computed output size: -2' if torch.device(device).type == 'cpu' and not requires_grad \
|
||||
else \
|
||||
r'Given input size: \(2x1x2\). Calculated output size: \(2x1x-2\). Output size is too small'
|
||||
error_msg = 'Invalid computed output size: -2'
|
||||
yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
|
||||
kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
|
||||
error_regex=error_msg)
|
||||
|
||||
# error inputs when kernel_size=0
|
||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
||||
error_msg = 'kernel_size must be greater than zero' if torch.device(
|
||||
device).type == 'cpu' and not requires_grad else r'stride should not be zero'
|
||||
error_msg = 'kernel_size must be greater than zero'
|
||||
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
|
||||
error_regex=error_msg)
|
||||
|
||||
# error inputs for strides > 0
|
||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
||||
error_msg = 'stride must be greater than zero' if torch.device(
|
||||
device).type == 'cpu' and not requires_grad else r'stride should not be zero'
|
||||
error_msg = 'stride must be greater than zero'
|
||||
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
|
||||
error_regex=error_msg)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user