pytorch/docs/source/sparse.rst
Wei Yang be7c618fd7 torch.sparse.sum() (#12430)
Summary:
- to fix #12241
- add `_sparse_sum()` to ATen and expose it as `torch.sparse.sum()`; `SparseTensor.sum()` is not supported currently
- this PR depends on #11253 and will need to be updated once it lands
- [x] implement forward
- [x] implement backward
- performance [benchmark script](https://gist.github.com/weiyangfb/f4c55c88b6092ef8f7e348f6b9ad8946#file-sparse_sum_benchmark-py):
  - summing over all dims is the fastest case for a sparse tensor
  - when the input is sparse enough (nnz = 0.1%), the sum of a sparse tensor is faster than dense on CPU, but not necessarily on CUDA
  - CUDA backward is comparable (<2x) between `sum several dims` and `sum all dims` in sparse
  - CPU backward uses binary search and is still slow in sparse: `sum [0, 2, 3] dims` takes `5x` the time of `sum all dims`
    - optimize CUDA backward for now
      - using thrust for sort and binary search, but runtime not improved
  - both CPU and CUDA forward are slow in sparse (`sum several dims` vs `sum all dims`): at most `20x` slower on CPU and `10x` on CUDA
    - improve CPU and CUDA forward kernels

(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA(sparse vs dense)
-- | -- | --
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 8.77 µs vs 72.9 µs | 42.5 µs vs 108 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 112 µs vs 4.47 ms | 484 µs vs 407 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 141 µs vs 148 µs | 647 µs vs 231 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 235 µs vs 1.23 ms | 781 µs vs 213 µs
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 48.5 µs vs 360 µs | 160 µs vs 2.03 ms
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 258 µs vs 1.22 ms | 798 µs vs 224 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 204 µs vs 882 µs | 443 µs vs 133 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 709 µs vs 1.15 ms | 893 µs vs 202 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 39.8 µs vs 81 µs | 42.4 µs vs 113 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 747 µs vs 4.7 ms | 2.4 ms vs 414 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 1.04 ms vs 126 µs | 5.03 ms vs 231 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.12 ms vs 1.24 ms | 5.99 ms vs 213 µs
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 133 µs vs 366 µs | 463 µs vs 2.03 ms
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.56 ms vs 1.22 ms | 6.11 ms vs 229 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.53 ms vs 799 µs | 824 µs vs 134 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 5.15 ms vs 1.09 ms | 7.02 ms vs 205 µs

- after improving CPU and CUDA forward kernels
  - in the `(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD)` forward, CPU takes ~~`171 µs`~~, of which `130 µs` is spent in `coalesce()`; for CUDA, total time is ~~`331 µs`~~, of which `141 µs` is spent in `coalesce()`. We need to reduce the time spent outside `coalesce()`.
  - after a few simple tweaks, the forward is now at most `10x` slower on CPU and `7x` on CUDA. The time taken by `sum dense dims only [2, 3]` is `~2x` that of `sum all dims`. The speed of `sum all sparse dims [0, 1]` is on par with `sum all dims`

(nnz,   sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA(sparse vs dense)
-- | -- | --
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 7 µs vs 69.5 µs | 31.5 µs vs 61.6 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 11.3 µs vs 4.72 ms | 35.2 µs vs 285 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 197 µs vs 124 µs | 857 µs vs 134 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 124 µs vs 833 µs | 796 µs vs 106 µs
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 20.5 µs vs 213 µs | 39.4 µs vs 1.24 ms
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 131 µs vs 830 µs | 881 µs vs 132 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 95.8 µs vs 409 µs | 246 µs vs 87.2 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 624 µs vs 820 µs | 953 µs vs 124 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 45.3 µs vs 72.9 µs | 33.9 µs vs 57.2 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 81.4 µs vs 4.49 ms | 39.7 µs vs 280 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 984 µs vs 111 µs | 6.41 ms vs 121 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.45 ms vs 828 µs | 6.77 ms vs 113 µs
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 74.9 µs vs 209 µs | 37.7 µs vs 1.23 ms
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.48 ms vs 845 µs | 6.96 ms vs 132 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.14 ms vs 411 µs | 252 µs vs 87.8 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 4.53 ms vs 851 µs | 7.12 ms vs 128 µs

- the CUDA backward of sparse takes very long and has large variance (with nnz=10000 it normally takes 6-7 ms). To improve the backward of sparse ops, we will need to debug in places other than the CUDA kernels. Here is a benchmark of `torch.copy_()`:
```
>>> import torch
>>> d = [1000, 1000, 2, 2]
>>> nnz = 10000
>>> I = torch.cat([torch.randint(0, d[0], size=(nnz,)),
...                torch.randint(0, d[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, d[2], d[3])
>>> size = torch.Size(d)
>>> S = torch.sparse_coo_tensor(I, V, size).coalesce().cuda()
>>> S2 = torch.sparse_coo_tensor(I, V, size).coalesce().cuda().requires_grad_()
>>> data = S2.clone()
>>> S.copy_(S2)
>>> y = S * 2
>>> torch.cuda.synchronize()
>>> %timeit y.backward(data, retain_graph=True); torch.cuda.synchronize()
7.07 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12430

Differential Revision: D12878313

Pulled By: weiyangfb

fbshipit-source-id: e16dc7681ba41fdabf4838cf05e491ca9108c6fe
2018-11-28 02:19:12 -08:00


.. currentmodule:: torch.sparse

.. _sparse-docs:

torch.sparse
============

.. warning::

    This API is currently experimental and may change in the near future.

Torch supports sparse tensors in COO(rdinate) format, which can
efficiently store and process tensors for which the majority of elements
are zeros.

A sparse tensor is represented as a pair of dense tensors: a tensor
of values and a 2D tensor of indices. A sparse tensor can be constructed
by providing these two tensors, as well as the size of the sparse tensor
(which cannot be inferred from these tensors!) Suppose we want to define
a sparse tensor with the entry 3 at location (0, 2), entry 4 at
location (1, 0), and entry 5 at location (1, 2). We would then write:

    >>> i = torch.LongTensor([[0, 1, 1],
    ...                       [2, 0, 2]])
    >>> v = torch.FloatTensor([3, 4, 5])
    >>> torch.sparse.FloatTensor(i, v, torch.Size([2,3])).to_dense()
     0  0  3
     4  0  5
    [torch.FloatTensor of size 2x3]

Note that the input to LongTensor is NOT a list of index tuples. If you want
to write your indices this way, you should transpose before passing them to
the sparse constructor:

    >>> i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
    >>> v = torch.FloatTensor([3, 4, 5])
    >>> torch.sparse.FloatTensor(i.t(), v, torch.Size([2,3])).to_dense()
     0  0  3
     4  0  5
    [torch.FloatTensor of size 2x3]

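
As a minimal sketch of the two index layouts described above, the following uses ``torch.sparse_coo_tensor`` (the factory function that appears in the benchmark snippet of the commit message) rather than the typed constructors. The variable names are illustrative only.

```python
import torch

# Layout expected by the constructor: one row per sparse dimension,
# one column per stored entry.
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])

# The same coordinates written as a list of (row, col) tuples.
tuples = torch.tensor([[0, 2], [1, 0], [1, 2]])

v = torch.tensor([3., 4., 5.])

# Passing the constructor-style indices directly.
a = torch.sparse_coo_tensor(i, v, size=(2, 3)).to_dense()

# Transposing the tuple-style indices before passing them.
b = torch.sparse_coo_tensor(tuples.t(), v, size=(2, 3)).to_dense()

print(torch.equal(a, b))
```

Both calls should describe the same 2x3 tensor, since ``tuples.t()`` has the same contents as ``i``.
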
You can also construct hybrid sparse tensors, where only the first n
dimensions are sparse, and the rest of the dimensions are dense.

    >>> i = torch.LongTensor([[2, 4]])
    >>> v = torch.FloatTensor([[1, 3], [5, 7]])
    >>> torch.sparse.FloatTensor(i, v).to_dense()
     0  0
     0  0
     1  3
     0  0
     5  7
    [torch.FloatTensor of size 5x2]

An empty sparse tensor can be constructed by specifying its size:

    >>> torch.sparse.FloatTensor(2, 3)
    SparseFloatTensor of size 2x3 with indices:
    [torch.LongTensor with no dimension]
    and values:
    [torch.FloatTensor with no dimension]

SparseTensor has the following invariants:

1. ``sparse_dim + dense_dim = len(SparseTensor.shape)``
2. ``SparseTensor._indices().shape = (sparse_dim, nnz)``
3. ``SparseTensor._values().shape = (nnz,) + SparseTensor.shape[sparse_dim:]``

Since ``SparseTensor._indices()`` is always a 2D tensor, the smallest ``sparse_dim`` is 1.
Therefore, a SparseTensor with ``sparse_dim = 0`` is represented simply as a dense tensor.

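
The three invariants can be checked directly on a small hybrid tensor. This is a sketch that assumes the ``sparse_dim()`` and ``dense_dim()`` accessors are available on the tensor; the shapes and values are made up for illustration.

```python
import torch

# Hybrid COO tensor: 2 sparse dims, 1 dense dim, 3 stored entries.
indices = torch.tensor([[0, 1, 1],
                        [2, 0, 2]])                      # (sparse_dim, nnz)
values = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])    # (nnz, 2)
s = torch.sparse_coo_tensor(indices, values, size=(2, 3, 2)).coalesce()

sparse_dim = s.sparse_dim()
dense_dim = s.dense_dim()
nnz = s._nnz()

# Invariant 1: sparse and dense dims together account for every dimension.
inv1 = sparse_dim + dense_dim == len(s.shape)
# Invariant 2: indices are always 2D, one row per sparse dimension.
inv2 = tuple(s._indices().shape) == (sparse_dim, nnz)
# Invariant 3: values carry nnz followed by the dense dimensions.
inv3 = tuple(s._values().shape) == (nnz,) + tuple(s.shape[sparse_dim:])

print(inv1, inv2, inv3)
```
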
.. note::

    Our sparse tensor format permits *uncoalesced* sparse tensors, where
    there may be duplicate coordinates in the indices; in this case,
    the interpretation is that the value at that index is the sum of all
    duplicate value entries. Uncoalesced tensors permit us to implement
    certain operators more efficiently.

    For the most part, you shouldn't have to care whether a
    sparse tensor is coalesced, as most operations will work
    identically given a coalesced or uncoalesced sparse tensor.
    However, there are two cases in which you may need to care.

    First, if you repeatedly perform an operation that can produce
    duplicate entries (e.g., :func:`torch.sparse.FloatTensor.add`), you
    should occasionally coalesce your sparse tensors to prevent
    them from growing too large.

    Second, some operators will produce different values depending on
    whether their input is coalesced (e.g.,
    :func:`torch.sparse.FloatTensor._values` and
    :func:`torch.sparse.FloatTensor._indices`, as well as
    :func:`torch.Tensor.sparse_mask`). These operators are
    prefixed by an underscore to indicate that they reveal internal
    implementation details and should be used with care, since code
    that works with coalesced sparse tensors may not work with
    uncoalesced sparse tensors; generally speaking, it is safest
    to explicitly coalesce before working with these operators.

    For example, suppose that we wanted to implement an operator
    by operating directly on :func:`torch.sparse.FloatTensor._values`.
    Multiplication by a scalar can be implemented in the obvious way,
    as multiplication distributes over addition; however, square root
    cannot be implemented directly, since ``sqrt(a + b) != sqrt(a) +
    sqrt(b)`` (which is what would be computed if you were given an
    uncoalesced tensor).

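
The scalar-multiplication and square-root cases from the note can be illustrated with a small uncoalesced tensor. This is a sketch with made-up values; it rebuilds tensors from the raw ``_indices()`` and ``_values()`` only to show the effect, not as a recommended way to implement operators.

```python
import torch

# Two stored entries share coordinate (0, 0); their values (9 and 16)
# are interpreted as a sum, so the logical value at (0, 0) is 25.
i = torch.tensor([[0, 0, 1],
                  [0, 0, 2]])
v = torch.tensor([9., 16., 4.])
s = torch.sparse_coo_tensor(i, v, size=(2, 3))

nnz_before = s._nnz()      # counts the duplicate entry separately
c = s.coalesce()
nnz_after = c._nnz()       # duplicates merged into one entry

# Scaling the raw values is safe: 2 * (a + b) == 2 * a + 2 * b.
scaled = torch.sparse_coo_tensor(s._indices(), s._values() * 2, s.shape)

# sqrt on the raw, uncoalesced values computes sqrt(9) + sqrt(16) at (0, 0).
wrong = torch.sparse_coo_tensor(s._indices(), s._values().sqrt(), s.shape)

# sqrt after coalescing computes sqrt(9 + 16) at (0, 0).
right = torch.sparse_coo_tensor(c.indices(), c.values().sqrt(), c.shape)

print(wrong.to_dense()[0, 0].item(), right.to_dense()[0, 0].item())
```

Coalescing first is what makes the element-wise square root agree with the dense result.
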
.. class:: FloatTensor()

    .. method:: add
    .. method:: add_
    .. method:: clone
    .. method:: dim
    .. method:: div
    .. method:: div_
    .. method:: get_device
    .. method:: hspmm
    .. method:: mm
    .. method:: mul
    .. method:: mul_
    .. method:: narrow_copy
    .. method:: resizeAs_
    .. method:: size
    .. method:: spadd
    .. method:: spmm
    .. method:: sspaddmm
    .. method:: sspmm
    .. method:: sub
    .. method:: sub_
    .. method:: t_
    .. method:: toDense
    .. method:: transpose
    .. method:: transpose_
    .. method:: zero_
    .. method:: coalesce
    .. method:: is_coalesced
    .. method:: _indices
    .. method:: _values
    .. method:: _nnz

Functions
----------------------------------

.. autofunction:: torch.sparse.addmm
.. autofunction:: torch.sparse.sum
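
A short usage sketch for the two functions listed above, using a small made-up tensor. It assumes that ``torch.sparse.sum`` returns a dense tensor when every sparse dimension is summed and a sparse tensor otherwise, and that ``torch.sparse.addmm`` takes a dense matrix, a sparse matrix and a dense matrix in that order.

```python
import torch

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3., 4., 5.])
s = torch.sparse_coo_tensor(i, v, size=(2, 3))

# Sum over all dimensions.
total = torch.sparse.sum(s)

# Sum over one sparse dimension only; one sparse dimension remains.
row_sums = torch.sparse.sum(s, dim=1)

# bias + s @ dense, with a sparse first factor.
dense = torch.ones(3, 2)
bias = torch.zeros(2, 2)
out = torch.sparse.addmm(bias, s, dense)

print(total.item(), row_sums.to_dense().tolist(), out.tolist())
```
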