mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Support sparse.sum on empty sparse tensor (#71091)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71091
Fixes https://github.com/pytorch/pytorch/issues/65394
The masked sum on a full input tensor (of any layout) with an all-true mask is the same as the sum on the strided input tensor (after applying `to_dense` to sparse inputs).
Since masked sum uses `torch.sparse.sum` then, for the simplicity of masked reductions implementations, its reduction behavior ought to be defined by the behavior of the `torch.sum`. This PR implements the behavioral connection with respect to the directional summation of empty sparse tensors that correspond to all-zero strided tensors.
cc nikitaved pearu cpuhrsch
Test Plan: Imported from OSS
Reviewed By: davidberard98
Differential Revision: D33651750
Pulled By: cpuhrsch
fbshipit-source-id: 703891bff88c8da6270b4272f5d2da81688db67d
(cherry picked from commit 53f97e80f7)
This commit is contained in:
parent
3b589c3497
commit
214f4bf2ff
|
|
@ -1178,8 +1178,6 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum, ScalarTyp
|
|||
}
|
||||
|
||||
Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
|
||||
TORCH_CHECK(input._nnz() > 0, "_sparse_sum: sparse tensor input._nnz() == 0, please call torch.sparse.sum(input) instead.")
|
||||
|
||||
const int64_t input_dim = input.dim();
|
||||
auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
|
||||
auto dims_to_sum_v = dims_to_sum.vec();
|
||||
|
|
@ -1189,7 +1187,6 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
|
|||
Tensor values = input._values();
|
||||
IntArrayRef sizes = input.sizes();
|
||||
const int64_t sparse_dim = input.sparse_dim();
|
||||
// const int64_t dense_dim = input.dense_dim();
|
||||
|
||||
auto dims_to_keep_v = std::vector<int64_t>();
|
||||
auto dense_dims_to_sum_v = std::vector<int64_t>();
|
||||
|
|
|
|||
|
|
@ -1516,7 +1516,7 @@ class TestSparse(TestCase):
|
|||
|
||||
# sum an empty tensor
|
||||
empty_S = torch.sparse_coo_tensor(size=with_size, dtype=dtype, device=device)
|
||||
self.assertRaises(RuntimeError, lambda: torch.sparse.sum(empty_S, [0]))
|
||||
self.assertEqual(torch.sparse.sum(empty_S, [0]).to_dense(), torch.sum(empty_S.to_dense(), [0]))
|
||||
self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0, dtype=dtype, device=device))
|
||||
empty_S.requires_grad_(True)
|
||||
empty_S_sum = torch.sparse.sum(empty_S)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user