port sparse_mm.reduce to pytorch and optimize it on CPU (#83727)

### Motivation of this PR

This patch migrates `spmm_reduce` from `torch-sparse` (a 3rd-party dependency of PyG) into `torch`, in response to the initial proposal for fusing **Gather, Apply, Scatter** in Message Passing for GNN inference/training: https://github.com/pytorch/pytorch/issues/71300

**GAS** is the major step in Message Passing. Its behavior can be classified into two kinds depending on the storage type of `EdgeIndex`, which records the connections between nodes (see the sketch after the list):

* COO: the hotspot is `scatter_reduce`
* CSR: the hotspot is `spmm_reduce`
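
For illustration, here is a minimal sketch of the two code paths (hypothetical node features and `edge_index`, not code from this patch):

```
import torch

num_nodes, feat_dim = 5, 8
x = torch.randn(num_nodes, feat_dim)        # node features

# COO path: gather neighbor features, then scatter-reduce per destination node
edge_index = torch.tensor([[0, 1, 2, 3],    # source nodes
                           [1, 1, 3, 4]])   # destination nodes
src, dst = edge_index[0], edge_index[1]
msg = x.index_select(0, src)                # "Gather" (+ "Apply", identity here)
out_coo = torch.zeros_like(x).scatter_reduce_(
    0, dst.view(-1, 1).expand_as(msg), msg, reduce="sum", include_self=False)

# CSR path: the same aggregation is a sparse-matrix x dense-matrix product,
# where spmm_reduce becomes the hotspot
adj = torch.sparse_coo_tensor(
    torch.stack([dst, src]), torch.ones(dst.numel()),
    (num_nodes, num_nodes)).to_sparse_csr()
out_csr = torch.sparse.mm(adj, x, "sum")    # matches out_coo
```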

The reduce type can be chosen from: "sum", "mean", "max", "min".

This patch extends `torch.sparse.mm` with a `reduce` argument, which maps to `torch.sparse_mm.reduce` internally.
`sparse_mm_reduce` is registered under the TensorTypeId of `SparseCsrCPU`. The operator relies on an internal interface, `_sparse_mm_reduce_impl`, which has two outputs (a usage sketch follows the list):
* `out` - the actual output
* `arg_out` - records the output indices of the nonzero elements when the reduce type is "max" or "min"; this is only needed for training, so it is not calculated during inference.
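
A hypothetical usage sketch of the extended API (CPU tensors; the values and the comparison against a dense matmul are only a sanity check, not taken from this patch):

```
import torch

a = torch.tensor([[1., 0., 2.],
                  [0., 3., 0.]]).to_sparse_csr()
b = torch.randn(3, 4)

# "sum" reproduces the plain sparse-dense matmul
out_sum = torch.sparse.mm(a, b, "sum")
assert torch.allclose(out_sum, a.to_dense() @ b)

# "mean" / "amax" / "amin" reduce over the nonzeros of each row instead
out_mean = torch.sparse.mm(a, b, "mean")
out_amax = torch.sparse.mm(a, b, "amax")

# With grad enabled and reduce "amax"/"amin", the internal
# `_sparse_mm_reduce_impl` also produces `arg_out` (indices of the winning
# nonzeros) for the backward pass; in inference mode it is skipped.
a_train = torch.tensor([[1., 0., 2.],
                        [0., 3., 0.]]).to_sparse_csr().requires_grad_()
b_train = b.clone().requires_grad_()
torch.sparse.mm(a_train, b_train, "amax").sum().backward()
```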

### Performance

Benchmarked on GCN for ogbn-products on a single Xeon socket, the workload is improved by `4.3x` with this patch.

The performance benefit for training will be even larger: the original backward impl for `sum|mean` is sequential, and the original backward impl for `max|min` is not fused.
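
For reference, a sketch of the per-nonzero input gradient that the fused backward computes (sum/mean case only; the helper name and Python loop are illustrative, not part of this patch):

```
import torch

def spmm_reduce_backward_input_ref(grad_out, crow, col, row, other, reduce):
    # gradient of the CSR `values` for reduce in {"sum", "mean"};
    # the fused kernel parallelizes this loop over the nonzeros
    nnz = col.numel()
    grad_values = torch.empty(nnz, dtype=other.dtype)
    for e in range(nnz):
        r, c = row[e], col[e]
        val = torch.dot(other[c], grad_out[r])   # row-wise dot product
        if reduce == "mean":
            val = val / (crow[r + 1] - crow[r])  # divide by the row's nonzero count
        grad_values[e] = val
    return grad_values
```

For `max|min`, the gradient is instead routed through `arg_out` to the winning nonzero of each output element, which is what the new fused `*_arg` backward kernels do.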

#### before:
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       torch_sparse::spmm_sum        97.09%       56.086s        97.09%       56.088s        6.232s             9
                 aten::linear         0.00%      85.000us         1.38%     795.485ms      88.387ms             9
                 aten::matmul         0.00%      57.000us         1.38%     795.260ms      88.362ms             9
                     aten::mm         1.38%     795.201ms         1.38%     795.203ms      88.356ms             9
                   aten::relu         0.00%      50.000us         0.76%     440.434ms      73.406ms             6
              aten::clamp_min         0.76%     440.384ms         0.76%     440.384ms      73.397ms             6
                   aten::add_         0.57%     327.801ms         0.57%     327.801ms      36.422ms             9
            aten::log_softmax         0.00%      23.000us         0.10%      55.503ms      18.501ms             3
```

#### after:
```
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
               aten::spmm_sum        87.35%       11.826s        87.36%       11.827s        1.314s             9
                 aten::linear         0.00%      92.000us         5.87%     794.451ms      88.272ms             9
                 aten::matmul         0.00%      62.000us         5.87%     794.208ms      88.245ms             9
                     aten::mm         5.87%     794.143ms         5.87%     794.146ms      88.238ms             9
                   aten::relu         0.00%      53.000us         3.35%     452.977ms      75.496ms             6
              aten::clamp_min         3.35%     452.924ms         3.35%     452.924ms      75.487ms             6
                   aten::add_         2.58%     348.663ms         2.58%     348.663ms      38.740ms             9
                 aten::argmax         0.42%      57.473ms         0.42%      57.475ms      14.369ms             4
            aten::log_softmax         0.00%      22.000us         0.39%      52.605ms      17.535ms             3
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83727
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch, https://github.com/rusty1s, https://github.com/pearu
mingfeima 2023-02-10 11:12:35 +08:00 committed by PyTorch MergeBot
parent 24ae50bcc7
commit c620ece726
18 changed files with 939 additions and 32 deletions

View File

@ -0,0 +1,512 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/SpmmReduceKernel.h>
#include <ATen/native/cpu/ReduceUtils.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_native.h>
#endif
namespace at { namespace native {
namespace {
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_kernel_impl(
const Tensor& out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
const Tensor& other_) {
int64_t nnz = other_.numel();
if (nnz == 0) {
return;
}
auto other = other_.contiguous();
// access `crow_indices`, `col_indices` and `values` via TensorAccessor
scalar_t* out_data = out.data_ptr<scalar_t>();
auto csr_data = crow_indices.accessor<index_t, 1>();
auto col_data = col_indices.accessor<index_t, 1>();
auto val_data = values.accessor<scalar_t, 1>();
scalar_t* other_data = other.data_ptr<scalar_t>();
int64_t M = crow_indices.numel() - 1;
int64_t K = other.size(-1);
using Vec = vec::Vectorized<scalar_t>;
utils::parallel_sparse_csr(csr_data, M, nnz, [&](int64_t begin, int64_t end) {
int64_t row_start, row_end, c;
for (const auto m : c10::irange(begin, end)) {
row_start = csr_data[m];
row_end = csr_data[m + 1];
scalar_t* out_ptr = out_data + m * K;
constexpr int64_t kVecSize = Vec::size();
constexpr int64_t kVLEN = kVecSize * 4;
constexpr int64_t CHUNK_SIZE = 16;
// step 1: reinit the output row for reduce type 'amax' and 'amin'
int64_t count = row_end - row_start;
if (count != 0) {
init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
}
// step 2: reduce, do blocking on rowwise to reduce write memory bandwidth
for (int64_t e0 = row_start; e0 < row_end; e0 += CHUNK_SIZE) {
int64_t e1 = std::min(e0 + CHUNK_SIZE, row_end);
int64_t k = 0;
for (; k < K - (K % kVLEN); k += kVLEN) {
Vec out_vec0 = Vec::loadu(out_ptr + k);
Vec out_vec1 = Vec::loadu(out_ptr + k + kVecSize);
Vec out_vec2 = Vec::loadu(out_ptr + k + kVecSize * 2);
Vec out_vec3 = Vec::loadu(out_ptr + k + kVecSize * 3);
for (const auto e : c10::irange(e0, e1)) {
c = col_data[e];
scalar_t val = val_data[e];
scalar_t* other_ptr = other_data + c * K + k;
out_vec0 = update<Vec, reduce>(out_vec0, Vec::loadu(other_ptr) * Vec(val));
out_vec1 = update<Vec, reduce>(out_vec1, Vec::loadu(other_ptr + kVecSize) * Vec(val));
out_vec2 = update<Vec, reduce>(out_vec2, Vec::loadu(other_ptr + kVecSize * 2) * Vec(val));
out_vec3 = update<Vec, reduce>(out_vec3, Vec::loadu(other_ptr + kVecSize * 3) * Vec(val));
}
out_vec0.store(out_ptr + k);
out_vec1.store(out_ptr + k + kVecSize);
out_vec2.store(out_ptr + k + kVecSize * 2);
out_vec3.store(out_ptr + k + kVecSize * 3);
}
for (; k < K - (K % kVecSize); k += kVecSize) {
Vec out_vec = Vec::loadu(out_ptr + k);
for (const auto e : c10::irange(e0, e1)) {
c = col_data[e];
scalar_t val = val_data[e];
scalar_t* other_ptr = other_data + c * K;
out_vec = update<Vec, reduce>(out_vec, Vec::loadu(other_ptr + k) * Vec(val));
}
out_vec.store(out_ptr + k);
}
for (; k < K; k++) {
scalar_t out_val = out_ptr[k];
for (const auto e : c10::irange(e0, e1)) {
c = col_data[e];
scalar_t val = val_data[e];
scalar_t* other_ptr = other_data + c * K;
out_val = update<scalar_t, reduce>(out_val, other_ptr[k] * val);
}
out_ptr[k] = out_val;
}
}
// step 3: finalize
write<scalar_t, reduce>(out_ptr, count, K);
}
});
}
// update both val and arg, used for `amin` and `amax`
// it is a little troublesome to vectorize it since `scalar_t` and `index_t`
// might have different vector lengths, for example, each vector holds 8 floats
// and 4 int64_t.
template <typename scalar_t, typename index_t, ReductionType reduce>
inline void update_with_index(scalar_t *val, scalar_t new_val, index_t *arg, index_t new_arg) {
if ((reduce == ReductionType::MIN && new_val < *val) ||
(reduce == ReductionType::MAX && new_val > *val) ||
at::_isnan<scalar_t>(new_val)) {
*val = new_val;
*arg = new_arg;
}
}
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_arg_kernel_impl(
const Tensor& out,
const Tensor& arg_out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
const Tensor& other_) {
TORCH_CHECK(reduce == ReductionType::MAX || reduce == ReductionType::MIN);
int64_t nnz = values.numel();
if (nnz == 0) {
return;
}
auto other = other_.contiguous();
scalar_t* out_data = out.data_ptr<scalar_t>();
index_t* arg_out_data = arg_out.data_ptr<index_t>();
auto csr_data = crow_indices.accessor<index_t, 1>();
auto col_data = col_indices.accessor<index_t, 1>();
auto val_data = values.accessor<scalar_t, 1>();
scalar_t* other_data = other.data_ptr<scalar_t>();
int64_t M = crow_indices.numel() - 1;
int64_t K = other.size(-1);
at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
int64_t row_start, row_end, c;
for (const auto m : c10::irange(begin, end)) {
row_start = csr_data[m];
row_end = csr_data[m + 1];
scalar_t* out_ptr = out_data + m * K;
index_t* arg_out_ptr = arg_out_data + m * K;
if (row_end != row_start) {
init<scalar_t, reduce>(out_ptr, K, /*include_self*/false);
for (const auto e : c10::irange(row_start, row_end)) {
c = col_data[e];
scalar_t val = val_data[e];
scalar_t* other_ptr = other_data + c * K;
for (const auto k : c10::irange(K)) {
update_with_index<scalar_t, index_t, reduce>(
&out_ptr[k], val * other_ptr[k], &arg_out_ptr[k], index_t(e));
};
}
}
}
});
}
template <typename scalar_t, typename index_t, ReductionType reduce>
void spmm_reduce_backward_input_kernel_impl(
const Tensor& grad_self,
const Tensor& grad_out_,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& other_,
const Tensor& row_indices) {
int64_t nnz = grad_self._nnz();
if (nnz == 0) {
return;
}
auto grad_out = grad_out_.contiguous();
auto other = other_.contiguous();
auto values = grad_self.values();
auto grad_values_data = values.accessor<scalar_t, 1>();
scalar_t* grad_out_data = grad_out.data_ptr<scalar_t>();
auto crow_data = crow_indices.accessor<index_t, 1>();
auto col_data = col_indices.accessor<index_t, 1>();
scalar_t* other_data = other.data_ptr<scalar_t>();
auto row_data = row_indices.accessor<index_t, 1>();
int64_t K = grad_out.size(1);
using Vec = vec::Vectorized<vec::vec_scalar_t<scalar_t>>;
at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
index_t row = row_data[i], col = col_data[i];
scalar_t val = vec::map2_reduce_all<scalar_t>(
[](Vec x, Vec y) { return x * y; },
[](Vec x, Vec y) { return x + y; },
other_data + col * K,
grad_out_data + row * K,
K);
if (reduce == ReductionType::MEAN) {
index_t row_start = crow_data[row], row_end = crow_data[row + 1];
val /= (row_end - row_start);
}
grad_values_data[i] = val;
}
});
}
// backward for reduce type 'amax' or 'amin'
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_input_arg_kernel_impl(
const Tensor& grad_self,
const Tensor& grad_out_,
const Tensor& col_indices,
const Tensor& other_,
const Tensor& arg_out_) {
int64_t nnz = grad_self._nnz();
if (nnz == 0) {
return;
}
auto grad_out = grad_out_.contiguous();
auto other = other_.contiguous();
auto arg_out = arg_out_.contiguous();
auto grad_values = grad_self.values();
auto grad_values_data = grad_values.accessor<scalar_t, 1>();
scalar_t* grad_out_data = grad_out.data_ptr<scalar_t>();
auto col_data = col_indices.accessor<index_t, 1>();
scalar_t* other_data = other.data_ptr<scalar_t>();
index_t* arg_out_data = arg_out.data_ptr<index_t>();
int64_t M = grad_out.size(0);
int64_t K = grad_out.size(1);
auto grad = at::empty({M, K}, grad_out.options());
scalar_t* grad_data = grad.data_ptr<scalar_t>();
at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
for (const auto m : c10::irange(begin, end)) {
scalar_t* grad_out_ptr = grad_out_data + m * K;
scalar_t* grad_ptr = grad_data + m * K;
index_t* arg_out_ptr = arg_out_data + m * K;
for (const auto k : c10::irange(K)) {
if (arg_out_ptr[k] == index_t(nnz)) {
grad_ptr[k] = scalar_t(0);
} else {
// collect weight at max/min indices
index_t col = col_data[arg_out_data[m * K + k]];
grad_ptr[k] = other_data[col * K + k] * grad_out_ptr[k];
}
}
}
});
// scatter_add; consider parallelizing this with atomics
for (const auto i : c10::irange(M * K)) {
index_t ind = arg_out_data[i];
if (ind != index_t(nnz)) {
grad_values_data[ind] += grad_data[i];
}
}
}
template <typename scalar_t, typename index_t>
void spmm_reduce_normalize_values_kernel_impl(
const Tensor& normalized_values,
const Tensor& values,
const Tensor& crow_indices,
const Tensor& row_indices) {
int64_t nnz = values.numel();
if (nnz == 0) {
return;
}
auto normalized_values_data = normalized_values.accessor<scalar_t, 1>();
auto values_data = values.accessor<scalar_t, 1>();
auto crow_data = crow_indices.accessor<index_t, 1>();
auto row_data = row_indices.accessor<index_t, 1>();
at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
index_t row = row_data[i];
index_t row_start = crow_data[row], row_end = crow_data[row + 1];
// Note that when the row index row is listed in row_indices,
// then crow_indices[row+1] > crow_indices[row] holds
normalized_values_data[i] = values_data[i] / (row_end - row_start);
}
});
}
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_other_arg_kernel_impl(
const Tensor& grad_other,
const Tensor& grad_out_,
const Tensor& col_indices,
const Tensor& values,
const Tensor& arg_out_) {
int64_t nnz = values.numel();
if (nnz == 0) {
return;
}
auto grad_out = grad_out_.contiguous();
auto arg_out = arg_out_.contiguous();
scalar_t* grad_other_data = grad_other.data_ptr<scalar_t>();
scalar_t* grad_out_data = grad_out.data_ptr<scalar_t>();
auto col_data = col_indices.accessor<index_t, 1>();
auto values_data = values.accessor<scalar_t, 1>();
index_t* arg_out_data = arg_out.data_ptr<index_t>();
int64_t M = grad_out.size(0);
int64_t K = grad_out.size(1);
auto grad = at::empty({M, K}, grad_out.options());
scalar_t* grad_data = grad.data_ptr<scalar_t>();
at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
for (const auto m : c10::irange(begin, end)) {
scalar_t* grad_out_ptr = grad_out_data + m * K;
scalar_t* grad_ptr = grad_data + m * K;
index_t* arg_out_ptr = arg_out_data + m * K;
for (const auto k : c10::irange(K)) {
if (arg_out_ptr[k] == index_t(nnz)) {
grad_ptr[k] = scalar_t(0);
} else {
grad_ptr[k] = values_data[arg_out_ptr[k]] * grad_out_ptr[k];
}
}
}
});
// scatter_add; consider parallelizing this with atomics
for (const auto m : c10::irange(M)) {
for (const auto k : c10::irange(K)) {
index_t ind = arg_out_data[m * K + k];
if (ind != index_t(nnz)) {
index_t col = col_data[ind];
grad_other_data[col * K + k] += grad_data[m * K + k];
}
}
}
}
void spmm_reduce_kernel(
const Tensor& out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
out, crow_indices, col_indices, values, other);
});
});
});
}
void spmm_reduce_arg_kernel(
const Tensor& out,
const Tensor& arg_out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
out, arg_out, crow_indices, col_indices, values, other);
});
});
});
}
void spmm_reduce_backward_input_kernel(
const Tensor& grad_self,
const Tensor& grad_out,
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& other,
const Tensor& row_indices,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
grad_self, grad_out, crow_indices, col_indices, other, row_indices);
});
});
});
}
void spmm_reduce_backward_input_arg_kernel(
const Tensor& grad_self,
const Tensor& grad_out,
const Tensor& col_indices,
const Tensor& other,
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
grad_self, grad_out, col_indices, other, arg_out);
});
});
}
void spmm_reduce_normalize_values_kernel(
const Tensor& normalized_values,
const Tensor& values,
const Tensor& crow_indices,
const Tensor& row_indices) {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
normalized_values, values, crow_indices, row_indices);
});
});
}
void spmm_reduce_backward_other_kernel(
const Tensor& grad_other,
const Tensor& grad_out,
const Tensor& crow_indices,
const Tensor& values,
const Tensor& row_indices,
const Tensor& ccol_indices,
const Tensor& csr2csc,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
// need to permute row_indices to CSC order
auto row = row_indices.index_select(0, csr2csc);
Tensor val;
if (reduce_op == ReductionType::MEAN) {
// for reduce type "mean", need to normalize the values
// with the row count for each nonzero element.
Tensor normalized_values = at::empty(values.sizes(), values.options());
spmm_reduce_normalize_values_kernel(normalized_values, values, crow_indices, row_indices);
val = normalized_values.index_select(0, csr2csc);
} else {
val = values.index_select(0, csr2csc);
}
spmm_reduce_kernel(grad_other, ccol_indices, row, val, grad_out, ReductionType::SUM);
}
void spmm_reduce_backward_other_arg_kernel(
const Tensor& grad_other,
const Tensor& grad_out,
const Tensor& col_indices,
const Tensor& values,
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
grad_other, grad_out, col_indices, values, arg_out);
});
});
}
} // anonymous namespace
REGISTER_DISPATCH(spmm_reduce_stub, &spmm_reduce_kernel);
REGISTER_DISPATCH(spmm_reduce_arg_stub, &spmm_reduce_arg_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_input_stub, &spmm_reduce_backward_input_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_input_arg_stub, &spmm_reduce_backward_input_arg_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_other_stub, &spmm_reduce_backward_other_kernel);
REGISTER_DISPATCH(spmm_reduce_backward_other_arg_stub, &spmm_reduce_backward_other_arg_kernel);
}} // at::native

View File

@ -0,0 +1,22 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReductionType.h>
namespace at::native {
using spmm_reduce_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_input_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_input_arg_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
using spmm_reduce_backward_other_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, const Tensor&, ReductionType op);
DECLARE_DISPATCH(spmm_reduce_fn, spmm_reduce_stub);
DECLARE_DISPATCH(spmm_reduce_arg_fn, spmm_reduce_arg_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_fn, spmm_reduce_backward_input_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_input_arg_stub);
DECLARE_DISPATCH(spmm_reduce_backward_other_fn, spmm_reduce_backward_other_stub);
DECLARE_DISPATCH(spmm_reduce_backward_input_arg_fn, spmm_reduce_backward_other_arg_stub);
} // at::native

View File

@ -3829,6 +3829,9 @@
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
python_module: sparse
- func: _sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor
python_module: sparse
- func: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
dispatch:
SparseCPU: sparse_sparse_matmul_cpu
@ -6440,6 +6443,16 @@
SparseCsrCUDA: sparse_sampled_addmm_sparse_csr_cuda
SparseCsrCPU: sparse_sampled_addmm_sparse_csr_cpu
- func: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
python_module: sparse
dispatch:
SparseCsrCPU: _sparse_mm_reduce_impl_sparse_csr_cpu
- func: _sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)
python_module: sparse
dispatch:
SparseCsrCPU: _sparse_mm_reduce_impl_backward_sparse_csr_cpu
- func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:

View File

@ -20,6 +20,7 @@
#include <ATen/Operators.h>
#else
#include <ATen/ops/_conj_physical_native.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
@ -51,6 +52,7 @@
#include <ATen/ops/deg2rad.h>
#include <ATen/ops/deg2rad_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/erf.h>
#include <ATen/ops/erf_native.h>
#include <ATen/ops/erfinv.h>
@ -1292,5 +1294,134 @@ Tensor _sparse_csr_prod_cpu(const Tensor& input, IntArrayRef dims_to_reduce, boo
return result;
}
std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_sparse_csr_cpu(
const Tensor& self,
const Tensor& other,
const c10::string_view reduce) {
auto layout = self.layout();
TORCH_CHECK(layout == kSparseCsr,
"sparse_mm_reduce: expect self to be SparseCsr, got ", layout);
TORCH_CHECK(self.dense_dim() == 0,
"sparse_mm_reduce: expected non-hybrid self tensor.");
TORCH_CHECK(self.dim() == 2,
"sparse_mm_reduce: expected self to be a 2-D tensor, got ", self.dim(), "-D tensor.");
sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/false>(
self, Tensor(), other);
auto op = get_reduction_enum(reduce);
TORCH_CHECK(op != ReductionType::PROD, "sparse_mm_reduce: reduce type of prod has not been enabled.")
auto crow = self.crow_indices();
auto col = self.col_indices();
auto val = self.values();
// init output to be all zeros; for rows that have no nonzero elements,
// the corresponding rows in the output will be zero.
auto out = at::zeros({self.size(0), other.size(1)}, other.options());
auto arg_out = at::empty({0}, col.options());
int64_t nnz = self._nnz();
if (nnz == 0) {
return std::make_tuple(out, arg_out);
}
// only need to calculate arg_out
// for reduce types "amax" and "amin" during training
bool need_arg_out = at::GradMode::is_enabled()
&& (self.requires_grad() || other.requires_grad())
&& (op == ReductionType::MAX || op == ReductionType::MIN);
if (!need_arg_out) {
spmm_reduce_stub(kCPU, out, crow, col, val, other, op);
} else {
// allocate memory and init with invalid index
arg_out.resize_(out.sizes());
arg_out.fill_(nnz);
spmm_reduce_arg_stub(kCPU, out, arg_out, crow, col, val, other, op);
}
return std::make_tuple(std::move(out), std::move(arg_out));
}
std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_backward_sparse_csr_cpu(
const Tensor& self,
const Tensor& grad_out,
const Tensor& other,
const c10::string_view reduce,
const Tensor& arg_out,
std::array<bool, 2> output_mask) {
auto layout = self.layout();
TORCH_CHECK(layout == kSparseCsr,
"sparse_mm_reduce: expect self to be SparseCsr, got ", layout);
sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/true>(
self, grad_out, other);
auto op = get_reduction_enum(reduce);
auto crow = self.crow_indices();
auto col = self.col_indices();
auto val = self.values();
// `row`: row indices of COO format
// `ccol`: ccol indices of CSC format (with permute)
// `permute`: permute pattern from CSR to CSC
//
// TODO: optimize the following section,
// currently `argsort` is sequential.
Tensor row, ccol, permute;
{
bool out_int32 = crow.scalar_type() == ScalarType::Int;
Tensor coo_indices = at::_convert_indices_from_csr_to_coo(
crow,
col,
out_int32,
/*transpose*/false);
row = coo_indices.select(0, 0);
// calculate the global index for CSC
// and get the conversion permute pattern
Tensor index = col.mul(self.size(0)).add_(row);
permute = index.argsort();
ccol = at::_convert_indices_from_coo_to_csr(
/*column indices*/col.index_select(0, permute),
/*column count*/self.size(1),
out_int32);
}
Tensor grad_self, grad_other;
if (output_mask[0]) {
// grad_input has the same indices and nnz as the input
grad_self = at::empty_like(self);
grad_self.values().zero_();
if (op == ReductionType::MAX || op == ReductionType::MIN) {
spmm_reduce_backward_input_arg_stub(kCPU, grad_self, grad_out, col, other, arg_out, op);
} else {
spmm_reduce_backward_input_stub(kCPU, grad_self, grad_out, crow, col, other, row, op);
}
}
if (output_mask[1]) {
grad_other = at::zeros(other.sizes(), other.options());
if (op == ReductionType::MAX || op == ReductionType::MIN) {
spmm_reduce_backward_other_arg_stub(kCPU, grad_other, grad_out, col, val, arg_out, op);
} else {
spmm_reduce_backward_other_stub(kCPU, grad_other, grad_out, crow, val, row, ccol, permute, op);
}
}
return std::make_tuple(std::move(grad_self), std::move(grad_other));
}
DEFINE_DISPATCH(spmm_reduce_stub);
DEFINE_DISPATCH(spmm_reduce_arg_stub);
DEFINE_DISPATCH(spmm_reduce_backward_input_stub);
DEFINE_DISPATCH(spmm_reduce_backward_input_arg_stub);
DEFINE_DISPATCH(spmm_reduce_backward_other_stub);
DEFINE_DISPATCH(spmm_reduce_backward_other_arg_stub);
} // namespace native
} // namespace at

View File

@ -2,6 +2,9 @@
#include <ATen/Tensor.h>
#include <ATen/core/Scalar.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/cpu/SpmmReduceKernel.h>
namespace at {
namespace native {
@ -59,6 +62,28 @@ inline void _check_dim(const Tensor& self, int64_t target_dim, c10::string_view
" instead.");
}
template <bool train>
inline void check_sparse_mm_reduce_impl_inputs(
const Tensor& self,
const Tensor& grad_out,
const Tensor& other) {
TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
const auto input_scalar_type = self.values().scalar_type();
CheckedFrom c = train ? "sparse_mm_reduce_backward" : "sparse_mm_reduce";
if (train) {
checkLayout(c, grad_out, kStrided);
checkScalarType(c, {grad_out, "grad_out", 1}, input_scalar_type);
check_dim_size(grad_out, 2, 0, self.size(0));
check_dim_size(grad_out, 2, 1, other.size(1));
}
int pos = train ? 2 : 1;
checkLayout(c, other, kStrided);
checkScalarType(c, {other, "other", pos}, input_scalar_type);
check_dim_size(other, 2, 0, self.size(1));
}
}
}
}

View File

@ -31,6 +31,8 @@
#include <ATen/ops/_sparse_sum_backward_native.h>
#include <ATen/ops/_sparse_sum_native.h>
#include <ATen/ops/_sparse_sparse_matmul.h>
#include <ATen/ops/_sparse_mm_reduce_impl.h>
#include <ATen/ops/_sparse_mm_reduce_impl_native.h>
#include <ATen/ops/add.h>
#include <ATen/ops/add_native.h>
#include <ATen/ops/addmm.h>
@ -1393,6 +1395,12 @@ SparseTensor& _sparse_mm_out(const SparseTensor& sparse,
return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch!
}
Tensor _sparse_mm(const Tensor& mat1, const Tensor& mat2, const c10::string_view reduce) {
// result: out, arg_out
auto result = at::_sparse_mm_reduce_impl(mat1, mat2, reduce);
return std::get<0>(result);
}
// --------------------------------------------------------------------
// hspmm(SparseTensor mat1, Tensor mat2)
// --------------------------------------------------------------------

View File

@ -1140,6 +1140,7 @@ aten_native_source_codegen_list = [
"aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
"aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
"aten/src/ATen/native/cpu/SampledAddmmKernel.cpp",
"aten/src/ATen/native/cpu/SpmmReduceKernel.cpp",
"aten/src/ATen/native/cpu/SparseFactories.cpp",
"aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
]

View File

@ -428,6 +428,7 @@ dtensor_fails = {
xfail("select_scatter"),
xfail("sort"),
xfail("sparse.sampled_addmm"),
xfail("sparse.mm", "reduce"),
xfail("special.airy_ai"),
xfail("special.bessel_j0"),
xfail("special.bessel_j1"),

View File

@ -457,6 +457,8 @@ aten::_sparse_log_softmax
aten::_sparse_log_softmax.out
aten::_sparse_log_softmax_backward_data
aten::_sparse_log_softmax_backward_data.out
aten::_sparse_mm_reduce_impl
aten::_sparse_mm_reduce_impl_backward
aten::_sparse_softmax
aten::_sparse_softmax.out
aten::_sparse_softmax_backward_data

View File

@ -2237,6 +2237,7 @@ aot_autograd_failures = {
xfail('cov'),
xfail('chalf'), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
skip('nn.functional.binary_cross_entropy_with_logits'), # seems to fail sometimes?
skip('nn.functional.margin_ranking_loss'), # seems flaky
skip('linalg.lu_solve'), # flaky

View File

@ -370,6 +370,7 @@ class TestOperators(TestCase):
@skipOps('TestOperators', 'test_grad', vjp_fail.union({
xfail('chalf', '', device_type='cpu'), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
xfail('sparse.sampled_addmm', ''), # RuntimeError: Sparse CSR tensors do not have strides
xfail('sparse.mm', 'reduce'), # RuntimeError: Sparse CSR tensors do not have strides
# Non-contiguous Bugs
#
@ -567,6 +568,7 @@ class TestOperators(TestCase):
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps('TestOperators', 'test_vjp', vjp_fail.union({
xfail('sparse.sampled_addmm', ''),
xfail('sparse.mm', 'reduce'),
# ---- Non-Contiguous Failures ----
# This is expected to fail as the operator
@ -645,6 +647,7 @@ class TestOperators(TestCase):
xfail('nn.functional.ctc_loss'), # Not Implemented
xfail('native_layer_norm', ''), # Expected a proper Tensor but got None for argument #1 'other'
xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
xfail('sparse.mm', 'reduce'), # sparse tensors have no strides
skip('nn.functional.scaled_dot_product_attention', device_type='cuda'),
# AssertionError: Tensor-likes are not close!
# Mismatched elements: 1 / 15 (6.7%)
@ -768,6 +771,7 @@ class TestOperators(TestCase):
xfail("quantile", device_type='cpu'), # Batching rule not implemented for `at::equal`
xfail("scatter_reduce", "prod"), # vmap (looks like you are calling item/data-dependent)
xfail("sparse.sampled_addmm"), # RuntimeError: Sparse CSR tensors do not have strides
xfail("sparse.mm", "reduce"), # RuntimeError: Sparse CSR tensors do not have strides
xfail("svd_lowrank"), # calls random op
xfail("take"), # vmap: inplace into a regular tensor
xfail("to"), # rank 4 tensor for channels_last
@ -894,6 +898,7 @@ class TestOperators(TestCase):
xfail('nn.functional.max_unpool2d', 'grad'),
xfail('sparse.sampled_addmm', ''),
xfail('sparse.mm', 'reduce'),
xfail('as_strided_scatter', ''), # calls as_strided
xfail('index_reduce', ''), # .item() call
# ---------------------------------------------------------------------
@ -1179,6 +1184,7 @@ class TestOperators(TestCase):
xfail('_segment_reduce', 'offsets'),
xfail('_segment_reduce', 'lengths'),
xfail('sparse.sampled_addmm', ''),
xfail('sparse.mm', 'reduce'),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
xfail("native_dropout_backward"),
@ -1252,6 +1258,7 @@ class TestOperators(TestCase):
xfail('nn.functional.dropout3d', ''),
xfail('as_strided_scatter', ''),
xfail('sparse.sampled_addmm', ''),
xfail('sparse.mm', 'reduce'),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
xfail('as_strided', 'partial_views'),
@ -1350,6 +1357,7 @@ class TestOperators(TestCase):
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
xfail('_segment_reduce', 'offsets'), # NYI: forward-AD for _segment_reduce
xfail('sparse.mm', 'reduce'), # Sparse tensors have no strides
xfail('index_reduce', ''), # NYI: forward-AD for index_reduce
xfail('_segment_reduce', 'lengths'), # NYI: forward-AD for _segment_reduce
xfail('native_dropout_backward'), # NYI
@ -1505,6 +1513,7 @@ class TestOperators(TestCase):
xfail('_segment_reduce', 'lengths'), # Forward AD not implemented and no decomposition
xfail('_segment_reduce', 'offsets'), # Forward AD not implemented and no decomposition
xfail('sparse.sampled_addmm'), # RuntimeError: Sparse CSR tensors do not have strides
xfail('sparse.mm', 'reduce'), # RuntimeError: Sparse CSR tensors do not have strides
xfail('svd_lowrank'), # calls random op
xfail('take'), # vmap: inplace into regular tensor
xfail('to'), # RuntimeError: required rank 4 tensor to use channels_last format
@ -1753,6 +1762,7 @@ class TestOperators(TestCase):
skip('linalg.lu_factor_ex', dtypes=(torch.float32,), device_type='cuda'), # fails on all but windows
skip('linalg.multi_dot', '', device_type='cpu'),
skip('sparse.sampled_addmm', ''),
skip('sparse.mm', 'reduce'),
skip('native_layer_norm', '', device_type='cpu'),
})
@opsToleranceOverride('TestOperators', 'test_vmap_autograd_grad', (

View File

@ -3475,6 +3475,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('pca_lowrank', ''), # random operation
xfail('svd_lowrank', ''), # random operation
xfail('sparse.sampled_addmm'), # sparse
xfail('sparse.mm', 'reduce'), # sparse
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function
skip('_softmax_backward_data'),
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
@ -3701,6 +3702,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('clamp_min', ''),
xfail('special.bessel_j0'),
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
xfail('special.bessel_y0'),
xfail('special.chebyshev_polynomial_u'),
xfail('special.modified_bessel_k1'),

View File

@ -250,6 +250,7 @@ inductor_expected_failures_single_sample["cpu"] = {
"scatter_reduce.prod": {f16, f32, f64},
"_segment_reduce.lengths": {f16, f32, f64},
"sparse.sampled_addmm": {f32, f64},
"sparse.mm.reduce": {bf16, f32, f64},
"stft": {f32, f64},
"tensor_split": {b8, f16, f32, f64, i32, i64},
"to_sparse": {f32, f64},

View File

@ -1189,6 +1189,7 @@ make_fx_failures = {
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
xfail('sparse.mm', 'reduce'),
# proxy tensor doesn't support sparse correctly right now
skip('to_sparse'),

View File

@ -2394,6 +2394,100 @@ class TestSparseCSR(TestCase):
with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to have strided layout"):
torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
def test_sparse_mm_reduce_sum(self, device, dtype):
def run_test(m, n, k, nnz, train):
sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)
dense = sparse.to_dense()
mat = torch.randn(k, n, dtype=dtype)
ref_mat = mat.clone()
if train:
sparse.requires_grad_()
mat.requires_grad_()
dense.requires_grad_()
ref_mat.requires_grad_()
ref_out = torch.mm(dense, ref_mat)
out = torch.sparse.mm(sparse, mat, 'sum')
self.assertEqual(out, ref_out)
if train:
ref_out.sum().backward()
out.sum().backward()
grad_input = sparse.grad
ref_grad_input = dense.grad
grad_mat = mat.grad
ref_grad_mat = ref_mat.grad
self.assertEqual(grad_input.to_dense(), ref_grad_input)
self.assertEqual(grad_mat, ref_grad_mat)
run_test(4, 5, 4, 10, False)
run_test(4, 4, 4, 16, True)
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
def test_sparse_mm_reduce(self, device, dtype):
def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
mat = torch.randn(n, k, dtype=dtype)
ref_mat = mat.clone()
ref_values = csr.values().clone()
out_int32 = index_dtype == torch.int32
coo_indices = torch._convert_indices_from_csr_to_coo(
csr.crow_indices(),
csr.col_indices(),
out_int32=out_int32)
row, col = coo_indices[0], coo_indices[1]
def ref(row, col, val, mat):
out = torch.zeros([m, k], dtype=dtype)
weight = mat.index_select(0, col)
src = weight.mul(val.view(-1, 1))
index = row.view(-1, 1).expand_as(weight)
index = index.to(dtype=torch.int64)
# scatter_reduce expects index to be int64
out.scatter_reduce_(0, index, src, reduce=reduce_type, include_self=False)
return out
if train:
csr.requires_grad_()
mat.requires_grad_()
ref_values.requires_grad_()
ref_mat.requires_grad_()
ref_out = ref(row, col, ref_values, ref_mat)
out = torch.sparse.mm(csr, mat, reduce_type)
self.assertEqual(out, ref_out)
if train and dtype is not torch.bfloat16:
ref_out.sum().backward()
out.sum().backward()
grad_values = csr.grad.values()
grad_weight = mat.grad
ref_grad_values = ref_values.grad
ref_grad_weight = ref_mat.grad
self.assertEqual(grad_values, ref_grad_values)
self.assertEqual(grad_weight, ref_grad_weight)
for train in [False, True]:
for index_dtype in [torch.int32, torch.int64]:
for reduce_type in ["sum", "mean", "amax", "amin"]:
# by setting nnz < M, create empty rows
run_test(3, 4, 11, 1, reduce_type, index_dtype, train)
run_test(3, 4, 11, 6, reduce_type, index_dtype, train)
run_test(3, 4, 11, 12, reduce_type, index_dtype, train)
# we are doing blocking with 4x vector length in the kernel,
# so need to test when K > 4x vector length
run_test(4, 7, 33, 13, reduce_type, index_dtype, train)
@skipMeta
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_coo_csr_conversion(self, device, dtype):

View File

@ -2415,6 +2415,10 @@
mat1: maybe_multiply(grad.sparse_mask(self).mm(mat2.mH()), alpha.conj())
mat2: maybe_multiply(mat1.mH().mm(grad.sparse_mask(self)), alpha.conj())
- name: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
output_differentiability: [True, False]
self, other: "grad.defined() ? _sparse_mm_reduce_impl_backward(self, grad, other, reduce, result1, grad_input_mask) : std::tuple<Tensor, Tensor>()"
- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor
grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta)
self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)

View File

@ -61,41 +61,65 @@ mm = _add_docstr(_sparse._sparse_mm, r"""
.. note::
This function doesn't support computing derivatives with respect to CSR matrices.
Args:
mat1 (Tensor): the first sparse matrix to be multiplied
mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
This function also accepts an optional :attr:`reduce` argument that allows
specification of a reduction operation, which mathematically performs the following operation:
Shape:
The format of the output tensor of this function follows:
- sparse x sparse -> sparse
- sparse x dense -> dense
.. math::
Example::
z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
>>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
>>> a
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
[0, 1, 2, 0, 1, 2]]),
values=tensor([ 1.5901, 0.0183, -0.6146, 1.8061, -0.0112, 0.6302]),
size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)
where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
CSR storage format on CPU device.
>>> b = torch.randn(3, 2, requires_grad=True)
>>> b
tensor([[-0.6479, 0.7874],
[-1.2056, 0.5641],
[-1.1716, -0.9923]], requires_grad=True)
Args:
mat1 (Tensor): the first sparse matrix to be multiplied
mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
reduce (str, optional): the reduction operation to apply for non-unique indices
(:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
>>> y = torch.sparse.mm(a, b)
>>> y
tensor([[-0.3323, 1.8723],
[-1.8951, 0.7904]], grad_fn=<SparseAddmmBackward>)
>>> y.sum().backward()
>>> a.grad
tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
[0, 1, 2, 0, 1, 2]]),
values=tensor([ 0.1394, -0.6415, -2.1639, 0.1394, -0.6415, -2.1639]),
size=(2, 3), nnz=6, layout=torch.sparse_coo)
""")
Shape:
The format of the output tensor of this function follows:
- sparse x sparse -> sparse
- sparse x dense -> dense
Example::
>>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
>>> a
tensor(indices=tensor([[0, 0, 1],
[0, 2, 1]]),
values=tensor([1., 2., 3.]),
size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
>>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
>>> b
tensor([[0., 1.],
[2., 0.],
[0., 0.]], requires_grad=True)
>>> y = torch.sparse.mm(a, b)
>>> y
tensor([[0., 1.],
[6., 0.]], grad_fn=<SparseAddmmBackward0>)
>>> y.sum().backward()
>>> a.grad
tensor(indices=tensor([[0, 0, 1],
[0, 2, 1]]),
values=tensor([1., 0., 2.]),
size=(2, 3), nnz=3, layout=torch.sparse_coo)
>>> c = a.detach().to_sparse_csr()
>>> c
tensor(crow_indices=tensor([0, 2, 3]),
col_indices=tensor([0, 2, 1]),
values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
layout=torch.sparse_csr)
>>> y1 = torch.sparse.mm(c, b, 'sum')
>>> y1
tensor([[0., 1.],
[6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
>>> y2 = torch.sparse.mm(c, b, 'max')
>>> y2
tensor([[0., 1.],
[6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
""")
sampled_addmm = _add_docstr(_sparse.sparse_sampled_addmm, r"""
@ -149,7 +173,6 @@ Examples::
size=(3, 3), nnz=3, layout=torch.sparse_csr)
""")
def sum(input: Tensor, dim: DimOrDims = None,
dtype: Optional[DType] = None) -> Tensor:
r"""

View File

@ -21,7 +21,7 @@ from torch.testing._internal.common_dtype import (
all_types, empty_types, complex_types_and, integral_types
)
from torch.testing._internal.common_device_type import \
(onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
skipCPUIfNoMklSparse,
toleranceOverride, tol)
@ -1050,6 +1050,21 @@ def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, **
beta=beta,
)
def sample_inputs_sparse_mm_reduce(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
reductions = ["sum", "mean", "amax", "amin"]
for m, k, reduce in product([5, 7], [3, 11], reductions):
yield SampleInput(
torch.eye(m, m)
.to(device=device, dtype=dtype)
.to_sparse_csr()
.requires_grad_(requires_grad),
make_arg((m, k)),
reduce,
)
def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
yield SampleInput(make_arg(S, M), make_arg(M))
@ -10392,6 +10407,47 @@ op_db: List[OpInfo] = [
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
)),
OpInfo('sparse.mm',
dtypes=floating_types_and(torch.bfloat16),
variant_test_name='reduce',
supports_autograd=True,
supports_out=False,
supports_gradgrad=False,
supports_forward_ad=False,
sample_inputs_func=sample_inputs_sparse_mm_reduce,
decorators=[onlyCPU],
skips=(
# NotImplementedError: Tensors of type SparseCsrTensorImpl do not have is_contiguous
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
# RuntimeError: Sparse CSR tensors do not have strides.
DecorateInfo(unittest.skip("Skipped!"), 'TestTags', 'test_tags'),
# RuntimeError: Sparse CSR tensors do not have strides
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
# RuntimeError: Sparse CSR tensors do not have strides
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_operator'),
# RuntimeError: Sparse CSR tensors do not have strides
DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_backward'),
# RuntimeError: Sparse CSR tensors do not have strides
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
# RuntimeError: Sparse CSR tensors do not have strides
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
# RuntimeError: Sparse CSR tensors do not have strides
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
# RuntimeError: Sparse CSR tensors do not have strides
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# RuntimeError: unsupported memory format option Preserve
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_fn_fwgrad_bwgrad'),
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'),
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
# GradcheckError: gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
)),
UnaryUfuncInfo('i0',
ref=np_unary_ufunc_integer_promotion_wrapper(
scipy.special.i0) if TEST_SCIPY else None,