From 7999735d23aeca844d4a7b23de6ac2370767099a Mon Sep 17 00:00:00 2001
From: Narek Malkhasyan
Date: Thu, 5 Jun 2025 14:51:25 +0000
Subject: [PATCH] [CUDA][MPS] Fix torch.arange bound validation for large
 float inputs (#154320)

Fixes #153133

Fixes an inconsistency in torch.arange on the CUDA and MPS backends when using
float32 and large input values. Previously, invalid ranges (e.g., start > end
with a positive step) could silently return empty tensors due to precision
loss in the validation logic.

The fix introduces double-precision validation for checking whether the step
sign is consistent with the range direction. This ensures torch.arange behaves
consistently with CPU for large float32 inputs and raises an appropriate error
when the range is invalid.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154320
Approved by: https://github.com/malfet
---
 aten/src/ATen/native/RangeFactories.cpp      |  8 +---
 aten/src/ATen/native/RangeUtils.h            | 37 +++++++++++++------
 aten/src/ATen/native/cuda/RangeFactories.cu  | 22 ++++-------
 .../native/mps/operations/RangeFactories.mm  | 19 ++--------
 torch/csrc/lazy/core/shape_inference.cpp     | 17 +--------
 .../_internal/common_methods_invocations.py  | 10 +++--
 6 files changed, 47 insertions(+), 66 deletions(-)

diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp
index 5ecc0f15933..24b745b1a68 100644
--- a/aten/src/ATen/native/RangeFactories.cpp
+++ b/aten/src/ATen/native/RangeFactories.cpp
@@ -157,12 +157,8 @@ Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Te
     auto xend = end.to<accscalar_t>();
     auto xstep = step.to<accscalar_t>();
 
-    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-    TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-                std::isfinite(static_cast<double>(xend)),
-                "unsupported range: ", xstart, " -> ", xend);
-    TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-                "upper bound and lower bound inconsistent with step sign");
+    arange_check_bounds(start, end, step);
+
     int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
     if (result.numel() != size) {
       result.resize_({size});
diff --git a/aten/src/ATen/native/RangeUtils.h b/aten/src/ATen/native/RangeUtils.h
index d3ad1c6ab7d..dcab86ca9a4 100644
--- a/aten/src/ATen/native/RangeUtils.h
+++ b/aten/src/ATen/native/RangeUtils.h
@@ -6,19 +6,30 @@
 
 namespace at::native {
 
+inline void arange_check_bounds(
+    const c10::Scalar& start,
+    const c10::Scalar& end,
+    const c10::Scalar& step) {
+  // use double precision for validation to avoid precision issues
+  double dstart = start.to<double>();
+  double dend = end.to<double>();
+  double dstep = step.to<double>();
+
+  TORCH_CHECK(dstep > 0 || dstep < 0, "step must be nonzero");
+  TORCH_CHECK(
+      std::isfinite(dstart) && std::isfinite(dend),
+      "unsupported range: ",
+      dstart,
+      " -> ",
+      dend);
+  TORCH_CHECK(
+      ((dstep > 0) && (dend >= dstart)) || ((dstep < 0) && (dend <= dstart)),
+      "upper bound and lower bound inconsistent with step sign");
+}
+
 template <typename scalar_t>
 int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) {
-  using accscalar_t = at::acc_type<scalar_t, false>;
-  auto xstart = start.to<accscalar_t>();
-  auto xend = end.to<accscalar_t>();
-  auto xstep = step.to<accscalar_t>();
-
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   // we use double precision for (start - end) / step
   // to compute size_d for consistency across devices.
@@ -29,6 +40,10 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar
   // the corner-case we do want to take into account is int64_t, which has higher precision than double
   double size_d;
   if constexpr (std::is_same_v<scalar_t, int64_t>) {
+    using accscalar_t = at::acc_type<scalar_t, false>;
+    auto xstart = start.to<accscalar_t>();
+    auto xend = end.to<accscalar_t>();
+    auto xstep = step.to<accscalar_t>();
     int64_t sgn = (xstep > 0) - (xstep < 0);
     size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
   } else {
diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu
index e471ce9f9d7..9d7ead7e498 100644
--- a/aten/src/ATen/native/cuda/RangeFactories.cu
+++ b/aten/src/ATen/native/cuda/RangeFactories.cu
@@ -1,10 +1,11 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include
-#include
 #include
-#include
+#include
+#include
 #include
+#include
 #include
+#include
 #include
 #include
@@ -181,12 +182,8 @@ Tensor& range_cuda_out(const Scalar& start, const Scalar& end, const Scalar& ste
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();
 
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
+
   int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
 
   if (result.numel() != size) {
@@ -217,12 +214,7 @@ Tensor& arange_cuda_out(const Scalar& start, const Scalar& end, const Scalar& st
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();
 
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   // we use double precision for (start - end) / step
   // to compute size_d for consistency across devices.
diff --git a/aten/src/ATen/native/mps/operations/RangeFactories.mm b/aten/src/ATen/native/mps/operations/RangeFactories.mm
index 613db5c5f48..4c1631f0f11 100644
--- a/aten/src/ATen/native/mps/operations/RangeFactories.mm
+++ b/aten/src/ATen/native/mps/operations/RangeFactories.mm
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include <ATen/native/RangeUtils.h>
 #include
 #include
 #include
@@ -65,14 +66,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
     size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>());
   }
 
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ",
-              xstart,
-              " -> ",
-              xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
               "invalid size, possible overflow?");
@@ -147,14 +141,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
     size_d = static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>() + 1;
   }
 
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ",
-              xstart,
-              " -> ",
-              xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
 
   TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
               "invalid size, possible overflow?");
diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp
index e2e9795ad5a..338dcb29f10 100644
--- a/torch/csrc/lazy/core/shape_inference.cpp
+++ b/torch/csrc/lazy/core/shape_inference.cpp
@@ -59,6 +59,7 @@
 #include
 #include
 #include
+#include <ATen/native/RangeUtils.h>
 #include
 #include
 #include
@@ -106,9 +107,6 @@ TORCH_API std::vector<Shape> compute_shape_arange_out(
       // Note: acc_type further defines an accumulataion type depending on the
       // scalar_t and whether its on cuda vs cpu.
       using accscalar_t = at::acc_type<scalar_t, false>;
-      auto xstart = start.to<accscalar_t>();
-      auto xend = end.to<accscalar_t>();
-      auto xstep = step.to<accscalar_t>();
 
       // we use double precision for (start - end) / step
       // to compute size_d for consistency across devices.
@@ -129,18 +127,7 @@ TORCH_API std::vector<Shape> compute_shape_arange_out(
             step.to<double>());
       }
 
-      TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-      TORCH_CHECK(
-          std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-          "unsupported range: ",
-          xstart,
-          " -> ",
-          xend);
-      TORCH_CHECK(
-          ((xstep > 0) && (xend >= xstart)) ||
-              ((xstep < 0) && (xend <= xstart)),
-          "upper bound and larger bound inconsistent with step sign");
+      at::native::arange_check_bounds(start, end, step);
 
       TORCH_CHECK(
           size_d >= 0 &&
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 554d259a1a8..d181611a0f8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -776,9 +776,13 @@ def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs):
     yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False})


 def error_inputs_arange(op, device, **kwargs):
-    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer')
-    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
-    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero')
+    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
     yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range')
     yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range')
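
Repro sketch (illustrative, not part of the patch): a minimal Python example of the behavior this change targets, reusing the bounds from the new error input above. The float32 dtype and the "cuda" device are assumptions based on the commit message and #153133; "mps" exercises the same path. At float32 precision both bounds round to the same representable value, so the old accscalar_t-based check on CUDA/MPS accepted the range and silently returned an empty tensor, whereas CPU (which accumulates float in double) raised. With the shared double-precision arange_check_bounds, all backends raise the same error.

    import torch

    # start > end with a positive step: an invalid range.
    # In float32, 1549556900 and 1549556828 both round to 1549556864,
    # so the old per-dtype check on CUDA/MPS failed to catch it.
    try:
        torch.arange(1549556900, 1549556828, 1989724,
                     dtype=torch.float32, device="cuda")
    except RuntimeError as e:
        print(e)  # upper bound and lower bound inconsistent with step sign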