[Conv1d] Check overflow before we compute padding size. (#162363)
Fixes https://github.com/pytorch/pytorch/issues/161877 and https://github.com/pytorch/pytorch/issues/161875.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162363
Approved by: https://github.com/jbschlosser
This commit is contained in:
parent 753d9bd806
commit 94eaeb9cb8
@@ -689,6 +689,10 @@ static void check_shape_forward(const at::Tensor& input,
           ", but got bias of size ", at::symint::sizes<T>(bias), " instead");
 
     for (const auto i : c10::irange(2, k)) {
+      // T could be int64_t or SymInt; numeric_limits<SymInt> is specialized in c10/core/SymInt.h
+      TORCH_CHECK(padding[i-2] <= (std::numeric_limits<T>::max() - padding[i-2]),
+                  "Given padding=", padding[i-2], " at dimension ", i-2, ", expected padding to be at most ",
+                  (std::numeric_limits<T>::max() / 2));
       input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i-2]);
       // log new kernel size considering dilation
       kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
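Worth spelling out why the check is written this way: `padding <= max - padding` asks whether `2 * padding` fits in `T` without ever evaluating the addition that could overflow, and it is equivalent to the `max / 2` bound quoted in the error message. A minimal standalone sketch of the same pattern (the `padding_fits` helper is illustrative, not part of the PR):

#include <cstdint>
#include <iostream>
#include <limits>

// "p + p > max" rearranged as "p > max - p": the right-hand side is always
// representable, so the guard itself can never overflow.
bool padding_fits(int64_t p) {
  return p <= std::numeric_limits<int64_t>::max() - p;
}

int main() {
  std::cout << padding_fits(100) << "\n";                    // 1: 2 * 100 fits
  std::cout << padding_fits(9223372036854775803LL) << "\n";  // 0: 2 * p would wrap
}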
@@ -715,6 +719,11 @@ static void check_shape_forward(const at::Tensor& input,
                 "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
     }
   } else { // transposed
+    for (const auto i : c10::irange(2, k)) {
+      TORCH_CHECK(padding[i-2] <= (std::numeric_limits<T>::max() - padding[i-2]),
+                  "Given padding=", padding[i-2], " at dimension ", i-2, ", expected padding to be at most ",
+                  (std::numeric_limits<T>::max() / 2));
+    }
     TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
         "Given transposed=", transposed, ", weight of size ", weight_sizes,
         ", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],
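The transposed branch repeats the guard because control flow skips the loop above while the unchecked padding values still feed the transposed output-size arithmetic, where a signed overflow is undefined behavior. For contrast, a compiler intrinsic could detect the same wrap-around, as in the sketch below (my illustration, not the PR's approach); the `numeric_limits` comparison was presumably preferred because it is portable and, with the specialization in the next hunk, also works when `T` is `c10::SymInt`:

#include <cstdint>
#include <iostream>

int main() {
  int64_t p = 9223372036854775803LL;
  int64_t doubled;
  // GCC/Clang builtin: computes p + p, reports whether it wrapped, and
  // avoids the undefined behavior of a plain signed addition.
  if (__builtin_add_overflow(p, p, &doubled)) {
    std::cout << "2 * padding would overflow int64_t\n";
  }
}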
@@ -556,3 +556,26 @@ inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
 }
 
 } // namespace c10
+
+#include <limits>
+
+namespace std {
+
+template <>
+class numeric_limits<c10::SymInt> {
+ public:
+  static constexpr bool is_specialized = true;
+
+  static constexpr int64_t max() noexcept {
+    return std::numeric_limits<int64_t>::max();
+  }
+
+  static constexpr int64_t min() noexcept {
+    return std::numeric_limits<int64_t>::min();
+  }
+
+  static constexpr bool is_signed = true;
+  static constexpr bool is_integer = true;
+};
+
+} // namespace std
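This specialization is what lets the templated `TORCH_CHECK` above compile when `T` is `c10::SymInt` rather than `int64_t`. A self-contained sketch of the mechanism with a toy type (`ToySymInt` is illustrative only; the real specialization returns `int64_t` from `max()` and relies on `SymInt`'s mixed-type operators):

#include <cstdint>
#include <iostream>
#include <limits>

// Toy stand-in for c10::SymInt, just to show the mechanism.
struct ToySymInt {
  int64_t v;
  bool operator<=(const ToySymInt& other) const { return v <= other.v; }
  ToySymInt operator-(const ToySymInt& other) const { return {v - other.v}; }
};

namespace std {
// Specializing numeric_limits for a program-defined type is permitted,
// and generic code querying max() then works for it.
template <>
class numeric_limits<ToySymInt> {
 public:
  static constexpr bool is_specialized = true;
  static constexpr ToySymInt max() noexcept { return {INT64_MAX}; }
};
} // namespace std

// The same generic guard as in check_shape_forward, now usable with
// both int64_t and ToySymInt.
template <typename T>
bool padding_fits(T p) {
  return p <= (std::numeric_limits<T>::max() - p);
}

int main() {
  std::cout << padding_fits<int64_t>(42) << "\n";           // 1
  std::cout << padding_fits(ToySymInt{INT64_MAX}) << "\n";  // 0
}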
@@ -93,6 +93,47 @@ class TestConvolutionNN(NNTestCase):
         input = torch.randn((1, 1, 1, 1), dtype=torch.float)
         self.assertEqual(m(input).size(), (1, 1, 1, 1))
 
+    def test_huge_padding(self):
+        class Conv1dModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv1d(
+                    in_channels=16,
+                    out_channels=32,
+                    kernel_size=3,
+                    stride=1,
+                    padding=9223372036854775803,
+                )
+                self.add_module(name="conv1", module=self.conv1)
+
+        input_data = torch.randn(1, 16, 100)
+        model = Conv1dModule()
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Given padding=9223372036854775803 at dimension 0, expected padding to be at most",
+        ):
+            model.conv1(input_data)
+
+        class ConvTransposed1dModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv_transposed1d = nn.ConvTranspose1d(
+                    in_channels=16,
+                    out_channels=32,
+                    kernel_size=3,
+                    stride=2,
+                    padding=9223372036854775803,
+                )
+                self.add_module(name="conv_transposed1d", module=self.conv_transposed1d)
+
+        input_data = torch.randn(1, 16, 100)
+        model = ConvTransposed1dModule()
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Given padding=9223372036854775803 at dimension 0, expected padding to be at most",
+        ):
+            model.conv_transposed1d(input_data)
+
     def test_invalid_conv1d(self):
         for dtype in [
             torch.half,
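For completeness, a hypothetical C++ (libtorch) analogue of the Python test above; since the guard lives in the shared `check_shape_forward` path, the same error should surface through the `torch::conv1d` functional API (this sketch is mine, not part of the PR):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto input = torch::randn({1, 16, 100});
  auto weight = torch::randn({32, 16, 3});  // (out_channels, in_channels, kernel_size)
  try {
    // Same oversized padding as the Python test above.
    torch::conv1d(input, weight, /*bias=*/{}, /*stride=*/{1},
                  /*padding=*/{9223372036854775803LL});
  } catch (const c10::Error& e) {
    std::cout << e.what() << "\n";  // "Given padding=... expected padding to be at most ..."
  }
}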