sparse.mm.backward: fix for non-contiguous grad values on CPU (#106127)

Fixes https://github.com/pytorch/pytorch/issues/102493.
The problem was that the backward implementation assumed inputs to be contiguous.
This might supersede https://github.com/pytorch/pytorch/pull/104520.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106127
Approved by: https://github.com/cpuhrsch
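
For context, the failure mode can be reproduced with a few lines of Python in the spirit of the regression test added below (a hypothetical standalone sketch, not part of this commit):

import torch

# Sketch of the scenario from gh-102493: backpropagating through
# out.coalesce().values().sum() can hand sparse.mm's backward a gradient
# whose values tensor is non-contiguous, which the CPU kernel used to
# mishandle.
A = torch.eye(5).to_sparse().requires_grad_(True)
B = torch.eye(5).to_sparse()
out = torch.sparse.mm(A, B)
out.coalesce().values().sum().backward()
# With A = B = I, the gradient w.r.t. A should again be the identity.
assert torch.allclose(A.grad.to_dense(), torch.eye(5))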
nikitaved 2023-07-27 14:11:36 +02:00 committed by PyTorch MergeBot
parent 93b2036bef
commit 01069ad4be
2 changed files with 70 additions and 44 deletions

@@ -50,17 +50,14 @@ void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
   }
 }
 
+template<typename index_t_ptr = int64_t*>
 int64_t _csr_matmult_maxnnz(
     const int64_t n_row,
     const int64_t n_col,
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[]) {
+    const index_t_ptr Ap,
+    const index_t_ptr Aj,
+    const index_t_ptr Bp,
+    const index_t_ptr Bj) {
   /*
     Compute needed buffer size for matrix `C` in `C = A@B` operation.
@@ -88,28 +85,22 @@ int64_t _csr_matmult_maxnnz(
   return nnz;
 }
 
-template<class scalar_t>
+template<typename index_t_ptr, typename scalar_t_ptr>
 void _csr_matmult(
     const int64_t n_row,
     const int64_t n_col,
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Ax[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Bx[],
+    const index_t_ptr Ap,
+    const index_t_ptr Aj,
+    const scalar_t_ptr Ax,
+    const index_t_ptr Bp,
+    const index_t_ptr Bj,
+    const scalar_t_ptr Bx,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cp[],
+    typename index_t_ptr::value_type Cp[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cj[],
+    typename index_t_ptr::value_type Cj[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    scalar_t Cx[]) {
+    typename scalar_t_ptr::value_type Cx[]) {
   /*
     Compute CSR entries for matrix C = A@B.
@@ -133,7 +124,10 @@ void _csr_matmult(
   Note:
     Output arrays Cp, Cj, and Cx must be preallocated
   */
-  std::vector<int64_t> next(n_col, -1);
+  using index_t = typename index_t_ptr::value_type;
+  using scalar_t = typename scalar_t_ptr::value_type;
+
+  std::vector<index_t> next(n_col, -1);
   std::vector<scalar_t> sums(n_col, 0);
 
   int64_t nnz = 0;
@@ -141,19 +135,19 @@ void _csr_matmult(
   Cp[0] = 0;
 
   for (const auto i : c10::irange(n_row)) {
-    int64_t head = -2;
-    int64_t length = 0;
+    index_t head = -2;
+    index_t length = 0;
 
-    int64_t jj_start = Ap[i];
-    int64_t jj_end = Ap[i + 1];
+    index_t jj_start = Ap[i];
+    index_t jj_end = Ap[i + 1];
     for (const auto jj : c10::irange(jj_start, jj_end)) {
-      int64_t j = Aj[jj];
+      index_t j = Aj[jj];
       scalar_t v = Ax[jj];
 
-      int64_t kk_start = Bp[j];
-      int64_t kk_end = Bp[j + 1];
+      index_t kk_start = Bp[j];
+      index_t kk_end = Bp[j + 1];
       for (const auto kk : c10::irange(kk_start, kk_end)) {
-        int64_t k = Bj[kk];
+        index_t k = Bj[kk];
         sums[k] += v * Bx[kk];
@@ -174,7 +168,7 @@ void _csr_matmult(
       Cx[nnz] = sums[head];
       nnz++;
 
-      int64_t temp = head;
+      index_t temp = head;
       head = next[head];
 
       next[temp] = -1; // clear arrays
@@ -183,6 +177,7 @@ void _csr_matmult(
 
     // Make sure that col indices are sorted.
     // TODO: a better approach is to implement a CSR @ CSC kernel.
+    // NOTE: Cx arrays are expected to be contiguous!
     auto col_indices_accessor = StridedRandomAccessor<int64_t>(Cj + nnz - length, 1);
     auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
     auto kv_accessor = CompositeRandomAccessorCPU<
@@ -212,13 +207,32 @@ void sparse_matmul_kernel(
   const auto mat1_csr = mat1.to_sparse_csr();
   const auto mat2_csr = mat2.to_sparse_csr();
 
+  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.crow_indices().data_ptr<int64_t>(),
+      mat1_csr.crow_indices().stride(-1));
+  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.col_indices().data_ptr<int64_t>(),
+      mat1_csr.col_indices().stride(-1));
+  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat1_csr.values().data_ptr<scalar_t>(),
+      mat1_csr.values().stride(-1));
+  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.crow_indices().data_ptr<int64_t>(),
+      mat2_csr.crow_indices().stride(-1));
+  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.col_indices().data_ptr<int64_t>(),
+      mat2_csr.col_indices().stride(-1));
+  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat2_csr.values().data_ptr<scalar_t>(),
+      mat2_csr.values().stride(-1));
+
   const auto nnz = _csr_matmult_maxnnz(
       M,
       N,
-      mat1_csr.crow_indices().data_ptr<int64_t>(),
-      mat1_csr.col_indices().data_ptr<int64_t>(),
-      mat2_csr.crow_indices().data_ptr<int64_t>(),
-      mat2_csr.col_indices().data_ptr<int64_t>());
+      mat1_crow_indices_ptr,
+      mat1_col_indices_ptr,
+      mat2_crow_indices_ptr,
+      mat2_col_indices_ptr);
 
   auto output_indices = output._indices();
   auto output_values = output._values();
@@ -234,12 +248,12 @@ void sparse_matmul_kernel(
   _csr_matmult(
       M,
       N,
-      mat1_csr.crow_indices().data_ptr<int64_t>(),
-      mat1_csr.col_indices().data_ptr<int64_t>(),
-      mat1_csr.values().data_ptr<scalar_t>(),
-      mat2_csr.crow_indices().data_ptr<int64_t>(),
-      mat2_csr.col_indices().data_ptr<int64_t>(),
-      mat2_csr.values().data_ptr<scalar_t>(),
+      mat1_crow_indices_ptr,
+      mat1_col_indices_ptr,
+      mat1_values_ptr,
+      mat2_crow_indices_ptr,
+      mat2_col_indices_ptr,
+      mat2_values_ptr,
       output_indptr.data_ptr<int64_t>(),
       output_col_indices.data_ptr<int64_t>(),
       output_values.data_ptr<scalar_t>());
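
The fix works by templating the CSR kernels over the pointer type, so they can be driven either by raw contiguous pointers or by stride-aware accessors built from each tensor's actual stride. A simplified sketch of the indexing idea (a hypothetical stand-in, not ATen's actual StridedRandomAccessor):

#include <cstdint>

// Hypothetical minimal strided accessor: element i lives at ptr[i * stride],
// so the same kernel code works for contiguous (stride == 1) and
// non-contiguous values alike.
template <typename T>
struct StridedPtrSketch {
  using value_type = T;  // kernels recover the element type via ::value_type
  T* ptr;
  int64_t stride;
  T& operator[](int64_t i) const { return ptr[i * stride]; }
};

// Usage sketch: StridedPtrSketch<float>{values_data, values_stride} can be
// passed wherever a kernel parameter is templated on the pointer type, as
// _csr_matmult above does with its index_t_ptr / scalar_t_ptr parameters.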

@@ -3700,6 +3700,17 @@ class TestSparse(TestSparseBase):
             self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
 
+        def test_backward_noncontiguous():
+            # Sparse.mm backward used to produce wrong results with non-contiguous grads,
+            # see https://github.com/pytorch/pytorch/issues/102493.
+            n_reps = 7
+            for _ in range(n_reps):
+                A = torch.eye(5).to_sparse().requires_grad_(True)
+                B = torch.eye(5).to_sparse()
+                out = torch.sparse.mm(A, B)
+                out.coalesce().values().sum().backward()
+                self.assertEqual(A.grad, A)
+
         for n in range(2, 5):
             for m in range(2, 8):
                 for p in range(2, 8):
@@ -3708,6 +3719,7 @@ class TestSparse(TestSparseBase):
         test_sparse_matmul(2, 0, [0, 0], [0, 0])
         test_sparse_matmul(2, 0, [0, 10], [10, 0])
         test_error_cases()
+        test_backward_noncontiguous()
 
     @coalescedonoff
     @dtypes(torch.double)