pytorch/torch/csrc/distributed
Tristan Rice ddd0ed1b43 distributed: templated ring attention (#124215)
This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. It does not add support for memory-efficient attention in DTensor; that will be added in a follow-up PR.

This templating also serves as a proof of concept for supporting other attention ops, such as jagged/nested tensor attention, as well as for implementing striped attention in a scalable way.
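
The sketch below illustrates the templating idea under stated assumptions, not the PR's actual code: the ring orchestration is written once and parameterized over a `local_attn` callable (e.g. flash or memory-efficient attention) that returns a partial output plus its log-sum-exp statistics. All helper names, signatures, and the merge logic here are illustrative.

```python
# A minimal sketch of a "templated" ring attention forward. Assumptions:
# q/k/v are sequence-sharded across ranks, and `local_attn` returns
# (partial_out [B, H, S, D], partial_lse [B, H, S]) for the local shard.
from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist


def _merge_partials(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    new_out: torch.Tensor,
    new_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Combine two partial attention results via their log-sum-exp stats so the
    # merged result equals attention over the union of the KV blocks seen so far.
    if out is None:
        return new_out, new_lse
    merged_lse = torch.logaddexp(lse, new_lse)
    merged_out = (
        out * torch.exp(lse - merged_lse).unsqueeze(-1)
        + new_out * torch.exp(new_lse - merged_lse).unsqueeze(-1)
    )
    return merged_out, merged_lse


def templated_ring_attention(
    group: dist.ProcessGroup,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    local_attn: Callable[
        [torch.Tensor, torch.Tensor, torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor],
    ],
) -> torch.Tensor:
    """Rotate KV shards around the ring; at each hop, attend the local queries
    to the received KV shard and merge the partial outputs."""
    rank = dist.get_rank(group)
    size = dist.get_world_size(group)
    out = lse = None
    for step in range(size):
        # The pluggable attention op: flash, memory-efficient, etc.
        partial_out, partial_lse = local_attn(q, k, v)
        out, lse = _merge_partials(out, lse, partial_out, partial_lse)
        if step + 1 < size:
            # Send our current KV shard to the next rank and receive the
            # previous rank's shard, completing one hop around the ring.
            next_k, next_v = torch.empty_like(k), torch.empty_like(v)
            ops = [
                dist.P2POp(dist.isend, k, (rank + 1) % size, group),
                dist.P2POp(dist.isend, v, (rank + 1) % size, group),
                dist.P2POp(dist.irecv, next_k, (rank - 1) % size, group),
                dist.P2POp(dist.irecv, next_v, (rank - 1) % size, group),
            ]
            for req in dist.batch_isend_irecv(ops):
                req.wait()
            k, v = next_k, next_v
    return out
```

With this shape, supporting another attention op (or a different rotation schedule, as striped attention would need) amounts to swapping the callable or the hop logic rather than duplicating the ring orchestration.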

Misc changes:

* Fixes the all_to_all_single autograd implementation on CUDA and adds an NCCL test (see the sketch after this list)
* Adds compile support to the ring attention implementations (required some tweaks to process groups)
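
As a hedged illustration of what the all_to_all_single autograd fix enables, a test along these lines could check that gradients flow back through the collective on an initialized NCCL process group. The exact entry point the PR touches is an assumption here; this sketch uses `torch.distributed._functional_collectives.all_to_all_single` and assumes it is differentiable after the fix.

```python
# Illustrative check, not the PR's actual test: gradients should flow back
# through an autograd-aware all_to_all_single on CUDA/NCCL.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol


def check_all_to_all_single_grad(group: dist.ProcessGroup) -> None:
    # Assumes torch.distributed is already initialized with the NCCL backend.
    world_size = dist.get_world_size(group)
    x = torch.randn(world_size, 8, device="cuda", requires_grad=True)
    # None split sizes mean an even split: each rank sends one row to every rank.
    out = funcol.all_to_all_single(
        x, output_split_sizes=None, input_split_sizes=None, group=group
    )
    out.sum().backward()
    # The backward of all_to_all is another all_to_all with the splits swapped,
    # so every input row should receive a gradient.
    assert x.grad is not None and x.grad.shape == x.shape
```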

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
2024-04-19 00:57:08 +00:00
autograd Fix ouput typos (#120870) 2024-02-29 08:29:14 +00:00
c10d distributed: templated ring attention (#124215) 2024-04-19 00:57:08 +00:00
rpc Remove unnecessary const_casts (#121225) 2024-03-05 17:34:24 +00:00