Fixes #121965
This PR adds support for complex numbers in the scatter/gather-related kernels. For brevity, only `complex<float>` is included for now, as `complex<double>`, for example, would be more complicated.
C++ unit tests are currently passing alongside the tests in `test_scatter_gather_ops.py`, and the Python test suites also appear to pass.
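A minimal sketch of the kind of call this enables (assuming a CUDA build; the shapes here are illustrative and not taken from the PR or its tests):
```python
import torch

# Gather/scatter on complex64 tensors, which the scatter/gather kernels now accept.
device = "cuda" if torch.cuda.is_available() else "cpu"
src = torch.randn(4, 5, dtype=torch.complex64, device=device)
index = torch.randint(0, 4, (4, 5), device=device)

gathered = torch.gather(src, 0, index)                      # complex64 gather
scattered = torch.zeros_like(src).scatter_(0, index, src)   # complex64 scatter
```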
Please keep the following in mind:
1) I think this is my first time using PyTorch.
2) This is my first contribution to PyTorch.
Environment:
3080 & WSL 2. `nvcc` is at 12.4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124809
Approved by: https://github.com/eqy, https://github.com/mikaylagawarecki
This reuses the existing deterministic implementation of `index_put`, which is based on sorting indices.
With the `accumulate` arg of `index_put`, this works for both `scatter` and `scatter_reduce` with the sum/mean reduction modes.
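A hedged sketch of the idea (illustrative shapes, 2-D case, dim 0; the actual dispatch lives in the C++ kernels): `scatter_add_`-style accumulation can be expressed through `index_put_` with `accumulate=True`, whose deterministic path sorts indices before accumulating.
```python
import torch

# scatter_add_-style accumulation expressed via index_put_ with accumulate=True.
self_t = torch.zeros(4, 3)
index = torch.tensor([[0, 1, 2], [2, 0, 1]])
src = torch.ones(2, 3)

# Column index broadcast to match `index`, so (row, col) pairs address self_t.
col = torch.arange(src.size(1)).expand_as(index)
out = self_t.clone().index_put_(
    (index.reshape(-1), col.reshape(-1)), src.reshape(-1), accumulate=True
)

# Same result as scatter_add_ along dim 0.
ref = self_t.clone().scatter_add_(0, index, src)
assert torch.equal(out, ref)
```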
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98060
Approved by: https://github.com/mikaylagawarecki
In the classic PyG use case for message passing, `gather` receives an `index` tensor in a broadcasted shape, e.g. shape `[5000, 128]` with stride `[1, 0]`, which indicates that the gather is done row-wise on the self tensor. The current implementation parallelizes on the inner dimension, which performs badly on CPU and cannot be vectorized.
This PR addresses this use case and optimizes it in a manner similar to `index_select`: parallelize on the outer dimension of `index` and do a vectorized copy on the inner dimension.
Performance benchmarking on a single Xeon Ice Lake socket with `GCN`: `gather` is reduced from `150.787ms` to `10.926ms`. After this optimization, `gather` is no longer the major bottleneck for training GNN models when `EdgeIndex` is in COO format.
For more details, please refer to https://github.com/pyg-team/pytorch_geometric/issues/4891#issuecomment-1288423705
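An illustrative reproduction of that index layout (the node and feature counts are made up, not taken from the benchmark):
```python
import torch

# Broadcasted index as produced in PyG message passing: one source node per
# edge, expanded along the feature dimension, giving stride [1, 0].
x = torch.randn(10000, 128)                       # node features
edge_src = torch.randint(0, 10000, (5000,))       # source node of each edge
index = edge_src.unsqueeze(1).expand(5000, 128)   # shape [5000, 128], stride [1, 0]

# Row-wise gather: every output row copies a full row of x, so the inner
# dimension can be one vectorized copy per row.
out = torch.gather(x, 0, index)
assert torch.equal(out, x.index_select(0, edge_src))
```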
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87586
Approved by: https://github.com/rusty1s, https://github.com/malfet
### Motivation of this PR
This PR targets improving the performance of `scatter_add` for GNN usage scenarios on PyG. Currently only the CPU optimization is covered.
Message passing is the major step in GNN learning; it means exchanging/aggregating info between nodes. From a perf point of view, when the `EdgeIndex` is stored as `[2, num_edges]`, `scatter_reduce` is a major perf hotspot in the current PyTorch implementation.
To be more specific, in message passing `scatter_add` is used in a very similar way to `index_select`, except that the `self` tensor is written to while `index_select` only reads. Therefore, the `index` tensor passed to `scatter_add` is expanded on dim 0, which means all the remaining dims hold the same value.
### Algorithm
The current scatter implementation parallelizes on the inner dims for such a case, which causes bad perf: a non-contiguous memory access pattern and no vectorization.
This PR sorts the `index` to resolve the write conflicts that arise when parallelizing directly on dim 0. The algorithm is equivalent to (see the sketch after this list):
* convert memory format from `COO` to `CSR`
* do spmm reduce
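A conceptual sketch of that equivalence in Python (not the actual C++ kernel; the sizes and the per-row loop are purely illustrative):
```python
import torch

# With a dim-0-expanded index, scatter_add along dim 0 is equivalent to
# converting the COO row indices to CSR row pointers and doing a segment
# (spmm-like) reduce per output row.
num_nodes = 4
row = torch.tensor([2, 0, 2, 1, 0])       # COO row index per edge (the dim-0 index)
src = torch.randn(5, 8)                   # per-edge messages

# COO -> CSR: sort by row, then build row pointers from the per-row counts.
order = torch.argsort(row)
sorted_src = src[order]
rowptr = torch.zeros(num_nodes + 1, dtype=torch.long)
rowptr[1:] = torch.bincount(row, minlength=num_nodes).cumsum(0)

# Segment reduce: each output row sums its contiguous slice of sorted_src,
# so the outer loop can run in parallel over rows without write conflicts.
out = torch.zeros(num_nodes, src.size(1))
for n in range(num_nodes):
    out[n] = sorted_src[rowptr[n]:rowptr[n + 1]].sum(dim=0)

ref = torch.zeros_like(out).scatter_add_(0, row.unsqueeze(1).expand_as(src), src)
assert torch.allclose(out, ref)
```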
### Perf improvement
The benchmark comes from https://github.com/pyg-team/pytorch_geometric/tree/master/examples: `python reddit.py`, which runs the SAGE model on the Reddit dataset.
CPU type: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
`aten::scatter_add_` has been reduced from **37.797s** to **5.989s**:
* breakdown before
```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::scatter_add_ 49.00% 37.797s 49.00% 37.797s 41.445ms 912
aten::index_select 19.74% 15.223s 19.74% 15.227s 6.678ms 2280
aten::linear 0.01% 5.706ms 15.04% 11.602s 12.721ms 912
aten::addmm 6.62% 5.108s 7.92% 6.112s 13.403ms 456
aten::matmul 0.00% 2.339ms 7.10% 5.475s 12.006ms 456
```
* breakdown after
```
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
aten::index_select 32.41% 14.677s 32.42% 14.681s 6.439ms 2280
aten::linear 0.01% 6.665ms 26.43% 11.968s 13.123ms 912
aten::addmm 11.76% 5.328s 13.76% 6.232s 13.667ms 456
aten::scatter_add_ 13.22% 5.989s 13.22% 5.989s 6.566ms 912
aten::matmul 0.01% 2.303ms 12.63% 5.720s 12.543ms 456
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82703
Approved by: https://github.com/jgong5, https://github.com/ezyang
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74226
Update the signature of `scatter_reduce_` to match `scatter_`/`scatter_add_` (a usage sketch follows the list below):
`Tensor.scatter_reduce_(int64 dim, Tensor index, Tensor src, str reduce)`
- Add new reduction options in ScatterGatherKernel.cpp and update `scatter_reduce` to call into the CPU kernel for `scatter.reduce`
- `scatter_reduce` now has the same shape constraints as `scatter_` and `scatter_add_`
- Migrate `test/test_torch.py:test_scatter_reduce` to `test/test_scatter_gather_ops.py`
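A hedged usage sketch of the in-place call under the updated signature (the values and the `"sum"` reduce mode are illustrative; the exact set of reduce modes is defined by the kernel and not shown here):
```python
import torch

# In-place scatter_reduce_ with the scatter_/scatter_add_-style signature:
# self.scatter_reduce_(dim, index, src, reduce), with the same shape
# constraints as scatter_/scatter_add_.
self_t = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 1]])
src = torch.arange(5, dtype=torch.float).unsqueeze(0)

self_t.scatter_reduce_(0, index, src, reduce="sum")
# For each column j, self_t[index[0, j], j] accumulated src[0, j].
```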
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D35222842
Pulled By: mikaylagawarecki
fbshipit-source-id: 84930add2ad30baf872c495251373313cb7428bd
(cherry picked from commit 1b45139482e22eb0dc8b6aec2a7b25a4b58e31df)