[ROCm] Improve backwards indexing when stride is not one (#147630)

Improve the performance of backwards indexing (the sorted index_put_ path) on ROCm when the indexing stride is not one: the duplicate-counting, warp-cooperative scheme previously used only by the stride == 1 kernel is extended to the general-stride kernel, replacing the indexing_backward_kernel_rocm path and the ROCm-only UNROLL = 1 special case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147630
Approved by: https://github.com/jeffdaily
Doru Bercea 2025-03-11 19:02:44 +00:00 committed by PyTorch MergeBot
parent daff65d671
commit a1cb67b69e
2 changed files with 230 additions and 235 deletions
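For orientation, the operation these kernels implement is the backward scatter for index_put_: for every (sorted) destination index, a contiguous slice of stride elements of grad_output is written into, or accumulated into, the matching row of grad_weight. Below is a minimal serial C++ sketch of that semantics; names mirror the kernel parameters, the outer_dim / stride_before handling is omitted, and this is an illustration rather than the PyTorch implementation.

    // Serial reference for the scatter parallelized by indexing_backward_kernel.
    // "stride" is the number of contiguous elements copied per index; the diff below
    // adds a ROCm kernel for stride != 1 alongside the existing stride == 1 kernel.
    #include <cstdint>
    #include <vector>

    void indexing_backward_reference(
        const std::vector<int64_t>& sorted_indices,  // destination rows, sorted
        const std::vector<int64_t>& indices,         // matching source rows in grad_output
        const std::vector<float>& grad_output,
        std::vector<float>& grad_weight,
        int64_t stride,
        bool accumulate) {
      const auto numel = static_cast<int64_t>(sorted_indices.size());
      for (int64_t i = 0; i < numel; ++i) {
        const int64_t weight_row = sorted_indices[i] * stride;
        const int64_t grad_row = indices[i] * stride;
        for (int64_t f = 0; f < stride; ++f) {
          if (accumulate) {
            grad_weight[weight_row + f] += grad_output[grad_row + f];
          } else {
            // Last occurrence wins, matching the idx + num_duplicates - 1 choice in the kernels.
            grad_weight[weight_row + f] = grad_output[grad_row + f];
          }
        }
      }
    }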


@@ -55,6 +55,205 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
#endif
}
#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;
extern __shared__ unsigned char smem[];
auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
int smem_offset = threadIdx.y * C10_WARP_SIZE;
int laneIdx = threadIdx.x % C10_WARP_SIZE;
int64_t grad_row = 0;
for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
// Init duplicates every time we compute a new set of entries:
smem_dups_cache[smem_offset + laneIdx] = 0;
WARP_SYNC();
int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
int64_t idx = base_idx + laneIdx;
if (idx < numel) {
int64_t crnt_sorted_idx = sorted_indices[idx];
if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
// Determine the number of duplicates in advance:
int64_t num_duplicates = 1;
// Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
while ((idx + num_duplicates + SKIP_SORTED_INDICES - 1) < numel) {
if (sorted_indices[idx + num_duplicates + SKIP_SORTED_INDICES - 1] != crnt_sorted_idx) break;
num_duplicates += SKIP_SORTED_INDICES;
}
while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
num_duplicates++;
}
smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
}
}
WARP_SYNC();
// All lanes in the warp are still active here. Use them all to reduce duplicates when
// a large number of duplicates is present:
for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
// All lanes read the shared memory entry for number of duplicates
int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
// Check if the original sub-warp had duplicates to eliminate, if not skip.
if (new_num_duplicates == 0)
continue;
// There are duplicates that need eliminating:
int64_t new_idx = base_idx + subwarp;
int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
if (!accumulate) {
const int64_t grad_row = ((int64_t)indices[new_idx + new_num_duplicates - 1]) * stride + z * numel * stride;
int64_t feature_dim = blockIdx.y * blockDim.x + threadIdx.x;
while (feature_dim < stride) {
grad_weight[new_weight_row + feature_dim] = grad_output[grad_row + feature_dim];
feature_dim += gridDim.y * blockDim.x;
}
continue;
}
for (int dup = 0; dup < new_num_duplicates; dup++) {
const int64_t grad_row = ((int64_t) indices[new_idx + dup]) * stride + z * numel * stride;
// All lanes do the same thing up to here.
int64_t feature_dim = blockIdx.y * blockDim.x + threadIdx.x;
// Each lane has a different feature_dim.
while (feature_dim < stride) {
grad_weight[new_weight_row + feature_dim] += grad_output[grad_row + feature_dim];
feature_dim += gridDim.y * blockDim.x;
}
}
}
}
}
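Both ROCm kernels rely on the same duplicate-counting step: the first lane of a run of equal sorted indices measures the run length with a coarse scan that jumps SKIP_SORTED_INDICES entries at a time, then finishes the tail element by element (correct because the indices are sorted, so if the entry SKIP_SORTED_INDICES positions ahead still matches, everything in between matches too). A host-side sketch of just that step follows, assuming idx points at the first occurrence of its value; kSkip is a stand-in for SKIP_SORTED_INDICES and this is an illustration, not the kernel code.

    #include <cstdint>

    constexpr int64_t kSkip = 32;  // stand-in for SKIP_SORTED_INDICES

    // Length of the run of entries equal to sorted_indices[idx], starting at idx.
    int64_t count_duplicates(const int64_t* sorted_indices, int64_t numel, int64_t idx) {
      const int64_t value = sorted_indices[idx];
      int64_t num_duplicates = 1;
      // Coarse pass: advance kSkip entries at a time while the run provably continues.
      while (idx + num_duplicates + kSkip - 1 < numel &&
             sorted_indices[idx + num_duplicates + kSkip - 1] == value) {
        num_duplicates += kSkip;
      }
      // Fine pass: finish the tail of the run one entry at a time.
      while (idx + num_duplicates < numel && sorted_indices[idx + num_duplicates] == value) {
        ++num_duplicates;
      }
      return num_duplicates;
    }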
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;
int laneIdx = threadIdx.x % C10_WARP_SIZE;
const opmath_t scale = (opmath_t)1.0;
int64_t grad_row = 0;
extern __shared__ unsigned char smem[];
auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
// Each warp gets a different section of the shared memory allocation:
int smem_offset = threadIdx.y * C10_WARP_SIZE;
// Number of values processed by each thread (grain size)
for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
// Init duplicates every time we compute a new set of entries:
smem_dups_cache[smem_offset + laneIdx] = 0;
int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
int64_t idx = base_idx + laneIdx;
// Each lane calculates the number of duplicates:
if (idx < numel) {
int64_t crnt_sorted_idx = sorted_indices[idx];
if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
// Determine the number of duplicates in advance:
int64_t num_duplicates = 1;
// Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
while ((idx + num_duplicates + SKIP_SORTED_INDICES - 1) < numel) {
if (sorted_indices[idx + num_duplicates + SKIP_SORTED_INDICES - 1] != crnt_sorted_idx) break;
num_duplicates += SKIP_SORTED_INDICES;
}
while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
num_duplicates++;
}
if (!accumulate) {
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
grad_weight[weight_row] =
static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
continue;
}
// Each lane sequentially handles the duplicate elimination:
if (num_duplicates < C10_WARP_SIZE) {
opmath_t gradient = (opmath_t)0.0;
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
for (int64_t i = 0; i < num_duplicates; ++i) {
grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}
grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
} else {
// Add duplicate to the cache:
smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
}
}
}
WARP_SYNC();
// All lanes in the warp are still active here. Use them all to reduce duplicates when
// a large number of duplicates is present:
for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
// All lanes read the shared memory entry for number of duplicates
int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
// Check if the original sub-warp had duplicates to eliminate, if not skip.
if (new_num_duplicates == 0)
continue;
// There are duplicates that need eliminating:
int64_t new_idx = base_idx + subwarp;
int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
// Result of the reduction will be in this variable:
opmath_t gradient = (opmath_t)0.0;
int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
// Parallel reduction across the array of duplicates using all the lanes in the warp:
for (int64_t i = 0; i < num_warp_passes; ++i) {
grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}
// Reduce across the lanes of the warp:
WARP_SYNC();
for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
gradient += WARP_SHFL_DOWN(gradient, offset);
}
if (laneIdx == 0) {
for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}
grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
}
}
}
}
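For runs of at least C10_WARP_SIZE duplicates, the stride-1 kernel switches to a warp-cooperative sum: each lane accumulates a strided subset of the run and a shuffle-down tree folds the partial sums into lane 0, which then does the single read-modify-write of grad_weight. Below is a standalone CUDA sketch of that pattern, with plain CUDA intrinsics standing in for the WARP_SHFL_DOWN / WARP_SYNC portability macros and the tail folded into the strided loop rather than handled separately by lane 0 (on ROCm the warp is typically 64 lanes wide and no sync mask is used); illustration only, not the kernel code.

    #include <cstdint>

    // Warp-cooperative sum over a run of duplicate indices (stride == 1 case).
    // Returns the full total on lane 0; other lanes hold partial values.
    __device__ float warp_sum_run(const float* grad_output, const int64_t* indices,
                                  int64_t run_start, int64_t run_len) {
      const int lane = threadIdx.x % warpSize;
      float partial = 0.f;
      // Strided pass: lane L handles run elements L, L + warpSize, L + 2*warpSize, ...
      for (int64_t i = lane; i < run_len; i += warpSize) {
        partial += grad_output[indices[run_start + i]];
      }
      // Tree reduction: fold the per-lane partial sums into lane 0.
      for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        partial += __shfl_down_sync(0xffffffffu, partial, offset);
      }
      return partial;
    }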
#else
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -133,169 +332,6 @@ __global__ void indexing_backward_kernel(
}
}
#ifdef USE_ROCM
template <typename scalar_t, bool accumulate>
__global__ void indexing_backward_kernel_rocm(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {
// This implementation is adopted from indexing_backward_kernel above.
using opmath_t = at::opmath_type<scalar_t>;
for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
if (idx < numel && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
do {
// if not accumulate, we only keep the last duplicate index so skip those before it
if constexpr (!accumulate) {
if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
idx++;
continue;
}
}
const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
opmath_t gradient;
opmath_t weight;
int64_t feature_dim = threadIdx.x + blockIdx.y * blockDim.x;
while (feature_dim < stride) {
gradient = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
if constexpr (accumulate) {
weight = static_cast<opmath_t>(grad_weight[weight_row + feature_dim]);
}
if constexpr (accumulate) {
weight += gradient;
} else {
weight = gradient;
}
grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight);
feature_dim += gridDim.y * blockDim.x;
}
idx++;
} while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
}
}
}
#endif
#ifdef USE_ROCM
#define SKIP 32
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;
int laneIdx = threadIdx.x % C10_WARP_SIZE;
const opmath_t scale = (opmath_t)1.0;
int64_t grad_row = 0;
extern __shared__ unsigned char smem[];
auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);
// Each warp gets a different section of the share memory allocation:
int smem_offset = threadIdx.y * C10_WARP_SIZE;
// Number of values processed by each thread (grain size)
for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
// Init duplicates every time we compute a new set of entries:
smem_dups_cache[smem_offset + laneIdx] = 0;
int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
int64_t idx = base_idx + laneIdx;
// Each lane calculates the number of duplicates:
if (idx < numel) {
int64_t crnt_sorted_idx = sorted_indices[idx];
if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
// Determine the number of duplicates in advance:
int64_t num_duplicates = 1;
// Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
while ((idx + num_duplicates + SKIP - 1) < numel) {
if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
num_duplicates += SKIP;
}
while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
num_duplicates++;
}
if (!accumulate) {
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
grad_weight[weight_row] =
static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
continue;
}
// Each lane sequentially handles the duplicate elimination:
if (num_duplicates < C10_WARP_SIZE) {
opmath_t gradient = (opmath_t)0.0;
const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
for (int64_t i = 0; i < num_duplicates; ++i) {
grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}
grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
} else {
// Add duplicate to the cache:
smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
}
}
}
WARP_SYNC();
// All lanes in the warp are still active here. Use them all to reduce duplicates when
// large number of duplicates are present:
for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
// All lanes read the shared memory entry for number of duplicates
int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];
// Check if the original sub-warp had duplicates to eliminate, if not skip.
if (new_num_duplicates == 0)
continue;
// There are duplicates that need eliminating:
int64_t new_idx = base_idx + subwarp;
int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;
// Result of the reduction will be in this variable:
opmath_t gradient = (opmath_t)0.0;
int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
// Parallel reduction across the array of duplicates using all the lanes in the warp:
for (int64_t i = 0; i < num_warp_passes; ++i) {
grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}
// Reduce across the lanes of the warp:
WARP_SYNC();
for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
gradient += WARP_SHFL_DOWN(gradient, offset);
}
if (laneIdx == 0) {
for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
}
grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
}
}
}
}
#else
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -664,11 +700,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
"number of flattened indices did not match number of elements in the value tensor: ",
linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
#ifdef USE_ROCM
const int UNROLL = 1;
#else
const int UNROLL = 4;
#endif
const int indices_per_block = 4;
const int warp_size = at::cuda::warp_size();
dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
@@ -676,10 +709,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
dim3 block(warp_size, indices_per_block);
if (sliceSize == 1) {
#ifdef USE_ROCM
// Adapt grid size to smaller virtual warp size:
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
#define KERNEL_GRID new_grid
@@ -688,6 +718,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
#define KERNEL_GRID grid
#define KERNEL_SMEM 0
#endif
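As a quick sanity check on the launch setup above: with indices_per_block = 4 and a 64-lane warp (typical for AMD targets; the real value comes from at::cuda::warp_size() at run time), each block covers 256 sorted indices per grid step and the duplicate cache needs 2 KiB of dynamic shared memory. The calculation below is standalone and not part of the PR; the 64-lane warp and the index count reused from the test further down are assumptions.

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t num_indices = 401988;   // size used by the index_put_ test below
      const int indices_per_block = 4;
      const int warp_size = 64;             // assumption; 32 on NVIDIA GPUs
      const int64_t per_block = static_cast<int64_t>(indices_per_block) * warp_size;
      const int64_t grid_x = (num_indices + per_block - 1) / per_block;   // ceil_div
      const size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
      std::printf("new_grid.x = %lld, dynamic smem per block = %zu bytes\n",
                  static_cast<long long>(grid_x), smem_dups_size);  // 1571 and 2048
      return 0;
    }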
if (sliceSize == 1) {
// This implementation is faster with high amounts of duplicates but could overflow
// if FP16 / BF16 is used
AT_DISPATCH_V2(
@@ -719,8 +751,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kHalf,
kBool,
kBFloat16);
#undef KERNEL_GRID
#undef KERNEL_SMEM
} else {
if (sliceSize <= warp_size) {
AT_DISPATCH_V2(
@@ -751,72 +781,12 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
kHalf,
kBool,
kBFloat16);
#ifdef USE_ROCM
} else if (UNROLL == 1) {
if (accumulate) {
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
AT_WRAP([&] {
indexing_backward_kernel_rocm<scalar_t, true><<<grid, block, 0, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
src_.mutable_data_ptr<scalar_t>(),
num_indices,
sliceSize,
strideBefore,
nElemBefore);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
kBFloat16);
} else {
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
AT_WRAP([&] {
indexing_backward_kernel_rocm<scalar_t, false><<<grid, block, 0, stream>>>(
indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
src_.mutable_data_ptr<scalar_t>(),
num_indices,
sliceSize,
strideBefore,
nElemBefore);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}),
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
// AT_EXPAND(AT_FLOAT8_TYPES),
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
// should not be supported here, then reenable AT_FLOAT8_DTYPES
kFloat8_e4m3fn,
kFloat8_e5m2,
kFloat8_e4m3fnuz,
kFloat8_e5m2fnuz,
kComplexHalf,
kHalf,
kBool,
kBFloat16);
}
#endif
} else {
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
AT_WRAP([&] {
indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
sorted_indices.const_data_ptr<int64_t>(),
orig_indices.const_data_ptr<int64_t>(),
expandedValue.const_data_ptr<scalar_t>(),
@@ -843,6 +813,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
}
}
#undef KERNEL_GRID
#undef KERNEL_SMEM
if (permuted) {
self.copy_(src_.permute(inversePerm));
} else if (!self_contiguous) {


@@ -992,6 +992,7 @@ class TestIndexing(TestCase):
num_indices = 401988
max_index_range = 2000
target_index_range = [16, 256, 2000]
# BFloat16
for generated_index_range in target_index_range:
# create CPU tensors
a_tensor_size = (max_index_range, 256)
@@ -1010,6 +1011,27 @@ class TestIndexing(TestCase):
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)
# Float32
for generated_index_range in target_index_range:
# create CPU tensors
a_tensor_size = (max_index_range, 256)
a = torch.randn(a_tensor_size, dtype=torch.float32)
b = generate_indices(
num_indices=num_indices, index_range=generated_index_range
)
c_tensor_size = (num_indices, 256)
c = torch.randn(c_tensor_size, dtype=torch.float32)
# create GPU copies
a_dev = a.to(device)
b_dev = b.to(device)
c_dev = c.to(device)
# run
torch.use_deterministic_algorithms(True)
a.index_put_(indices=[b], values=c, accumulate=True)
torch.use_deterministic_algorithms(False)
a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
self.assertEqual(a_dev.cpu(), a)
@onlyCUDA
def test_index_put_accumulate_non_contiguous(self, device):
t = torch.zeros((5, 2, 2))