[Conv1d] Check overflow before we compute padding size. (#162363)

Fixes https://github.com/pytorch/pytorch/issues/161877
also fixes https://github.com/pytorch/pytorch/issues/161875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162363
Approved by: https://github.com/jbschlosser
This commit is contained in:
thenumberouscode 2025-10-29 03:27:17 +00:00 committed by PyTorch MergeBot
parent 753d9bd806
commit 94eaeb9cb8
3 changed files with 73 additions and 0 deletions

View File

@ -689,6 +689,10 @@ static void check_shape_forward(const at::Tensor& input,
", but got bias of size ", at::symint::sizes<T>(bias), " instead");
for (const auto i : c10::irange(2, k)) {
// T could be int64_t or SymInt, Specialized numeric_limts<SymInt> in c10/core/SymInt.h
TORCH_CHECK(padding[i-2] <= (std::numeric_limits<T>::max() - padding[i-2]),
"Given padding=", padding[i-2], " at dimension ", i-2, " , expected padding to be at most ",
(std::numeric_limits<T>::max() / 2));
input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i-2]);
// log new kernel size considering dilation
kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1);
@ -715,6 +719,11 @@ static void check_shape_forward(const at::Tensor& input,
"Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
}
} else { // transposed
for (const auto i : c10::irange(2, k)) {
TORCH_CHECK(padding[i-2] <= (std::numeric_limits<T>::max() - padding[i-2]),
"Given padding=", padding[i-2], " at dimension ", i-2, " , expected padding to be at most ",
(std::numeric_limits<T>::max() / 2));
}
TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
"Given transposed=", transposed, ", weight of size ", weight_sizes,
", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],

View File

@ -556,3 +556,26 @@ inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
}
} // namespace c10
#include <limits>
namespace std {
template <>
class numeric_limits<c10::SymInt> {
public:
static constexpr bool is_specialized = true;
static constexpr int64_t max() noexcept {
return std::numeric_limits<int64_t>::max();
}
static constexpr int64_t min() noexcept {
return std::numeric_limits<int64_t>::min();
}
static constexpr bool is_signed = true;
static constexpr bool is_integer = true;
};
} // namespace std

View File

@ -93,6 +93,47 @@ class TestConvolutionNN(NNTestCase):
input = torch.randn((1, 1, 1, 1), dtype=torch.float)
self.assertEqual(m(input).size(), (1, 1, 1, 1))
def test_huge_padding(self):
class Conv1dModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv1d(
in_channels=16,
out_channels=32,
kernel_size=3,
stride=1,
padding=9223372036854775803,
)
self.add_module(name="conv1", module=self.conv1)
input_data = torch.randn(1, 16, 100)
model = Conv1dModule()
with self.assertRaisesRegex(
RuntimeError,
r"Given padding=9223372036854775803 at dimension 0 , expected padding to be at most",
):
model.conv1(input_data)
class ConvTransposed1dModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv_transposed1d = nn.ConvTranspose1d(
in_channels=16,
out_channels=32,
kernel_size=3,
stride=2,
padding=9223372036854775803,
)
self.add_module(name="conv_transposed1d", module=self.conv_transposed1d)
input_data = torch.randn(1, 16, 100)
model = ConvTransposed1dModule()
with self.assertRaisesRegex(
RuntimeError,
r"Given padding=9223372036854775803 at dimension 0 , expected padding to be at most",
):
model.conv_transposed1d(input_data)
def test_invalid_conv1d(self):
for dtype in [
torch.half,