Support sparse.sum on empty sparse tensor (#71091)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71091

Fixes https://github.com/pytorch/pytorch/issues/65394

The masked sum on a full input tensor (of any layout) with an all-true mask is the same as the sum on the strided input tensor (after applying `to_dense` to sparse inputs).
Since masked sum uses `torch.sparse.sum` then, for the simplicity of masked reductions implementations, its reduction behavior ought to be defined by the behavior of the `torch.sum`. This PR implements the behavioral connection with respect to the directional summation of empty sparse tensors that correspond to all-zero strided tensors.

cc nikitaved pearu cpuhrsch

Test Plan: Imported from OSS

Reviewed By: davidberard98

Differential Revision: D33651750

Pulled By: cpuhrsch

fbshipit-source-id: 703891bff88c8da6270b4272f5d2da81688db67d
(cherry picked from commit 53f97e80f7)
This commit is contained in:
Pearu Peterson 2022-01-19 10:52:42 -08:00 committed by PyTorch MergeBot
parent 3b589c3497
commit 214f4bf2ff
2 changed files with 1 additions and 4 deletions

View File

@ -1178,8 +1178,6 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum, ScalarTyp
}
Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
TORCH_CHECK(input._nnz() > 0, "_sparse_sum: sparse tensor input._nnz() == 0, please call torch.sparse.sum(input) instead.")
const int64_t input_dim = input.dim();
auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
auto dims_to_sum_v = dims_to_sum.vec();
@ -1189,7 +1187,6 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
Tensor values = input._values();
IntArrayRef sizes = input.sizes();
const int64_t sparse_dim = input.sparse_dim();
// const int64_t dense_dim = input.dense_dim();
auto dims_to_keep_v = std::vector<int64_t>();
auto dense_dims_to_sum_v = std::vector<int64_t>();

View File

@ -1516,7 +1516,7 @@ class TestSparse(TestCase):
# sum an empty tensor
empty_S = torch.sparse_coo_tensor(size=with_size, dtype=dtype, device=device)
self.assertRaises(RuntimeError, lambda: torch.sparse.sum(empty_S, [0]))
self.assertEqual(torch.sparse.sum(empty_S, [0]).to_dense(), torch.sum(empty_S.to_dense(), [0]))
self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0, dtype=dtype, device=device))
empty_S.requires_grad_(True)
empty_S_sum = torch.sparse.sum(empty_S)