Beta value is ignored for sparse torch.addmm with non-MKL build (#72430)
Summary:
When PyTorch is built without MKL, or on Windows, `torch.addmm` for sparse CSR tensors on CPU falls back to a native implementation. That implementation had a bug where the `beta` value was ignored, causing newly added tests to fail (see https://github.com/pytorch/pytorch/pull/71949#issuecomment-1024639741).
In addition, this change enables complex number support for this code path.
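For reference, `torch.addmm(input, mat1, mat2, beta=b, alpha=a)` computes `b * input + a * (mat1 @ mat2)`, so ignoring `beta` silently drops the `input` term. A minimal sketch of the affected call pattern (assuming a CPU build without MKL and a CSR `mat1`; shapes and values are illustrative, not taken from the tests in this commit):

    import torch

    # Small CSR x dense addmm; on a non-MKL CPU build this dispatches to the
    # native kernel touched by this fix.
    m = torch.randn(3, 3)
    sparse = m.to_sparse_csr()
    dense = torch.randn(3, 4)
    inp = torch.randn(3, 4)

    out = torch.addmm(inp, sparse, dense, beta=0.5, alpha=2.0)
    expected = 0.5 * inp + 2.0 * (m @ dense)
    torch.testing.assert_close(out, expected)  # failed before the fix: the beta term was dropped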
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72430
Reviewed By: davidberard98
Differential Revision: D34045670
Pulled By: cpuhrsch
fbshipit-source-id: b2b63f22ba3eea895a31c5c2925b0fb1555d2c6f
(cherry picked from commit ac0a2080bb)
This commit is contained in: parent 8bb1d06702, commit ad5a5a9794
@@ -346,7 +346,7 @@ void addmm_out_sparse_csr_native_cpu(const Tensor& sparse, const Tensor& dense,
   auto values = sparse.values();
 
   scalar_t cast_alpha = alpha.to<scalar_t>();
-  scalar_t cast_beta = beta.to<scalar_t>();
+  r.mul_(beta);
   AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
     auto csr_accessor = csr.accessor<index_t, 1>();
     auto col_indices_accessor = col_indices.accessor<index_t, 1>();
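In words: the fix scales the output buffer by `beta` up front (`r.mul_(beta)`) and then accumulates the `alpha`-scaled sparse-dense product row by row over the CSR structure. A rough pure-Python rendering of that computation (illustrative only, not the actual ATen kernel; the function name and argument layout are hypothetical):

    import torch

    def addmm_csr_reference(crow_indices, col_indices, values, dense, out, alpha, beta):
        # Apply beta first, mirroring r.mul_(beta) in the fixed kernel.
        out *= beta
        n_rows = crow_indices.numel() - 1
        for i in range(n_rows):
            # Nonzeros of row i live in values[crow_indices[i]:crow_indices[i + 1]].
            for p in range(int(crow_indices[i]), int(crow_indices[i + 1])):
                j = int(col_indices[p])
                # Accumulate alpha * A[i, j] * B[j, :] into row i of the output.
                out[i] += alpha * values[p] * dense[j]
        return out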
@@ -470,7 +470,7 @@ Tensor& addmm_out_sparse_csr_cpu(
       "Please use PyTorch built with MKL on Linux.");
   }
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.layout() == kStrided);
-  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "addmm_sparse_dense", [&] {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "addmm_sparse_dense", [&] {
    addmm_out_sparse_csr_native_cpu<scalar_t>(mat1, mat2, result, alpha, beta);
  });
 #else
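The dispatch change above widens the native CPU path from floating-point dtypes to floating-point and complex dtypes. A quick sketch of a complex CSR addmm that this is meant to enable on non-MKL builds (dtype and shapes are illustrative):

    import torch

    # Complex CSR x dense addmm on CPU; with the widened dispatch this should
    # reach addmm_out_sparse_csr_native_cpu instead of erroring out.
    a = torch.randn(4, 4, dtype=torch.complex64)
    b = torch.randn(4, 5, dtype=torch.complex64)
    c = torch.randn(4, 5, dtype=torch.complex64)

    out = torch.addmm(c, a.to_sparse_csr(), b, beta=1.0, alpha=1.0)
    torch.testing.assert_close(out, c + a @ b)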
@@ -826,7 +826,6 @@ class TestSparseCSR(TestCase):
         test_shape(7, 8, 9, 20, False, index_dtype)
         test_shape(7, 8, 9, 20, True, index_dtype)
 
-    @skipCPUIfNoMklSparse
     @dtypes(*floating_and_complex_types())
     @dtypesIfCUDA(*get_all_complex_dtypes(),
                   *get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,