add shape check for avg_pool2d (#161952)
Fix https://github.com/pytorch/pytorch/issues/153312.

**Example:**
```python
import torch

print(torch.__version__)

tensor = torch.tensor([[ -7.8130e-88, -2.2092e-138, -1.8673e+03, -7.6272e-253,  3.9203e+110,
                          1.8380e-51,  2.8762e+268,  2.9094e+286,  5.1816e-228, -4.4916e+191,
                         -7.4057e+80, -9.1955e-18,  5.6536e+225,  8.8364e-175,  1.5053e-226],
                       [-3.0521e+239, -2.8307e+306,  1.3297e-03, -9.9969e-132,  2.8920e-286,
                          2.3964e+58, -6.8138e-281,  2.0321e-305, -3.5127e+74, -4.7560e-92,
                         -8.9403e-99, -1.9739e-187, -2.5124e-173,  2.0458e+295,  4.4992e+52],
                       [ 6.8752e+21,  1.9332e+189, -8.6940e-189, -6.6743e-15,  1.4691e+41,
                          1.0338e+63, -2.0779e-28, -7.6642e+104,  1.3390e+284, -8.0859e+194,
                          8.4600e+107,  4.9115e-44,  1.1665e+285,  5.1275e+203,  9.7580e+303]],
                      dtype=torch.float64)

try:
    res = torch.nn.functional.lp_pool1d(
        tensor,
        norm_type=-1.38119e+150,
        kernel_size=7879455037536781369,
        ceil_mode=True,
    )
    print("CPU result:", res)
except RuntimeError as e:
    print(f"CPU error: {e}")

tensor_gpu = tensor.to("cuda:0")
try:
    res = torch.nn.functional.lp_pool1d(
        tensor_gpu,
        norm_type=-1.38119e+150,
        kernel_size=7879455037536781369,
        ceil_mode=True,
    )
    print("GPU result:", res)
except RuntimeError as e:
    print(f"GPU error: {e}")
```

**Output:**

- before
```
2.9.0a0+git8703deb
CPU result: tensor([[0.],
        [0.],
        [0.]], dtype=torch.float64)
GPU error: integer out of range
```
- after
```
2.9.0a0+git2e893df
CPU error: integer out of range
GPU error: integer out of range
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161952
Approved by: https://github.com/mingfeima, https://github.com/malfet
This commit is contained in:
parent fd5da81fdd
commit 8d599045cf
```diff
@@ -25,18 +25,19 @@ TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
   // #20866, #22032: Guarantee this for the official C++ API?
   TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
     "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
-  const int64_t kH = kernel_size[0];
-  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];
+  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
+  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
 
   TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
     "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
-  const int64_t dH = stride.empty() ? kH : stride[0];
-  const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];
+  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
+  const int dW = stride.empty() ? kW :
+                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
 
   TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
     "avg_pool2d: padding must either be a single int, or a tuple of two ints");
-  const int64_t padH = padding[0];
-  const int64_t padW = padding.size() == 1 ? padH : padding[1];
+  const int padH = safe_downcast<int, int64_t>(padding[0]);
+  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
 
   TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
     "divisor must be not zero");
```
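The heart of the change is routing every int64 -> int narrowing in the avg_pool2d meta function through `safe_downcast<int, int64_t>`, which is what turns the previously silently accepted oversized kernel_size into the "integer out of range" error shown in the outputs above. Roughly, and as a sketch only (a Python analogue, not ATen's actual C++ helper):

```python
# Python analogue (a sketch, not ATen's C++ helper) of the behaviour the diff
# relies on: reject any value that does not fit in a 32-bit signed int instead
# of silently truncating it during the int64 -> int narrowing.
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def safe_downcast_to_int(v: int) -> int:
    if not (INT32_MIN <= v <= INT32_MAX):
        raise RuntimeError("integer out of range")
    return v

print(safe_downcast_to_int(7))                 # 7
try:
    safe_downcast_to_int(7879455037536781369)  # the repro's kernel_size
except RuntimeError as e:
    print(e)                                   # integer out of range
```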
```diff
@@ -898,6 +898,16 @@ torch.cuda.synchronize()
         inp = torch.ones(1, 0, 50, 44, 31, device=device)
         mod(inp)
 
+    @onlyCPU
+    def test_LPPool1d_kernel_size_overflow_large(self, device):
+        avgpool = torch.nn.LPPool1d(
+            -1.38119e150, 7879455037536781369, ceil_mode=True
+        ).to(device)
+        inp = torch.randn(3, 15, device=device)
+
+        with self.assertRaisesRegex(RuntimeError, "integer out of range"):
+            avgpool(inp)
+
     @onlyNativeDeviceTypes
     def test_AvgPool2d_empty(self, device):
         avgpool = torch.nn.AvgPool2d(3, stride=2).to(device)
```
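The regression test above uses the module form, `torch.nn.LPPool1d`; per the "after" output in the commit message, the functional form from the original report now fails the same way on CPU. A short post-fix sanity check (values taken from the repro above):

```python
import torch

# After this change the oversized kernel_size is rejected on CPU as well,
# matching the pre-existing CUDA behaviour.
x = torch.randn(3, 15, dtype=torch.float64)
try:
    torch.nn.functional.lp_pool1d(
        x, norm_type=-1.38119e150, kernel_size=7879455037536781369, ceil_mode=True
    )
except RuntimeError as e:
    assert "integer out of range" in str(e)
    print("CPU error:", e)
```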