Support more dtypes for input, indices in gather (#151822)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151822
Approved by: https://github.com/ngimel
This commit is contained in:
Isuru Fernando 2025-05-01 01:48:25 +00:00 committed by PyTorch MergeBot
parent 4c8dee7986
commit f0c9b3385d
10 changed files with 97 additions and 68 deletions

View File

@ -19,8 +19,8 @@ inline void scatter_gather_dtype_check(
) {
if (index.numel() != 0) {
TORCH_CHECK(
index.scalar_type() == at::ScalarType::Long,
method_name, "(): Expected dtype int64 for index"
index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int,
method_name, "(): Expected dtype int32/int64 for index"
);
}

View File

@ -175,9 +175,10 @@ TORCH_META_FUNC(gather)
auto is_index_empty = index.numel() == 0;
if (!is_index_empty) {
TORCH_CHECK(
index.scalar_type() == at::ScalarType::Long,
index.scalar_type() == ScalarType::Long ||
index.scalar_type() == ScalarType::Int,
"gather",
"(): Expected dtype int64 for index");
"(): Expected dtype int32/int64 for index");
}
if (is_index_empty)
return;

View File

@ -167,10 +167,11 @@ template <bool is_scatter_like = true>
struct cpu_scatter_gather_base_kernel {
template <typename func_t>
void operator()(const Tensor& self, int64_t dim,
const Tensor& index, const Scalar& value,
const Tensor& _index, const Scalar& value,
const std::string& method_name, func_t& kernel_func) {
Tensor buffer;
Tensor index = _index.to(ScalarType::Long);
bool need_acc = isReducedFloatingType(self.scalar_type());
create_acc_buffer(buffer, self, need_acc);
@ -263,10 +264,11 @@ struct cpu_scatter_gather_base_kernel {
template <typename func_t>
void operator()(const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src,
const Tensor& _index, const Tensor& src,
const std::string& method_name, func_t& kernel_func) {
Tensor buffer;
Tensor index = _index.to(ScalarType::Long);
bool need_acc = isReducedFloatingType(self.scalar_type());
create_acc_buffer(buffer, self, need_acc);
@ -358,10 +360,11 @@ struct cpu_scatter_gather_base_kernel {
}
void operator()(const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src,
const Tensor& _index, const Tensor& src,
const std::string& method_name, ReduceMean& kernel_func) {
Tensor buffer;
Tensor index = _index.to(ScalarType::Long);
bool need_acc = isReducedFloatingType(self.scalar_type());
create_acc_buffer(buffer, self, need_acc);
@ -453,9 +456,10 @@ struct cpu_scatter_gather_base_kernel {
}
void operator()(const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src,
const Tensor& _index, const Tensor& src,
const std::string& method_name, ReduceMaximum& kernel_func) {
Tensor buffer;
Tensor index = _index.to(ScalarType::Long);
bool need_acc = isReducedFloatingType(self.scalar_type());
create_acc_buffer(buffer, self, need_acc);
@ -547,10 +551,11 @@ struct cpu_scatter_gather_base_kernel {
}
void operator()(const Tensor& self, int64_t dim,
const Tensor& index, const Tensor& src,
const Tensor& _index, const Tensor& src,
const std::string& method_name, ReduceMinimum& kernel_func) {
Tensor buffer;
Tensor index = _index.to(ScalarType::Long);
bool need_acc = isReducedFloatingType(self.scalar_type());
create_acc_buffer(buffer, self, need_acc);
@ -810,7 +815,8 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
}
template <typename scalar_t>
void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& _index, const Tensor& self) {
Tensor index = _index.to(ScalarType::Long);
const int64_t* index_data = index.const_data_ptr<int64_t>();
scalar_t* result_data = result.data_ptr<scalar_t>();
const scalar_t* self_data = self.const_data_ptr<scalar_t>();

View File

@ -84,7 +84,7 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
auto inp_stride_bytes = index_stride[0];
auto out_stride_bytes = iter.strides(0)[1];
if (iter.numel() == 0) return;
at::native::vectorized_gather_kernel_launch<alignment>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
return;
}

View File

@ -7,8 +7,8 @@
#include <ATen/ceil_div.h>
namespace at::native {
template <int Alignment>
__global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
template <int Alignment, typename index_t>
__global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
int64_t ind = idx[blockIdx.x];
if (allow_neg_indices) {
ind = (ind < 0) ? ind + ind_dim_size : ind;
@ -22,8 +22,8 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx,
template <int64_t Alignment>
void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
template <int64_t Alignment, typename index_t>
void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices){
constexpr int64_t max_num_threads=256;
@ -32,13 +32,15 @@ void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int
static_cast<int64_t>(C10_WARP_SIZE));
dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
auto block = std::min(max_num_threads, num_threads);
vectorized_gather_kernel<Alignment><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// explicit template instantiation
template void vectorized_gather_kernel_launch<16>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
template void vectorized_gather_kernel_launch<16, int64_t>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
template void vectorized_gather_kernel_launch<16, int32_t>(char * out, char * inp, int32_t * idx, int num_ind, int64_t slice_size_in_bytes,
int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
}

View File

@ -26,8 +26,8 @@ inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const
get_alignment(static_cast<size_t>(iter.strides(0)[1])) == alignment;
}
template <int64_t Alignment>
void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
template <int64_t Alignment, typename index_t>
void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
bool allow_neg_indices=false);

View File

@ -116,7 +116,7 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <bool is_scatter_like, typename scalar_t>
template <bool is_scatter_like, typename scalar_t, typename index_t>
struct _cuda_scatter_gather_internal_kernel {
template <typename func_t>
void operator() (
@ -128,7 +128,7 @@ struct _cuda_scatter_gather_internal_kernel {
) {
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
_cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t>()(
_cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t, index_t>()(
sub_iter, index_size, index_stride, numel, f
);
}
@ -151,7 +151,7 @@ struct _cuda_scatter_gather_internal_kernel {
auto inp_stride_bytes = index_stride * element_size;
auto out_stride_bytes = iter.strides(0)[1];
if (iter.numel() == 0) return;
at::native::vectorized_gather_kernel_launch<alignment>(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
at::native::vectorized_gather_kernel_launch<alignment, index_t>(self_ptr, src_ptr, (index_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
return;
}
}
@ -159,7 +159,7 @@ struct _cuda_scatter_gather_internal_kernel {
auto loop = [=]C10_DEVICE(int i) {
auto offsets = offset_calc.get(i);
int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
&& "scatter gather kernel index out of bounds");
@ -229,9 +229,11 @@ struct cuda_scatter_gather_base_kernel {
using dtype = typename std::conditional<cast_to_opaque,
OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
iter, index_size, index_stride, self.numel(), f
);
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
iter, index_size, index_stride, self.numel(), f
);
});
}
);
}
@ -279,19 +281,40 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(),
"cuda_scatter_gather_base_kernel_func", [&] {
if (self.is_quantized()) {
TORCH_CHECK(
self.qscheme() == kPerTensorAffine,
"Only per_tensor quantized quantized tensors are supported by gather.")
AT_DISPATCH_QINT_TYPES(iter.dtype(), "gather_quant_cuda", [&] {
using dtype = typename std::conditional<cast_to_opaque,
OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
iter, index_size, index_stride, self.numel(), f
);
}
);
OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
iter, index_size, index_stride, self.numel(), f
);
});
});
} else {
AT_DISPATCH_V2(
iter.dtype(),
"gather_cuda",
AT_WRAP([&] {
using dtype = typename std::conditional<cast_to_opaque,
OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
iter, index_size, index_stride, self.numel(), f
);
});
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
AT_EXPAND(AT_FLOAT8_TYPES),
kComplexHalf,
kHalf,
kBool,
kBFloat16);
}
}
template <typename func_t>
@ -338,7 +361,6 @@ struct cuda_scatter_gather_base_kernel {
auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
iter.dtype(),
@ -346,15 +368,17 @@ struct cuda_scatter_gather_base_kernel {
using dtype = typename std::conditional<cast_to_opaque,
OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
iter, index_size, index_stride, self.numel(), f
);
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
_cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
iter, index_size, index_stride, self.numel(), f
);
});
}
);
}
}; // struct cuda_scatter_gather_base_kernel
template <typename scalar_t>
template <typename scalar_t, typename index_t>
struct _cuda_scatter_fill_internal_kernel {
template <typename func_t>
void operator()(
@ -367,7 +391,7 @@ struct _cuda_scatter_fill_internal_kernel {
) {
if (!iter.can_use_32bit_indexing()) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
_cuda_scatter_fill_internal_kernel<scalar_t>()(
_cuda_scatter_fill_internal_kernel<scalar_t, index_t>()(
sub_iter, src_val, index_size, index_stride, numel, f
);
}
@ -381,7 +405,7 @@ struct _cuda_scatter_fill_internal_kernel {
auto loop = [=]C10_DEVICE(int i) {
auto offsets = offset_calc.get(i);
int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]);
int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
&& "index out of bounds"
);
@ -437,9 +461,11 @@ struct cuda_scatter_fill_base_kernel {
auto src_scalar_val = src.to<scalar_t>();
auto src_val = *(dtype*)&src_scalar_val;
_cuda_scatter_fill_internal_kernel<dtype>()(
iter, src_val, index_size, index_stride, self.numel(), f
);
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_func", [&] () {
_cuda_scatter_fill_internal_kernel<dtype, index_t>()(
iter, src_val, index_size, index_stride, self.numel(), f
);
});
}
);
}
@ -480,9 +506,11 @@ struct cuda_scatter_fill_base_kernel {
auto src_scalar_val = src.to<scalar_t>();
auto src_val = *(dtype*)&src_scalar_val;
_cuda_scatter_fill_internal_kernel<dtype>()(
iter, src_val, index_size, index_stride, self.numel(), f
);
AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_reduce_multiply", [&] () {
_cuda_scatter_fill_internal_kernel<dtype, index_t>()(
iter, src_val, index_size, index_stride, self.numel(), f
);
});
}
);
}

View File

@ -3335,7 +3335,6 @@ def gather(x, dim, index, sparse_grad=False):
# Empty index case. Return an empty array with the same shape
return new_empty(x, index.get_size())
assert index.get_dtype() == torch.int64
size = x.get_size()
offset = len(size) == 0
dim = _validate_dim(x, dim, offset)

View File

@ -5420,8 +5420,8 @@ def meta_gather(self, dim, index, sparse_grad=False):
is_index_empty = guard_size_oblivious(index.numel() == 0)
if not is_index_empty:
torch._check(
index.dtype == torch.long,
lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
index.dtype == torch.long or index.dtype == torch.int,
lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}",
)
gather_shape_check(self, wrapped_dim, index)
return self.new_empty(index.shape)
@ -5460,8 +5460,8 @@ def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
if guard_size_oblivious(index.numel() != 0):
torch._check(
index.dtype == torch.long,
lambda: f"{method_name}(): Expected dtype int64 for index",
index.dtype == torch.long or index.dtype == torch.int,
lambda: f"{method_name}(): Expected dtype int32/int64 for index",
)
if src_opt is not None:

View File

@ -2618,6 +2618,10 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
make_arg((M, S)),
0,
gather_variable((S, S), 1, M, True, device=device))
yield SampleInput(
make_arg((M, S)),
0,
gather_variable((S, S), 1, M, True, device=device).to(torch.int32))
yield SampleInput(
make_arg((M, S)),
1,
@ -2663,11 +2667,6 @@ def error_inputs_gather(op_info, device, **kwargs):
yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
error_regex="Size does not match at dimension 0")
# Index must have long dtype
bad_idx = idx.to(torch.int32)
yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
error_regex="Expected dtype int64 for index")
# TODO: FIXME
# out.dtype must match src.dtype
# Creates new src & idx since SampleInputs can't share tensors
@ -2740,13 +2739,6 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
error_regex="Expected self.dtype to be equal to src.dtype")
# Index dtype must be long
src = make_tensor((2, 5), device=device, dtype=torch.float32)
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
error_regex="Expected dtype int64 for index")
# Index and destination must have the same number of dimensions
src = make_tensor((2, 5), device=device, dtype=torch.float32)
idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
@ -7139,6 +7131,7 @@ def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs):
zero = torch.tensor(0, dtype=torch.long, device=device)
test_cases = (
(_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))),
(_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))),
(_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))),
(_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))),
(_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))),