sparse.mm.backward: fix for non-contiguous grad values on CPU (#106127)

Fixes https://github.com/pytorch/pytorch/issues/102493. The problem was that the backward implementation assumed its inputs to be contiguous. This might supersede https://github.com/pytorch/pytorch/pull/104520.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106127
Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent 93b2036bef
commit 01069ad4be
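For context, the failure mode is easy to trigger from Python. The sketch below is adapted from the regression test added in this commit (identity matrices keep the expected gradient trivial to verify; before this fix, A.grad could come back with wrong values on CPU):

    import torch

    A = torch.eye(5).to_sparse().requires_grad_(True)
    B = torch.eye(5).to_sparse()

    out = torch.sparse.mm(A, B)
    # Going through coalesce() can hand backward a gradient whose
    # values tensor is non-contiguous, which the old CPU kernel misread.
    out.coalesce().values().sum().backward()

    # d/dA of sum(A @ I), masked to A's sparsity pattern, is A itself.
    assert torch.equal(A.grad.to_dense(), torch.eye(5))

The test in this commit repeats the check several times because the wrong results depended on allocator state and did not reproduce on every run.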
aten/src/ATen/native/sparse/SparseMatMul.cpp

@@ -50,17 +50,14 @@ void csr_to_coo(const int64_t n_row, const int64_t Ap[], int64_t Bi[]) {
   }
 }
 
+template<typename index_t_ptr = int64_t*>
 int64_t _csr_matmult_maxnnz(
     const int64_t n_row,
     const int64_t n_col,
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[]) {
+    const index_t_ptr Ap,
+    const index_t_ptr Aj,
+    const index_t_ptr Bp,
+    const index_t_ptr Bj) {
   /*
     Compute needed buffer size for matrix `C` in `C = A@B` operation.
 
@@ -88,28 +85,22 @@ int64_t _csr_matmult_maxnnz(
   return nnz;
 }
 
-template<class scalar_t>
+template<typename index_t_ptr, typename scalar_t_ptr>
 void _csr_matmult(
     const int64_t n_row,
     const int64_t n_col,
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Ap[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Aj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Ax[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bp[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const int64_t Bj[],
-    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    const scalar_t Bx[],
+    const index_t_ptr Ap,
+    const index_t_ptr Aj,
+    const scalar_t_ptr Ax,
+    const index_t_ptr Bp,
+    const index_t_ptr Bj,
+    const scalar_t_ptr Bx,
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cp[],
+    typename index_t_ptr::value_type Cp[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    int64_t Cj[],
+    typename index_t_ptr::value_type Cj[],
     // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
-    scalar_t Cx[]) {
+    typename scalar_t_ptr::value_type Cx[]) {
   /*
     Compute CSR entries for matrix C = A@B.
 
@@ -133,7 +124,10 @@ void _csr_matmult(
     Note:
       Output arrays Cp, Cj, and Cx must be preallocated
   */
-  std::vector<int64_t> next(n_col, -1);
+  using index_t = typename index_t_ptr::value_type;
+  using scalar_t = typename scalar_t_ptr::value_type;
+
+  std::vector<index_t> next(n_col, -1);
   std::vector<scalar_t> sums(n_col, 0);
 
   int64_t nnz = 0;
@@ -141,19 +135,19 @@ void _csr_matmult(
   Cp[0] = 0;
 
   for (const auto i : c10::irange(n_row)) {
-    int64_t head = -2;
-    int64_t length = 0;
+    index_t head = -2;
+    index_t length = 0;
 
-    int64_t jj_start = Ap[i];
-    int64_t jj_end = Ap[i + 1];
+    index_t jj_start = Ap[i];
+    index_t jj_end = Ap[i + 1];
     for (const auto jj : c10::irange(jj_start, jj_end)) {
-      int64_t j = Aj[jj];
+      index_t j = Aj[jj];
       scalar_t v = Ax[jj];
 
-      int64_t kk_start = Bp[j];
-      int64_t kk_end = Bp[j + 1];
+      index_t kk_start = Bp[j];
+      index_t kk_end = Bp[j + 1];
       for (const auto kk : c10::irange(kk_start, kk_end)) {
-        int64_t k = Bj[kk];
+        index_t k = Bj[kk];
 
         sums[k] += v * Bx[kk];
 
@@ -174,7 +168,7 @@ void _csr_matmult(
       Cx[nnz] = sums[head];
       nnz++;
 
-      int64_t temp = head;
+      index_t temp = head;
       head = next[head];
 
       next[temp] = -1; // clear arrays
@@ -183,6 +177,7 @@ void _csr_matmult(
 
     // Make sure that col indices are sorted.
     // TODO: a better approach is to implement a CSR @ CSC kernel.
+    // NOTE: Cx arrays are expected to be contiguous!
     auto col_indices_accessor = StridedRandomAccessor<int64_t>(Cj + nnz - length, 1);
     auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
     auto kv_accessor = CompositeRandomAccessorCPU<
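The next hunk shows the other half of the fix: rather than passing raw data_ptr<...>() values, which silently assume unit stride, sparse_matmul_kernel now wraps every indices/values buffer in a StridedRandomAccessor carrying the tensor's last-dimension stride. A small Python illustration of the invariant involved (the variable names here are made up for the example):

    import torch

    values = torch.arange(10.)
    nc = values[::2]              # a strided view: every other element

    print(nc.is_contiguous())     # False
    print(nc.stride(-1))          # 2 -- the per-element step in the buffer

    # Walking nc's storage from its data pointer with an assumed stride
    # of 1 reads values[0], values[1], ... instead of the correct
    # values[0], values[2], ... -- precisely the bug with non-contiguous
    # grad values.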
@@ -212,13 +207,32 @@ void sparse_matmul_kernel(
   const auto mat1_csr = mat1.to_sparse_csr();
   const auto mat2_csr = mat2.to_sparse_csr();
 
+  auto mat1_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.crow_indices().data_ptr<int64_t>(),
+      mat1_csr.crow_indices().stride(-1));
+  auto mat1_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat1_csr.col_indices().data_ptr<int64_t>(),
+      mat1_csr.col_indices().stride(-1));
+  auto mat1_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat1_csr.values().data_ptr<scalar_t>(),
+      mat1_csr.values().stride(-1));
+  auto mat2_crow_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.crow_indices().data_ptr<int64_t>(),
+      mat2_csr.crow_indices().stride(-1));
+  auto mat2_col_indices_ptr = StridedRandomAccessor<int64_t>(
+      mat2_csr.col_indices().data_ptr<int64_t>(),
+      mat2_csr.col_indices().stride(-1));
+  auto mat2_values_ptr = StridedRandomAccessor<scalar_t>(
+      mat2_csr.values().data_ptr<scalar_t>(),
+      mat2_csr.values().stride(-1));
+
   const auto nnz = _csr_matmult_maxnnz(
       M,
       N,
-      mat1_csr.crow_indices().data_ptr<int64_t>(),
-      mat1_csr.col_indices().data_ptr<int64_t>(),
-      mat2_csr.crow_indices().data_ptr<int64_t>(),
-      mat2_csr.col_indices().data_ptr<int64_t>());
+      mat1_crow_indices_ptr,
+      mat1_col_indices_ptr,
+      mat2_crow_indices_ptr,
+      mat2_col_indices_ptr);
 
   auto output_indices = output._indices();
   auto output_values = output._values();
@@ -234,12 +248,12 @@ void sparse_matmul_kernel(
   _csr_matmult(
       M,
       N,
-      mat1_csr.crow_indices().data_ptr<int64_t>(),
-      mat1_csr.col_indices().data_ptr<int64_t>(),
-      mat1_csr.values().data_ptr<scalar_t>(),
-      mat2_csr.crow_indices().data_ptr<int64_t>(),
-      mat2_csr.col_indices().data_ptr<int64_t>(),
-      mat2_csr.values().data_ptr<scalar_t>(),
+      mat1_crow_indices_ptr,
+      mat1_col_indices_ptr,
+      mat1_values_ptr,
+      mat2_crow_indices_ptr,
+      mat2_col_indices_ptr,
+      mat2_values_ptr,
       output_indptr.data_ptr<int64_t>(),
       output_col_indices.data_ptr<int64_t>(),
       output_values.data_ptr<scalar_t>());
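For readers tracing the C++ above: the kernel is the classic row-merge (SMMP-style) CSR @ CSR multiply, and the accessor change only alters how elements are fetched, not the algorithm. A compact, unoptimized Python rendering of the same next/sums linked-list scheme (a sketch only; the real kernel additionally sorts each output row's column indices afterwards):

    def csr_matmult(n_row, n_col, Ap, Aj, Ax, Bp, Bj, Bx):
        # Row-merge CSR @ CSR mirroring _csr_matmult.
        Cp, Cj, Cx = [0], [], []
        next_ = [-1] * n_col   # linked list over columns touched in this row
        sums = [0.0] * n_col   # per-column accumulators
        for i in range(n_row):
            head, length = -2, 0
            for jj in range(Ap[i], Ap[i + 1]):
                j, v = Aj[jj], Ax[jj]
                for kk in range(Bp[j], Bp[j + 1]):
                    k = Bj[kk]
                    sums[k] += v * Bx[kk]
                    if next_[k] == -1:    # first touch: push column k
                        next_[k] = head
                        head = k
                        length += 1
            for _ in range(length):       # drain the list into row i of C
                Cj.append(head)
                Cx.append(sums[head])
                temp = head
                head = next_[head]
                next_[temp] = -1          # clear for the next row
                sums[temp] = 0.0
            Cp.append(len(Cj))
        return Cp, Cj, Cx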
test/test_sparse.py

@@ -3700,6 +3700,17 @@ class TestSparse(TestSparseBase):
 
             self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes)
 
+        def test_backward_noncontiguous():
+            # Sparse.mm backward used to produce wrong results with
+            # non-contiguous grads,
+            # see https://github.com/pytorch/pytorch/issues/102493.
+            n_reps = 7
+            for _ in range(n_reps):
+                A = torch.eye(5).to_sparse().requires_grad_(True)
+                B = torch.eye(5).to_sparse()
+                out = torch.sparse.mm(A, B)
+                out.coalesce().values().sum().backward()
+                self.assertEqual(A.grad, A)
+
         for n in range(2, 5):
             for m in range(2, 8):
                 for p in range(2, 8):
@@ -3708,6 +3719,7 @@ class TestSparse(TestSparseBase):
         test_sparse_matmul(2, 0, [0, 0], [0, 0])
         test_sparse_matmul(2, 0, [0, 10], [10, 0])
         test_error_cases()
+        test_backward_noncontiguous()
 
     @coalescedonoff
     @dtypes(torch.double)
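To exercise the regression check locally, run the sparse matmul test; the exact test id depends on the device/dtype instantiation, but with a pytest-style runner something along these lines should work (the command is an assumption, not part of the commit):

    pytest test/test_sparse.py -k test_sparse_matmul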