diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp
index 05691d2998d..74588f8506e 100644
--- a/aten/src/ATen/native/TensorConversions.cpp
+++ b/aten/src/ATen/native/TensorConversions.cpp
@@ -6,6 +6,7 @@
 #include
 #include
+#include

 namespace at {
 namespace native {
@@ -326,22 +327,25 @@ Tensor to_dense(const Tensor& tensor, c10::optional<ScalarType> dtype) {
   if (tensor.layout() == c10::kSparse) {
     return tensor._to_dense(dtype);
   }
-  if (tensor.layout() == c10::kSparseCsr || tensor.layout() == c10::kSparseCsc) {
+  if (tensor.layout() == c10::kSparseCsr ||
+      tensor.layout() == c10::kSparseCsc ||
+      tensor.layout() == c10::kSparseBsr) {
     return tensor._to_dense(dtype);
   }
   if (tensor.layout() == c10::kMkldnn) {
     return tensor._to_dense(dtype);
   }
-  TORCH_CHECK(tensor.layout() == c10::kStrided, "to_dense does not support layout ", tensor.layout());
+  TORCH_CHECK(
+      tensor.layout() == c10::kStrided,
+      "to_dense does not support layout ",
+      tensor.layout());
   if (dtype) {
     return tensor.to(*dtype);
   }
   return tensor;
 }

-Tensor sparse_to_dense(
-    const Tensor& self,
-    c10::optional<ScalarType> dtype) {
+Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {
   TORCH_CHECK(
       !dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
   Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
@@ -352,11 +356,34 @@ Tensor sparse_compressed_to_dense(
     const Tensor& self,
     c10::optional<ScalarType> dtype) {
   TORCH_CHECK(
-      !dtype.has_value(), "dtype argument is not supported by sparse_csr_to_dense");
+      !dtype.has_value(),
+      "dtype argument is not supported by sparse_csr_to_dense");
   if (self.layout() == kSparseCsr) {
     Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
     return dst.add_(self);
   }
+  if (self.layout() == kSparseBsr) {
+    TORCH_CHECK(self.dim() == 2, "Can only convert 2D SparseBsr to Strided.");
+    Tensor indices = at::_convert_indices_from_csr_to_coo(
+        self.crow_indices(), self.col_indices(), false, false);
+    auto values = self.values();
+    int64_t blocksize[2] = {values.size(-2), values.size(-1)};
+    DimVector expanded_size(
+        {self.size(0) / blocksize[0],
+         self.size(1) / blocksize[1],
+         blocksize[0],
+         blocksize[1]});
+    // We make use of COO dense dimensions here to use the COO to dense format
+    // conversion.
+    auto self_coo =
+        at::native::_sparse_coo_tensor_unsafe(indices, values, expanded_size)
+            .coalesce();
+    auto dense = self_coo.to_dense();
+    // Here we are untiling the result.
+    dense = dense.transpose(1, 2);
+    dense = dense.reshape({self.size(0), self.size(1)});
+    return dense;
+  }
   return self.to_sparse().to_dense();
 }

@@ -489,10 +516,97 @@ Tensor dense_to_sparse_csc(const Tensor& self) {
   return self.to_sparse().to_sparse_csc();
 }

+Tensor _tile_tensor(const Tensor& self, IntArrayRef blocksize) {
+  // This code turns a matrix into a sequence of blocks
+  //
+  // Given matrix
+  //
+  //  1  2  3  4
+  //  5  6  7  8
+  //  9 10 11 12
+  // 13 14 15 16
+  //
+  // _tile_tensor(matrix, {2, 2}) will yield the following 2 by 2 blocks
+  //
+  //  1  2 |  3  4 |  9 10 | 11 12
+  //  5  6 |  7  8 | 13 14 | 15 16
+  //
+  // via a 4D Tensor of shape (2, 2, 2, 2)
+  //
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0);
+  auto block_size_0 = self.size(0) / blocksize[0];
+  auto block_size_1 = self.size(1) / blocksize[1];
+  return self.reshape({block_size_0, blocksize[0], block_size_1, blocksize[1]})
+      .transpose(1, 2)
+      .contiguous();
+}
+
+std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
+    Tensor not_zero_mask,
+    ScalarType index_dtype,
+    Device index_device) {
+  auto col_indices =
+      at::native::arange(not_zero_mask.size(-1), index_dtype, kStrided, index_device)
+          .view({1, not_zero_mask.size(-1)})
+          .expand_as(not_zero_mask)
+          .masked_select(not_zero_mask);
+  auto row_indices =
+      at::native::arange(not_zero_mask.size(-2), index_dtype, kStrided, index_device)
+          .view({not_zero_mask.size(-2), 1})
+          .expand_as(not_zero_mask)
+          .masked_select(not_zero_mask);
+  return std::pair<Tensor, Tensor>(col_indices, row_indices);
+}
+
 Tensor dense_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize) {
-  AT_ERROR(
-      "Conversion from ", self.layout(), " to SparseBsr is currently not supported.");
-  return self;
+  TORCH_CHECK(self.dim() == 2, "Can only convert 2D Tensor to BSR.");
+  TORCH_CHECK(
+      blocksize[0] > 0 && blocksize[1] > 0,
+      "blocksize needs to be non-zero, but got ",
+      blocksize);
+  TORCH_CHECK(
+      self.size(0) % blocksize[0] == 0,
+      "Tensor size(0) ",
+      self.size(0),
+      " needs to be divisible by blocksize[0] ",
+      blocksize[0]);
+  TORCH_CHECK(
+      self.size(1) % blocksize[1] == 0,
+      "Tensor size(1) ",
+      self.size(1),
+      " needs to be divisible by blocksize[1] ",
+      blocksize[1]);
+  auto block_size_0 = self.size(0) / blocksize[0];
+
+  auto values = _tile_tensor(self, blocksize);
+  auto not_zero_mask = _tile_tensor((self != 0), blocksize);
+  // Find tiles that have at least 1 non-zero value in them.
+  not_zero_mask = not_zero_mask.any(-1).any(-1);
+  Tensor col_indices;
+  Tensor row_indices;
+  std::tie(col_indices, row_indices) =
+      _not_zero_mask_to_col_row_indices(not_zero_mask, at::kLong, not_zero_mask.device());
+  Tensor crow_indices = at::_convert_indices_from_coo_to_csr(
+      row_indices.view({-1}), block_size_0, false /* out_int32 */);
+  values = values.reshape({-1, values.size(-2), values.size(-1)});
+  not_zero_mask = not_zero_mask.reshape({-1});
+  // TODO: masked_select does not support some form of broadcasting, so we're
+  // using the mask to construct indices that are then passed into index_select.
+  // This isn't ideal.
+  values = values.index_select(
+      0,
+      at::native::arange(not_zero_mask.numel(), at::kLong, kStrided, not_zero_mask.device())
+          .masked_select(not_zero_mask));
+
+  return at::native::_sparse_bsr_tensor_unsafe(
+      crow_indices,
+      col_indices,
+      values,
+      self.sizes(),
+      values.scalar_type(),
+      c10::kSparseBsr,
+      values.device());
 }

 Tensor dense_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize) {
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 3a2dddf398b..376564bdb0a 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -2098,7 +2098,7 @@ class TestSparseCSR(TestCase):
             detached_inp = inp.detach()
             self.assertEqual(inp, detached_inp)

-    def _convert_to_layout(self, a, target_layout):
+    def _convert_to_layout(self, a, target_layout, blocksize=(2, 2)):
         """
         Helper function to call the correct layout conversion
         with reasonable defaults for the block size. Clearly there
@@ -2109,12 +2109,12 @@ class TestSparseCSR(TestCase):
         if target_layout is torch.sparse_csc:
             return a.to_sparse_csc()
         if target_layout is torch.sparse_bsr:
-            return a.to_sparse_bsr((2, 2))
+            return a.to_sparse_bsr(blocksize)
         if target_layout is torch.sparse_bsc:
-            return a.to_sparse_bsc((2, 2))
+            return a.to_sparse_bsc(blocksize)
         raise NotImplementedError(repr(a))

-    def _construct_sp_matrix(self, tensor, layout):
+    def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)):
         if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]:
             tensor = tensor.to_dense()
         else:
@@ -2124,7 +2124,7 @@ class TestSparseCSR(TestCase):
         if layout is torch.sparse_csc:
             return sp.csc_matrix(tensor.cpu().numpy())
         if layout is torch.sparse_bsr:
-            return sp.bsr_matrix(tensor.cpu().numpy())
+            return sp.bsr_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
         # No native scipy BSC support?
         raise NotImplementedError(repr(tensor))

@@ -2173,25 +2173,20 @@ class TestSparseCSR(TestCase):
         if layout is torch.sparse_bsc:
             # TODO: Remove this once support has been enabled
             return
-        if layout is torch.sparse_bsr:
-            # TODO: Remove this once support has been enabled
-            return

-        for shape in [(0, 10), (6, 0), (6, 10), (0, 0)]:
+        shapes = [(6, 10), (0, 10), (6, 0), (0, 0)]
+
+        blocksizes = [(2, 2)]
+        if layout is torch.sparse_bsr:
+            blocksizes += [(3, 5), (6, 10)]
+
+        for shape, blocksize in itertools.product(shapes, blocksizes):
             dense = make_tensor(shape, dtype=torch.float, device=device)
             dense = dense.relu()  # Introduce some sparsity
-            sp_matrix = self._construct_sp_matrix(dense, layout)
-            pt_matrix = self._convert_to_layout(dense, layout)
+            sp_matrix = self._construct_sp_matrix(dense, layout, blocksize=blocksize)
+            pt_matrix = self._convert_to_layout(dense, layout, blocksize=blocksize)

-            compressed_indices_mth = {
-                torch.sparse_csr: torch.Tensor.crow_indices,
-                torch.sparse_csc: torch.Tensor.ccol_indices,
-            }[layout]
-
-            plain_indices_mth = {
-                torch.sparse_csr: torch.Tensor.col_indices,
-                torch.sparse_csc: torch.Tensor.row_indices,
-            }[layout]
+            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]

             self.assertEqual(layout, pt_matrix.layout)
             self.assertEqual(sp_matrix.shape, pt_matrix.shape)
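
For reference, a usage sketch (not part of the patch) of the round trip this diff enables, assuming a build with the patch applied; the example tensor and printed values are illustrative:

import torch

dense = torch.tensor([[1., 2., 0., 0.],
                      [3., 4., 0., 0.],
                      [0., 0., 0., 0.],
                      [0., 0., 5., 6.]])
# Dense -> BSR with 2x2 blocks; only the two blocks containing non-zeros
# are stored.
bsr = dense.to_sparse_bsr((2, 2))
print(bsr.crow_indices())  # tensor([0, 1, 2]): one stored block per block-row
print(bsr.col_indices())   # tensor([0, 1])
print(bsr.values().shape)  # torch.Size([2, 2, 2]): (num_blocks, 2, 2)
# BSR -> dense exercises the new sparse_compressed_to_dense branch.
assert torch.equal(bsr.to_dense(), dense)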
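
The reshape/transpose trick in _tile_tensor is the core of the dense-to-BSR path. A minimal Python mirror of the same computation (the name tile_tensor and this standalone form are mine, not part of the patch):

import torch

def tile_tensor(t, blocksize):
    # Split each dimension into (num_blocks, block_len), then swap the two
    # middle dims so entry [i, j] of the result is the (i, j) block of t.
    nr = t.size(0) // blocksize[0]
    nc = t.size(1) // blocksize[1]
    return (t.reshape(nr, blocksize[0], nc, blocksize[1])
             .transpose(1, 2)
             .contiguous())

m = torch.arange(1, 17).reshape(4, 4)  # the matrix from the comment in the diff
blocks = tile_tensor(m, (2, 2))        # shape (2, 2, 2, 2)
print(blocks[0, 0])  # tensor([[1, 2], [5, 6]])
print(blocks[1, 0])  # tensor([[ 9, 10], [13, 14]])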
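
The BSR-to-dense branch leans on hybrid COO tensors: block values become dense dimensions, so the existing COO to_dense does the scatter, and a transpose plus reshape untiles the result. A Python sketch of that idea, with indices hard-coded to match the dense tensor in the first sketch (illustrative, not the ATen kernel itself):

import torch

row = torch.tensor([0, 1])   # crow_indices expanded to one row id per block
col = torch.tensor([0, 1])
vals = torch.tensor([[[1., 2.], [3., 4.]],
                     [[0., 0.], [5., 6.]]])
# Two sparse dims (block row/col) plus two dense dims (the block itself).
coo = torch.sparse_coo_tensor(torch.stack([row, col]), vals, (2, 2, 2, 2))
out = coo.to_dense()                     # (2, 2, 2, 2), blocks in place
out = out.transpose(1, 2).reshape(4, 4)  # untile back to the 4x4 matrix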