This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:

- [ ] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
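As a rough illustration of how the new backend is reached (a minimal sketch, not part of this PR; it assumes a ROCm build of PyTorch running on a supported MI200 GPU), Flash Attention is exercised through the standard `scaled_dot_product_attention` path, with inputs chosen to respect the limitations listed above:

```python
import torch
import torch.nn.functional as F

# ROCm builds expose HIP devices through the "cuda" device type.
device = "cuda"
dtype = torch.float16

# Power-of-two sequence length and a supported head dimension (16/32/64/128).
batch, heads, seq_len, head_dim = 2, 8, 1024, 64
q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the flash kernel so a silent fallback to the math or
# memory-efficient backends doesn't mask missing Flash Attention support.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 1024, 64])
```

If the flash backend is unavailable for the given inputs (e.g., a non-power-of-two sequence length on this platform), restricting the kernel selection this way surfaces a runtime error instead of silently dispatching to another backend.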