mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Beta value is ignored for sparse torch.addmm with non-MKL build (#72430)
Summary:
When PyTorch is not built with MKL or on Windows there's a native implementation of `torch.addmm` for tensors on CPU. There was a bug that `beta` value was ignored, causing new tests to fail (see https://github.com/pytorch/pytorch/pull/71949#issuecomment-1024639741).
In addition, I also enabled complex numbers support for this code path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72430
Reviewed By: davidberard98
Differential Revision: D34045670
Pulled By: cpuhrsch
fbshipit-source-id: b2b63f22ba3eea895a31c5c2925b0fb1555d2c6f
(cherry picked from commit ac0a2080bb)
This commit is contained in:
parent
8bb1d06702
commit
ad5a5a9794
|
|
@ -346,7 +346,7 @@ void addmm_out_sparse_csr_native_cpu(const Tensor& sparse, const Tensor& dense,
|
||||||
auto values = sparse.values();
|
auto values = sparse.values();
|
||||||
|
|
||||||
scalar_t cast_alpha = alpha.to<scalar_t>();
|
scalar_t cast_alpha = alpha.to<scalar_t>();
|
||||||
scalar_t cast_beta = beta.to<scalar_t>();
|
r.mul_(beta);
|
||||||
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
|
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
|
||||||
auto csr_accessor = csr.accessor<index_t, 1>();
|
auto csr_accessor = csr.accessor<index_t, 1>();
|
||||||
auto col_indices_accessor = col_indices.accessor<index_t, 1>();
|
auto col_indices_accessor = col_indices.accessor<index_t, 1>();
|
||||||
|
|
@ -470,7 +470,7 @@ Tensor& addmm_out_sparse_csr_cpu(
|
||||||
"Please use PyTorch built with MKL on Linux.");
|
"Please use PyTorch built with MKL on Linux.");
|
||||||
}
|
}
|
||||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.layout() == kStrided);
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.layout() == kStrided);
|
||||||
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "addmm_sparse_dense", [&] {
|
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "addmm_sparse_dense", [&] {
|
||||||
addmm_out_sparse_csr_native_cpu<scalar_t>(mat1, mat2, result, alpha, beta);
|
addmm_out_sparse_csr_native_cpu<scalar_t>(mat1, mat2, result, alpha, beta);
|
||||||
});
|
});
|
||||||
#else
|
#else
|
||||||
|
|
|
||||||
|
|
@ -826,7 +826,6 @@ class TestSparseCSR(TestCase):
|
||||||
test_shape(7, 8, 9, 20, False, index_dtype)
|
test_shape(7, 8, 9, 20, False, index_dtype)
|
||||||
test_shape(7, 8, 9, 20, True, index_dtype)
|
test_shape(7, 8, 9, 20, True, index_dtype)
|
||||||
|
|
||||||
@skipCPUIfNoMklSparse
|
|
||||||
@dtypes(*floating_and_complex_types())
|
@dtypes(*floating_and_complex_types())
|
||||||
@dtypesIfCUDA(*get_all_complex_dtypes(),
|
@dtypesIfCUDA(*get_all_complex_dtypes(),
|
||||||
*get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
|
*get_all_fp_dtypes(include_half=SM53OrLater and TEST_CUSPARSE_GENERIC,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user