Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00).
add BFloat16 sparse operators on CPU: sparse_mask, add_out, addmm
This commit is contained in:
Parent: 4267e6e55e
Commit: 55428dfb0b
|
|
@ -770,7 +770,7 @@ SparseTensor& sparse_mask_out_cpu(
|
|||
// TODO: Re-audit this; it used to be an indexSelect directly into r_values
|
||||
at::index_select_out(r_values, t_view, 0, indices);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(r_values.scalar_type(), "sparse_mask", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, r_values.scalar_type(), "sparse_mask", [&] {
|
||||
sparse_mask_out_cpu_kernel<scalar_t>(
|
||||
r_values, t, r_nnz, sparse_dim, mask_indices);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -474,7 +474,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
|
|||
auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
|
||||
auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
|
||||
commonDtype, "cadd_sparse", [&] {
|
||||
scalar_t* t_values_ptr = t_values.data_ptr<scalar_t>();
|
||||
scalar_t* s_values_ptr = s_values.data_ptr<scalar_t>();
|
||||
|
|
@ -899,7 +899,7 @@ Tensor& s_addmm_out_sparse_dense_cpu(
|
|||
Tensor indices = sparse_._indices();
|
||||
Tensor values = sparse_._values();
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
|
||||
values.scalar_type(), "addmm_sparse_dense", [&] {
|
||||
s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes
|
|||
from torch.testing._internal.common_cuda import \
|
||||
(SM53OrLater, SM80OrLater, CUDA11OrLater)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
|
||||
(instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
|
||||
deviceCountAtLeast, OpDTypes)
|
||||
from torch.testing._internal.common_methods_invocations import \
|
||||
(sparse_unary_ufuncs)
|
||||
|
|
@ -189,6 +189,7 @@ class TestSparse(TestCase):
|
|||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
|
||||
def test_coalesce(self, device, dtype, coalesced):
|
||||
|
||||
def _test_coalesce(t):
|
||||
|
|
@ -663,6 +664,7 @@ class TestSparse(TestCase):
|
|||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
|
||||
def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
|
||||
# This is for testing torch.copy_(SparseTensor, SparseTensor)
|
||||
sparse_dims = 3
|
||||
|
|
@ -1240,6 +1242,8 @@ class TestSparse(TestCase):
|
|||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
|
||||
@precisionOverride({torch.bfloat16: 2e-1})
|
||||
def test_sparse_addmm(self, device, dtype, coalesced):
|
||||
def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
|
||||
if alpha_beta is None:
|
||||
|
|
@ -1261,7 +1265,8 @@ class TestSparse(TestCase):
|
|||
|
||||
def fn(S, D1, D2, beta=beta, alpha=alpha):
|
||||
return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
|
||||
gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
|
||||
if dtype == torch.double or dtype == torch.cdouble:
|
||||
gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
|
||||
|
||||
test_shape(7, 8, 9, 20, False, None)
|
||||
test_shape(7, 8, 9, 20, True, None)
|
||||
|
|
@ -1401,15 +1406,17 @@ class TestSparse(TestCase):
|
|||
_test_spadd()
|
||||
_test_spadd_hybrid()
|
||||
|
||||
@onlyCUDA
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
def test_sparse_add_out_bfloat16(self, device, dtype, coalesced):
|
||||
# fp32
|
||||
x, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
|
||||
y, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
|
||||
x = x.float().cuda()
|
||||
y = y.float().cuda()
|
||||
x = x.float()
|
||||
y = y.float()
|
||||
if device == 'cuda':
|
||||
x = x.cuda()
|
||||
y = y.cuda()
|
||||
res_fp32 = torch.add(x, y)
|
||||
|
||||
# bfloat16
|
||||
|
|
@ -1628,6 +1635,7 @@ class TestSparse(TestCase):
|
|||
_test_basic_ops_hybrid()
|
||||
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
|
||||
def test_add_dense_sparse_mismatch(self, device, dtype):
|
||||
def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size):
|
||||
x = torch.zeros(dense_size, dtype=dtype, device=device)
|
||||
|
|
@ -1666,6 +1674,7 @@ class TestSparse(TestCase):
|
|||
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
|
||||
def test_sparse_mask(self, device, dtype, coalesced):
|
||||
def _test_sparse_mask_fixed():
|
||||
i = self.index_tensor([
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user