Conversion between SparseBsr and Strided (#78025)

Adds conversion between the Strided and SparseBsr layouts.

[Based on code by @bhosmer!](https://colab.research.google.com/drive/1NHWti04TU269dzbRjLfxGxVlzZWo1XLo?usp=sharing)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78025
Approved by: https://github.com/pearu, https://github.com/jbschlosser
Christian Puhrsch 2022-05-25 15:03:35 +00:00 committed by PyTorch MergeBot
parent 2679aa4789
commit b9fb940dec
2 changed files with 138 additions and 29 deletions
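
A minimal usage sketch of the round trip this change enables (illustrative values and blocksize; assumes a PyTorch build that includes this PR):

    import torch

    # A dense 4x4 matrix; the blocksize (2, 2) divides both dimensions evenly.
    dense = torch.tensor([[1., 2., 0., 0.],
                          [5., 6., 0., 0.],
                          [0., 0., 11., 12.],
                          [0., 0., 15., 16.]])

    bsr = dense.to_sparse_bsr((2, 2))          # Strided -> SparseBsr
    assert bsr.layout == torch.sparse_bsr
    assert torch.equal(bsr.to_dense(), dense)  # SparseBsr -> Strided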

aten/src/ATen/native/TensorConversions.cpp

@@ -6,6 +6,7 @@
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/native/IndexingUtils.h>
namespace at {
namespace native {
@@ -326,22 +327,25 @@ Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) {
if (tensor.layout() == c10::kSparse) {
return tensor._to_dense(dtype);
}
if (tensor.layout() == c10::kSparseCsr || tensor.layout() == c10::kSparseCsc) {
if (tensor.layout() == c10::kSparseCsr ||
tensor.layout() == c10::kSparseCsc ||
tensor.layout() == c10::kSparseBsr) {
return tensor._to_dense(dtype);
}
if (tensor.layout() == c10::kMkldnn) {
return tensor._to_dense(dtype);
}
TORCH_CHECK(tensor.layout() == c10::kStrided, "to_dense does not support layout ", tensor.layout());
TORCH_CHECK(
tensor.layout() == c10::kStrided,
"to_dense does not support layout ",
tensor.layout());
if (dtype) {
return tensor.to(*dtype);
}
return tensor;
}
Tensor sparse_to_dense(
const Tensor& self,
c10::optional<ScalarType> dtype) {
Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {
TORCH_CHECK(
!dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
@@ -352,11 +356,34 @@ Tensor sparse_compressed_to_dense(
const Tensor& self,
c10::optional<ScalarType> dtype) {
TORCH_CHECK(
!dtype.has_value(), "dtype argument is not supported by sparse_csr_to_dense");
!dtype.has_value(),
"dtype argument is not supported by sparse_csr_to_dense");
if (self.layout() == kSparseCsr) {
Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
return dst.add_(self);
}
if (self.layout() == kSparseBsr) {
TORCH_CHECK(self.dim() == 2, "Can only convert 2D SparseBsr to Strided.");
Tensor indices = at::_convert_indices_from_csr_to_coo(
self.crow_indices(), self.col_indices(), false, false);
auto values = self.values();
int64_t blocksize[2] = {values.size(-2), values.size(-1)};
DimVector expanded_size(
{self.size(0) / blocksize[0],
self.size(1) / blocksize[1],
blocksize[0],
blocksize[1]});
// We make use of COO dense dimensions here to use the COO to dense format
// conversion.
auto self_coo =
at::native::_sparse_coo_tensor_unsafe(indices, values, expanded_size)
.coalesce();
auto dense = self_coo.to_dense();
// Here we are untiling the result.
dense = dense.transpose(1, 2);
dense = dense.reshape({self.size(0), self.size(1)});
return dense;
}
return self.to_sparse().to_dense();
}
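
The COO dense-dimension trick above, restated as a Python sketch (a rough equivalent of the C++ path, not the exported API; the helper name is made up):

    import torch

    def bsr_to_dense_sketch(bsr):
        crow, col, values = bsr.crow_indices(), bsr.col_indices(), bsr.values()
        bh, bw = values.shape[-2], values.shape[-1]   # blocksize
        # Expand the compressed row pointers back to one row index per block
        # (the CSR -> COO step done by _convert_indices_from_csr_to_coo).
        row = torch.repeat_interleave(torch.arange(crow.numel() - 1), crow.diff())
        # Build a COO tensor over the block grid, keeping each block as a
        # dense (bh, bw) slice so COO's to_dense handles the block values.
        coo = torch.sparse_coo_tensor(
            torch.stack([row, col]), values,
            (bsr.size(0) // bh, bsr.size(1) // bw, bh, bw)).coalesce()
        # Untile: (R, C, bh, bw) -> (R, bh, C, bw) -> (R*bh, C*bw).
        return coo.to_dense().transpose(1, 2).reshape(bsr.size(0), bsr.size(1))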
@@ -489,10 +516,97 @@ Tensor dense_to_sparse_csc(const Tensor& self) {
return self.to_sparse().to_sparse_csc();
}
Tensor _tile_tensor(const Tensor& self, IntArrayRef blocksize) {
// This code turns a matrix into a sequence of blocks
//
// Given matrix
//
// 1 2 3 4
// 5 6 7 8
// 9 10 11 12
// 14 15 16 17
//
// _tile_tensor(matrix, {2, 2}) will yield the following 2 by 2 blocks
//
// 1 2 | 3 4 | 9 10 | 11 12
// 5 6 | 7 8 | 14 15 | 16 17
//
// via a 4D Tensor of shape (2, 2, 2, 2)
//
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0);
auto block_size_0 = self.size(0) / blocksize[0];
auto block_size_1 = self.size(1) / blocksize[1];
return self.reshape({block_size_0, blocksize[0], block_size_1, blocksize[1]})
.transpose(1, 2)
.contiguous();
}
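
The same tiling, sketched in Python (hypothetical helper name; mirrors the reshape/transpose above):

    import torch

    def tile_tensor_sketch(mat, blocksize):
        br, bc = blocksize
        nr, nc = mat.size(0) // br, mat.size(1) // bc
        # (nr, br, nc, bc) -> (nr, nc, br, bc): one (br, bc) block per leading pair.
        return mat.reshape(nr, br, nc, bc).transpose(1, 2).contiguous()

    mat = torch.arange(1, 17).reshape(4, 4)
    blocks = tile_tensor_sketch(mat, (2, 2))   # shape (2, 2, 2, 2)
    assert torch.equal(blocks[0, 0], torch.tensor([[1, 2], [5, 6]]))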
std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
Tensor not_zero_mask,
ScalarType index_dtype,
Device index_device) {
auto col_indices =
at::native::arange(not_zero_mask.size(-1), index_dtype, kStrided, index_device)
.view({1, not_zero_mask.size(-1)})
.expand_as(not_zero_mask)
.masked_select(not_zero_mask);
auto row_indices =
at::native::arange(not_zero_mask.size(-2), index_dtype, kStrided, index_device)
.view({not_zero_mask.size(-2), 1})
.expand_as(not_zero_mask)
.masked_select(not_zero_mask);
return std::pair<Tensor, Tensor>(col_indices, row_indices);
}
Tensor dense_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize) {
AT_ERROR(
"Conversion from ", self.layout(), " to SparseBsr is currently not supported.");
return self;
TORCH_CHECK(self.dim() == 2, "Can only convert 2D Tensor to BSR.");
TORCH_CHECK(
    blocksize[0] > 0 && blocksize[1] > 0,
    "blocksize needs to be nonzero, but got ",
    blocksize);
TORCH_CHECK(
self.size(0) % blocksize[0] == 0,
"Tensor size(0) ",
self.size(0),
" needs to be divisible by blocksize[0] ",
blocksize[0]);
TORCH_CHECK(
self.size(1) % blocksize[1] == 0,
"Tensor size(1) ",
self.size(1),
" needs to be divisible by blocksize[1] ",
blocksize[1]);
auto block_size_0 = self.size(0) / blocksize[0];
auto values = _tile_tensor(self, blocksize);
auto not_zero_mask = _tile_tensor((self != 0), blocksize);
// Find tiles that have at least 1 non-zero value in them.
not_zero_mask = not_zero_mask.any(-1).any(-1);
Tensor col_indices;
Tensor row_indices;
std::tie(col_indices, row_indices) =
_not_zero_mask_to_col_row_indices(not_zero_mask, at::kLong, not_zero_mask.device());
Tensor crow_indices = at::_convert_indices_from_coo_to_csr(
row_indices.view({-1}), block_size_0, false /* out_int32 */);
values = values.reshape({-1, values.size(-2), values.size(-1)});
not_zero_mask = not_zero_mask.reshape({-1});
// TODO: masked_select does not support some form of broadcasting, so we're
// using the mask to construct indices that are then passed into index_select.
// This isn't ideal.
values = values.index_select(
0,
at::native::arange(not_zero_mask.numel(), at::kLong, kStrided, not_zero_mask.device())
.masked_select(not_zero_mask));
return at::native::_sparse_bsr_tensor_unsafe(
crow_indices,
col_indices,
values,
self.sizes(),
values.scalar_type(),
c10::kSparseBsr,
values.device());
}
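
Putting the pieces together, the dense-to-BSR path is: tile, mask out all-zero blocks, derive block indices, compress the rows. A rough Python rendering, reusing the sketches above (torch.sparse_bsr_tensor stands in for the unsafe internal factory):

    import torch

    def dense_to_bsr_sketch(mat, blocksize):
        blocks = tile_tensor_sketch(mat, blocksize)   # (nr, nc, br, bc)
        nz = blocks.ne(0).any(-1).any(-1)             # (nr, nc) block mask
        col_idx = torch.arange(nz.size(1)).view(1, -1).expand_as(nz).masked_select(nz)
        row_idx = torch.arange(nz.size(0)).view(-1, 1).expand_as(nz).masked_select(nz)
        # Boolean indexing stands in for the index_select workaround used in C++.
        values = blocks.reshape(-1, *blocksize)[nz.reshape(-1)]
        # Compress the sorted row indices into row pointers (COO -> CSR).
        crow = torch.cat([torch.zeros(1, dtype=torch.long),
                          torch.bincount(row_idx, minlength=nz.size(0)).cumsum(0)])
        return torch.sparse_bsr_tensor(crow, col_idx, values, tuple(mat.shape))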
Tensor dense_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize) {

test/test_sparse_csr.py

@@ -2098,7 +2098,7 @@ class TestSparseCSR(TestCase):
detached_inp = inp.detach()
self.assertEqual(inp, detached_inp)
def _convert_to_layout(self, a, target_layout):
def _convert_to_layout(self, a, target_layout, blocksize=(2, 2)):
"""
Helper function to call the correct layout conversion
with reasonable defaults for the block size. Clearly there
@@ -2109,12 +2109,12 @@ class TestSparseCSR(TestCase):
if target_layout is torch.sparse_csc:
return a.to_sparse_csc()
if target_layout is torch.sparse_bsr:
return a.to_sparse_bsr((2, 2))
return a.to_sparse_bsr(blocksize)
if target_layout is torch.sparse_bsc:
return a.to_sparse_bsc((2, 2))
return a.to_sparse_bsc(blocksize)
raise NotImplementedError(repr(a))
def _construct_sp_matrix(self, tensor, layout):
def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)):
if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]:
tensor = tensor.to_dense()
else:
@@ -2124,7 +2124,7 @@ class TestSparseCSR(TestCase):
if layout is torch.sparse_csc:
return sp.csc_matrix(tensor.cpu().numpy())
if layout is torch.sparse_bsr:
return sp.bsr_matrix(tensor.cpu().numpy())
return sp.bsr_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
# No native scipy BSC support?
raise NotImplementedError(repr(tensor))
@@ -2173,25 +2173,20 @@ class TestSparseCSR(TestCase):
if layout is torch.sparse_bsc:
# TODO: Remove this once support has been enabled
return
if layout is torch.sparse_bsr:
# TODO: Remove this once support has been enabled
return
for shape in [(0, 10), (6, 0), (6, 10), (0, 0)]:
shapes = [(6, 10), (0, 10), (6, 0), (0, 0)]
blocksizes = [(2, 2)]
if layout is torch.sparse_bsr:
blocksizes += [(3, 5), (6, 10)]
for shape, blocksize in itertools.product(shapes, blocksizes):
dense = make_tensor(shape, dtype=torch.float, device=device)
dense = dense.relu() # Introduce some sparsity
sp_matrix = self._construct_sp_matrix(dense, layout)
pt_matrix = self._convert_to_layout(dense, layout)
sp_matrix = self._construct_sp_matrix(dense, layout, blocksize=blocksize)
pt_matrix = self._convert_to_layout(dense, layout, blocksize=blocksize)
compressed_indices_mth = {
torch.sparse_csr: torch.Tensor.crow_indices,
torch.sparse_csc: torch.Tensor.ccol_indices,
}[layout]
plain_indices_mth = {
torch.sparse_csr: torch.Tensor.col_indices,
torch.sparse_csc: torch.Tensor.row_indices,
}[layout]
compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
self.assertEqual(layout, pt_matrix.layout)
self.assertEqual(sp_matrix.shape, pt_matrix.shape)