sparse.mm.backward: fix for non-contiguous grad values on CPU (#106127)
Fixes https://github.com/pytorch/pytorch/issues/102493. The problem was that the backward implementation assumed its inputs to be contiguous and read them through raw data pointers, so non-contiguous gradient values produced incorrect results. This might supersede https://github.com/pytorch/pytorch/pull/104520.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106127
Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent: 93b2036bef
commit: 01069ad4be
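The heart of the fix is visible in the hunks below (from the CPU sparse matmul kernel, aten/src/ATen/native/sparse/SparseMatMul.cpp): the kernels used to consume raw data_ptr arrays and index them with an implicit stride of 1, which reads the wrong elements whenever a values tensor is a non-contiguous view. A minimal sketch of the stride-aware alternative; strided_ptr here is a hypothetical stand-in for ATen's StridedRandomAccessor, not its actual interface:

#include <cstdint>
#include <vector>

// Hypothetical minimal analogue of a strided accessor: a (pointer, stride)
// pair so that element i lives at data[i * stride] rather than data[i].
template <typename T>
struct strided_ptr {
  T* data;
  std::int64_t stride;
  T& operator[](std::int64_t i) const { return data[i * stride]; }
};

int main() {
  // Storage of a 2x3 row-major tensor; its first column is a view with
  // stride 3 over this buffer.
  std::vector<double> buf = {0, 1, 2, 3, 4, 5};

  // Stride-1 raw-pointer indexing (the old code path) would read {0, 1}
  // for that column; the stride-aware accessor reads the intended {0, 3}.
  strided_ptr<double> col{buf.data(), /*stride=*/3};
  return (col[0] == 0.0 && col[1] == 3.0) ? 0 : 1;
}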
@@ -50,17 +50,14 @@ void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
   }
 }
 
+template<typename index_t_ptr = int64_t*>
 int64_t _csr_matmult_maxnnz(
     const int64_t n_row,
     const int64_t n_col,
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[]) {
+    const index_t_ptr Ap,
+    const index_t_ptr Aj,
+    const index_t_ptr Bp,
+    const index_t_ptr Bj) {
   /*
     Compute needed buffer size for matrix `C` in `C = A@B` operation.
 
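For reference, _csr_matmult_maxnnz is the standard SMMP counting pass: for each row of A it counts the distinct columns reachable through B, and the row-wise sum bounds the buffer size for C. A hedged sketch of that pass with plain pointers (the templated version above only swaps the pointer type):

#include <cstdint>
#include <vector>

// Sketch of the SMMP counting pass: for each row i of A, mark every column k
// reachable through B and count the distinct marks. The sum over rows is the
// buffer size needed for C = A@B.
std::int64_t csr_matmult_maxnnz_sketch(
    std::int64_t n_row, std::int64_t n_col,
    const std::int64_t* Ap, const std::int64_t* Aj,
    const std::int64_t* Bp, const std::int64_t* Bj) {
  std::vector<std::int64_t> mask(n_col, -1);
  std::int64_t nnz = 0;
  for (std::int64_t i = 0; i < n_row; ++i) {
    std::int64_t row_nnz = 0;
    for (std::int64_t jj = Ap[i]; jj < Ap[i + 1]; ++jj) {
      std::int64_t j = Aj[jj];
      for (std::int64_t kk = Bp[j]; kk < Bp[j + 1]; ++kk) {
        std::int64_t k = Bj[kk];
        if (mask[k] != i) {  // first time column k shows up in row i
          mask[k] = i;
          ++row_nnz;
        }
      }
    }
    nnz += row_nnz;
  }
  return nnz;
}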
@@ -88,28 +85,22 @@ int64_t _csr_matmult_maxnnz(
   return nnz;
 }
 
-template<class scalar_t>
+template<typename index_t_ptr, typename scalar_t_ptr>
 void _csr_matmult(
     const int64_t n_row,
     const int64_t n_col,
+    const index_t_ptr Ap,
+    const index_t_ptr Aj,
+    const scalar_t_ptr Ax,
+    const index_t_ptr Bp,
+    const index_t_ptr Bj,
+    const scalar_t_ptr Bx,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
+    typename index_t_ptr::value_type Cp[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
+    typename index_t_ptr::value_type Cj[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Ax[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Bx[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cp[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    scalar_t Cx[]) {
+    typename scalar_t_ptr::value_type Cx[]) {
   /*
     Compute CSR entries for matrix C = A@B.
 
@@ -133,7 +124,10 @@ void _csr_matmult(
   Note:
     Output arrays Cp, Cj, and Cx must be preallocated
   */
-  std::vector<int64_t> next(n_col, -1);
+  using index_t = typename index_t_ptr::value_type;
+  using scalar_t = typename scalar_t_ptr::value_type;
+
+  std::vector<index_t> next(n_col, -1);
   std::vector<scalar_t> sums(n_col, 0);
 
   int64_t nnz = 0;
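Since the kernel now receives accessor types instead of raw element types, it recovers the element types through value_type and sizes its scratch buffers accordingly. A tiny sketch of the pattern; fake_accessor is hypothetical, and any accessor type exposing a value_type member (such as StridedRandomAccessor) works:

#include <cstdint>
#include <type_traits>
#include <vector>

// Hypothetical accessor-like type exposing its element type as value_type,
// mirroring how _csr_matmult recovers index_t/scalar_t from its accessors.
template <typename T>
struct fake_accessor {
  using value_type = T;
  T* data;
};

int main() {
  using index_t = typename fake_accessor<std::int64_t>::value_type;
  static_assert(std::is_same<index_t, std::int64_t>::value,
                "element type recovered from the accessor");
  // Scratch buffer typed off the accessor, as in the patched kernel.
  std::vector<index_t> next(8, -1);
  return next[0] == -1 ? 0 : 1;
}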
@@ -141,19 +135,19 @@ void _csr_matmult(
   Cp[0] = 0;
 
   for (const auto i : c10::irange(n_row)) {
-    int64_t head = -2;
-    int64_t length = 0;
+    index_t head = -2;
+    index_t length = 0;
 
-    int64_t jj_start = Ap[i];
-    int64_t jj_end = Ap[i + 1];
+    index_t jj_start = Ap[i];
+    index_t jj_end = Ap[i + 1];
     for (const auto jj : c10::irange(jj_start, jj_end)) {
-      int64_t j = Aj[jj];
+      index_t j = Aj[jj];
       scalar_t v = Ax[jj];
 
-      int64_t kk_start = Bp[j];
-      int64_t kk_end = Bp[j + 1];
+      index_t kk_start = Bp[j];
+      index_t kk_end = Bp[j + 1];
       for (const auto kk : c10::irange(kk_start, kk_end)) {
-        int64_t k = Bj[kk];
+        index_t k = Bj[kk];
 
         sums[k] += v * Bx[kk];
 
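The head/next pair in this loop is the usual SMMP trick: a singly linked list threaded through a dense scratch array, so the occupied columns of the current row can be visited and reset in time proportional to the row's nnz instead of n_col. A standalone sketch of just that trick (hypothetical driver, not kernel code):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Linked list threaded through a dense array: next[k] points to the
  // previously inserted column, head to the most recent one.
  // -1 means "not in the list", -2 terminates the chain.
  const std::int64_t n_col = 8;
  std::vector<std::int64_t> next(n_col, -1);
  std::int64_t head = -2;

  for (std::int64_t k : {5, 2, 7, 2}) {  // insert columns; duplicates are no-ops
    if (next[k] == -1) {
      next[k] = head;
      head = k;
    }
  }

  // Visit and clear in O(#inserted): prints 7, 2, 5 (reverse insert order).
  while (head != -2) {
    std::cout << head << '\n';
    std::int64_t tmp = head;
    head = next[head];
    next[tmp] = -1;  // clear, as the kernel does
  }
  return 0;
}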
@@ -174,7 +168,7 @@ void _csr_matmult(
         Cx[nnz] = sums[head];
         nnz++;
 
-        int64_t temp = head;
+        index_t temp = head;
         head = next[head];
 
         next[temp] = -1; // clear arrays
@@ -183,6 +177,7 @@ void _csr_matmult(
 
     // Make sure that col indices are sorted.
     // TODO: a better approach is to implement a CSR @ CSC kernel.
+    // NOTE: Cx arrays are expected to be contiguous!
     auto col_indices_accessor = StridedRandomAccessor<int64_t>(Cj + nnz - length, 1);
     auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
     auto kv_accessor = CompositeRandomAccessorCPU<
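The composite accessor lets sort code permute a row's column indices and values in lockstep, in place. A plain-pair sketch of the same effect, as an illustration of what the row sort achieves rather than how CompositeRandomAccessorCPU does it:

#include <algorithm>
#include <cstdint>
#include <vector>

// Sort a row's (column index, value) pairs by column, keeping values aligned
// with their indices -- done here through a temporary vector of pairs.
void sort_row_sketch(std::int64_t* cols, double* vals, std::int64_t length) {
  std::vector<std::pair<std::int64_t, double>> kv(length);
  for (std::int64_t i = 0; i < length; ++i) kv[i] = {cols[i], vals[i]};
  std::sort(kv.begin(), kv.end());  // lexicographic: by column index first
  for (std::int64_t i = 0; i < length; ++i) {
    cols[i] = kv[i].first;
    vals[i] = kv[i].second;
  }
}

int main() {
  std::vector<std::int64_t> cols = {4, 1, 3};
  std::vector<double> vals = {40.0, 10.0, 30.0};
  sort_row_sketch(cols.data(), vals.data(), 3);
  // cols is now {1, 3, 4} and vals {10.0, 30.0, 40.0}.
  return (cols[0] == 1 && vals[0] == 10.0) ? 0 : 1;
}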
@@ -212,13 +207,32 @@ void sparse_matmul_kernel(
   const auto mat1_csr = mat1.to_sparse_csr();
   const auto mat2_csr = mat2.to_sparse_csr();
 
+  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.crow_indices().data_ptr<int64_t>(),
+      mat1_csr.crow_indices().stride(-1));
+  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.col_indices().data_ptr<int64_t>(),
+      mat1_csr.col_indices().stride(-1));
+  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat1_csr.values().data_ptr<scalar_t>(),
+      mat1_csr.values().stride(-1));
+  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.crow_indices().data_ptr<int64_t>(),
+      mat2_csr.crow_indices().stride(-1));
+  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.col_indices().data_ptr<int64_t>(),
+      mat2_csr.col_indices().stride(-1));
+  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat2_csr.values().data_ptr<scalar_t>(),
+      mat2_csr.values().stride(-1));
+
   const auto nnz = _csr_matmult_maxnnz(
       M,
       N,
-      mat1_csr.crow_indices().data_ptr<int64_t>(),
-      mat1_csr.col_indices().data_ptr<int64_t>(),
-      mat2_csr.crow_indices().data_ptr<int64_t>(),
-      mat2_csr.col_indices().data_ptr<int64_t>());
+      mat1_crow_indices_ptr,
+      mat1_col_indices_ptr,
+      mat2_crow_indices_ptr,
+      mat2_col_indices_ptr);
 
   auto output_indices = output._indices();
   auto output_values = output._values();
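stride(-1) is the tensor's step along its last dimension, so each accessor above degenerates to a plain pointer walk (stride 1) for contiguous inputs and stays correct otherwise. Reusing the hypothetical strided_ptr from the first sketch to show what a stride-2 values view looks like:

#include <cstdint>
#include <vector>

// Same hypothetical helper as in the first sketch.
template <typename T>
struct strided_ptr {
  T* data;
  std::int64_t stride;
  T& operator[](std::int64_t i) const { return data[i * stride]; }
};

int main() {
  // A values() view aliasing every other element of its storage -- the kind
  // of non-contiguous layout a gradient tensor can carry.
  std::vector<double> storage = {10, -1, 20, -1, 30, -1};
  strided_ptr<double> values{storage.data(), /*stride=*/2};
  // values "contains" {10, 20, 30}; stride-1 reads would see {10, -1, 20}.
  return (values[0] == 10 && values[1] == 20 && values[2] == 30) ? 0 : 1;
}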
@@ -234,12 +248,12 @@ void sparse_matmul_kernel(
   _csr_matmult(
       M,
       N,
-      mat1_csr.crow_indices().data_ptr<int64_t>(),
-      mat1_csr.col_indices().data_ptr<int64_t>(),
-      mat1_csr.values().data_ptr<scalar_t>(),
-      mat2_csr.crow_indices().data_ptr<int64_t>(),
-      mat2_csr.col_indices().data_ptr<int64_t>(),
-      mat2_csr.values().data_ptr<scalar_t>(),
+      mat1_crow_indices_ptr,
+      mat1_col_indices_ptr,
+      mat1_values_ptr,
+      mat2_crow_indices_ptr,
+      mat2_col_indices_ptr,
+      mat2_values_ptr,
       output_indptr.data_ptr<int64_t>(),
       output_col_indices.data_ptr<int64_t>(),
       output_values.data_ptr<scalar_t>());
@@ -3700,6 +3700,17 @@ class TestSparse(TestSparseBase):
 
             self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
 
+        def test_backward_noncontiguous():
+            # sparse.mm backward used to be wrong with non-contiguous grads,
+            # see https://github.com/pytorch/pytorch/issues/102493.
+            n_reps = 7
+            for _ in range(n_reps):
+                A = torch.eye(5).to_sparse().requires_grad_(True)
+                B = torch.eye(5).to_sparse()
+                out = torch.sparse.mm(A, B)
+                out.coalesce().values().sum().backward()
+                self.assertEqual(A.grad, A)
+
         for n in range(2, 5):
             for m in range(2, 8):
                 for p in range(2, 8):
@@ -3708,6 +3719,7 @@ class TestSparse(TestSparseBase):
         test_sparse_matmul(2, 0, [0, 0], [0, 0])
         test_sparse_matmul(2, 0, [0, 10], [10, 0])
         test_error_cases()
+        test_backward_noncontiguous()
 
     @coalescedonoff
     @dtypes(torch.double)