add BFloat16 sparse operators on CPU: sparse_mask, add_out, addmm

This commit is contained in:
Jiayi Sun 2022-01-18 16:53:31 +08:00
parent 4267e6e55e
commit 55428dfb0b
3 changed files with 17 additions and 8 deletions

View File

@ -770,7 +770,7 @@ SparseTensor& sparse_mask_out_cpu(
// TODO: Re-audit this; it used to be an indexSelect directly into r_values
at::index_select_out(r_values, t_view, 0, indices);
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(r_values.scalar_type(), "sparse_mask", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, r_values.scalar_type(), "sparse_mask", [&] {
sparse_mask_out_cpu_kernel<scalar_t>(
r_values, t, r_nnz, sparse_dim, mask_indices);
});

View File

@ -474,7 +474,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
commonDtype, "cadd_sparse", [&] {
scalar_t* t_values_ptr = t_values.data_ptr<scalar_t>();
scalar_t* s_values_ptr = s_values.data_ptr<scalar_t>();
@ -899,7 +899,7 @@ Tensor& s_addmm_out_sparse_dense_cpu(
Tensor indices = sparse_._indices();
Tensor values = sparse_._values();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
values.scalar_type(), "addmm_sparse_dense", [&] {
s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense);
}

View File

@ -21,7 +21,7 @@ from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes
from torch.testing._internal.common_cuda import \
(SM53OrLater, SM80OrLater, CUDA11OrLater)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
(instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
deviceCountAtLeast, OpDTypes)
from torch.testing._internal.common_methods_invocations import \
(sparse_unary_ufuncs)
@ -189,6 +189,7 @@ class TestSparse(TestCase):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
def test_coalesce(self, device, dtype, coalesced):
def _test_coalesce(t):
@ -663,6 +664,7 @@ class TestSparse(TestCase):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
# This is for testing torch.copy_(SparseTensor, SparseTensor)
sparse_dims = 3
@ -1240,6 +1242,8 @@ class TestSparse(TestCase):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
@precisionOverride({torch.bfloat16: 2e-1})
def test_sparse_addmm(self, device, dtype, coalesced):
def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
if alpha_beta is None:
@ -1261,7 +1265,8 @@ class TestSparse(TestCase):
def fn(S, D1, D2, beta=beta, alpha=alpha):
return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
if dtype == torch.double or dtype == torch.cdouble:
gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
test_shape(7, 8, 9, 20, False, None)
test_shape(7, 8, 9, 20, True, None)
@ -1401,15 +1406,17 @@ class TestSparse(TestCase):
_test_spadd()
_test_spadd_hybrid()
@onlyCUDA
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
def test_sparse_add_out_bfloat16(self, device, dtype, coalesced):
# fp32
x, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
y, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
x = x.float().cuda()
y = y.float().cuda()
x = x.float()
y = y.float()
if device == 'cuda':
x = x.cuda()
y = y.cuda()
res_fp32 = torch.add(x, y)
# bfloat16
@ -1628,6 +1635,7 @@ class TestSparse(TestCase):
_test_basic_ops_hybrid()
@dtypes(torch.double, torch.cdouble)
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
def test_add_dense_sparse_mismatch(self, device, dtype):
def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size):
x = torch.zeros(dense_size, dtype=dtype, device=device)
@ -1666,6 +1674,7 @@ class TestSparse(TestCase):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
def test_sparse_mask(self, device, dtype, coalesced):
def _test_sparse_mask_fixed():
i = self.index_tensor([