[ROCm] Improve backwards indexing when stride is not one (#147630)
Improve backwards indexing when stride is not one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147630
Approved by: https://github.com/jeffdaily
This commit is contained in:
parent daff65d671
commit a1cb67b69e
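For orientation: the kernels in this diff back the sorted-indices path of `index_put_` with `accumulate=True`, and the PR targets the case where each index addresses a slice wider than one element (stride != 1). The sketch below is not part of the commit; it mirrors the shape of the new Float32 test added at the bottom of the diff (401988 duplicate-heavy indices into rows of width 256), with `torch.randint` standing in for the test suite's `generate_indices` helper.

```python
# Illustrative only (not part of this commit): a duplicate-heavy
# index_put_ with accumulate=True whose value rows are 256 wide, i.e.
# the stride-not-one path handled by indexing_backward_kernel.
import torch

device = "cuda"  # ROCm builds also expose the HIP device as "cuda"
num_indices, max_index_range = 401988, 2000

a = torch.randn(max_index_range, 256, device=device)
# torch.randint stands in for the test's generate_indices helper:
# many repeated indices drawn from a small range (here 16 rows).
b = torch.randint(0, 16, (num_indices,), device=device)
c = torch.randn(num_indices, 256, device=device)

a.index_put_(indices=[b], values=c, accumulate=True)
```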
@@ -55,6 +55,205 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
#endif
}

#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
  using opmath_t = at::opmath_type<scalar_t>;

  extern __shared__ unsigned char smem[];
  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);

  int smem_offset = threadIdx.y * C10_WARP_SIZE;

  int laneIdx = threadIdx.x % C10_WARP_SIZE;
  int64_t grad_row = 0;

  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
    // Init duplicates every time we compute a new set of entries:
    smem_dups_cache[smem_offset + laneIdx] = 0;
    WARP_SYNC();

    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
    int64_t idx = base_idx + laneIdx;

    if (idx < numel) {
      int64_t crnt_sorted_idx = sorted_indices[idx];

      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
        // Determine the number of duplicates in advance:
        int64_t num_duplicates = 1;

        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
        while ((idx + num_duplicates + SKIP_SORTED_INDICES - 1) < numel) {
          if (sorted_indices[idx + num_duplicates + SKIP_SORTED_INDICES - 1] != crnt_sorted_idx) break;
          num_duplicates += SKIP_SORTED_INDICES;
        }
        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
          num_duplicates++;
        }

        smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
      }
    }

    WARP_SYNC();

    // All lanes in the warp are still active here. Use them all to reduce duplicates when
    // large number of duplicates are present:
    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
      // All lanes read the shared memory entry for number of duplicates
      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];

      // Check if the original sub-warp had duplicates to eliminate, if not skip.
      if (new_num_duplicates == 0)
        continue;

      // There are duplicates that need eliminating:
      int64_t new_idx = base_idx + subwarp;
      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;

      if (!accumulate) {
        const int64_t grad_row = ((int64_t)indices[new_idx + new_num_duplicates - 1]) * stride + z * numel * stride;
        int64_t feature_dim = blockIdx.y * blockDim.x + threadIdx.x;
        while (feature_dim < stride) {
          grad_weight[new_weight_row + feature_dim] = grad_output[grad_row + feature_dim];
          feature_dim += gridDim.y * blockDim.x;
        }
        continue;
      }

      for (int dup = 0; dup < new_num_duplicates; dup++) {
        const int64_t grad_row = ((int64_t) indices[new_idx + dup]) * stride + z * numel * stride;

        // All lanes do the same thing up to here.
        int64_t feature_dim = blockIdx.y * blockDim.x + threadIdx.x;

        // Each lane has a different feature_dim.
        while (feature_dim < stride) {
          grad_weight[new_weight_row + feature_dim] += grad_output[grad_row + feature_dim];
          feature_dim += gridDim.y * blockDim.x;
        }
      }
    }
  }
}

template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
  using opmath_t = at::opmath_type<scalar_t>;

  int laneIdx = threadIdx.x % C10_WARP_SIZE;

  const opmath_t scale = (opmath_t)1.0;
  int64_t grad_row = 0;

  extern __shared__ unsigned char smem[];
  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);

  // Each warp gets a different section of the share memory allocation:
  int smem_offset = threadIdx.y * C10_WARP_SIZE;

  // Number of values processed by each thread (grain size)
  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
    // Init duplicates every time we compute a new set of entries:
    smem_dups_cache[smem_offset + laneIdx] = 0;

    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
    int64_t idx = base_idx + laneIdx;

    // Each lane calculates the number of duplicates:
    if (idx < numel) {
      int64_t crnt_sorted_idx = sorted_indices[idx];

      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
        // Determine the number of duplicates in advance:
        int64_t num_duplicates = 1;

        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
        while ((idx + num_duplicates + SKIP_SORTED_INDICES - 1) < numel) {
          if (sorted_indices[idx + num_duplicates + SKIP_SORTED_INDICES - 1] != crnt_sorted_idx) break;
          num_duplicates += SKIP_SORTED_INDICES;
        }
        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
          num_duplicates++;
        }

        if (!accumulate) {
          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
          grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
          grad_weight[weight_row] =
            static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
          continue;
        }

        // Each lane sequentially handles the duplicate elimination:
        if (num_duplicates < C10_WARP_SIZE) {
          opmath_t gradient = (opmath_t)0.0;
          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
          for (int64_t i = 0; i < num_duplicates; ++i) {
            grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
            gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
          }

          grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
        } else {
          // Add duplicate to the cache:
          smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
        }
      }
    }

    WARP_SYNC();

    // All lanes in the warp are still active here. Use them all to reduce duplicates when
    // large number of duplicates are present:
    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
      // All lanes read the shared memory entry for number of duplicates
      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];

      // Check if the original sub-warp had duplicates to eliminate, if not skip.
      if (new_num_duplicates == 0)
        continue;

      // There are duplicates that need eliminating:
      int64_t new_idx = base_idx + subwarp;
      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;

      // Result of the reduction will be in this variable:
      opmath_t gradient = (opmath_t)0.0;

      int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
      // Parallel reduction across the array of duplicates using all the lanes in the warp:
      for (int64_t i = 0; i < num_warp_passes; ++i) {
        grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
        gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
      }

      // Reduce across the lanes of the warp:
      WARP_SYNC();
      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
        gradient += WARP_SHFL_DOWN(gradient, offset);
      }

      if (laneIdx == 0) {
        for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
          grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
          gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
        }

        grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
      }
    }
  }
}
#else
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -133,169 +332,6 @@ __global__ void indexing_backward_kernel(
  }
}

#ifdef USE_ROCM
template <typename scalar_t, bool accumulate>
__global__ void indexing_backward_kernel_rocm(
  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {

  // This implementation is adopted from indexing_backward_kernel above.
  using opmath_t = at::opmath_type<scalar_t>;
  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
    int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
    if (idx < numel && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
      do {
        // if not accumulate, we only keep the last duplicate index so skip those before it
        if constexpr (!accumulate) {
          if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
            idx++;
            continue;
          }
        }
        const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
        const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;

        opmath_t gradient;
        opmath_t weight;

        int64_t feature_dim = threadIdx.x + blockIdx.y * blockDim.x;
        while (feature_dim < stride) {
          gradient = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
          if constexpr (accumulate) {
            weight = static_cast<opmath_t>(grad_weight[weight_row + feature_dim]);
          }

          if constexpr (accumulate) {
            weight += gradient;
          } else {
            weight = gradient;
          }

          grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight);
          feature_dim += gridDim.y * blockDim.x;
        }

        idx++;
      } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
    }
  }
}
#endif

#ifdef USE_ROCM
#define SKIP 32
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
  using opmath_t = at::opmath_type<scalar_t>;

  int laneIdx = threadIdx.x % C10_WARP_SIZE;

  const opmath_t scale = (opmath_t)1.0;
  int64_t grad_row = 0;

  extern __shared__ unsigned char smem[];
  auto smem_dups_cache = reinterpret_cast<int64_t*>(smem);

  // Each warp gets a different section of the share memory allocation:
  int smem_offset = threadIdx.y * C10_WARP_SIZE;

  // Number of values processed by each thread (grain size)
  for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z) {
    // Init duplicates every time we compute a new set of entries:
    smem_dups_cache[smem_offset + laneIdx] = 0;

    int64_t base_idx = blockIdx.x * blockDim.y * C10_WARP_SIZE + threadIdx.y * C10_WARP_SIZE;
    int64_t idx = base_idx + laneIdx;

    // Each lane calculates the number of duplicates:
    if (idx < numel) {
      int64_t crnt_sorted_idx = sorted_indices[idx];

      if (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]) {
        // Determine the number of duplicates in advance:
        int64_t num_duplicates = 1;

        // Lookahead in case there is a large number of duplicates. Once that is done, handle the tail.
        while ((idx + num_duplicates + SKIP - 1) < numel) {
          if (sorted_indices[idx + num_duplicates + SKIP - 1] != crnt_sorted_idx) break;
          num_duplicates += SKIP;
        }
        while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
          num_duplicates++;
        }

        if (!accumulate) {
          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
          grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
          grad_weight[weight_row] =
            static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
          continue;
        }

        // Each lane sequentially handles the duplicate elimination:
        if (num_duplicates < C10_WARP_SIZE) {
          opmath_t gradient = (opmath_t)0.0;
          const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
          for (int64_t i = 0; i < num_duplicates; ++i) {
            grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
            gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
          }

          grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
        } else {
          // Add duplicate to the cache:
          smem_dups_cache[smem_offset + laneIdx] = num_duplicates;
        }
      }
    }

    WARP_SYNC();

    // All lanes in the warp are still active here. Use them all to reduce duplicates when
    // large number of duplicates are present:
    for (int subwarp = 0; subwarp < C10_WARP_SIZE; subwarp++) {
      // All lanes read the shared memory entry for number of duplicates
      int64_t new_num_duplicates = smem_dups_cache[smem_offset + subwarp];

      // Check if the original sub-warp had duplicates to eliminate, if not skip.
      if (new_num_duplicates == 0)
        continue;

      // There are duplicates that need eliminating:
      int64_t new_idx = base_idx + subwarp;
      int64_t new_crnt_sorted_idx = sorted_indices[new_idx];
      const int64_t new_weight_row = new_crnt_sorted_idx * stride + z * stride_before;

      // Result of the reduction will be in this variable:
      opmath_t gradient = (opmath_t)0.0;

      int64_t num_warp_passes = new_num_duplicates / C10_WARP_SIZE;
      // Parallel reduction across the array of duplicates using all the lanes in the warp:
      for (int64_t i = 0; i < num_warp_passes; ++i) {
        grad_row = ((int64_t) indices[new_idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
        gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
      }

      // Reduce across the lanes of the warp:
      WARP_SYNC();
      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
        gradient += WARP_SHFL_DOWN(gradient, offset);
      }

      if (laneIdx == 0) {
        for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < new_num_duplicates; ++i) {
          grad_row = ((int64_t) indices[new_idx + i]) * stride + z * numel * stride;
          gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
        }

        grad_weight[new_weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[new_weight_row]) + gradient);
      }
    }
  }
}
#else
template <typename scalar_t>
__global__ void indexing_backward_kernel_stride_1(
  const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -664,11 +700,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
    linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
    "number of flattened indices did not match number of elements in the value tensor: ",
    linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
#ifdef USE_ROCM
  const int UNROLL = 1;
#else
  const int UNROLL = 4;
#endif
  const int indices_per_block = 4;
  const int warp_size = at::cuda::warp_size();
  dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
@@ -676,18 +709,17 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
            std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
  dim3 block(warp_size, indices_per_block);

  if (sliceSize == 1) {
#ifdef USE_ROCM
    // Adapt grid size to smaller virtual warp size:
    dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
    size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
  dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
  size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
#define KERNEL_GRID new_grid
#define KERNEL_SMEM smem_dups_size
#else
#define KERNEL_GRID grid
#define KERNEL_SMEM 0
#endif

  if (sliceSize == 1) {
    // This implementation is faster with high amounts of duplicates but could overflow
    // if FP16 / BF16 is used
    AT_DISPATCH_V2(
@@ -719,8 +751,6 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
      kHalf,
      kBool,
      kBFloat16);
#undef KERNEL_GRID
#undef KERNEL_SMEM
  } else {
    if (sliceSize <= warp_size) {
      AT_DISPATCH_V2(
@@ -751,72 +781,12 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
        kHalf,
        kBool,
        kBFloat16);
#ifdef USE_ROCM
    } else if (UNROLL == 1) {
      if (accumulate) {
        AT_DISPATCH_V2(
          expandedValue.scalar_type(),
          "indexing_backward",
          AT_WRAP([&] {
            indexing_backward_kernel_rocm<scalar_t, true><<<grid, block, 0, stream>>>(
              sorted_indices.const_data_ptr<int64_t>(),
              orig_indices.const_data_ptr<int64_t>(),
              expandedValue.const_data_ptr<scalar_t>(),
              src_.mutable_data_ptr<scalar_t>(),
              num_indices,
              sliceSize,
              strideBefore,
              nElemBefore);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }),
          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
          // AT_EXPAND(AT_FLOAT8_TYPES),
          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
          // should not be supported here, then reenable AT_FLOAT8_DTYPES
          kFloat8_e4m3fn,
          kFloat8_e5m2,
          kFloat8_e4m3fnuz,
          kFloat8_e5m2fnuz,
          kComplexHalf,
          kHalf,
          kBool,
          kBFloat16);
      } else {
        AT_DISPATCH_V2(
          expandedValue.scalar_type(),
          "indexing_backward",
          AT_WRAP([&] {
            indexing_backward_kernel_rocm<scalar_t, false><<<grid, block, 0, stream>>>(
              sorted_indices.const_data_ptr<int64_t>(),
              orig_indices.const_data_ptr<int64_t>(),
              expandedValue.const_data_ptr<scalar_t>(),
              src_.mutable_data_ptr<scalar_t>(),
              num_indices,
              sliceSize,
              strideBefore,
              nElemBefore);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          }),
          AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
          // AT_EXPAND(AT_FLOAT8_TYPES),
          // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
          // should not be supported here, then reenable AT_FLOAT8_DTYPES
          kFloat8_e4m3fn,
          kFloat8_e5m2,
          kFloat8_e4m3fnuz,
          kFloat8_e5m2fnuz,
          kComplexHalf,
          kHalf,
          kBool,
          kBFloat16);
      }
#endif
    } else {
      AT_DISPATCH_V2(
        expandedValue.scalar_type(),
        "indexing_backward",
        AT_WRAP([&] {
          indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
          indexing_backward_kernel<scalar_t, UNROLL><<<KERNEL_GRID, block, KERNEL_SMEM, stream>>>(
            sorted_indices.const_data_ptr<int64_t>(),
            orig_indices.const_data_ptr<int64_t>(),
            expandedValue.const_data_ptr<scalar_t>(),
@@ -843,6 +813,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
    }
  }

#undef KERNEL_GRID
#undef KERNEL_SMEM

  if (permuted) {
    self.copy_(src_.permute(inversePerm));
  } else if (!self_contiguous) {
@@ -992,6 +992,7 @@ class TestIndexing(TestCase):
        num_indices = 401988
        max_index_range = 2000
        target_index_range = [16, 256, 2000]
        # BFloat16
        for generated_index_range in target_index_range:
            # create CPU tensors
            a_tensor_size = (max_index_range, 256)
@@ -1010,6 +1011,27 @@ class TestIndexing(TestCase):
            a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
            self.assertEqual(a_dev.cpu(), a)

        # Float32
        for generated_index_range in target_index_range:
            # create CPU tensors
            a_tensor_size = (max_index_range, 256)
            a = torch.randn(a_tensor_size, dtype=torch.float32)
            b = generate_indices(
                num_indices=num_indices, index_range=generated_index_range
            )
            c_tensor_size = (num_indices, 256)
            c = torch.randn(c_tensor_size, dtype=torch.float32)
            # create GPU copies
            a_dev = a.to(device)
            b_dev = b.to(device)
            c_dev = c.to(device)
            # run
            torch.use_deterministic_algorithms(True)
            a.index_put_(indices=[b], values=c, accumulate=True)
            torch.use_deterministic_algorithms(False)
            a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True)
            self.assertEqual(a_dev.cpu(), a)

    @onlyCUDA
    def test_index_put_accumulate_non_contiguous(self, device):
        t = torch.zeros((5, 2, 2))
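Since the change is a performance optimization of the stride > 1 path (the new tests above check the GPU result against a deterministic CPU reference), a rough way to observe its effect on a ROCm device is to time the same duplicate-heavy accumulate across the index ranges the tests use. A minimal timing sketch, not part of this PR; it assumes a GPU build and uses `torch.cuda.Event`, which also works on ROCm:

```python
# Rough timing sketch (not part of this PR): compare builds before and
# after this commit on the same duplicate-heavy workload as the new tests.
import torch

def time_index_put(index_range, num_indices=401988, width=256, iters=10):
    a = torch.randn(2000, width, device="cuda")
    b = torch.randint(0, index_range, (num_indices,), device="cuda")
    c = torch.randn(num_indices, width, device="cuda")
    a.index_put_(indices=[b], values=c, accumulate=True)  # warm-up
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        a.index_put_(indices=[b], values=c, accumulate=True)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

for r in (16, 256, 2000):  # same index ranges as the new tests
    print(f"index_range={r}: {time_index_put(r):.3f} ms")
```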