Support more dtypes for input, indices in gather (#151822)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151822
Approved by: https://github.com/ngimel
parent 4c8dee7986
commit f0c9b3385d
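The user-visible effect of this change is that gather (and the scatter variants sharing the same index dtype check) now accept int32 index tensors in addition to int64. A minimal sketch of the new behavior (illustrative usage only, not code from this commit):

import torch

x = torch.arange(12.0).reshape(3, 4)
# int32 indices were previously rejected with "Expected dtype int64 for index".
idx32 = torch.tensor([[0, 2], [1, 3], [3, 0]], dtype=torch.int32)
out = torch.gather(x, 1, idx32)
# Same result as the long-standing int64 path.
assert torch.equal(out, torch.gather(x, 1, idx32.long()))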
@@ -19,8 +19,8 @@ inline void scatter_gather_dtype_check(
 ) {
   if (index.numel() != 0) {
     TORCH_CHECK(
-      index.scalar_type() == at::ScalarType::Long,
-      method_name, "(): Expected dtype int64 for index"
+      index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int,
+      method_name, "(): Expected dtype int32/int64 for index"
     );
   }
 
@@ -175,9 +175,10 @@ TORCH_META_FUNC(gather)
   auto is_index_empty = index.numel() == 0;
   if (!is_index_empty) {
     TORCH_CHECK(
-        index.scalar_type() == at::ScalarType::Long,
+        index.scalar_type() == ScalarType::Long ||
+            index.scalar_type() == ScalarType::Int,
         "gather",
-        "(): Expected dtype int64 for index");
+        "(): Expected dtype int32/int64 for index");
   }
   if (is_index_empty)
     return;
@@ -167,10 +167,11 @@ template <bool is_scatter_like = true>
 struct cpu_scatter_gather_base_kernel {
   template <typename func_t>
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Scalar& value,
+    const Tensor& _index, const Scalar& value,
     const std::string& method_name, func_t& kernel_func) {
 
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -263,10 +264,11 @@ struct cpu_scatter_gather_base_kernel {
 
   template <typename func_t>
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, func_t& kernel_func) {
 
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -358,10 +360,11 @@ struct cpu_scatter_gather_base_kernel {
   }
 
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMean& kernel_func) {
 
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -453,9 +456,10 @@ struct cpu_scatter_gather_base_kernel {
   }
 
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMaximum& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -547,10 +551,11 @@ struct cpu_scatter_gather_base_kernel {
   }
 
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMinimum& kernel_func) {
 
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);
@@ -810,7 +815,8 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
 }
 
 template <typename scalar_t>
-void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
+void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& _index, const Tensor& self) {
+  Tensor index = _index.to(ScalarType::Long);
   const int64_t* index_data = index.const_data_ptr<int64_t>();
   scalar_t* result_data = result.data_ptr<scalar_t>();
   const scalar_t* self_data = self.const_data_ptr<scalar_t>();
@@ -84,7 +84,7 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
       auto inp_stride_bytes = index_stride[0];
       auto out_stride_bytes = iter.strides(0)[1];
       if (iter.numel() == 0) return;
-      at::native::vectorized_gather_kernel_launch<alignment>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
+      at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
                                                              slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
       return;
     }
@@ -7,8 +7,8 @@
 #include <ATen/ceil_div.h>
 
 namespace at::native {
-template <int Alignment>
-__global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
+template <int Alignment, typename index_t>
+__global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
   int64_t ind = idx[blockIdx.x];
   if (allow_neg_indices) {
     ind = (ind < 0) ? ind + ind_dim_size : ind;
@@ -22,8 +22,8 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx,
 
 
 
-template <int64_t Alignment>
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template <int64_t Alignment, typename index_t>
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
                                      int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices){
 
   constexpr int64_t max_num_threads=256;
@@ -32,13 +32,15 @@ void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int
                                       static_cast<int64_t>(C10_WARP_SIZE));
   dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
   auto block = std::min(max_num_threads, num_threads);
-  vectorized_gather_kernel<Alignment><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
+  vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
                            ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 // explicit template instantiation
-template void vectorized_gather_kernel_launch<16>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+template void vectorized_gather_kernel_launch<16, int64_t>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+  int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
+template void vectorized_gather_kernel_launch<16, int32_t>(char * out, char * inp, int32_t * idx, int num_ind, int64_t slice_size_in_bytes,
   int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
 
 }
@@ -26,8 +26,8 @@ inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const
          get_alignment(static_cast<size_t>(iter.strides(0)[1])) == alignment;
 }
 
-template <int64_t Alignment>
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template <int64_t Alignment, typename index_t>
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
                                      int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
                                      bool allow_neg_indices=false);
 
@@ -116,7 +116,7 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-template <bool is_scatter_like, typename scalar_t>
+template <bool is_scatter_like, typename scalar_t, typename index_t>
 struct _cuda_scatter_gather_internal_kernel {
   template <typename func_t>
   void operator() (
@@ -128,7 +128,7 @@ struct _cuda_scatter_gather_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t>()(
+        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t, index_t>()(
          sub_iter, index_size, index_stride, numel, f
        );
      }
@@ -151,7 +151,7 @@ struct _cuda_scatter_gather_internal_kernel {
         auto inp_stride_bytes = index_stride * element_size;
         auto out_stride_bytes = iter.strides(0)[1];
         if (iter.numel() == 0) return;
-        at::native::vectorized_gather_kernel_launch<alignment>(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
+        at::native::vectorized_gather_kernel_launch<alignment, index_t>(self_ptr, src_ptr, (index_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
         return;
       }
     }
@@ -159,7 +159,7 @@ struct _cuda_scatter_gather_internal_kernel {
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);
 
-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "scatter gather kernel index out of bounds");
 
@@ -229,9 +229,11 @@ struct cuda_scatter_gather_base_kernel {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
 
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
@@ -279,19 +281,40 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
-      iter.dtype(),
-      "cuda_scatter_gather_base_kernel_func", [&] {
+    if (self.is_quantized()) {
+      TORCH_CHECK(
+        self.qscheme() == kPerTensorAffine,
+        "Only per_tensor quantized quantized tensors are supported by gather.")
+      AT_DISPATCH_QINT_TYPES(iter.dtype(), "gather_quant_cuda", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
-
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
           iter, index_size, index_stride, self.numel(), f
         );
-      }
-    );
+        });
+      });
+    } else {
+      AT_DISPATCH_V2(
+        iter.dtype(),
+        "gather_cuda",
+        AT_WRAP([&] {
+          using dtype = typename std::conditional<cast_to_opaque,
+            OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
+          AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+            _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+              iter, index_size, index_stride, self.numel(), f
+            );
+          });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
+    }
   }
 
   template <typename func_t>
@@ -338,7 +361,6 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
-
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
       iter.dtype(),
@@ -346,15 +368,17 @@ struct cuda_scatter_gather_base_kernel {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
 
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
 }; // struct cuda_scatter_gather_base_kernel
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 struct _cuda_scatter_fill_internal_kernel {
   template <typename func_t>
   void operator()(
@@ -367,7 +391,7 @@ struct _cuda_scatter_fill_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_fill_internal_kernel<scalar_t>()(
+        _cuda_scatter_fill_internal_kernel<scalar_t, index_t>()(
          sub_iter, src_val, index_size, index_stride, numel, f
        );
      }
@@ -381,7 +405,7 @@ struct _cuda_scatter_fill_internal_kernel {
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);
 
-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "index out of bounds"
       );
@@ -437,9 +461,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;
 
-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_func", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
@@ -480,9 +506,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;
 
-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_reduce_multiply", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
@@ -3335,7 +3335,6 @@ def gather(x, dim, index, sparse_grad=False):
         # Empty index case. Return an empty array with the same shape
         return new_empty(x, index.get_size())
 
-    assert index.get_dtype() == torch.int64
     size = x.get_size()
     offset = len(size) == 0
     dim = _validate_dim(x, dim, offset)
@@ -5420,8 +5420,8 @@ def meta_gather(self, dim, index, sparse_grad=False):
     is_index_empty = guard_size_oblivious(index.numel() == 0)
     if not is_index_empty:
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}",
         )
         gather_shape_check(self, wrapped_dim, index)
     return self.new_empty(index.shape)
@@ -5460,8 +5460,8 @@ def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
 
     if guard_size_oblivious(index.numel() != 0):
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"{method_name}(): Expected dtype int64 for index",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"{method_name}(): Expected dtype int32/int64 for index",
         )
 
     if src_opt is not None:
@@ -2618,6 +2618,10 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
         make_arg((M, S)),
         0,
         gather_variable((S, S), 1, M, True, device=device))
+    yield SampleInput(
+        make_arg((M, S)),
+        0,
+        gather_variable((S, S), 1, M, True, device=device).to(torch.int32))
     yield SampleInput(
         make_arg((M, S)),
         1,
@@ -2663,11 +2667,6 @@ def error_inputs_gather(op_info, device, **kwargs):
     yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
                      error_regex="Size does not match at dimension 0")
 
-    # Index must have long dtype
-    bad_idx = idx.to(torch.int32)
-    yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
-                     error_regex="Expected dtype int64 for index")
-
     # TODO: FIXME
     # out.dtype must match src.dtype
     # Creates new src & idx since SampleInputs can't share tensors
@@ -2740,13 +2739,6 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
     yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                      error_regex="Expected self.dtype to be equal to src.dtype")
 
-    # Index dtype must be long
-    src = make_tensor((2, 5), device=device, dtype=torch.float32)
-    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
-    dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
-    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
-                     error_regex="Expected dtype int64 for index")
-
     # Index and destination must have the same number of dimensions
     src = make_tensor((2, 5), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)
@@ -7139,6 +7131,7 @@ def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs):
     zero = torch.tensor(0, dtype=torch.long, device=device)
     test_cases = (
         (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))),
+        (_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))),
         (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))),
         (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))),
         (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))),