mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Align max_pool1d Error Checking between CPU and CUDA/CPU requires_grad (#90211)
Fixes https://github.com/pytorch/pytorch/issues/85712 Standardizes error checking for max_pool1d between CPU and CPU requires_grad/CUDA. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90211 Approved by: https://github.com/mruberry
This commit is contained in:
parent
3859aace20
commit
df58020bb6
|
|
@ -24,14 +24,13 @@ DEFINE_DISPATCH(max_pool1d_stub);
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Tensor max_pool1d_impl(
|
static void check_max_pool1d(
|
||||||
const Tensor& self,
|
const Tensor& self,
|
||||||
IntArrayRef kernel_size,
|
IntArrayRef kernel_size,
|
||||||
IntArrayRef stride,
|
IntArrayRef stride,
|
||||||
IntArrayRef padding,
|
IntArrayRef padding,
|
||||||
IntArrayRef dilation,
|
IntArrayRef dilation,
|
||||||
bool ceil_mode) {
|
bool ceil_mode) {
|
||||||
NoNamesGuard guard;
|
|
||||||
|
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
self.dim() == 2 || self.dim() == 3,
|
self.dim() == 2 || self.dim() == 3,
|
||||||
|
|
@ -58,6 +57,45 @@ Tensor max_pool1d_impl(
|
||||||
stride = kernel_size;
|
stride = kernel_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
kernel_size[0] > 0,
|
||||||
|
"max_pool1d() kernel_size must be greater than zero, but got ",
|
||||||
|
kernel_size[0]);
|
||||||
|
TORCH_CHECK(
|
||||||
|
stride[0] > 0, "max_pool1d() stride must be greater than zero, but got ", stride[0]);
|
||||||
|
TORCH_CHECK(
|
||||||
|
padding[0] >= 0, "max_pool1d() padding must be non-negative, but got ", padding[0]);
|
||||||
|
TORCH_CHECK(
|
||||||
|
padding[0] <= kernel_size[0] / 2,
|
||||||
|
"max_pool1d() padding should be at most half of kernel size, but got padding=",
|
||||||
|
padding[0],
|
||||||
|
" and kernel_size=",
|
||||||
|
kernel_size[0]);
|
||||||
|
TORCH_CHECK(
|
||||||
|
dilation[0] > 0, "max_pool1d() dilation must be greater than zero, but got ", dilation[0]);
|
||||||
|
|
||||||
|
const int64_t OW = pooling_output_shape(self.size(-1), kernel_size[0], padding[0], stride[0], dilation[0], ceil_mode);
|
||||||
|
TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Tensor max_pool1d_impl(
|
||||||
|
const Tensor& self,
|
||||||
|
IntArrayRef kernel_size,
|
||||||
|
IntArrayRef stride,
|
||||||
|
IntArrayRef padding,
|
||||||
|
IntArrayRef dilation,
|
||||||
|
bool ceil_mode) {
|
||||||
|
NoNamesGuard guard;
|
||||||
|
|
||||||
|
// If stride=None then set it to kernel_size
|
||||||
|
if (stride.empty()) {
|
||||||
|
stride = kernel_size;
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
|
const int64_t NB = self.dim() == 3 ? self.size(-3) : 1;
|
||||||
const int64_t NC = self.size(-2);
|
const int64_t NC = self.size(-2);
|
||||||
const int64_t IW = self.size(-1);
|
const int64_t IW = self.size(-1);
|
||||||
|
|
@ -66,25 +104,7 @@ Tensor max_pool1d_impl(
|
||||||
const int64_t PJ = padding[0];
|
const int64_t PJ = padding[0];
|
||||||
const int64_t DJ = dilation[0];
|
const int64_t DJ = dilation[0];
|
||||||
|
|
||||||
TORCH_CHECK(
|
|
||||||
KW > 0,
|
|
||||||
"max_pool1d() kernel_size must be greater than zero, but got ",
|
|
||||||
KW);
|
|
||||||
TORCH_CHECK(
|
|
||||||
SJ > 0, "max_pool1d() stride must be greater than zero, but got ", SJ);
|
|
||||||
TORCH_CHECK(
|
|
||||||
PJ >= 0, "max_pool1d() padding must be non-negative, but got ", PJ);
|
|
||||||
TORCH_CHECK(
|
|
||||||
PJ <= KW / 2,
|
|
||||||
"max_pool1d() padding should be at most half of kernel size, but got padding=",
|
|
||||||
PJ,
|
|
||||||
" and kernel_size=",
|
|
||||||
KW);
|
|
||||||
TORCH_CHECK(
|
|
||||||
DJ > 0, "max_pool1d() dilation must be greater than zero, but got ", DJ);
|
|
||||||
|
|
||||||
const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
|
const int64_t OW = pooling_output_shape(IW, KW, PJ, SJ, DJ, ceil_mode);
|
||||||
TORCH_CHECK(OW >= 0, "max_pool1d() Invalid computed output size: ", OW);
|
|
||||||
Tensor output = at::empty({NB, NC, OW}, self.options());
|
Tensor output = at::empty({NB, NC, OW}, self.options());
|
||||||
|
|
||||||
PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
|
PoolingParams1D params{NB, NC, IW, OW, KW, SJ, PJ, DJ};
|
||||||
|
|
@ -121,6 +141,8 @@ Tensor max_pool1d(
|
||||||
return at::quantized_max_pool1d(
|
return at::quantized_max_pool1d(
|
||||||
self, kernel_size, stride, padding, dilation, ceil_mode);
|
self, kernel_size, stride, padding, dilation, ceil_mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
check_max_pool1d(self, kernel_size, stride, padding, dilation, ceil_mode);
|
||||||
if ((self.requires_grad() && at::GradMode::is_enabled()) ||
|
if ((self.requires_grad() && at::GradMode::is_enabled()) ||
|
||||||
self._fw_grad(/*level */ 0).defined() ||
|
self._fw_grad(/*level */ 0).defined() ||
|
||||||
!self.device().is_cpu() ||
|
!self.device().is_cpu() ||
|
||||||
|
|
|
||||||
|
|
@ -3032,40 +3032,29 @@ def error_inputs_max_pool1d(op_info, device, **kwargs):
|
||||||
error_regex=error_msg)
|
error_regex=error_msg)
|
||||||
|
|
||||||
# error inputs for empty input with stride=0
|
# error inputs for empty input with stride=0
|
||||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
error_msg = 'stride must be greater than zero, but got 0'
|
||||||
error_msg = 'stride must be greater than zero, but got 0' if torch.device(
|
|
||||||
device).type == 'cpu' and not requires_grad else 'stride should not be zero'
|
|
||||||
yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
|
yield ErrorInput(SampleInput(make_arg((3, 3, 3)), kwargs={'kernel_size': 1, 'stride': 0}),
|
||||||
error_regex=error_msg)
|
error_regex=error_msg)
|
||||||
|
|
||||||
# error inputs for empty input with dilation=0
|
# error inputs for empty input with dilation=0
|
||||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
error_msg = 'dilation must be greater than zero, but got 0'
|
||||||
error_msg = 'dilation must be greater than zero, but got 0' if torch.device(
|
|
||||||
device).type == 'cpu' and not requires_grad else 'dilation should be greater than zero, but got dilation'
|
|
||||||
yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
|
yield ErrorInput(SampleInput(make_arg((3, 3, 3)),
|
||||||
kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
|
kwargs={'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 0}),
|
||||||
error_regex=error_msg)
|
error_regex=error_msg)
|
||||||
|
|
||||||
# error inputs for invalid output size
|
# error inputs for invalid output size
|
||||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
error_msg = 'Invalid computed output size: -2'
|
||||||
error_msg = 'Invalid computed output size: -2' if torch.device(device).type == 'cpu' and not requires_grad \
|
|
||||||
else \
|
|
||||||
r'Given input size: \(2x1x2\). Calculated output size: \(2x1x-2\). Output size is too small'
|
|
||||||
yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
|
yield ErrorInput(SampleInput(make_arg((2, 2, 2)),
|
||||||
kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
|
kwargs={'kernel_size': 5, 'stride': 1, 'padding': 0, 'dilation': 1}),
|
||||||
error_regex=error_msg)
|
error_regex=error_msg)
|
||||||
|
|
||||||
# error inputs when kernel_size=0
|
# error inputs when kernel_size=0
|
||||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
error_msg = 'kernel_size must be greater than zero'
|
||||||
error_msg = 'kernel_size must be greater than zero' if torch.device(
|
|
||||||
device).type == 'cpu' and not requires_grad else r'stride should not be zero'
|
|
||||||
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
|
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 0}),
|
||||||
error_regex=error_msg)
|
error_regex=error_msg)
|
||||||
|
|
||||||
# error inputs for strides > 0
|
# error inputs for strides > 0
|
||||||
# NOTE: CPU vs (CPU with requires_grad and CUDA) error messages are different.
|
error_msg = 'stride must be greater than zero'
|
||||||
error_msg = 'stride must be greater than zero' if torch.device(
|
|
||||||
device).type == 'cpu' and not requires_grad else r'stride should not be zero'
|
|
||||||
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
|
yield ErrorInput(SampleInput(x, kwargs={'kernel_size': 2, 'stride': 0}),
|
||||||
error_regex=error_msg)
|
error_regex=error_msg)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user