Conversion between SparseBsr and Strided (#78025)
Adds conversion between the strided and SparseBsr layouts. [Based on code by @bhosmer!](https://colab.research.google.com/drive/1NHWti04TU269dzbRjLfxGxVlzZWo1XLo?usp=sharing)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78025
Approved by: https://github.com/pearu, https://github.com/jbschlosser
parent 2679aa4789
commit b9fb940dec
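For reference, a minimal round trip in Python, assuming a build that includes this commit; the example matrix and blocksize are illustrative:

import torch

dense = torch.tensor([[1., 2., 0., 0.],
                      [3., 4., 0., 0.],
                      [0., 0., 5., 6.],
                      [0., 0., 7., 8.]])
bsr = dense.to_sparse_bsr((2, 2))  # strided -> SparseBsr, 2x2 blocks
print(bsr.crow_indices())          # tensor([0, 1, 2])
print(bsr.col_indices())           # tensor([0, 1])
print(bsr.values().shape)          # torch.Size([2, 2, 2])
assert torch.equal(bsr.to_dense(), dense)  # SparseBsr -> strided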
aten/src/ATen/native/TensorConversions.cpp
@@ -6,6 +6,7 @@
 #include <c10/core/impl/DeviceGuardImplInterface.h>
 #include <ATen/SparseTensorUtils.h>
+#include <ATen/native/IndexingUtils.h>
 
 namespace at {
 namespace native {
@@ -326,22 +327,25 @@ Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) {
   if (tensor.layout() == c10::kSparse) {
     return tensor._to_dense(dtype);
   }
-  if (tensor.layout() == c10::kSparseCsr || tensor.layout() == c10::kSparseCsc) {
+  if (tensor.layout() == c10::kSparseCsr ||
+      tensor.layout() == c10::kSparseCsc ||
+      tensor.layout() == c10::kSparseBsr) {
     return tensor._to_dense(dtype);
   }
   if (tensor.layout() == c10::kMkldnn) {
     return tensor._to_dense(dtype);
   }
-  TORCH_CHECK(tensor.layout() == c10::kStrided, "to_dense does not support layout ", tensor.layout());
+  TORCH_CHECK(
+      tensor.layout() == c10::kStrided,
+      "to_dense does not support layout ",
+      tensor.layout());
   if (dtype) {
     return tensor.to(*dtype);
   }
   return tensor;
 }
 
-Tensor sparse_to_dense(
-    const Tensor& self,
-    c10::optional<ScalarType> dtype) {
+Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {
   TORCH_CHECK(
       !dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
   Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
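With the extra kSparseBsr case, Tensor.to_dense() now dispatches to _to_dense for all three compressed layouts uniformly. A quick sanity sketch, assuming a build with this commit:

import torch

dense = torch.eye(4)
for to_compressed in (torch.Tensor.to_sparse_csr,
                      torch.Tensor.to_sparse_csc,
                      lambda t: t.to_sparse_bsr((2, 2))):
    sparse = to_compressed(dense)
    assert sparse.to_dense().layout == torch.strided
    assert torch.equal(sparse.to_dense(), dense)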
@@ -352,11 +356,34 @@ Tensor sparse_compressed_to_dense(
     const Tensor& self,
     c10::optional<ScalarType> dtype) {
   TORCH_CHECK(
-      !dtype.has_value(), "dtype argument is not supported by sparse_csr_to_dense");
+      !dtype.has_value(),
+      "dtype argument is not supported by sparse_csr_to_dense");
   if (self.layout() == kSparseCsr) {
     Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
     return dst.add_(self);
   }
+  if (self.layout() == kSparseBsr) {
+    TORCH_CHECK(self.dim() == 2, "Can only convert 2D SparseBsr to Strided.");
+    Tensor indices = at::_convert_indices_from_csr_to_coo(
+        self.crow_indices(), self.col_indices(), false, false);
+    auto values = self.values();
+    int64_t blocksize[2] = {values.size(-2), values.size(-1)};
+    DimVector expanded_size(
+        {self.size(0) / blocksize[0],
+         self.size(1) / blocksize[1],
+         blocksize[0],
+         blocksize[1]});
+    // We make use of COO dense dimensions here to use the COO to dense format
+    // conversion.
+    auto self_coo =
+        at::native::_sparse_coo_tensor_unsafe(indices, values, expanded_size)
+            .coalesce();
+    auto dense = self_coo.to_dense();
+    // Here we are untiling the result.
+    dense = dense.transpose(1, 2);
+    dense = dense.reshape({self.size(0), self.size(1)});
+    return dense;
+  }
   return self.to_sparse().to_dense();
 }
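The BSR branch converts crow_indices to per-block COO rows, treats each block as two dense dimensions of a 4D COO tensor of shape (nrows/bh, ncols/bw, bh, bw), densifies, then untiles with a transpose and reshape. A hedged Python mirror of those steps using public ops only (repeat_interleave stands in for the private _convert_indices_from_csr_to_coo helper; names are illustrative):

import torch

bsr = torch.tensor([[1., 2., 0., 0.],
                    [3., 4., 0., 0.],
                    [0., 0., 5., 6.],
                    [0., 0., 7., 8.]]).to_sparse_bsr((2, 2))
values = bsr.values()                      # (nnz_blocks, bh, bw)
bh, bw = values.size(-2), values.size(-1)
nrows, ncols = bsr.shape
crow, col = bsr.crow_indices(), bsr.col_indices()
# CSR row pointers -> one block-row index per stored block.
row = torch.repeat_interleave(torch.arange(crow.numel() - 1), crow.diff())
# Scatter the blocks into a (nrows/bh, ncols/bw, bh, bw) grid ...
grid = torch.zeros(nrows // bh, ncols // bw, bh, bw)
grid[row, col] = values
# ... then untile: swap the middle dims and flatten, as the C++ above does.
dense = grid.transpose(1, 2).reshape(nrows, ncols)
assert torch.equal(dense, bsr.to_dense())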
@@ -489,10 +516,97 @@ Tensor dense_to_sparse_csc(const Tensor& self) {
   return self.to_sparse().to_sparse_csc();
 }
 
+Tensor _tile_tensor(const Tensor& self, IntArrayRef blocksize) {
+  // This code turns a matrix into a sequence of blocks.
+  //
+  // Given the matrix
+  //
+  //  1  2  3  4
+  //  5  6  7  8
+  //  9 10 11 12
+  // 14 15 16 17
+  //
+  // _tile_tensor(matrix, {2, 2}) will yield the following 2 by 2 blocks
+  //
+  //  1  2 |  3  4 |  9 10 | 11 12
+  //  5  6 |  7  8 | 14 15 | 16 17
+  //
+  // via a 4D Tensor of shape (2, 2, 2, 2).
+  //
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0);
+  auto block_size_0 = self.size(0) / blocksize[0];
+  auto block_size_1 = self.size(1) / blocksize[1];
+  return self.reshape({block_size_0, blocksize[0], block_size_1, blocksize[1]})
+      .transpose(1, 2)
+      .contiguous();
+}
+
+std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
+    Tensor not_zero_mask,
+    ScalarType index_dtype,
+    Device index_device) {
+  auto col_indices =
+      at::native::arange(not_zero_mask.size(-1), index_dtype, kStrided, index_device)
+          .view({1, not_zero_mask.size(-1)})
+          .expand_as(not_zero_mask)
+          .masked_select(not_zero_mask);
+  auto row_indices =
+      at::native::arange(not_zero_mask.size(-2), index_dtype, kStrided, index_device)
+          .view({not_zero_mask.size(-2), 1})
+          .expand_as(not_zero_mask)
+          .masked_select(not_zero_mask);
+  return std::pair<Tensor, Tensor>(col_indices, row_indices);
+}
+
 Tensor dense_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize) {
-  AT_ERROR(
-      "Conversion from ", self.layout(), " to SparseBsr is currently not supported.");
-  return self;
+  TORCH_CHECK(self.dim() == 2, "Can only convert 2D Tensor to BSR.");
+  TORCH_CHECK(
+      blocksize[0] > 0 && blocksize[1] > 0,
+      "blocksize needs to be non-zero, but got ",
+      blocksize);
+  TORCH_CHECK(
+      self.size(0) % blocksize[0] == 0,
+      "Tensor size(0) ",
+      self.size(0),
+      " needs to be divisible by blocksize[0] ",
+      blocksize[0]);
+  TORCH_CHECK(
+      self.size(1) % blocksize[1] == 0,
+      "Tensor size(1) ",
+      self.size(1),
+      " needs to be divisible by blocksize[1] ",
+      blocksize[1]);
+  auto block_size_0 = self.size(0) / blocksize[0];
+
+  auto values = _tile_tensor(self, blocksize);
+  auto not_zero_mask = _tile_tensor((self != 0), blocksize);
+  // Find tiles that have at least 1 non-zero value in them.
+  not_zero_mask = not_zero_mask.any(-1).any(-1);
+  Tensor col_indices;
+  Tensor row_indices;
+  std::tie(col_indices, row_indices) =
+      _not_zero_mask_to_col_row_indices(not_zero_mask, at::kLong, not_zero_mask.device());
+  Tensor crow_indices = at::_convert_indices_from_coo_to_csr(
+      row_indices.view({-1}), block_size_0, false /* out_int32 */);
+  values = values.reshape({-1, values.size(-2), values.size(-1)});
+  not_zero_mask = not_zero_mask.reshape({-1});
+  // TODO: masked_select does not support some form of broadcasting, so we're
+  // using the mask to construct indices that are then passed into index_select.
+  // This isn't ideal.
+  values = values.index_select(
+      0,
+      at::native::arange(not_zero_mask.numel(), at::kLong, kStrided, not_zero_mask.device())
+          .masked_select(not_zero_mask));
+
+  return at::native::_sparse_bsr_tensor_unsafe(
+      crow_indices,
+      col_indices,
+      values,
+      self.sizes(),
+      values.scalar_type(),
+      c10::kSparseBsr,
+      values.device());
+}
+
 Tensor dense_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize) {
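A hedged Python mirror of the dense-to-BSR path above: the same reshape/transpose tiling as _tile_tensor, a per-block nonzero mask in place of _not_zero_mask_to_col_row_indices, and a cumsum for the CSR row pointers. The torch.sparse_bsr_tensor factory is assumed to be available in your build; names are illustrative:

import torch

def tile(t, bh, bw):
    # (H, W) -> (H/bh, W/bw, bh, bw): the _tile_tensor reshape/transpose trick.
    H, W = t.shape
    return t.reshape(H // bh, bh, W // bw, bw).transpose(1, 2).contiguous()

dense = torch.tensor([[1., 2., 0., 0.],
                      [3., 4., 0., 0.],
                      [0., 0., 5., 6.],
                      [0., 0., 7., 8.]])
blocks = tile(dense, 2, 2)
mask = tile(dense != 0, 2, 2).any(-1).any(-1)  # blocks holding any nonzero
row, col = mask.nonzero(as_tuple=True)         # COO block coordinates
crow = torch.zeros(mask.size(0) + 1, dtype=torch.int64)
crow[1:] = mask.sum(-1).cumsum(0)              # COO rows -> CSR row pointers
values = blocks[row, col]                      # keep only non-empty blocks
bsr = torch.sparse_bsr_tensor(crow, col, values, size=dense.shape)
assert torch.equal(bsr.to_dense(), dense)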
test/test_sparse_csr.py
@@ -2098,7 +2098,7 @@ class TestSparseCSR(TestCase):
         detached_inp = inp.detach()
         self.assertEqual(inp, detached_inp)
 
-    def _convert_to_layout(self, a, target_layout):
+    def _convert_to_layout(self, a, target_layout, blocksize=(2, 2)):
         """
         Helper function to call the correct layout conversion
         with reasonable defaults for the block size. Clearly there
@@ -2109,12 +2109,12 @@ class TestSparseCSR(TestCase):
         if target_layout is torch.sparse_csc:
             return a.to_sparse_csc()
         if target_layout is torch.sparse_bsr:
-            return a.to_sparse_bsr((2, 2))
+            return a.to_sparse_bsr(blocksize)
         if target_layout is torch.sparse_bsc:
-            return a.to_sparse_bsc((2, 2))
+            return a.to_sparse_bsc(blocksize)
         raise NotImplementedError(repr(a))
 
-    def _construct_sp_matrix(self, tensor, layout):
+    def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)):
         if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]:
             tensor = tensor.to_dense()
         else:
@@ -2124,7 +2124,7 @@ class TestSparseCSR(TestCase):
         if layout is torch.sparse_csc:
             return sp.csc_matrix(tensor.cpu().numpy())
         if layout is torch.sparse_bsr:
-            return sp.bsr_matrix(tensor.cpu().numpy())
+            return sp.bsr_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
         # No native scipy BSC support?
         raise NotImplementedError(repr(tensor))
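Why both changes in this hunk: scipy picks its own blocksize unless one is passed, and sorted_indices() normalizes the intra-row column order so the index arrays compare cleanly against PyTorch's. A small illustration of the reference matrix this helper now builds:

import numpy as np
import scipy.sparse as sp

dense = np.array([[1., 2., 0., 0.],
                  [3., 4., 0., 0.],
                  [0., 0., 5., 6.],
                  [0., 0., 7., 8.]])
ref = sp.bsr_matrix(dense, blocksize=(2, 2)).sorted_indices()
print(ref.indptr)      # [0 1 2]   -> compared against crow_indices()
print(ref.indices)     # [0 1]     -> compared against col_indices()
print(ref.data.shape)  # (2, 2, 2) -> one 2x2 block per stored entry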
@@ -2173,25 +2173,20 @@ class TestSparseCSR(TestCase):
         if layout is torch.sparse_bsc:
            # TODO: Remove this once support has been enabled
            return
-        if layout is torch.sparse_bsr:
-            # TODO: Remove this once support has been enabled
-            return
 
-        for shape in [(0, 10), (6, 0), (6, 10), (0, 0)]:
+        shapes = [(6, 10), (0, 10), (6, 0), (0, 0)]
+
+        blocksizes = [(2, 2)]
+        if layout is torch.sparse_bsr:
+            blocksizes += [(3, 5), (6, 10)]
+
+        for shape, blocksize in itertools.product(shapes, blocksizes):
             dense = make_tensor(shape, dtype=torch.float, device=device)
             dense = dense.relu()  # Introduce some sparsity
-            sp_matrix = self._construct_sp_matrix(dense, layout)
-            pt_matrix = self._convert_to_layout(dense, layout)
+            sp_matrix = self._construct_sp_matrix(dense, layout, blocksize=blocksize)
+            pt_matrix = self._convert_to_layout(dense, layout, blocksize=blocksize)
 
-            compressed_indices_mth = {
-                torch.sparse_csr: torch.Tensor.crow_indices,
-                torch.sparse_csc: torch.Tensor.ccol_indices,
-            }[layout]
-
-            plain_indices_mth = {
-                torch.sparse_csr: torch.Tensor.col_indices,
-                torch.sparse_csc: torch.Tensor.row_indices,
-            }[layout]
+            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
 
             self.assertEqual(layout, pt_matrix.layout)
             self.assertEqual(sp_matrix.shape, pt_matrix.shape)
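The refactor assumes a shared sparse_compressed_indices_methods table defined elsewhere in test_sparse_csr.py; the hunk does not show it, but a plausible shape, hedged, is:

# Plausible sketch only; the PR defines the real table elsewhere in the file.
# BSR/BSC reuse the CSR/CSC accessor pairs since they share the compressed
# index format at block granularity.
sparse_compressed_indices_methods = {
    torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
    torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
    torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
    torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
}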