Beta value is ignored for sparse torch.addmm with non-MKL build (#72430)

Summary:
When PyTorch is not built with MKL or on Windows there's a native implementation of `torch.addmm` for tensors on CPU. There was a bug that `beta` value was ignored, causing new tests to fail (see https://github.com/pytorch/pytorch/pull/71949#issuecomment-1024639741).

In addition, I also enabled complex numbers support for this code path.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72430

Reviewed By: davidberard98

Differential Revision: D34045670

Pulled By: cpuhrsch

fbshipit-source-id: b2b63f22ba3eea895a31c5c2925b0fb1555d2c6f
(cherry picked from commit ac0a2080bb)
This commit is contained in:
Ivan Yashchuk 2022-02-08 16:26:28 -08:00 committed by PyTorch MergeBot
parent 8bb1d06702
commit ad5a5a9794
2 changed files with 2 additions and 3 deletions

View File

@ -346,7 +346,7 @@ void addmm_out_sparse_csr_native_cpu(const Tensor& sparse, const Tensor& dense,
auto values = sparse.values(); auto values = sparse.values();
scalar_t cast_alpha = alpha.to<scalar_t>(); scalar_t cast_alpha = alpha.to<scalar_t>();
scalar_t cast_beta = beta.to<scalar_t>(); r.mul_(beta);
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() { AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
auto csr_accessor = csr.accessor<index_t, 1>(); auto csr_accessor = csr.accessor<index_t, 1>();
auto col_indices_accessor = col_indices.accessor<index_t, 1>(); auto col_indices_accessor = col_indices.accessor<index_t, 1>();
@ -470,7 +470,7 @@ Tensor& addmm_out_sparse_csr_cpu(
"Please use PyTorch built with MKL on Linux."); "Please use PyTorch built with MKL on Linux.");
} }
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.layout() == kStrided); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.layout() == kStrided);
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "addmm_sparse_dense", [&] { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "addmm_sparse_dense", [&] {
addmm_out_sparse_csr_native_cpu<scalar_t>(mat1, mat2, result, alpha, beta); addmm_out_sparse_csr_native_cpu<scalar_t>(mat1, mat2, result, alpha, beta);
}); });
#else #else

View File

@ -826,7 +826,6 @@ class TestSparseCSR(TestCase):
test_shape(7, 8, 9, 20, False, index_dtype) test_shape(7, 8, 9, 20, False, index_dtype)
test_shape(7, 8, 9, 20, True, index_dtype) test_shape(7, 8, 9, 20, True, index_dtype)
@skipCPUIfNoMklSparse
@dtypes(*floating_and_complex_types()) @dtypes(*floating_and_complex_types())
@dtypesIfCUDA(*get_all_complex_dtypes(), @dtypesIfCUDA(*get_all_complex_dtypes(),
*get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC, *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,