From df9a4824e62d358f28cd824aa0dffc6385bf97f8 Mon Sep 17 00:00:00 2001
From: Samuel Park
Date: Fri, 19 Sep 2025 20:57:00 +0000
Subject: [PATCH] Bugfix for doing negative padding (#161639)

Fixes #161014

This fix makes the size check consistent with its exception message. As
outlined in issue #161014, there is an edge case where negative padding does
not make the tensor size negative, yet the operator still raises the exception
claiming the output size is negative.

The fix is simply relaxing the check to `new_dim >= 0` so that a zero-sized
dimension is accepted and the operator returns an empty tensor.

This PR also adds a sample input covering the edge case where negative padding
reduces a dimension to zero. The sample only covers the `constant` padding
mode; feedback is welcome on whether the same sample should be added for the
`reduce` type as well.

This is my first PR contributing to PyTorch, and any help/feedback is welcome.
Thank you!

@malfet @manuelcandales @janeyx99 @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161639
Approved by: https://github.com/manuelcandales
---
 aten/src/ATen/native/PadNd.cpp                        | 2 +-
 test/distributed/tensor/test_dtensor_ops.py           | 1 -
 torch/_refs/__init__.py                               | 2 +-
 torch/csrc/lazy/core/shape_inference.cpp              | 2 +-
 torch/testing/_internal/common_methods_invocations.py | 1 +
 5 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/PadNd.cpp b/aten/src/ATen/native/PadNd.cpp
index 8099648d37b..3c00a16108c 100644
--- a/aten/src/ATen/native/PadNd.cpp
+++ b/aten/src/ATen/native/PadNd.cpp
@@ -73,7 +73,7 @@ Tensor constant_pad_nd(const Tensor& self, IntArrayRef pad, const Scalar& value)
   for (const auto i : c10::irange((size_t)l_pad)) {
     auto pad_idx = pad.size() - ((i + 1) * 2);
     auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1];
-    TORCH_CHECK(new_dim > 0, "The input size ", input_sizes[l_diff + i], ", plus negative padding ",
+    TORCH_CHECK(new_dim >= 0, "The input size ", input_sizes[l_diff + i], ", plus negative padding ",
                 pad[pad_idx], " and ", pad[pad_idx + 1], " resulted in a negative output size, "
                 "which is invalid. Check dimension ", l_diff + i, " of your input.");
     new_shape.emplace_back(new_dim);
diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py
index 423dda9d43d..cfdab9d63f2 100644
--- a/test/distributed/tensor/test_dtensor_ops.py
+++ b/test/distributed/tensor/test_dtensor_ops.py
@@ -193,7 +193,6 @@ dtensor_fails = {
     xfail("linalg.lu_factor_ex"),
     xfail("linalg.lu_solve"),
     xfail("linalg.matrix_power"),
-    xfail("linalg.multi_dot"),
     xfail("linalg.pinv"),
     xfail("linalg.pinv", "hermitian"),
     xfail("linalg.slogdet"),
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 59ab6302624..099a5388557 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2995,7 +2995,7 @@ def constant_pad_nd(
         pad_idx = len(pad) - ((i + 1) * 2)
         new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
         torch._check(
-            new_dim > 0,
+            new_dim >= 0,
             lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
             f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
             f"which is invalid. Check dimension {l_diff + i} of your input.",
diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp
index 5e9c7dd2956..e7ab494d18e 100644
--- a/torch/csrc/lazy/core/shape_inference.cpp
+++ b/torch/csrc/lazy/core/shape_inference.cpp
@@ -225,7 +225,7 @@ std::vector<Shape> compute_shape_constant_pad_nd(
     auto pad_idx = pad.size() - ((i + 1) * 2);
     auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1];
     TORCH_CHECK(
-        new_dim > 0,
+        new_dim >= 0,
         "The input size ",
         input_sizes[l_diff + i],
         ", plus negative padding ",
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 080c95bc7d2..f81104cbf4d 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -5905,6 +5905,7 @@ def sample_inputs_nn_pad(op_info, device, dtype, requires_grad, mode, **kwargs):
             ((1, 3), (1, 2)),
             ((1, 3), (0, 1)),
             ((1, 3), (0, 2, 0, 1)),
+            ((5, 3), (-1, -2, 1, 1)),
             ((0, 3, 3), (1, 2)),
             ((0, 3, 3), (0, 1)),
             ((0, 3, 3), (0, 2, 0, 1)),
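
Not part of the patch: a minimal repro sketch of the edge case this fix enables, using the sample shape added above (input `(5, 3)`, pad `(-1, -2, 1, 1)`). On builds without this fix the same call raises the "negative output size" error.

```python
import torch
import torch.nn.functional as F

x = torch.ones(5, 3)
# The last dim is padded by (-1, -2): 3 - 1 - 2 = 0, and the first dim by
# (1, 1): 5 + 1 + 1 = 7. With the relaxed check (new_dim >= 0) this returns
# an empty tensor instead of raising "resulted in a negative output size".
y = F.pad(x, (-1, -2, 1, 1), mode="constant", value=0.0)
print(y.shape)  # torch.Size([7, 0])
```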