Support more dtypes for input, indices in gather (#151822)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151822
Approved by: https://github.com/ngimel
This commit is contained in:
Parent: 4c8dee7986
Commit: f0c9b3385d
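Summary of the change: gather (and the related scatter checks) now accept int32 as well as int64 indices. The CPU kernels convert an int32 index to int64 up front, while the CUDA kernels dispatch on the index dtype directly. Below is a minimal usage sketch, assuming a PyTorch build that includes this change; it is an illustration, not part of the patch.

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
idx64 = torch.tensor([[0, 1], [2, 3], [1, 0]], dtype=torch.int64)
idx32 = idx64.to(torch.int32)

out64 = torch.gather(x, 1, idx64)   # previously the only supported index dtype
out32 = torch.gather(x, 1, idx32)   # int32 index, newly accepted by this change
assert torch.equal(out64, out32)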
@@ -19,8 +19,8 @@ inline void scatter_gather_dtype_check(
 ) {
   if (index.numel() != 0) {
     TORCH_CHECK(
-      index.scalar_type() == at::ScalarType::Long,
-      method_name, "(): Expected dtype int64 for index"
+      index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int,
+      method_name, "(): Expected dtype int32/int64 for index"
     );
   }

@@ -175,9 +175,10 @@ TORCH_META_FUNC(gather)
   auto is_index_empty = index.numel() == 0;
   if (!is_index_empty) {
     TORCH_CHECK(
-        index.scalar_type() == at::ScalarType::Long,
+        index.scalar_type() == ScalarType::Long ||
+            index.scalar_type() == ScalarType::Int,
         "gather",
-        "(): Expected dtype int64 for index");
+        "(): Expected dtype int32/int64 for index");
   }
   if (is_index_empty)
     return;

@@ -167,10 +167,11 @@ template <bool is_scatter_like = true>
 struct cpu_scatter_gather_base_kernel {
   template <typename func_t>
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Scalar& value,
+    const Tensor& _index, const Scalar& value,
     const std::string& method_name, func_t& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -263,10 +264,11 @@ struct cpu_scatter_gather_base_kernel {

   template <typename func_t>
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, func_t& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -358,10 +360,11 @@ struct cpu_scatter_gather_base_kernel {
   }

   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMean& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -453,9 +456,10 @@ struct cpu_scatter_gather_base_kernel {
   }

   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMaximum& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -547,10 +551,11 @@ struct cpu_scatter_gather_base_kernel {
   }

   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMinimum& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -810,7 +815,8 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
 }

 template <typename scalar_t>
-void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
+void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& _index, const Tensor& self) {
+  Tensor index = _index.to(ScalarType::Long);
   const int64_t* index_data = index.const_data_ptr<int64_t>();
   scalar_t* result_data = result.data_ptr<scalar_t>();
   const scalar_t* self_data = self.const_data_ptr<scalar_t>();

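The CPU hunks above all follow the same pattern: the kernel now receives the raw index as _index and upcasts it once with Tensor index = _index.to(ScalarType::Long), so an int32 index pays a one-time conversion and then reuses the existing int64 code path. A small sketch of the observable effect, assuming a build with this change:

import torch

x = torch.randn(4, 5)
idx32 = torch.randint(0, 5, (4, 3), dtype=torch.int32)
idx64 = idx32.long()  # the conversion the CPU kernel performs internally

# Both calls should produce identical results on the CPU path.
assert torch.equal(torch.gather(x, 1, idx32), torch.gather(x, 1, idx64))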
@@ -84,7 +84,7 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
       auto inp_stride_bytes = index_stride[0];
       auto out_stride_bytes = iter.strides(0)[1];
       if (iter.numel() == 0) return;
-      at::native::vectorized_gather_kernel_launch<alignment>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
+      at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
                                                              slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
       return;
     }

@@ -7,8 +7,8 @@
 #include <ATen/ceil_div.h>

 namespace at::native {
-template <int Alignment>
-__global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
+template <int Alignment, typename index_t>
+__global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
   int64_t ind = idx[blockIdx.x];
   if (allow_neg_indices) {
     ind = (ind < 0) ? ind + ind_dim_size : ind;

@@ -22,8 +22,8 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx,
 }


-template <int64_t Alignment>
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template <int64_t Alignment, typename index_t>
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
     int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices){

   constexpr int64_t max_num_threads=256;

@@ -32,13 +32,15 @@ void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int
                                       static_cast<int64_t>(C10_WARP_SIZE));
   dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
   auto block = std::min(max_num_threads, num_threads);
-  vectorized_gather_kernel<Alignment><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
+  vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
       ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 // explicit template instantiation
-template void vectorized_gather_kernel_launch<16>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+template void vectorized_gather_kernel_launch<16, int64_t>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
     int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
+template void vectorized_gather_kernel_launch<16, int32_t>(char * out, char * inp, int32_t * idx, int num_ind, int64_t slice_size_in_bytes,
+    int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);

 }

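Unlike the CPU path, the CUDA vectorized gather kernel is now templated on index_t and explicitly instantiated for int64_t and int32_t, so an int32 index is read directly with no conversion. A sketch exercising that path, guarded because it assumes a CUDA build that includes this change:

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 256, device="cuda")
    idx32 = torch.randint(0, 1024, (512,), dtype=torch.int32, device="cuda")
    # Expand the int32 index along dim 0 to gather whole rows.
    gathered = torch.gather(x, 0, idx32.unsqueeze(1).expand(-1, 256))
    assert gathered.shape == (512, 256)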
@@ -26,8 +26,8 @@ inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const
          get_alignment(static_cast<size_t>(iter.strides(0)[1])) == alignment;
 }

-template <int64_t Alignment>
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template <int64_t Alignment, typename index_t>
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
     int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
     bool allow_neg_indices=false);

@@ -116,7 +116,7 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

-template <bool is_scatter_like, typename scalar_t>
+template <bool is_scatter_like, typename scalar_t, typename index_t>
 struct _cuda_scatter_gather_internal_kernel {
   template <typename func_t>
   void operator() (

@@ -128,7 +128,7 @@ struct _cuda_scatter_gather_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t>()(
+        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t, index_t>()(
          sub_iter, index_size, index_stride, numel, f
        );
      }

@@ -151,7 +151,7 @@ struct _cuda_scatter_gather_internal_kernel {
         auto inp_stride_bytes = index_stride * element_size;
         auto out_stride_bytes = iter.strides(0)[1];
         if (iter.numel() == 0) return;
-        at::native::vectorized_gather_kernel_launch<alignment>(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
+        at::native::vectorized_gather_kernel_launch<alignment, index_t>(self_ptr, src_ptr, (index_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
         return;
       }
     }

@@ -159,7 +159,7 @@ struct _cuda_scatter_gather_internal_kernel {
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);

-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "scatter gather kernel index out of bounds");

@@ -229,9 +229,11 @@ struct cuda_scatter_gather_base_kernel {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }

@@ -279,19 +281,40 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
-      iter.dtype(),
-      "cuda_scatter_gather_base_kernel_func", [&] {
+    if (self.is_quantized()) {
+      TORCH_CHECK(
+        self.qscheme() == kPerTensorAffine,
+        "Only per_tensor quantized quantized tensors are supported by gather.")
+      AT_DISPATCH_QINT_TYPES(iter.dtype(), "gather_quant_cuda", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
-          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
-
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
-      }
-    );
+          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
+      });
+    } else {
+      AT_DISPATCH_V2(
+        iter.dtype(),
+        "gather_cuda",
+        AT_WRAP([&] {
+          using dtype = typename std::conditional<cast_to_opaque,
+            OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
+          AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+            _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+              iter, index_size, index_stride, self.numel(), f
+            );
+          });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
+    }
   }

   template <typename func_t>

@@ -338,7 +361,6 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;

-
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
       iter.dtype(),

@@ -346,15 +368,17 @@ struct cuda_scatter_gather_base_kernel {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
 }; // struct cuda_scatter_gather_base_kernel

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 struct _cuda_scatter_fill_internal_kernel {
   template <typename func_t>
   void operator()(

@@ -367,7 +391,7 @@ struct _cuda_scatter_fill_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_fill_internal_kernel<scalar_t>()(
+        _cuda_scatter_fill_internal_kernel<scalar_t, index_t>()(
          sub_iter, src_val, index_size, index_stride, numel, f
        );
      }

@@ -381,7 +405,7 @@ struct _cuda_scatter_fill_internal_kernel {
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);

-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "index out of bounds"
                          );

@@ -437,9 +461,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;

-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_func", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }

@@ -480,9 +506,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;

-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_reduce_multiply", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }

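The two hunks above extend the scatter-fill path (scatter with a scalar value) the same way: the scalar-type dispatch now wraps an AT_DISPATCH_INDEX_TYPES over the index dtype. A sketch exercising that path, guarded because it assumes a CUDA build with this change:

import torch

if torch.cuda.is_available():
    dst = torch.zeros(3, 5, device="cuda")
    idx32 = torch.tensor([[0, 1, 2, 0, 0]], dtype=torch.int32, device="cuda")
    dst.scatter_(0, idx32, 1.0)  # scalar-value scatter with an int32 index
    print(dst)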
@@ -3335,7 +3335,6 @@ def gather(x, dim, index, sparse_grad=False):
         # Empty index case. Return an empty array with the same shape
         return new_empty(x, index.get_size())

-    assert index.get_dtype() == torch.int64
     size = x.get_size()
     offset = len(size) == 0
     dim = _validate_dim(x, dim, offset)

@@ -5420,8 +5420,8 @@ def meta_gather(self, dim, index, sparse_grad=False):
     is_index_empty = guard_size_oblivious(index.numel() == 0)
     if not is_index_empty:
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}",
         )
         gather_shape_check(self, wrapped_dim, index)
     return self.new_empty(index.shape)

@@ -5460,8 +5460,8 @@ def scatter_gather_dtype_check(method_name, self, index, src_opt=None):

     if guard_size_oblivious(index.numel() != 0):
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"{method_name}(): Expected dtype int64 for index",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"{method_name}(): Expected dtype int32/int64 for index",
         )

     if src_opt is not None:

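With the meta and dtype checks above, only the accepted dtypes change; anything other than int32/int64 still fails, and the error text now reads "Expected dtype int32/int64 for index". A sketch of the updated error behavior, assuming a build with this change:

import torch

x = torch.randn(3, 4)
bad_idx = torch.zeros(3, 4, dtype=torch.int16)
try:
    torch.gather(x, 1, bad_idx)
except RuntimeError as e:
    # Expected to mention "Expected dtype int32/int64 for index"
    print(e)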
@@ -2618,6 +2618,10 @@ def sample_inputs_gather(op_info, device, dtype, requires_grad, **kwargs):
         make_arg((M, S)),
         0,
         gather_variable((S, S), 1, M, True, device=device))
+    yield SampleInput(
+        make_arg((M, S)),
+        0,
+        gather_variable((S, S), 1, M, True, device=device).to(torch.int32))
     yield SampleInput(
         make_arg((M, S)),
         1,

@@ -2663,11 +2667,6 @@ def error_inputs_gather(op_info, device, **kwargs):
     yield ErrorInput(SampleInput(bad_src, args=(1, idx,)),
                      error_regex="Size does not match at dimension 0")

-    # Index must have long dtype
-    bad_idx = idx.to(torch.int32)
-    yield ErrorInput(SampleInput(src, args=(1, bad_idx)),
-                     error_regex="Expected dtype int64 for index")
-
     # TODO: FIXME
     # out.dtype must match src.dtype
     # Creates new src & idx since SampleInputs can't share tensors

@@ -2740,13 +2739,6 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs):
     yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
                      error_regex="Expected self.dtype to be equal to src.dtype")

-    # Index dtype must be long
-    src = make_tensor((2, 5), device=device, dtype=torch.float32)
-    idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.int32)
-    dst = torch.zeros((3, 5), device=device, dtype=torch.float32)
-    yield ErrorInput(SampleInput(dst, args=(0, idx, src)),
-                     error_regex="Expected dtype int64 for index")
-
     # Index and destination must have the same number of dimensions
     src = make_tensor((2, 5), device=device, dtype=torch.float32)
     idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long)

@@ -7139,6 +7131,7 @@ def sample_inputs_scatter(op_info, device, dtype, requires_grad, **kwargs):
     zero = torch.tensor(0, dtype=torch.long, device=device)
     test_cases = (
         (_tensor((M, S)), (0, _gather((S, S), 1, M), _tensor((S, S)))),
+        (_tensor((M, S)), (0, _gather((S, S), 1, M).to(torch.int32), _tensor((S, S)))),
         (_tensor((M, S)), (1, _gather((S, S), 0, S), _tensor((S, S)))),
         (_tensor((M, S)), (-1, _gather((S, S), 0, S), _tensor((S, S)))),
         (_tensor((M, S)), (0, _gather((M, S // 2), 1, M), _tensor((M, S // 2)))),

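The test hunks above add int32-index sample inputs for gather and scatter and drop the old error inputs that asserted the index must be int64. A standalone sketch mirroring the new scatter sample with an int32 index, assuming a build with this change:

import torch

dst = torch.zeros(3, 5)
src = torch.ones(2, 5)
idx32 = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.int32)
out = dst.scatter(0, idx32, src)  # previously required an int64 index
print(out)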