Note about the updates. This PR:
1. Skips more flash attention related UTs on MI200.
2. Fixes additional ATen compiling errors after hipification.
3. Fixes the author "root" of a specific commit.
4. Includes the patch from Nikita in favor of block-level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force it to run on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) note: This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:
- Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- Only supports power-of-two sequence lengths.
- No support for varlen APIs.
- Only supports head dimensions 16, 32, 64, and 128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
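For context, here is a minimal sketch of how this backend could be exercised from Python, assuming a ROCm build of PyTorch on an MI200 (gfx90a) GPU. The `torch.backends.cuda.sdp_kernel` context manager and the `gcnArchName` device property are existing PyTorch APIs (ROCm devices are exposed through the `cuda` namespace); the shapes are chosen to respect the limitations listed above. This snippet is illustrative only and is not part of the PR.

```python
# Sketch: run flash attention via SDPA on a ROCm build, assuming an
# MI200 (gfx90a) GPU. Shapes respect the stated limitations:
# power-of-two sequence length, head dimension in {16, 32, 64, 128}.
import torch
import torch.nn.functional as F

device = torch.device("cuda")  # ROCm devices use the cuda API in PyTorch
props = torch.cuda.get_device_properties(device)
print(props.gcnArchName)  # expected: gfx90a:sramecc+:xnack-

batch, heads, seq_len, head_dim = 2, 8, 1024, 64  # seq_len is a power of two
q = torch.randn(batch, heads, seq_len, head_dim,
                device=device, dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Restrict SDPA to the flash attention kernel so that a silent fallback
# to the math or mem-efficient backends does not mask a missing kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 1024, 64])
```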
| File |
|---|
| EigenBLAS.cmake |
| nccl.cmake |
| nnpack.cmake |
| oort.cmake |
| rccl.cmake |
| ucc.cmake |