Support multi-dimensional lengths in segment_reduce to support pytorch_scatter.segment_* functionalities (CUDA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77061
Approved by: https://github.com/cpuhrsch

commit 40f7ef1f3d
parent c0a7c1d02e
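As a rough usage sketch (not taken from the PR; shapes and values are made up for illustration), the multi-dimensional lengths support means each row of lengths segments the matching slice of data along the reduction axis, which must be the last dimension of lengths:

import torch

# Illustrative only: device="cuda" targets the code path touched by this commit;
# per the removed @onlyCPU decorators below, the CPU path already handled this case.
data = torch.arange(12, dtype=torch.float, device="cuda").reshape(2, 6)
lengths = torch.tensor([[1, 2, 3], [3, 2, 1]], device="cuda")  # each row sums to data.size(1)

out = torch.segment_reduce(data, "sum", lengths=lengths, axis=1, unsafe=False)
print(out.shape)  # torch.Size([2, 3]) -- one reduced value per segment, per row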

aten/src/ATen/native/cuda/SegmentReduce.cu
@@ -13,6 +13,8 @@
 #else
 #include <ATen/ops/empty.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/cat.h>
+#include <ATen/ops/cumsum.h>
 #endif
 
 namespace at {
@@ -68,7 +70,7 @@ Tensor _get_complete_sum(const Tensor& lengths) {
   offsets[0].zero_();
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         at::cuda::cub::inclusive_sum(
@@ -108,22 +110,34 @@ __global__ void segment_reduce_forward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
+    const int64_t lengths_stride_axis,
     bool is_initial_set,
-    scalar_t initial_value) {
+    scalar_t initial_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
 
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];
+
   // ===== step2: apply reduction
-  for (int64_t j = offset_start; j < offset_end; ++j) {
-    int64_t starting_index = (j * stride_count) + lane_id;
-    const auto data = values_data[starting_index];
+  for (index_t j = offset_start; j < offset_end; ++j) {
+    int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+        + j * data_stride_axis + lane_id;
+    const auto data = values_data[data_index];
     // TODO: There is no need to branch with every element
     if (reduction == SegmentReductionType::MAX) {
       initial_value =
@@ -142,19 +156,22 @@ __global__ void segment_reduce_forward_kernel(
   }
 
   // ===== step3: finalize reduction
-  CUDA_KERNEL_ASSERT(lengths_data[row_id] >= 0);
-  if (lengths_data[row_id] == 0 && !is_initial_set &&
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
+  if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
       reduction == SegmentReductionType::MEAN) {
     initial_value = static_cast<scalar_t>(NAN);
   } else if (
-      reduction == SegmentReductionType::MEAN && lengths_data[row_id] > 0 &&
+      reduction == SegmentReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
       !at::_isnan(initial_value)) {
-    initial_value = initial_value / lengths_data[row_id];
+    initial_value = initial_value / lengths_data[lengths_idx];
   }
-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+      + dim_idx * output_stride_axis + lane_id;
   output_data[output_index] = initial_value;
 }
 
 template <typename scalar_t, typename index_t>
 __global__ void segment_reduce_backward_kernel(
     SegmentReductionType reduction,
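A note on the indexing above (my own sketch, not part of the commit): each flat thread index is decomposed as row_id = idx / inner_offset, lane_id = idx % inner_offset, outer_idx = row_id / segment_count, dim_idx = row_id % segment_count, and for contiguous data the expression outer_idx * data_stride_axis * data_size_axis + j * data_stride_axis + lane_id is just the row-major flat index of data[outer_idx, j, lane_id]:

import torch

# Assumed shapes for illustration: data viewed as (outer, size_along_axis, inner), axis = 1.
data = torch.arange(2 * 5 * 3).reshape(2, 5, 3)
stride_axis, size_axis = data.stride(1), data.size(1)   # 3 and 5 for this contiguous tensor
outer_idx, j, lane_id = 1, 4, 2
data_index = outer_idx * stride_axis * size_axis + j * stride_axis + lane_id
assert data.reshape(-1)[data_index] == data[outer_idx, j, lane_id]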
@@ -165,32 +182,46 @@ __global__ void segment_reduce_backward_kernel(
     const index_t* lengths_data,
     const index_t* lengths_cumsum_data,
     const int64_t segment_count,
-    const int64_t stride_count,
-    scalar_t initial_prod_value) {
+    const int64_t lengths_stride_axis,
+    scalar_t initial_prod_value,
+    const int64_t outer_offset,
+    const int64_t inner_offset,
+    const int64_t data_stride_axis,
+    const int64_t data_size_axis,
+    const int64_t output_stride_axis,
+    const int64_t output_size_axis,
+    const int64_t lengths_cumsum_stride_axis) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t row_id = idx / stride_count;
-  int64_t lane_id = idx % stride_count;
-
-  if (idx >= (segment_count * stride_count)) {
+  if (idx >= (outer_offset * segment_count * inner_offset)) {
     return;
   }
-  if (lengths_data[row_id] == 0) {
+  int64_t row_id = idx / inner_offset;
+  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
+  int64_t outer_idx = row_id / segment_count;
+  int64_t dim_idx = row_id % segment_count;
+
+  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
+  auto segment_length = lengths_data[lengths_idx];
+  if (segment_length == 0) {
     return;
   }
 
-  int64_t offset_start = lengths_cumsum_data[row_id];
-  int64_t offset_end = lengths_cumsum_data[row_id + 1];
+  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
+  index_t offset_start = lengths_cumsum_data[offset_idx];
+  index_t offset_end = lengths_cumsum_data[offset_idx + 1];
 
-  int64_t output_index = (row_id * stride_count) + lane_id;
+  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+      + dim_idx * output_stride_axis + lane_id;
 
   if (reduction == SegmentReductionType::MAX ||
       reduction == SegmentReductionType::MIN) {
     int64_t counter = 0;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == output_data[output_index]) {
-        grad_input_data[starting_index] = grad_data[output_index];
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == output_data[output_index]) {
+        grad_input_data[data_index] = grad_data[output_index];
         counter++;
       }
     }
@@ -200,42 +231,47 @@ __global__ void segment_reduce_backward_kernel(
       return;
     }
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (grad_input_data[starting_index] > 0) {
-        grad_input_data[starting_index] =
-            grad_input_data[starting_index] / counter;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (grad_input_data[data_index] > 0) {
+        grad_input_data[data_index] =
+            grad_input_data[data_index] / counter;
       }
     }
   } else if (reduction == SegmentReductionType::MEAN) {
-    auto grad_val = grad_data[output_index] / lengths_data[row_id];
+    auto grad_val = grad_data[output_index] / segment_length;
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::SUM) {
     const auto& grad_val = grad_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      grad_input_data[starting_index] = grad_val;
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      grad_input_data[data_index] = grad_val;
     }
   } else if (reduction == SegmentReductionType::PROD) {
     const auto& grad_val = grad_data[output_index] * output_data[output_index];
     for (int64_t j = offset_start; j < offset_end; ++j) {
-      int64_t starting_index = (j * stride_count) + lane_id;
-      if (at::_isnan(values_data[starting_index]) ||
-          values_data[starting_index] == 0) {
+      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+          + j * data_stride_axis + lane_id;
+      if (at::_isnan(values_data[data_index]) ||
+          values_data[data_index] == 0) {
         // explicitly compute exclusive prod
         scalar_t exclusive_prod = initial_prod_value;
-        int64_t idx;
+        int64_t prod_idx;
        for (int64_t k = offset_start; k < offset_end; ++k) {
           if (k != j) {
-            idx = (k * stride_count) + lane_id;
-            exclusive_prod *= values_data[idx];
+            prod_idx = outer_idx * data_stride_axis * data_size_axis
+                + k * data_stride_axis + lane_id;
+            exclusive_prod *= values_data[prod_idx];
           }
         }
-        grad_input_data[starting_index] = grad_data[output_index] * exclusive_prod;
+        grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
       } else {
-        grad_input_data[starting_index] = grad_val / values_data[starting_index];
+        grad_input_data[data_index] = grad_val / values_data[data_index];
       }
     }
   }
@@ -251,28 +287,43 @@ Tensor _segment_reduce_cuda_backward_kernel(
     const Tensor& lengths_contig,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths_contig.numel();
-  auto output_shape = data_contig.sizes().vec();
-  output_shape[axis] = segment_count;
+  axis = lengths_contig.dim() - 1;
+  int64_t segment_count = lengths_contig.size(axis);
+  int64_t lengths_stride_axis = lengths_contig.stride(axis);
   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
 
-  int64_t stride_count = data_contig.numel() / data_contig.size(axis);
+  auto zeros_shape = lengths_contig.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
+  offsets.cumsum_(axis);
 
-  auto offsets = _get_complete_sum(lengths_contig);
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output_contig.size(d);
+  }
+  for (int64_t d = axis + 1; d < output_contig.dim(); d++) {
+    inner_offset *= output_contig.size(d);
+  }
 
   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  int64_t num_blocks = (outer_offset * inner_offset * segment_count + threads_per_block - 1) / threads_per_block;
   num_blocks = std::max(num_blocks, (int64_t)1);
 
+  auto data_stride_axis = data_contig.stride(axis);
+  auto data_size_axis = data_contig.size(axis);
+  auto output_stride_axis = output_contig.stride(axis);
+  auto output_size_axis = output_contig.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);
+
   AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
         const auto* lengths_data = lengths_contig.data_ptr<index_t>();
         auto* offsets_data = offsets.data_ptr<index_t>();
 
-        // TODO: Swtich to TensorIterator for better maintainablility and
+        // TODO: Switch to TensorIterator for better maintainablility and
         // readability
         AT_DISPATCH_FLOATING_TYPES_AND2(
             kBFloat16,
@@ -305,8 +356,16 @@ Tensor _segment_reduce_cuda_backward_kernel(
                   lengths_data,
                   offsets_data,
                   segment_count,
-                  stride_count,
-                  initial_prod_value);
+                  lengths_stride_axis,
+                  initial_prod_value,
+                  outer_offset,
+                  inner_offset,
+                  data_stride_axis,
+                  data_size_axis,
+                  output_stride_axis,
+                  output_size_axis,
+                  offsets_stride_axis
+              );
               C10_CUDA_KERNEL_LAUNCH_CHECK();
             }));
       }));
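The host-side offsets construction above replaces _get_complete_sum, which, per the in-code comment, only supports 1-D lengths. A rough Python equivalent (mine, for illustration): prepend a zero along the segment axis and take a cumulative sum, so every outer row gets segment_count + 1 boundaries, matching the (segment_count + 1) factor used inside the kernels:

import torch

lengths = torch.tensor([[1, 2, 3], [3, 2, 1]])
axis = lengths.dim() - 1
zeros_shape = list(lengths.shape)
zeros_shape[axis] = 1
offsets = torch.cat([torch.zeros(zeros_shape, dtype=lengths.dtype), lengths], dim=axis).cumsum(dim=axis)
print(offsets)
# tensor([[0, 1, 3, 6],
#         [0, 3, 5, 6]])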
@@ -319,24 +378,46 @@ Tensor _segment_reduce_cuda_kernel(
     const Tensor& lengths,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
-  int64_t segment_count = lengths.numel();
+  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
+  axis = lengths.dim() - 1;
+  int64_t segment_count = lengths.size(axis);
+  int64_t lengths_stride_axis = lengths.stride(axis);
   auto output_shape = data.sizes().vec();
   output_shape[axis] = segment_count;
   auto output = at::empty(output_shape, data.options());
 
-  int64_t stride_count = data.numel() / data.size(axis);
+  // _get_complete_sum only supports 1D?
+  auto zeros_shape = lengths.sizes().vec();
+  zeros_shape[axis] = 1;
+  auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
+  offsets.cumsum_(axis);
 
-  auto offsets = _get_complete_sum(lengths);
+  // outer_offset is the size of the outer dimensions of output (before axis)
+  // inner_offset is the size of the inner dimensions of output (after axis)
+  int64_t outer_offset = 1, inner_offset = 1;
+  for (int64_t d = 0; d < axis; d++) {
+    outer_offset *= output.size(d);
+  }
+  for (int64_t d = axis + 1; d < output.dim(); d++) {
+    inner_offset *= output.size(d);
+  }
 
   constexpr int threads_per_block = 256;
-  int64_t num_blocks =
-      ((segment_count * stride_count) + threads_per_block - 1) /
-      threads_per_block;
+  // segment_count * stride_count is just output.numel() ?
+  int64_t num_blocks = (output.numel() + threads_per_block - 1) / threads_per_block;
   num_blocks = std::max(num_blocks, (int64_t)1);
 
+  auto data_stride_axis = data.stride(axis);
+  auto data_size_axis = data.size(axis);
+  auto output_stride_axis = output.stride(axis);
+  auto output_size_axis = output.size(axis);
+  auto offsets_stride_axis = offsets.stride(axis);
+
   AT_DISPATCH_INDEX_TYPES(
-      lengths.type(), "_segment_reduce_cuda_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -376,9 +457,17 @@ Tensor _segment_reduce_cuda_kernel(
                     lengths_data_ptr,
                     offsets_data_ptr,
                     segment_count,
-                    stride_count,
+                    lengths_stride_axis,
                     initial.has_value(),
-                    initial_value);
+                    initial_value,
+                    outer_offset,
+                    inner_offset,
+                    data_stride_axis,
+                    data_size_axis,
+                    output_stride_axis,
+                    output_size_axis,
+                    offsets_stride_axis
+                );
                 C10_CUDA_KERNEL_LAUNCH_CHECK();
               } else {
                 if (reduction == SegmentReductionType::MAX) {

test/test_segment_reductions.py
@@ -7,13 +7,12 @@ import torch
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes,
-    onlyCPU
 )
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
     gradcheck,
-    parametrize
+    parametrize,
 )
 
@@ -300,7 +299,6 @@ class TestSegmentReductions(TestCase):
             )
         )
     @parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
-    @onlyCPU  # will be removed in next PR where CUDA implementation of segment_reduce is adjusted
     def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
         val_dtype, length_dtype = dtypes
         # zero-length segments are filled with reduction inits contrary to pytorch_scatter.
@@ -384,7 +382,6 @@ class TestSegmentReductions(TestCase):
             axis=dim,
             unsafe=True,
         )
-
         self.assertEqual(actual_result, expected)
 
         if val_dtype == torch.float64:
@@ -469,20 +466,19 @@ class TestSegmentReductions(TestCase):
             check_backward,
         )
 
-    @onlyCPU
     @dtypes(torch.int, torch.int64)
     def test_unsafe_flag(self, device, dtype):
         length_type = dtype
-        lengths = torch.tensor([0, 2, 3, 0], dtype=length_type)
-        data = torch.arange(6).float()
+        lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
+        data = torch.arange(6, dtype=torch.float, device=device)
 
         # test for error on 1-D lenghts
         with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
             torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)
 
         # test for error on multi-D lengths
-        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type)
-        nd_data = torch.arange(12).reshape(2, 6).float()
+        nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
+        nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
         with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
             torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
 
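For the test_unsafe_flag changes above, my reading (an assumption, inferred only from the error message in the test) is that with unsafe=False every row of a multi-dimensional lengths tensor must sum to data.size(axis); in the test's nd_lengths the second row sums to 5 while nd_data has 6 columns, so the safe path rejects it:

import torch

nd_data = torch.arange(12, dtype=torch.float).reshape(2, 6)
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]])  # row sums: 6 and 5
try:
    torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
except RuntimeError as e:
    print(e)  # expected to mention "Expected all rows of lengths along axis ..."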

torch/testing/_internal/common_methods_invocations.py
@@ -8327,7 +8327,7 @@ def sample_inputs_scatter_reduce(op_info, device, dtype, requires_grad, **kwargs
                  args=(1, idx, src, reduce),
                  kwargs={'include_self': True})
 
-def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, **kwargs):
+def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode='lengths', **kwargs):
     def _tensor(shape, dtype=dtype, low=None, high=None):
         return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad)
 
@@ -8340,6 +8340,11 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, **kwargs
         ((S, S), 0, [0, 1, 2, 2], False),
         # test when lengths do not sum to dim size
         ((M, S, S), 0, [1, 2, 0, 6, 0], True),
+        # test for higher dimensions
+        ((S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
+        ((S, S, S), 1, [[0, 1, 2, 2] for _ in range(S)], False),
+        ((S, S, S), 1, [[2, 0, 3, 0], [0, 1, 2, 2], [3, 0, 2, 0], [1, 1, 1, 2], [0, 1, 2, 2]], False),
     )
 
     reductions = ["max", "mean", "min", "sum", "prod"]
@@ -19373,6 +19378,7 @@ op_db: List[OpInfo] = [
     ),
     OpInfo(
         'segment_reduce',
+        variant_test_name='lengths',
         dtypes=floating_types_and(torch.float16, torch.bfloat16),
         supports_out=False,
         # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented