[CUDA][MPS] Fix torch.arange bound validation for large float inputs (#154320)
Fixes #153133

Fixes an inconsistency in torch.arange on the CUDA and MPS backends when using float32 with large input values. Previously, invalid ranges (e.g., start > end with a positive step) could silently return empty tensors because precision loss in the validation logic hid the inconsistency. The fix introduces double-precision validation of whether the step sign is consistent with the range direction, so torch.arange behaves consistently with CPU for large float32 inputs and raises an appropriate error when the range is invalid.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154320
Approved by: https://github.com/malfet
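For context, the silent failure comes from float32 rounding: the start and end values used in the new error-input test below differ by less than the float32 spacing at that magnitude, so a float32 comparison cannot see that start > end. A small sketch of the effect (illustrative only, assuming a CUDA or MPS device; the error text comes from the new arange_check_bounds helper in the diff):

import torch

# At ~1.5e9 the float32 spacing is 128, so these two distinct integers round to
# the same float32 value and a float32 "end >= start" test cannot detect the
# invalid range.
start, end, step = 1549556900, 1549556828, 1989724
assert torch.tensor(start, dtype=torch.float32) == torch.tensor(end, dtype=torch.float32)

# Before this change the call below could silently return an empty tensor on
# CUDA/MPS for float32; with double-precision validation it should raise
# "upper bound and lower bound inconsistent with step sign" instead.
# torch.arange(start, end, step, dtype=torch.float32, device="cuda")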
This commit is contained in:
parent
ed661a5f11
commit
7999735d23
@@ -157,12 +157,8 @@ Tensor& range_out(const Scalar& start, const Scalar& end, const Scalar& step, Te
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();

-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and lower bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
+
   int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);
   if (result.numel() != size) {
     result.resize_({size});

@@ -6,19 +6,30 @@
 namespace at::native {

+inline void arange_check_bounds(
+    const c10::Scalar& start,
+    const c10::Scalar& end,
+    const c10::Scalar& step) {
+  // use double precision for validation to avoid precision issues
+  double dstart = start.to<double>();
+  double dend = end.to<double>();
+  double dstep = step.to<double>();
+
+  TORCH_CHECK(dstep > 0 || dstep < 0, "step must be nonzero");
+  TORCH_CHECK(
+      std::isfinite(dstart) && std::isfinite(dend),
+      "unsupported range: ",
+      dstart,
+      " -> ",
+      dend);
+  TORCH_CHECK(
+      ((dstep > 0) && (dend >= dstart)) || ((dstep < 0) && (dend <= dstart)),
+      "upper bound and lower bound inconsistent with step sign");
+}
+
 template <typename scalar_t>
 int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) {
-  using accscalar_t = at::acc_type<scalar_t, false>;
-  auto xstart = start.to<accscalar_t>();
-  auto xend = end.to<accscalar_t>();
-  auto xstep = step.to<accscalar_t>();
-
-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);

   // we use double precision for (start - end) / step
   // to compute size_d for consistency across devices.

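To make the new helper's behavior concrete, here is a Python mirror of its double-precision checks (illustration only; the real helper is the C++ function added above and reports failures through TORCH_CHECK):

import math

def check_arange_bounds(start, end, step):
    # Python floats are doubles, matching the start.to<double>() conversions above.
    dstart, dend, dstep = float(start), float(end), float(step)
    if dstep == 0:
        raise RuntimeError("step must be nonzero")
    if not (math.isfinite(dstart) and math.isfinite(dend)):
        raise RuntimeError(f"unsupported range: {dstart} -> {dend}")
    if not ((dstep > 0 and dend >= dstart) or (dstep < 0 and dend <= dstart)):
        raise RuntimeError("upper bound and lower bound inconsistent with step sign")

# The float32-rounded operands would pass the old check, but in double precision
# start > end is preserved, so this raises:
check_arange_bounds(1549556900, 1549556828, 1989724)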
@@ -29,6 +40,10 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar
   // the corner-case we do want to take into account is int64_t, which has higher precision than double
   double size_d;
   if constexpr (std::is_same_v<scalar_t, int64_t>) {
+    using accscalar_t = at::acc_type<scalar_t, false>;
+    auto xstart = start.to<accscalar_t>();
+    auto xend = end.to<accscalar_t>();
+    auto xstep = step.to<accscalar_t>();
     int64_t sgn = (xstep > 0) - (xstep < 0);
     size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
   } else {

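The int64_t branch above keeps the size computation in integer arithmetic (via the sgn adjustment) rather than going through double, since int64_t has higher precision than double. A minimal sketch of that formula follows; the helper name arange_size_i64 is made up for illustration, the real code keeps the expression inline in compute_arange_size:

def arange_size_i64(start, end, step):
    # ceil((end - start) / step) without leaving integer arithmetic; for valid
    # ranges the quotient is non-negative, so Python's floor division matches
    # the C++ truncating division used above.
    sgn = (step > 0) - (step < 0)
    return (end - start + step - sgn) // step

print(arange_size_i64(0, 10, 3))   # 4 -> [0, 3, 6, 9]
print(arange_size_i64(10, 0, -3))  # 4 -> [10, 7, 4, 1]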
@@ -1,10 +1,11 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
-#include <ATen/Dispatch.h>
 #include <ATen/AccumulateType.h>
-#include <ATen/cuda/Exceptions.h>
+#include <ATen/Dispatch.h>
+#include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
 #include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/RangeUtils.h>
 #include <cmath>
 #include <limits>

@@ -181,12 +182,8 @@ Tensor& range_cuda_out(const Scalar& start, const Scalar& end, const Scalar& ste
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();

-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);
+
   int64_t size = static_cast<int64_t>(((xend - xstart) / xstep) + 1);

   if (result.numel() != size) {

@@ -217,12 +214,7 @@ Tensor& arange_cuda_out(const Scalar& start, const Scalar& end, const Scalar& st
   auto xend = end.to<accscalar_t>();
   auto xstep = step.to<accscalar_t>();

-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) &&
-              std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ", xstart, " -> ", xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);

   // we use double precision for (start - end) / step
   // to compute size_d for consistency across devices.

@@ -3,6 +3,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
 #include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/RangeUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
 #include <ATen/ops/arange_native.h>
 #include <ATen/ops/linspace_native.h>

@@ -65,14 +66,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
     size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>());
   }

-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ",
-              xstart,
-              " -> ",
-              xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);

   TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
               "invalid size, possible overflow?");

@@ -147,14 +141,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
     size_d = static_cast<double>(end.to<double>() - start.to<double>()) / step.to<double>() + 1;
   }

-  TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-  TORCH_CHECK(std::isfinite(static_cast<double>(xstart)) && std::isfinite(static_cast<double>(xend)),
-              "unsupported range: ",
-              xstart,
-              " -> ",
-              xend);
-  TORCH_CHECK(((xstep > 0) && (xend >= xstart)) || ((xstep < 0) && (xend <= xstart)),
-              "upper bound and larger bound inconsistent with step sign");
+  arange_check_bounds(start, end, step);

   TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
               "invalid size, possible overflow?");

@@ -59,6 +59,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/WrapDimUtils.h>
 #include <ATen/native/ConvUtils.h>
+#include <ATen/native/RangeUtils.h>
 #include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/TensorConversions.h>
 #include <c10/core/ScalarType.h>

@@ -106,9 +107,6 @@ TORCH_API std::vector<Shape> compute_shape_arange_out(
     // Note: acc_type further defines an accumulataion type depending on the
     // scalar_t and whether its on cuda vs cpu.
     using accscalar_t = at::acc_type<scalar_t, false>;
-    auto xstart = start.to<accscalar_t>();
-    auto xend = end.to<accscalar_t>();
-    auto xstep = step.to<accscalar_t>();

     // we use double precision for (start - end) / step
     // to compute size_d for consistency across devices.

@@ -129,18 +127,7 @@ TORCH_API std::vector<Shape> compute_shape_arange_out(
           step.to<double>());
     }

-    TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
-    TORCH_CHECK(
-        std::isfinite(static_cast<double>(xstart)) &&
-            std::isfinite(static_cast<double>(xend)),
-        "unsupported range: ",
-        xstart,
-        " -> ",
-        xend);
-    TORCH_CHECK(
-        ((xstep > 0) && (xend >= xstart)) ||
-            ((xstep < 0) && (xend <= xstart)),
-        "upper bound and larger bound inconsistent with step sign");
+    at::native::arange_check_bounds(start, end, step);

     TORCH_CHECK(
         size_d >= 0 &&

@@ -776,9 +776,13 @@ def sample_inputs_add_sub(op, device, dtype, requires_grad, **kwargs):
     yield SampleInput(lhs, args=(rhs,), kwargs={'alpha': False})

 def error_inputs_arange(op, device, **kwargs):
-    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzer')
-    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
-    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError, error_regex='bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(3, 0)), error_type=RuntimeError, error_regex='step must be nonzero')
+    yield ErrorInput(SampleInput(0, args=(-3, 2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(0, args=(3, -2)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
+    yield ErrorInput(SampleInput(1549556900, args=(1549556828, 1989724)), error_type=RuntimeError,
+                     error_regex='upper bound and lower bound inconsistent with step sign')
     yield ErrorInput(SampleInput(0, args=(float('inf'), 2)), error_type=RuntimeError, error_regex='unsupported range')
     yield ErrorInput(SampleInput(float('-inf'), args=(1, 2)), error_type=RuntimeError, error_regex='unsupported range')