port sparse_mm.reduce to pytorch and optimize it on CPU (#83727)
### Motivation of this PR

This patch migrates `spmm_reduce` from `torch-sparse` (a third-party dependency of PyG) to `torch`, as a response to the initial proposal for fusing **Gather, Apply, Scatter** (GAS) in the message passing of GNN inference/training: https://github.com/pytorch/pytorch/issues/71300

**GAS** is the major step of message passing; its behavior falls into two kinds depending on the storage type of `EdgeIndex`, which records the connections of nodes:

* COO: the hotspot is `scatter_reduce`
* CSR: the hotspot is `spmm_reduce`

The reduce type can be chosen from: "sum", "mean", "max", "min".

This patch extends `torch.sparse.mm` with a `reduce` argument, which maps to `torch.sparse_mm.reduce` internally. `sparse_mm_reduce` is registered under the TensorTypeId of `SparseCsrCPU`, and the operator requires an internal interface `_sparse_mm_reduce_impl` with two outputs:

* `out` - the actual output
* `arg_out` - records the output indices into the nonzero elements when the reduce type is "max" or "min"; it is only needed for training, so it is not computed for inference.

### Performance

Benchmarked with GCN on ogbn-products on a single Xeon socket, the workload is improved by `4.3x` with this patch. The benefit for training will be even bigger: the original backward impl for `sum|mean` is sequential, and the original backward impl for `max|min` is not fused.

#### before:
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Name                           Self CPU %    Self CPU      CPU total %   CPU total     CPU time avg  # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
torch_sparse::spmm_sum         97.09%        56.086s       97.09%        56.088s       6.232s        9
aten::linear                   0.00%         85.000us      1.38%         795.485ms    88.387ms       9
aten::matmul                   0.00%         57.000us      1.38%         795.260ms    88.362ms       9
aten::mm                       1.38%         795.201ms     1.38%         795.203ms    88.356ms       9
aten::relu                     0.00%         50.000us      0.76%         440.434ms    73.406ms       6
aten::clamp_min                0.76%         440.384ms     0.76%         440.384ms    73.397ms       6
aten::add_                     0.57%         327.801ms     0.57%         327.801ms    36.422ms       9
aten::log_softmax              0.00%         23.000us      0.10%         55.503ms     18.501ms       3
```

#### after:
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Name                           Self CPU %    Self CPU      CPU total %   CPU total     CPU time avg  # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
aten::spmm_sum                 87.35%        11.826s       87.36%        11.827s       1.314s         9
aten::linear                   0.00%         92.000us      5.87%         794.451ms    88.272ms        9
aten::matmul                   0.00%         62.000us      5.87%         794.208ms    88.245ms        9
aten::mm                       5.87%         794.143ms     5.87%         794.146ms    88.238ms        9
aten::relu                     0.00%         53.000us      3.35%         452.977ms    75.496ms        6
aten::clamp_min                3.35%         452.924ms     3.35%         452.924ms    75.487ms        6
aten::add_                     2.58%         348.663ms     2.58%         348.663ms    38.740ms        9
aten::argmax                   0.42%         57.473ms      0.42%         57.475ms     14.369ms        4
aten::log_softmax              0.00%         22.000us      0.39%         52.605ms     17.535ms        3
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83727
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch, https://github.com/rusty1s, https://github.com/pearu
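For readers new to the API, a minimal usage sketch of the new `reduce` argument (it assumes a build that contains this patch; the tensor values are illustrative and mirror the docstring and tests added below):

```python
import torch

# A small CSR matrix and a dense operand (illustrative values only).
a = torch.tensor([[1., 0., 2.], [0., 3., 0.]]).to_sparse_csr()
b = torch.tensor([[0., 1.], [2., 0.], [0., 0.]], requires_grad=True)

# `reduce` selects how the nonzero products contributing to each output
# element are combined; the CSR-only CPU fast path added here supports
# "sum", "mean", "amax" and "amin".
y_sum = torch.sparse.mm(a, b, 'sum')    # same result as a plain matmul
y_max = torch.sparse.mm(a, b, 'amax')   # max over the nonzero products per output element

# Backward is routed through _sparse_mm_reduce_impl_backward (see derivatives.yaml below).
y_sum.sum().backward()
print(y_sum, y_max, b.grad, sep="\n")
```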
This commit is contained in:
parent: 24ae50bcc7
commit: c620ece726

512	aten/src/ATen/native/cpu/SpmmReduceKernel.cpp	(new file)

@@ -0,0 +1,512 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/SpmmReduceKernel.h>
#include <ATen/native/cpu/ReduceUtils.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#endif

namespace at { namespace native {

namespace {

template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_kernel_impl(
    const Tensor& out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other_) {

  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto other = other_.contiguous();

  // access `crow_indices`, `col_indices` and `values` via TensorAccessor
  scalar_t* out_data = out.data_ptr<scalar_t>();
  auto csr_data = crow_indices.accessor<index_t, 1>();
  auto col_data = col_indices.accessor<index_t, 1>();
  auto val_data = values.accessor<scalar_t, 1>();
  scalar_t* other_data = other.data_ptr<scalar_t>();

  int64_t M = crow_indices.numel() - 1;
  int64_t K = other.size(-1);

  using Vec = vec::Vectorized<scalar_t>;
  utils::parallel_sparse_csr(csr_data, M, nnz, [&](int64_t begin, int64_t end) {
    int64_t row_start, row_end, c;
    for (const auto m : c10::irange(begin, end)) {
      row_start = csr_data[m];
      row_end = csr_data[m + 1];

      scalar_t* out_ptr = out_data + m * K;

      constexpr int64_t kVecSize = Vec::size();
      constexpr int64_t kVLEN = kVecSize * 4;
      constexpr int64_t CHUNK_SIZE = 16;

      // step 1: reinit the output row for reduce types 'amax' and 'amin'
      int64_t count = row_end - row_start;
      if (count != 0) {
        init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
      }

      // step 2: reduce, blocking along the row dimension to reduce write memory bandwidth
      for (int64_t e0 = row_start; e0 < row_end; e0 += CHUNK_SIZE) {
        int64_t e1 = std::min(e0 + CHUNK_SIZE, row_end);

        int64_t k = 0;
        for (; k < K - (K % kVLEN); k += kVLEN) {
          Vec out_vec0 = Vec::loadu(out_ptr + k);
          Vec out_vec1 = Vec::loadu(out_ptr + k + kVecSize);
          Vec out_vec2 = Vec::loadu(out_ptr + k + kVecSize * 2);
          Vec out_vec3 = Vec::loadu(out_ptr + k + kVecSize * 3);
          for (const auto e : c10::irange(e0, e1)) {
            c = col_data[e];
            scalar_t val = val_data[e];
            scalar_t* other_ptr = other_data + c * K + k;

            out_vec0 = update<Vec, reduce>(out_vec0, Vec::loadu(other_ptr) * Vec(val));
            out_vec1 = update<Vec, reduce>(out_vec1, Vec::loadu(other_ptr + kVecSize) * Vec(val));
            out_vec2 = update<Vec, reduce>(out_vec2, Vec::loadu(other_ptr + kVecSize * 2) * Vec(val));
            out_vec3 = update<Vec, reduce>(out_vec3, Vec::loadu(other_ptr + kVecSize * 3) * Vec(val));
          }
          out_vec0.store(out_ptr + k);
          out_vec1.store(out_ptr + k + kVecSize);
          out_vec2.store(out_ptr + k + kVecSize * 2);
          out_vec3.store(out_ptr + k + kVecSize * 3);
        }
        for (; k < K - (K % kVecSize); k += kVecSize) {
          Vec out_vec = Vec::loadu(out_ptr + k);
          for (const auto e : c10::irange(e0, e1)) {
            c = col_data[e];
            scalar_t val = val_data[e];
            scalar_t* other_ptr = other_data + c * K;
            out_vec = update<Vec, reduce>(out_vec, Vec::loadu(other_ptr + k) * Vec(val));
          }
          out_vec.store(out_ptr + k);
        }
        for (; k < K; k++) {
          scalar_t out_val = out_ptr[k];
          for (const auto e : c10::irange(e0, e1)) {
            c = col_data[e];
            scalar_t val = val_data[e];
            scalar_t* other_ptr = other_data + c * K;
            out_val = update<scalar_t, reduce>(out_val, other_ptr[k] * val);
          }
          out_ptr[k] = out_val;
        }
      }

      // step 3: finalize
      write<scalar_t, reduce>(out_ptr, count, K);
    }
  });
}

// update both val and arg, used for `amin` and `amax`;
// it is a little troublesome to vectorize since `scalar_t` and `index_t`
// may have different vector lengths, e.g. a vector of 8 floats
// pairs with a vector of 4 int64_t.
template <typename scalar_t, typename index_t, ReductionType reduce>
inline void update_with_index(scalar_t *val, scalar_t new_val, index_t *arg, index_t new_arg) {
  if ((reduce == ReductionType::MIN && new_val < *val) ||
      (reduce == ReductionType::MAX && new_val > *val) ||
      at::_isnan<scalar_t>(new_val)) {
    *val = new_val;
    *arg = new_arg;
  }
}

template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_arg_kernel_impl(
    const Tensor& out,
    const Tensor& arg_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other_) {

  TORCH_CHECK(reduce == ReductionType::MAX || reduce == ReductionType::MIN);
  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto other = other_.contiguous();

  scalar_t* out_data = out.data_ptr<scalar_t>();
  index_t* arg_out_data = arg_out.data_ptr<index_t>();
  auto csr_data = crow_indices.accessor<index_t, 1>();
  auto col_data = col_indices.accessor<index_t, 1>();
  auto val_data = values.accessor<scalar_t, 1>();
  scalar_t* other_data = other.data_ptr<scalar_t>();

  int64_t M = crow_indices.numel() - 1;
  int64_t K = other.size(-1);

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    int64_t row_start, row_end, c;
    for (const auto m : c10::irange(begin, end)) {
      row_start = csr_data[m];
      row_end = csr_data[m + 1];

      scalar_t* out_ptr = out_data + m * K;
      index_t* arg_out_ptr = arg_out_data + m * K;

      if (row_end != row_start) {
        init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
        for (const auto e : c10::irange(row_start, row_end)) {
          c = col_data[e];
          scalar_t val = val_data[e];

          scalar_t* other_ptr = other_data + c * K;
          for (const auto k : c10::irange(K)) {
            update_with_index<scalar_t, index_t, reduce>(
                &out_ptr[k], val * other_ptr[k], &arg_out_ptr[k], index_t(e));
          }
        }
      }
    }
  });
}

template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_backward_input_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& row_indices) {

  int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();

  auto values = grad_self.values();
  auto grad_values_data = values.accessor<scalar_t, 1>();
  scalar_t* grad_out_data = grad_out.data_ptr<scalar_t>();
  auto crow_data = crow_indices.accessor<index_t, 1>();
  auto col_data = col_indices.accessor<index_t, 1>();
  scalar_t* other_data = other.data_ptr<scalar_t>();
  auto row_data = row_indices.accessor<index_t, 1>();

  int64_t K = grad_out.size(1);

  using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      index_t row = row_data[i], col = col_data[i];

      scalar_t val = vec::map2_reduce_all<scalar_t>(
          [](Vec x, Vec y) { return x * y; },
          [](Vec x, Vec y) { return x + y; },
          other_data + col * K,
          grad_out_data + row * K,
          K);

      if (reduce == ReductionType::MEAN) {
        index_t row_start = crow_data[row], row_end = crow_data[row + 1];
        val /= (row_end - row_start);
      }

      grad_values_data[i] = val;
    }
  });
}

// backward for reduce type 'amax' or 'amin'
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_input_arg_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& arg_out_) {

  int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();
  auto arg_out = arg_out_.contiguous();

  auto grad_values = grad_self.values();
  auto grad_values_data = grad_values.accessor<scalar_t, 1>();
  scalar_t* grad_out_data = grad_out.data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<index_t, 1>();
  scalar_t* other_data = other.data_ptr<scalar_t>();
  index_t* arg_out_data = arg_out.data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.data_ptr<scalar_t>();

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_out_ptr[k] == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          // collect weight at max/min indices
          index_t col = col_data[arg_out_data[m * K + k]];
          grad_ptr[k] = other_data[col * K + k] * grad_out_ptr[k];
        }
      }
    }
  });

  // scatter_add; consider parallelizing this with atomics
  for (const auto i : c10::irange(M * K)) {
    index_t ind = arg_out_data[i];
    if (ind != index_t(nnz)) {
      grad_values_data[ind] += grad_data[i];
    }
  }
}

template <typename scalar_t, typename index_t>
void spmm_reduce_normalize_values_kernel_impl(
    const Tensor& normalized_values,
    const Tensor& values,
    const Tensor& crow_indices,
    const Tensor& row_indices) {

  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto normalized_values_data = normalized_values.accessor<scalar_t, 1>();
  auto values_data = values.accessor<scalar_t, 1>();
  auto crow_data = crow_indices.accessor<index_t, 1>();
  auto row_data = row_indices.accessor<index_t, 1>();

  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      index_t row = row_data[i];
      index_t row_start = crow_data[row], row_end = crow_data[row + 1];
      // Note that when the row index `row` is listed in `row_indices`,
      // crow_indices[row + 1] > crow_indices[row] holds.
      normalized_values_data[i] = values_data[i] / (row_end - row_start);
    }
  });
}

template <typename scalar_t, typename index_t>
void spmm_reduce_backward_other_arg_kernel_impl(
    const Tensor& grad_other,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& arg_out_) {

  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto arg_out = arg_out_.contiguous();

  scalar_t* grad_other_data = grad_other.data_ptr<scalar_t>();
  scalar_t* grad_out_data = grad_out.data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<index_t, 1>();
  auto values_data = values.accessor<scalar_t, 1>();
  index_t* arg_out_data = arg_out.data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.data_ptr<scalar_t>();

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_out_ptr[k] == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          grad_ptr[k] = values_data[arg_out_ptr[k]] * grad_out_ptr[k];
        }
      }
    }
  });

  // scatter_add; consider parallelizing this with atomics
  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      index_t ind = arg_out_data[m * K + k];
      if (ind != index_t(nnz)) {
        index_t col = col_data[ind];
        grad_other_data[col * K + k] += grad_data[m * K + k];
      }
    }
  }
}

void spmm_reduce_kernel(
    const Tensor& out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other,
    ReductionType reduce_op) {
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
        spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
            out, crow_indices, col_indices, values, other);
      });
    });
  });
}

void spmm_reduce_arg_kernel(
    const Tensor& out,
    const Tensor& arg_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& other,
    ReductionType reduce_op) {
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
        spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
            out, arg_out, crow_indices, col_indices, values, other);
      });
    });
  });
}

void spmm_reduce_backward_input_kernel(
    const Tensor& grad_self,
    const Tensor& grad_out,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& other,
    const Tensor& row_indices,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
      AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
        spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
            grad_self, grad_out, crow_indices, col_indices, other, row_indices);
      });
    });
  });
}

void spmm_reduce_backward_input_arg_kernel(
    const Tensor& grad_self,
    const Tensor& grad_out,
    const Tensor& col_indices,
    const Tensor& other,
    const Tensor& arg_out,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
      spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
          grad_self, grad_out, col_indices, other, arg_out);
    });
  });
}

void spmm_reduce_normalize_values_kernel(
    const Tensor& normalized_values,
    const Tensor& values,
    const Tensor& crow_indices,
    const Tensor& row_indices) {
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
      spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
          normalized_values, values, crow_indices, row_indices);
    });
  });
}

void spmm_reduce_backward_other_kernel(
    const Tensor& grad_other,
    const Tensor& grad_out,
    const Tensor& crow_indices,
    const Tensor& values,
    const Tensor& row_indices,
    const Tensor& ccol_indices,
    const Tensor& csr2csc,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
  // need to permute row_indices to CSC order
  auto row = row_indices.index_select(0, csr2csc);

  Tensor val;
  if (reduce_op == ReductionType::MEAN) {
    // for reduce type "mean", normalize the values
    // by the row count of each nonzero element.
    Tensor normalized_values = at::empty(values.sizes(), values.options());
    spmm_reduce_normalize_values_kernel(normalized_values, values, crow_indices, row_indices);
    val = normalized_values.index_select(0, csr2csc);
  } else {
    val = values.index_select(0, csr2csc);
  }

  spmm_reduce_kernel(grad_other, ccol_indices, row, val, grad_out, ReductionType::SUM);
}

void spmm_reduce_backward_other_arg_kernel(
    const Tensor& grad_other,
    const Tensor& grad_out,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& arg_out,
    ReductionType reduce_op) {
  TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
      spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
          grad_other, grad_out, col_indices, values, arg_out);
    });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(spmm_reduce_stub, &spmm_reduce_kernel);
REGISTER_DISPATCH(spmm_reduce_arg_stub, &spmm_reduce_arg_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_input_stub, &spmm_reduce_backward_input_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_input_arg_stub, &spmm_reduce_backward_input_arg_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_other_stub, &spmm_reduce_backward_other_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_other_arg_stub, &spmm_reduce_backward_other_arg_kernel);

}} // at::native

22	aten/src/ATen/native/cpu/SpmmReduceKernel.h	(new file)

@@ -0,0 +1,22 @@
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>

namespace at::native {

using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);

DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub);
DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub);
DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub);

} // at::native

@@ -3829,6 +3829,9 @@
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
  python_module: sparse

- func: _sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor
  python_module: sparse

- func: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
  dispatch:
    SparseCPU: sparse_sparse_matmul_cpu

@@ -6440,6 +6443,16 @@
    SparseCsrCUDA: sparse_sampled_addmm_sparse_csr_cuda
    SparseCsrCPU: sparse_sampled_addmm_sparse_csr_cpu

- func: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
  python_module: sparse
  dispatch:
    SparseCsrCPU: _sparse_mm_reduce_impl_sparse_csr_cpu

- func: _sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)
  python_module: sparse
  dispatch:
    SparseCsrCPU: _sparse_mm_reduce_impl_backward_sparse_csr_cpu

- func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:

@@ -20,6 +20,7 @@
#include <ATen/Operators.h>
#else
#include <ATen/ops/_conj_physical_native.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>

@@ -51,6 +52,7 @@
#include <ATen/ops/deg2rad.h>
#include <ATen/ops/deg2rad_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/erf.h>
#include <ATen/ops/erf_native.h>
#include <ATen/ops/erfinv.h>

@@ -1292,5 +1294,134 @@ Tensor _sparse_csr_prod_cpu(const Tensor& input, IntArrayRef dims_to_reduce, boo
  return result;
}

std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& other,
    const c10::string_view reduce) {

  auto layout = self.layout();
  TORCH_CHECK(layout == kSparseCsr,
      "sparse_mm_reduce: expect self to be SparseCsr, got ", layout);
  TORCH_CHECK(self.dense_dim() == 0,
      "sparse_mm_reduce: expected non-hybrid self tensor.");
  TORCH_CHECK(self.dim() == 2,
      "sparse_mm_reduce: expected self to be a 2-D tensor, got ", self.dim(), "-D tensor.");

  sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/false>(
      self, Tensor(), other);

  auto op = get_reduction_enum(reduce);
  TORCH_CHECK(op != ReductionType::PROD, "sparse_mm_reduce: reduce type of prod has not been enabled.");

  auto crow = self.crow_indices();
  auto col = self.col_indices();
  auto val = self.values();

  // init the output to all zeros; for rows that have no nonzero elements,
  // the corresponding rows in the output stay zero.
  auto out = at::zeros({self.size(0), other.size(1)}, other.options());
  auto arg_out = at::empty({0}, col.options());

  int64_t nnz = self._nnz();
  if (nnz == 0) {
    return std::make_tuple(out, arg_out);
  }

  // arg_out only needs to be calculated for reduce types
  // "amax" and "amin", and only for training
  bool need_arg_out = at::GradMode::is_enabled()
      && (self.requires_grad() || other.requires_grad())
      && (op == ReductionType::MAX || op == ReductionType::MIN);

  if (!need_arg_out) {
    spmm_reduce_stub(kCPU, out, crow, col, val, other, op);
  } else {
    // allocate memory and init with invalid index
    arg_out.resize_(out.sizes());
    arg_out.fill_(nnz);
    spmm_reduce_arg_stub(kCPU, out, arg_out, crow, col, val, other, op);
  }

  return std::make_tuple(std::move(out), std::move(arg_out));
}

std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_backward_sparse_csr_cpu(
    const Tensor& self,
    const Tensor& grad_out,
    const Tensor& other,
    const c10::string_view reduce,
    const Tensor& arg_out,
    std::array<bool, 2> output_mask) {

  auto layout = self.layout();
  TORCH_CHECK(layout == kSparseCsr,
      "sparse_mm_reduce: expect self to be SparseCsr, got ", layout);

  sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/true>(
      self, grad_out, other);

  auto op = get_reduction_enum(reduce);

  auto crow = self.crow_indices();
  auto col = self.col_indices();
  auto val = self.values();

  // `row`: row indices of COO format
  // `ccol`: ccol indices of CSC format (with permute)
  // `permute`: permute pattern from CSR to CSC
  //
  // TODO: optimize the following section,
  // currently `argsort` is sequential.
  Tensor row, ccol, permute;
  {
    bool out_int32 = crow.scalar_type() == ScalarType::Int;
    Tensor coo_indices = at::_convert_indices_from_csr_to_coo(
        crow,
        col,
        out_int32,
        /*transpose*/false);
    row = coo_indices.select(0, 0);

    // calculate the global index for CSC
    // and get the conversion permute pattern
    Tensor index = col.mul(self.size(0)).add_(row);
    permute = index.argsort();

    ccol = at::_convert_indices_from_coo_to_csr(
        /*column indices*/col.index_select(0, permute),
        /*column count*/self.size(1),
        out_int32);
  }

  Tensor grad_self, grad_other;
  if (output_mask[0]) {
    // grad_input has the same indices and nnz as the input
    grad_self = at::empty_like(self);
    grad_self.values().zero_();
    if (op == ReductionType::MAX || op == ReductionType::MIN) {
      spmm_reduce_backward_input_arg_stub(kCPU, grad_self, grad_out, col, other, arg_out, op);
    } else {
      spmm_reduce_backward_input_stub(kCPU, grad_self, grad_out, crow, col, other, row, op);
    }
  }
  if (output_mask[1]) {
    grad_other = at::zeros(other.sizes(), other.options());
    if (op == ReductionType::MAX || op == ReductionType::MIN) {
      spmm_reduce_backward_other_arg_stub(kCPU, grad_other, grad_out, col, val, arg_out, op);
    } else {
      spmm_reduce_backward_other_stub(kCPU, grad_other, grad_out, crow, val, row, ccol, permute, op);
    }
  }

  return std::make_tuple(std::move(grad_self), std::move(grad_other));
}

DEFINE_DISPATCH(spmm_reduce_stub);
DEFINE_DISPATCH(spmm_reduce_arg_stub);
DEFINE_DISPATCH(spmm_reduce_backward_input_stub);
DEFINE_DISPATCH(spmm_reduce_backward_input_arg_stub);
DEFINE_DISPATCH(spmm_reduce_backward_other_stub);
DEFINE_DISPATCH(spmm_reduce_backward_other_arg_stub);

} // namespace native
} // namespace at

@@ -2,6 +2,9 @@

#include <ATen/Tensor.h>
#include <ATen/core/Scalar.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/cpu/SpmmReduceKernel.h>

namespace at {
namespace native {

@@ -59,6 +62,28 @@ inline void _check_dim(const Tensor& self, int64_t target_dim, c10::string_view
      " instead.");
}

template <bool train>
inline void check_sparse_mm_reduce_impl_inputs(
    const Tensor& self,
    const Tensor& grad_out,
    const Tensor& other) {
  TORCH_INTERNAL_ASSERT(self.is_sparse_csr());

  const auto input_scalar_type = self.values().scalar_type();
  CheckedFrom c = train ? "sparse_mm_reduce_backward" : "sparse_mm_reduce";
  if (train) {
    checkLayout(c, grad_out, kStrided);
    checkScalarType(c, {grad_out, "grad_out", 1}, input_scalar_type);
    check_dim_size(grad_out, 2, 0, self.size(0));
    check_dim_size(grad_out, 2, 1, other.size(1));
  }

  int pos = train ? 2 : 1;
  checkLayout(c, other, kStrided);
  checkScalarType(c, {other, "other", pos}, input_scalar_type);
  check_dim_size(other, 2, 0, self.size(1));
}

}
}
}

@@ -31,6 +31,8 @@
#include <ATen/ops/_sparse_sum_backward_native.h>
#include <ATen/ops/_sparse_sum_native.h>
#include <ATen/ops/_sparse_sparse_matmul.h>
#include <ATen/ops/_sparse_mm_reduce_impl.h>
#include <ATen/ops/_sparse_mm_reduce_impl_native.h>
#include <ATen/ops/add.h>
#include <ATen/ops/add_native.h>
#include <ATen/ops/addmm.h>

@@ -1393,6 +1395,12 @@ SparseTensor& _sparse_mm_out(const SparseTensor& sparse,
  return at::addmm_out(result, t, sparse, dense, 0, 1);  // redispatch!
}

Tensor _sparse_mm(const Tensor& mat1, const Tensor& mat2, const c10::string_view reduce) {
  // result: out, arg_out
  auto result = at::_sparse_mm_reduce_impl(mat1, mat2, reduce);
  return std::get<0>(result);
}

// --------------------------------------------------------------------
// hspmm(SparseTensor mat1, Tensor mat2)
// --------------------------------------------------------------------

@@ -1140,6 +1140,7 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
    "aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
    "aten/src/ATen/native/cpu/SampledAddmmKernel.cpp",
    "aten/src/ATen/native/cpu/SpmmReduceKernel.cpp",
    "aten/src/ATen/native/cpu/SparseFactories.cpp",
    "aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
]

@@ -428,6 +428,7 @@ dtensor_fails = {
    xfail("select_scatter"),
    xfail("sort"),
    xfail("sparse.sampled_addmm"),
    xfail("sparse.mm", "reduce"),
    xfail("special.airy_ai"),
    xfail("special.bessel_j0"),
    xfail("special.bessel_j1"),

@@ -457,6 +457,8 @@ aten::_sparse_log_softmax
aten::_sparse_log_softmax.out
aten::_sparse_log_softmax_backward_data
aten::_sparse_log_softmax_backward_data.out
aten::_sparse_mm_reduce_impl
aten::_sparse_mm_reduce_impl_backward
aten::_sparse_softmax
aten::_sparse_softmax.out
aten::_sparse_softmax_backward_data

@@ -2237,6 +2237,7 @@ aot_autograd_failures = {
    xfail('cov'),
    xfail('chalf'),  # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),
    skip('nn.functional.binary_cross_entropy_with_logits'),  # seems to fail sometimes?
    skip('nn.functional.margin_ranking_loss'),  # seems flaky
    skip('linalg.lu_solve'),  # flaky

@@ -370,6 +370,7 @@ class TestOperators(TestCase):
    @skipOps('TestOperators', 'test_grad', vjp_fail.union({
        xfail('chalf', '', device_type='cpu'),  # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
        xfail('sparse.sampled_addmm', ''),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail('sparse.mm', 'reduce'),  # RuntimeError: Sparse CSR tensors do not have strides

        # Non-contiguous Bugs
        #

@@ -567,6 +568,7 @@
    @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
    @skipOps('TestOperators', 'test_vjp', vjp_fail.union({
        xfail('sparse.sampled_addmm', ''),
        xfail('sparse.mm', 'reduce'),

        # ---- Non-Contiguous Failures ----
        # This is expected to fail as the operator

@@ -645,6 +647,7 @@
        xfail('nn.functional.ctc_loss'),  # Not Implemented
        xfail('native_layer_norm', ''),  # Expected a proper Tensor but got None for argument #1 'other'
        xfail('sparse.sampled_addmm', ''),  # sparse tensors have no strides
        xfail('sparse.mm', 'reduce'),  # sparse tensors have no strides
        skip('nn.functional.scaled_dot_product_attention', device_type='cuda'),
        # AssertionError: Tensor-likes are not close!
        # Mismatched elements: 1 / 15 (6.7%)

@@ -768,6 +771,7 @@
        xfail("quantile", device_type='cpu'),  # Batching rule not implemented for `at::equal`
        xfail("scatter_reduce", "prod"),  # vmap (looks like you are calling item/data-dependent)
        xfail("sparse.sampled_addmm"),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail("sparse.mm", "reduce"),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail("svd_lowrank"),  # calls random op
        xfail("take"),  # vmap: inplace into a regular tensor
        xfail("to"),  # rank 4 tensor for channels_last

@@ -894,6 +898,7 @@
        xfail('nn.functional.max_unpool2d', 'grad'),

        xfail('sparse.sampled_addmm', ''),
        xfail('sparse.mm', 'reduce'),
        xfail('as_strided_scatter', ''),  # calls as_strided
        xfail('index_reduce', ''),  # .item() call
        # ---------------------------------------------------------------------

@@ -1179,6 +1184,7 @@
        xfail('_segment_reduce', 'offsets'),
        xfail('_segment_reduce', 'lengths'),
        xfail('sparse.sampled_addmm', ''),
        xfail('sparse.mm', 'reduce'),
        xfail("native_batch_norm"),
        xfail("_native_batch_norm_legit"),
        xfail("native_dropout_backward"),

@@ -1252,6 +1258,7 @@
        xfail('nn.functional.dropout3d', ''),
        xfail('as_strided_scatter', ''),
        xfail('sparse.sampled_addmm', ''),
        xfail('sparse.mm', 'reduce'),
        xfail("native_batch_norm"),
        xfail("_native_batch_norm_legit"),
        xfail('as_strided', 'partial_views'),

@@ -1350,6 +1357,7 @@
        skip('linalg.householder_product', '', device_type='cuda'),  # flaky, I'm not sure why
        xfail('sparse.sampled_addmm', ''),  # Sparse tensors have no strides
        xfail('_segment_reduce', 'offsets'),  # NYI: forward-AD for _segment_reduce
        xfail('sparse.mm', 'reduce'),  # Sparse tensors have no strides
        xfail('index_reduce', ''),  # NYI: forward-AD for index_reduce
        xfail('_segment_reduce', 'lengths'),  # NYI: forward-AD for _segment_reduce
        xfail('native_dropout_backward'),  # NYI

@@ -1505,6 +1513,7 @@
        xfail('_segment_reduce', 'lengths'),  # Forward AD not implemented and no decomposition
        xfail('_segment_reduce', 'offsets'),  # Forward AD not implemented and no decomposition
        xfail('sparse.sampled_addmm'),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail('sparse.mm', 'reduce'),  # RuntimeError: Sparse CSR tensors do not have strides
        xfail('svd_lowrank'),  # calls random op
        xfail('take'),  # vmap: inplace into regular tensor
        xfail('to'),  # RuntimeError: required rank 4 tensor to use channels_last format

@@ -1753,6 +1762,7 @@
        skip('linalg.lu_factor_ex', dtypes=(torch.float32,), device_type='cuda'),  # fails on all but windows
        skip('linalg.multi_dot', '', device_type='cpu'),
        skip('sparse.sampled_addmm', ''),
        skip('sparse.mm', 'reduce'),
        skip('native_layer_norm', '', device_type='cpu'),
    })
    @opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', (

@@ -3475,6 +3475,7 @@ class TestVmapOperatorsOpInfo(TestCase):
        xfail('pca_lowrank', ''),  # random operation
        xfail('svd_lowrank', ''),  # random operation
        xfail('sparse.sampled_addmm'),  # sparse
        xfail('sparse.mm', 'reduce'),  # sparse
        xfail("NumpyCubeNotComposableAutogradFunction"),  # Not composable autograd.Function
        skip('_softmax_backward_data'),
        skip('linalg.eigh', ''),  # not unique, see test_linalg_eigh for manual test

@@ -3701,6 +3702,7 @@
        xfail('clamp_min', ''),
        xfail('special.bessel_j0'),
        xfail('sparse.sampled_addmm'),
        xfail('sparse.mm', 'reduce'),
        xfail('special.bessel_y0'),
        xfail('special.chebyshev_polynomial_u'),
        xfail('special.modified_bessel_k1'),

@@ -250,6 +250,7 @@ inductor_expected_failures_single_sample["cpu"] = {
    "scatter_reduce.prod": {f16, f32, f64},
    "_segment_reduce.lengths": {f16, f32, f64},
    "sparse.sampled_addmm": {f32, f64},
    "sparse.mm.reduce": {bf16, f32, f64},
    "stft": {f32, f64},
    "tensor_split": {b8, f16, f32, f64, i32, i64},
    "to_sparse": {f32, f64},

@@ -1189,6 +1189,7 @@ make_fx_failures = {

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),

    # proxy tensor doesn't support sparse correctly right now
    skip('to_sparse'),

@@ -2394,6 +2394,100 @@ class TestSparseCSR(TestCase):
        with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to have strided layout"):
            torch.sparse.sampled_addmm(a_sparse, a, a_sparse)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_sparse_mm_reduce_sum(self, device, dtype):
        def run_test(m, n, k, nnz, train):
            sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)
            dense = sparse.to_dense()

            mat = torch.randn(k, n, dtype=dtype)
            ref_mat = mat.clone()

            if train:
                sparse.requires_grad_()
                mat.requires_grad_()
                dense.requires_grad_()
                ref_mat.requires_grad_()

            ref_out = torch.mm(dense, ref_mat)
            out = torch.sparse.mm(sparse, mat, 'sum')

            self.assertEqual(out, ref_out)

            if train:
                ref_out.sum().backward()
                out.sum().backward()

                grad_input = sparse.grad
                ref_grad_input = dense.grad
                grad_mat = mat.grad
                ref_grad_mat = ref_mat.grad

                self.assertEqual(grad_input.to_dense(), ref_grad_input)
                self.assertEqual(grad_mat, ref_grad_mat)

        run_test(4, 5, 4, 10, False)
        run_test(4, 4, 4, 16, True)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_sparse_mm_reduce(self, device, dtype):
        def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
            csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            mat = torch.randn(n, k, dtype=dtype)
            ref_mat = mat.clone()
            ref_values = csr.values().clone()

            out_int32 = index_dtype == torch.int32
            coo_indices = torch._convert_indices_from_csr_to_coo(
                csr.crow_indices(),
                csr.col_indices(),
                out_int32=out_int32)
            row, col = coo_indices[0], coo_indices[1]

            def ref(row, col, val, mat):
                out = torch.zeros([m, k], dtype=dtype)
                weight = mat.index_select(0, col)
                src = weight.mul(val.view(-1, 1))
                index = row.view(-1, 1).expand_as(weight)
                index = index.to(dtype=torch.int64)
                # scatter_reduce expects index to be int64
                out.scatter_reduce_(0, index, src, reduce=reduce_type, include_self=False)
                return out

            if train:
                csr.requires_grad_()
                mat.requires_grad_()
                ref_values.requires_grad_()
                ref_mat.requires_grad_()

            ref_out = ref(row, col, ref_values, ref_mat)
            out = torch.sparse.mm(csr, mat, reduce_type)
            self.assertEqual(out, ref_out)

            if train and dtype is not torch.bfloat16:
                ref_out.sum().backward()
                out.sum().backward()

                grad_values = csr.grad.values()
                grad_weight = mat.grad
                ref_grad_values = ref_values.grad
                ref_grad_weight = ref_mat.grad
                self.assertEqual(grad_values, ref_grad_values)
                self.assertEqual(grad_weight, ref_grad_weight)

        for train in [False, True]:
            for index_dtype in [torch.int32, torch.int64]:
                for reduce_type in ["sum", "mean", "amax", "amin"]:
                    # by setting nnz < M, create empty rows
                    run_test(3, 4, 11, 1, reduce_type, index_dtype, train)
                    run_test(3, 4, 11, 6, reduce_type, index_dtype, train)
                    run_test(3, 4, 11, 12, reduce_type, index_dtype, train)
                    # the kernel blocks with 4x the vector length,
                    # so also test K > 4x the vector length
                    run_test(4, 7, 33, 13, reduce_type, index_dtype, train)

    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_coo_csr_conversion(self, device, dtype):

@@ -2415,6 +2415,10 @@
  mat1: maybe_multiply(grad.sparse_mask(self).mm(mat2.mH()), alpha.conj())
  mat2: maybe_multiply(mat1.mH().mm(grad.sparse_mask(self)), alpha.conj())

- name: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
  output_differentiability: [True, False]
  self, other: "grad.defined() ? _sparse_mm_reduce_impl_backward(self, grad, other, reduce, result1, grad_input_mask) : std::tuple<Tensor, Tensor>()"

- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor
  grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta)
  self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)

@ -61,41 +61,65 @@ mm = _add_docstr(_sparse._sparse_mm, r"""
|
|||
.. note::
|
||||
This function doesn't support computing derivaties with respect to CSR matrices.
|
||||
|
||||
Args:
|
||||
mat1 (Tensor): the first sparse matrix to be multiplied
|
||||
mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
|
||||
This function also additionally accepts an optional :attr:`reduce` argument that allows
|
||||
specification of an optional reduction operation, mathematically performs the following operation:
|
||||
|
||||
Shape:
|
||||
The format of the output tensor of this function follows:
|
||||
- sparse x sparse -> sparse
|
||||
- sparse x dense -> dense
|
||||
.. math::
|
||||
|
||||
Example::
|
||||
z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
|
||||
|
||||
>>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
|
||||
>>> a
|
||||
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
|
||||
[0, 1, 2, 0, 1, 2]]),
|
||||
values=tensor([ 1.5901, 0.0183, -0.6146, 1.8061, -0.0112, 0.6302]),
|
||||
size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)
|
||||
where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
|
||||
CSR storage format on CPU device.
|
||||
|
||||
>>> b = torch.randn(3, 2, requires_grad=True)
|
||||
>>> b
|
||||
tensor([[-0.6479, 0.7874],
|
||||
[-1.2056, 0.5641],
|
||||
[-1.1716, -0.9923]], requires_grad=True)
|
||||
Args:
|
||||
mat1 (Tensor): the first sparse matrix to be multiplied
|
||||
mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
|
||||
reduce (str, optional): the reduction operation to apply for non-unique indices
|
||||
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
|
||||
|
||||
>>> y = torch.sparse.mm(a, b)
|
||||
>>> y
|
||||
tensor([[-0.3323, 1.8723],
|
||||
[-1.8951, 0.7904]], grad_fn=<SparseAddmmBackward>)
|
||||
>>> y.sum().backward()
|
||||
>>> a.grad
|
||||
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
|
||||
[0, 1, 2, 0, 1, 2]]),
|
||||
values=tensor([ 0.1394, -0.6415, -2.1639, 0.1394, -0.6415, -2.1639]),
|
||||
size=(2, 3), nnz=6, layout=torch.sparse_coo)
|
||||
""")
|
||||
Shape:
|
||||
The format of the output tensor of this function follows:
|
||||
- sparse x sparse -> sparse
|
||||
- sparse x dense -> dense
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
|
||||
>>> a
|
||||
tensor(indices=tensor([[0, 0, 1],
|
||||
[0, 2, 1]]),
|
||||
values=tensor([1., 2., 3.]),
|
||||
size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
|
||||
>>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
|
||||
>>> b
|
||||
tensor([[0., 1.],
|
||||
[2., 0.],
|
||||
[0., 0.]], requires_grad=True)
|
||||
>>> y = torch.sparse.mm(a, b)
|
||||
>>> y
|
||||
tensor([[0., 1.],
|
||||
[6., 0.]], grad_fn=<SparseAddmmBackward0>)
|
||||
>>> y.sum().backward()
|
||||
>>> a.grad
|
||||
tensor(indices=tensor([[0, 0, 1],
|
||||
[0, 2, 1]]),
|
||||
values=tensor([1., 0., 2.]),
|
||||
size=(2, 3), nnz=3, layout=torch.sparse_coo)
|
||||
>>> c = a.detach().to_sparse_csr()
|
||||
>>> c
|
||||
tensor(crow_indices=tensor([0, 2, 3]),
|
||||
col_indices=tensor([0, 2, 1]),
|
||||
values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
|
||||
layout=torch.sparse_csr)
|
||||
>>> y1 = torch.sparse.mm(c, b, 'sum')
|
||||
>>> y1
|
||||
tensor([[0., 1.],
|
||||
[6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
|
||||
>>> y2 = torch.sparse.mm(c, b, 'max')
|
||||
>>> y2
|
||||
tensor([[0., 1.],
|
||||
[6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
|
||||
""")
|
||||
|
||||
|
||||
sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
|
||||
|
|
@ -149,7 +173,6 @@ Examples::
|
|||
size=(3, 3), nnz=3, layout=torch.sparse_csr)
|
||||
""")
|
||||
|
||||
|
||||
def sum(input: Tensor, dim: DimOrDims = None,
|
||||
dtype: Optional[DType] = None) -> Tensor:
|
||||
r"""
|
||||
|
|
|
|||
|
|
@@ -21,7 +21,7 @@ from torch.testing._internal.common_dtype import (
    all_types, empty_types, complex_types_and, integral_types
)
from torch.testing._internal.common_device_type import \
    (onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
     skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
     skipCPUIfNoMklSparse,
     toleranceOverride, tol)

@@ -1050,6 +1050,21 @@ def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **
        beta=beta,
    )

def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    reductions = ["sum", "mean", "amax", "amin"]
    for m, k, reduce in product([5, 7], [3, 11], reductions):
        yield SampleInput(
            torch.eye(m, m)
            .to(device=device, dtype=dtype)
            .to_sparse_csr()
            .requires_grad_(requires_grad),
            make_arg((m, k)),
            reduce,
        )


def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
    yield SampleInput(make_arg(S, M), make_arg(M))

@@ -10392,6 +10407,47 @@ op_db: List[OpInfo] = [
        # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
        DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
    )),
    OpInfo('sparse.mm',
           dtypes=floating_types_and(torch.bfloat16),
           variant_test_name='reduce',
           supports_autograd=True,
           supports_out=False,
           supports_gradgrad=False,
           supports_forward_ad=False,
           sample_inputs_func=sample_inputs_sparse_mm_reduce,
           decorators=[onlyCPU],
           skips=(
               # NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
               # RuntimeError: Sparse CSR tensors do not have strides.
               DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
               # RuntimeError: Sparse CSR tensors do not have strides
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # RuntimeError: unsupported memory format option Preserve
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
               # GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
               DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
           )),
    UnaryUfuncInfo('i0',
                   ref=np_unary_ufunc_integer_promotion_wrapper(
                       scipy.special.i0) if TEST_SCIPY else None,