[CUDA][MPS] Fix torch.arange bound validation for large float inputs (#154320)

Fixes #153133

Fixes an inconsistency in torch.arange on the CUDA and MPS backends when using float32 with large input values. Previously, invalid ranges (e.g., start > end with a positive step) could silently return empty tensors because precision loss in the validation logic made the bounds appear consistent with the step sign.

The fix adds a shared arange_check_bounds helper that performs the validation in double precision: it checks that the step is nonzero, that start and end are finite, and that the step sign is consistent with the range direction.

This ensures torch.arange behaves consistently with the CPU backend for large float32 inputs and raises an appropriate error when the range is invalid.
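
Not part of the patch, but as a minimal standalone C++ sketch of the failure mode: around 1.5e9 the spacing between adjacent float32 values is 128, so the start/end values used by the new error input below (which differ by 72) round to the same float. A bound check done in float32 then passes and the computed size collapses to 0, silently producing an empty tensor; comparing in double, as arange_check_bounds does, preserves the ordering so the check fails as intended.

// precision_demo.cpp -- illustration only, not PyTorch source.
// Shows why a float32 bound check can accept an invalid arange range.
#include <cstdio>

int main() {
  // Values from the new error input: arange(1549556900, 1549556828, 1989724).
  // end < start while step > 0, so the range is invalid.
  const double start = 1549556900.0;
  const double end = 1549556828.0;
  const double step = 1989724.0;

  // Checking in float32: both bounds round to 1549556864.0f,
  // so "end >= start" holds and no error would be raised.
  const float fstart = static_cast<float>(start);
  const float fend = static_cast<float>(end);
  std::printf("float32: start=%.1f end=%.1f (end >= start) = %d\n",
              fstart, fend, static_cast<int>(fend >= fstart));

  // Size computed from the rounded values is 0 -> silent empty tensor.
  std::printf("float32 size = %d\n",
              static_cast<int>((fend - fstart) / static_cast<float>(step)));

  // Double-precision comparison keeps the ordering, so the
  // "inconsistent with step sign" error fires as expected.
  std::printf("double: (end >= start) = %d\n",
              static_cast<int>(end >= start));
  return 0;
}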

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154320
Approved by: https://github.com/malfet
Narek Malkhasyan 2025-06-05 14:51:25 +00:00 committed by PyTorch MergeBot
parent ed661a5f11
commit 7999735d23
6 changed files with 47 additions and 66 deletions

View File

@@ -157,12 +157,8 @@ Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Te
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and lower bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
   if (result.numel() != size) {
     result.resize_({size});

View File

@@ -6,19 +6,30 @@
 namespace at::native {
 
+inline void arange_check_bounds(
+    const c10::Scalar& start,
+    const c10::Scalar& end,
+    const c10::Scalar& step) {
+  // use double precision for validation to avoid precision issues
+  double dstart = start.to<double>();
+  double dend = end.to<double>();
+  double dstep = step.to<double>();
+  TORCH_CHECK(dstep > 0 || dstep < 0, "step must be nonzero");
+  TORCH_CHECK(
+      std::isfinite(dstart) && std::isfinite(dend),
+      "unsupported range: ",
+      dstart,
+      " -> ",
+      dend);
+  TORCH_CHECK(
+      ((dstep > 0) && (dend >= dstart)) || ((dstep < 0) && (dend <= dstart)),
+      "upper bound and lower bound inconsistent with step sign");
+}
+
 template <typename scalar_t>
 int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) {
-  using accscalar_t = at::acc_type<scalar_t, false>;
-  auto xstart = start.to<accscalar_t>();
-  auto xend = end.to<accscalar_t>();
-  auto xstep = step.to<accscalar_t>();
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   // we use double precision for (start - end) / step
   // to compute size_d for consistency across devices.
@@ -29,6 +40,10 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar
   // the corner-case we do want to take into account is int64_t, which has higher precision than double
   double size_d;
   if constexpr (std::is_same_v<scalar_t, int64_t>) {
+    using accscalar_t = at::acc_type<scalar_t, false>;
+    auto xstart = start.to<accscalar_t>();
+    auto xend = end.to<accscalar_t>();
+    auto xstep = step.to<accscalar_t>();
     int64_t sgn = (xstep > 0) - (xstep < 0);
     size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
   } else {

View File

@@ -1,10 +1,11 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
-#include <ATen/Dispatch.h>
 #include <ATen/AccumulateType.h>
-#include <ATen/cuda/Exceptions.h>
+#include <ATen/Dispatch.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
 #include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/RangeUtils.h>
 
 #include <cmath>
 #include <limits>
@@ -181,12 +182,8 @@ Tensor& range_cuda_out(const Scalar& start, const Scalar& end, const Scalar& ste
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
   if (result.numel() != size) {
@@ -217,12 +214,7 @@ Tensor& arange_cuda_out(const Scalar& start, const Scalar& end, const Scalar& st
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   // we use double precision for (start - end) / step
   // to compute size_d for consistency across devices.

View File

@@ -3,6 +3,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
 #include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/RangeUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
 #include <ATen/ops/arange_native.h>
 #include <ATen/ops/linspace_native.h>
@@ -65,14 +66,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
     size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>());
   }
 
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ",
-              xstart,
-              " -> ",
-              xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
               "invalid size, possible overflow?");
@@ -147,14 +141,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
     size_d = static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>() + 1;
   }
 
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ",
-              xstart,
-              " -> ",
-              xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
               "invalid size, possible overflow?");

View File

@@ -59,6 +59,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/WrapDimUtils.h>
 #include <ATen/native/ConvUtils.h>
+#include <ATen/native/RangeUtils.h>
 #include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/TensorConversions.h>
 #include <c10/core/ScalarType.h>
@@ -106,9 +107,6 @@ TORCH_API std::vector<Shape> compute_shape_arange_out(
         // Note: acc_type further defines an accumulataion type depending on the
         // scalar_t and whether its on cuda vs cpu.
         using accscalar_t = at::acc_type<scalar_t, false>;
-        auto xstart = start.to<accscalar_t>();
-        auto xend = end.to<accscalar_t>();
-        auto xstep = step.to<accscalar_t>();
 
         // we use double precision for (start - end) / step
         // to compute size_d for consistency across devices.
@@ -129,18 +127,7 @@ TORCH_API std::vector<Shape> compute_shape_arange_out(
               step.to<double>());
         }
 
-        TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-        TORCH_CHECK(
-            std::isfinite(static_cast<double>(xstart)) &&
-                std::isfinite(static_cast<double>(xend)),
-            "unsupported range: ",
-            xstart,
-            " -> ",
-            xend);
-        TORCH_CHECK(
-            ((xstep > 0) && (xend >= xstart)) ||
-                ((xstep < 0) && (xend <= xstart)),
-            "upper bound and larger bound inconsistent with step sign");
+        at::native::arange_check_bounds(start, end, step);
 
         TORCH_CHECK(
             size_d >= 0 &&

View File

@@ -776,9 +776,13 @@ def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs):
         yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False})
 
 
 def error_inputs_arange(op, device, **kwargs):
-    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer')
-    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
-    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero')
+    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
     yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range')
     yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range')