diff --git a/aten/src/ATen/native/ScatterGatherChecks.h b/aten/src/ATen/native/ScatterGatherChecks.h
index 3a826a7a1b9..9fa850fdc2c 100644
--- a/aten/src/ATen/native/ScatterGatherChecks.h
+++ b/aten/src/ATen/native/ScatterGatherChecks.h
@@ -19,8 +19,8 @@ inline void scatter_gather_dtype_check(
 ) {
   if (index.numel() != 0) {
     TORCH_CHECK(
-      index.scalar_type() == at::ScalarType::Long,
-      method_name, "(): Expected dtype int64 for index"
+      index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int,
+      method_name, "(): Expected dtype int32/int64 for index"
     );
   }
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 5b5572b2b31..bd7d8c6f037 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -175,9 +175,10 @@ TORCH_META_FUNC(gather)
   auto is_index_empty = index.numel() == 0;
   if (!is_index_empty) {
     TORCH_CHECK(
-        index.scalar_type() == at::ScalarType::Long,
+        index.scalar_type() == ScalarType::Long ||
+            index.scalar_type() == ScalarType::Int,
         "gather",
-        "(): Expected dtype int64 for index");
+        "(): Expected dtype int32/int64 for index");
   }
   if (is_index_empty)
     return;
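[Editor's note, not part of the diff.] The two checks above are the user-visible change: gather/scatter dtype validation now accepts an int32 index in addition to int64. A minimal sketch of the resulting behavior, assuming the int32 path otherwise matches the existing int64 one (tensor shapes and values here are illustrative only):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
idx64 = torch.tensor([[0, 2], [1, 3], [2, 0]], dtype=torch.int64)
idx32 = idx64.to(torch.int32)  # previously rejected: "Expected dtype int64 for index"

# With the relaxed dtype check, both index dtypes are accepted and should agree.
out64 = torch.gather(x, 1, idx64)
out32 = torch.gather(x, 1, idx32)
assert torch.equal(out64, out32)
```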
diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
index 651fc2d91e8..b6d8d684ae6 100644
--- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
+++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp
@@ -167,10 +167,11 @@ template
 struct cpu_scatter_gather_base_kernel {
   template
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Scalar& value,
+    const Tensor& _index, const Scalar& value,
     const std::string& method_name, func_t& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -263,10 +264,11 @@ struct cpu_scatter_gather_base_kernel {
   template
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, func_t& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -358,10 +360,11 @@ struct cpu_scatter_gather_base_kernel {
   }
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMean& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -453,9 +456,10 @@ struct cpu_scatter_gather_base_kernel {
   }
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMaximum& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -547,10 +551,11 @@ struct cpu_scatter_gather_base_kernel {
   }
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMinimum& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -810,7 +815,8 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
 }
 
 template
-void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
+void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& _index, const Tensor& self) {
+  Tensor index = _index.to(ScalarType::Long);
   const int64_t* index_data = index.const_data_ptr();
   scalar_t* result_data = result.data_ptr();
   const scalar_t* self_data = self.const_data_ptr();
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
index 7481398acf0..850032931de 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cu
+++ b/aten/src/ATen/native/cuda/IndexKernel.cu
@@ -84,7 +84,7 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
     auto inp_stride_bytes = index_stride[0];
     auto out_stride_bytes = iter.strides(0)[1];
     if (iter.numel() == 0) return;
-    at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
+    at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
      slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
     return;
   }
diff --git a/aten/src/ATen/native/cuda/IndexKernelUtils.cu b/aten/src/ATen/native/cuda/IndexKernelUtils.cu
index abd256bcace..3e13f934e21 100644
--- a/aten/src/ATen/native/cuda/IndexKernelUtils.cu
+++ b/aten/src/ATen/native/cuda/IndexKernelUtils.cu
@@ -7,8 +7,8 @@
 #include
 
 namespace at::native {
-template
-__global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
+template
+__global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
   int64_t ind = idx[blockIdx.x];
   if (allow_neg_indices) {
     ind = (ind < 0) ? ind + ind_dim_size : ind;
@@ -22,8 +22,8 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx,
 
-template
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
     int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
     bool allow_neg_indices){
   constexpr int64_t max_num_threads=256;
@@ -32,13 +32,15 @@ void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int
                                  static_cast(C10_WARP_SIZE));
   dim3 grid = {static_cast(num_ind), static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
   auto block = std::min(max_num_threads, num_threads);
-  vectorized_gather_kernel<<>>(out, inp, idx, num_ind, slice_size_in_bytes,
+  vectorized_gather_kernel<<>>(out, inp, idx, num_ind, slice_size_in_bytes,
   ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 // explicit template instantiation
-template void vectorized_gather_kernel_launch<16>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+template void vectorized_gather_kernel_launch<16, int64_t>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
+template void vectorized_gather_kernel_launch<16, int32_t>(char * out, char * inp, int32_t * idx, int num_ind, int64_t slice_size_in_bytes,
 int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
 }
diff --git a/aten/src/ATen/native/cuda/IndexKernelUtils.h b/aten/src/ATen/native/cuda/IndexKernelUtils.h
index 9b00d91155e..20a67dac3f6 100644
--- a/aten/src/ATen/native/cuda/IndexKernelUtils.h
+++ b/aten/src/ATen/native/cuda/IndexKernelUtils.h
@@ -26,8 +26,8 @@ inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const
          get_alignment(static_cast(iter.strides(0)[1])) == alignment;
 }
 
-template
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
     int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
     bool allow_neg_indices=false);
diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu
index 6a3485fab5f..26064dd9837 100644
--- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu
+++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu
@@ -116,7 +116,7 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-template
+template
 struct _cuda_scatter_gather_internal_kernel {
   template
   void operator() (
@@ -128,7 +128,7 @@ struct _cuda_scatter_gather_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_gather_internal_kernel()(
+        _cuda_scatter_gather_internal_kernel()(
          sub_iter, index_size, index_stride, numel, f
        );
      }
@@ -151,7 +151,7 @@ struct _cuda_scatter_gather_internal_kernel {
         auto inp_stride_bytes = index_stride * element_size;
         auto out_stride_bytes = iter.strides(0)[1];
         if (iter.numel() == 0) return;
-        at::native::vectorized_gather_kernel_launch(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
+        at::native::vectorized_gather_kernel_launch(self_ptr, src_ptr, (index_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
         return;
       }
     }
@@ -159,7 +159,7 @@ struct _cuda_scatter_gather_internal_kernel {
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);
 
-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "scatter gather kernel index out of bounds");
 
@@ -229,9 +229,11 @@ struct cuda_scatter_gather_base_kernel {
         using dtype = typename std::conditional, scalar_t>::type;
 
-        _cuda_scatter_gather_internal_kernel()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
@@ -279,19 +281,40 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
-
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
-      iter.dtype(),
-      "cuda_scatter_gather_base_kernel_func", [&] {
+    if (self.is_quantized()) {
+      TORCH_CHECK(
+        self.qscheme() == kPerTensorAffine,
+        "Only per_tensor quantized quantized tensors are supported by gather.")
+      AT_DISPATCH_QINT_TYPES(iter.dtype(), "gather_quant_cuda", [&] {
         using dtype = typename std::conditional, scalar_t>::type;
-
-        _cuda_scatter_gather_internal_kernel()(
-          iter, index_size, index_stride, self.numel(), f
-        );
-      }
-    );
+          OpaqueType, scalar_t>::type;
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
+      });
+    } else {
+      AT_DISPATCH_V2(
+        iter.dtype(),
+        "gather_cuda",
+        AT_WRAP([&] {
+          using dtype = typename std::conditional, scalar_t>::type;
+          AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+            _cuda_scatter_gather_internal_kernel()(
+              iter, index_size, index_stride, self.numel(), f
+            );
+          });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
+    }
   }
 
   template
@@ -338,7 +361,6 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
-
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
       iter.dtype(),
@@ -346,15 +368,17 @@ struct cuda_scatter_gather_base_kernel {
        using dtype = typename std::conditional, scalar_t>::type;
 
-        _cuda_scatter_gather_internal_kernel()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
 }; // struct cuda_scatter_gather_base_kernel
 
-template
+template
 struct _cuda_scatter_fill_internal_kernel {
   template
   void operator()(
@@ -367,7 +391,7 @@ struct _cuda_scatter_fill_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_fill_internal_kernel()(
+        _cuda_scatter_fill_internal_kernel()(
          sub_iter, src_val, index_size, index_stride, numel, f
        );
      }
@@ -381,7 +405,7 @@ struct _cuda_scatter_fill_internal_kernel {
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);
 
-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "index out of bounds"
       );
@@ -437,9 +461,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to();
         auto src_val = *(dtype*)&src_scalar_val;
 
-        _cuda_scatter_fill_internal_kernel()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_func", [&] () {
+          _cuda_scatter_fill_internal_kernel()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
@@ -480,9 +506,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to();
         auto src_val = *(dtype*)&src_scalar_val;
 
-        _cuda_scatter_fill_internal_kernel()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_reduce_multiply", [&] () {
+          _cuda_scatter_fill_internal_kernel()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
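[Editor's note, not part of the diff.] The CPU kernels above simply upcast the incoming index to int64 (_index.to(ScalarType::Long)), while the CUDA path keeps the original dtype and dispatches on it through AT_DISPATCH_INDEX_TYPES with a templated index_t. A minimal scatter-side sketch that exercises the int32 index path; shapes and values are illustrative only, and on a CUDA build the tensors would be moved to the GPU to hit the dispatched kernels:

```python
import torch

src = torch.ones(2, 5)
idx32 = torch.tensor([[0, 1, 2, 0, 0],
                      [2, 0, 0, 1, 2]], dtype=torch.int32)  # int32 index, not int64
dst = torch.zeros(3, 5)

# Routes through the scatter/gather kernels patched above; previously this
# raised "Expected dtype int64 for index".
dst.scatter_add_(0, idx32, src)
```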
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index a6ced95237b..96f582d5826 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3335,7 +3335,6 @@ def gather(x, dim, index, sparse_grad=False):
         # Empty index case. Return an empty array with the same shape
         return new_empty(x, index.get_size())
 
-    assert index.get_dtype() == torch.int64
     size = x.get_size()
     offset = len(size) == 0
     dim = _validate_dim(x, dim, offset)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index c9a10268635..f91347087a1 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5420,8 +5420,8 @@ def meta_gather(self, dim, index, sparse_grad=False):
     is_index_empty = guard_size_oblivious(index.numel() == 0)
     if not is_index_empty:
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}",
         )
     gather_shape_check(self, wrapped_dim, index)
     return self.new_empty(index.shape)
@@ -5460,8 +5460,8 @@ def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
     if guard_size_oblivious(index.numel() != 0):
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"{method_name}(): Expected dtype int64 for index",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"{method_name}(): Expected dtype int32/int64 for index",
         )
 
     if src_opt is not None:
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ff8ba68bbd0..67c9d2e1291 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2618,6 +2618,10 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
         make_arg((M, S)),
         0,
         gather_variable((S, S), 1, M, True, device=device))
+    yield SampleInput(
+        make_arg((M, S)),
+        0,
+        gather_variable((S, S), 1, M, True, device=device).to(torch.int32))
     yield SampleInput(
         make_arg((M, S)),
         1,
@@ -2663,11 +2667,6 @@ def error_inputs_gather(op_info, device, **kwargs):
     yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
                      error_regex="Size does not match at dimension 0")
 
-    # Index must have long dtype
-    bad_idx = idx.to(torch.int32)
-    yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
-                     error_regex="Expected dtype int64 for index")
-
     # TODO: FIXME
     # out.dtype must match src.dtype
     # Creates new src & idx since SampleInputs can't share tensors
@@ -2740,13 +2739,6 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
     yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                      error_regex="Expected self.dtype to be equal to src.dtype")
 
-    # Index dtype must be long
-    src = make_tensor((2, 5), device=device, dtype=torch.float32)
-    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
-    dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
-    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
-                     error_regex="Expected dtype int64 for index")
-
     # Index and destination must have the same number of dimensions
     src = make_tensor((2, 5), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
@@ -7139,6 +7131,7 @@ def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs):
     zero = torch.tensor(0, dtype=torch.long, device=device)
     test_cases = (
         (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))),
+        (_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))),
         (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))),
         (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))),
         (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))),