commit 55428dfb0b
parent 4267e6e55e

    add BFloat16 sparse operators on CPU: sparse_mask, add_out, addmm
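In short: three CPU sparse kernels gain a BFloat16 case in their dispatch macros, and test_sparse.py widens its dtype coverage on CPU to match. A rough sketch of what this enables end to end (shapes and values are illustrative, not from the commit):

    import torch

    d = torch.randn(4, 4).to(torch.bfloat16)            # dense bfloat16 on CPU
    m = torch.eye(4, dtype=torch.bfloat16).to_sparse()  # sparse COO mask
    masked = d.sparse_mask(m)               # sparse_mask: now dispatches for BFloat16
    s = torch.add(masked, m)                # sparse add_out path
    out = torch.sparse.addmm(d, masked, d)  # sparse-dense addmm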
@@ -770,7 +770,7 @@ SparseTensor& sparse_mask_out_cpu(
     // TODO: Re-audit this; it used to be an indexSelect directly into r_values
     at::index_select_out(r_values, t_view, 0, indices);
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(r_values.scalar_type(), "sparse_mask", [&] {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, r_values.scalar_type(), "sparse_mask", [&] {
       sparse_mask_out_cpu_kernel<scalar_t>(
         r_values, t, r_nnz, sparse_dim, mask_indices);
     });
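The only functional change here is the dispatch macro: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, ...) adds a BFloat16 case to the generated type switch, so sparse_mask_out_cpu_kernel<scalar_t> is now also instantiated with scalar_t = c10::BFloat16. A minimal sketch of the call this unlocks (assumed shapes; previously, inputs taking this branch failed dispatch with a "not implemented for 'BFloat16'" error):

    import torch

    dense = torch.randn(3, 3, dtype=torch.bfloat16)
    mask = torch.eye(3, dtype=torch.bfloat16).to_sparse()
    res = dense.sparse_mask(mask)  # selects the BFloat16 kernel on CPU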
@@ -474,7 +474,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
   auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
   auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
       commonDtype, "cadd_sparse", [&] {
         scalar_t* t_values_ptr = t_values.data_ptr<scalar_t>();
         scalar_t* s_values_ptr = s_values.data_ptr<scalar_t>();
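Same one-line pattern for elementwise sparse addition: the "cadd_sparse" dispatch that fills the output values now covers kBFloat16, so the raw t_values_ptr/s_values_ptr loops get a BFloat16 instantiation. Sketch (illustrative shapes):

    import torch

    x = torch.randn(5, 5).to(torch.bfloat16).to_sparse()
    y = torch.randn(5, 5).to(torch.bfloat16).to_sparse()
    z = torch.add(x, y)  # exercises the cadd_sparse dispatch on CPU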
@@ -899,7 +899,7 @@ Tensor& s_addmm_out_sparse_dense_cpu(
   Tensor indices = sparse_._indices();
   Tensor values = sparse_._values();
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
       values.scalar_type(), "addmm_sparse_dense", [&] {
         s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense);
       }
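addmm gets the same treatment: s_addmm_out_sparse_dense_worker<scalar_t> is the CPU reference loop over the sparse operand's nnz entries, and it is now instantiated for BFloat16 too. A hedged usage sketch (shapes are illustrative):

    import torch

    S = torch.randn(3, 4).to(torch.bfloat16).to_sparse()  # sparse (3, 4)
    D1 = torch.randn(3, 5, dtype=torch.bfloat16)          # dense addend
    D2 = torch.randn(4, 5, dtype=torch.bfloat16)          # dense rhs
    # computes beta * D1 + alpha * (S @ D2)
    out = torch.sparse.addmm(D1, S, D2, beta=1.0, alpha=1.0)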
@@ -21,7 +21,7 @@ from torch.testing import get_all_complex_dtypes, get_all_fp_dtypes
 from torch.testing._internal.common_cuda import \
     (SM53OrLater, SM80OrLater, CUDA11OrLater)
 from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
+    (instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyCPU, onlyCUDA, precisionOverride,
      deviceCountAtLeast, OpDTypes)
 from torch.testing._internal.common_methods_invocations import \
     (sparse_unary_ufuncs)
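dtypesIfCPU joins the import list so the tests below can widen their dtype coverage on CPU only. It composes the same way the already-imported dtypesIfCUDA does: the device-specific list overrides the default one on that device. Sketch (hypothetical test name):

    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)                       # default dtype list
    @dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)  # CPU-only override
    def test_example(self, device, dtype, coalesced):
        ...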
@@ -189,6 +189,7 @@ class TestSparse(TestCase):
 
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
     def test_coalesce(self, device, dtype, coalesced):
 
         def _test_coalesce(t):
@@ -663,6 +664,7 @@ class TestSparse(TestCase):
 
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
     def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
         # This is for testing torch.copy_(SparseTensor, SparseTensor)
         sparse_dims = 3
@@ -1240,6 +1242,8 @@ class TestSparse(TestCase):
 
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
+    @precisionOverride({torch.bfloat16: 2e-1})
     def test_sparse_addmm(self, device, dtype, coalesced):
         def test_shape(m, n, p, nnz, broadcast, alpha_beta=None):
             if alpha_beta is None:
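The two new decorators do the actual enabling: @dtypesIfCPU adds torch.bfloat16 to the CPU dtype list, and @precisionOverride({torch.bfloat16: 2e-1}) loosens the comparison tolerance for that dtype only. The wide tolerance is expected: bfloat16 keeps an 8-bit significand (roughly 2-3 decimal digits), so an accumulated matrix product cannot be expected to match the double-precision reference much tighter than that.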
@@ -1261,7 +1265,8 @@ class TestSparse(TestCase):
 
             def fn(S, D1, D2, beta=beta, alpha=alpha):
                 return torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
-            gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
+            if dtype == torch.double or dtype == torch.cdouble:
+                gradcheck(fn, (S, D1, D2), check_sparse_nnz=True)
 
         test_shape(7, 8, 9, 20, False, None)
         test_shape(7, 8, 9, 20, True, None)
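gradcheck is now gated to double/cdouble: it verifies analytic gradients against finite differences, and with bfloat16's coarse precision the numerical Jacobian would be dominated by rounding error. The forward path is still exercised for bfloat16; only the gradient check stays on the high-precision dtypes.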
@@ -1401,15 +1406,17 @@ class TestSparse(TestCase):
         _test_spadd()
         _test_spadd_hybrid()
 
-    @onlyCUDA
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_sparse_add_out_bfloat16(self, device, dtype, coalesced):
         # fp32
         x, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
         y, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
-        x = x.float().cuda()
-        y = y.float().cuda()
+        x = x.float()
+        y = y.float()
+        if device == 'cuda':
+            x = x.cuda()
+            y = y.cuda()
         res_fp32 = torch.add(x, y)
 
         # bfloat16
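Dropping @onlyCUDA lets this test run on CPU as well. The unconditional .float().cuda() casts become a plain .float() plus an explicit device guard, so the fp32 reference tensors end up on whichever device the instantiated test targets.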
@@ -1628,6 +1635,7 @@ class TestSparse(TestCase):
         _test_basic_ops_hybrid()
 
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
     def test_add_dense_sparse_mismatch(self, device, dtype):
         def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size):
             x = torch.zeros(dense_size, dtype=dtype, device=device)
@@ -1666,6 +1674,7 @@ class TestSparse(TestCase):
 
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfCPU(torch.double, torch.cdouble, torch.bfloat16)
     def test_sparse_mask(self, device, dtype, coalesced):
         def _test_sparse_mask_fixed():
             i = self.index_tensor([