Commit Graph

95 Commits

Author SHA1 Message Date
Aidyn-A
086e2c2399 [TEST][ATen][CUDA] Skip row-wise scaled matrix mmultiplication tests on sm_120+ (#152814)
The float8 row-wise scaled matmuls are not supported on Blackwell yet. This PR adds skips to those tests to decrease the noise on `sm_120+` machines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152814
Approved by: https://github.com/eqy, https://github.com/Skylion007
2025-05-08 19:34:20 +00:00
Fuzzkatt
c6d3b8f861 add xfail for distributed tests on Jetson (#152224)
We are hitting distributed import failures on Jetson in test/export/test_export.py tests in NVIDIA internal testing with the recent additions of https://github.com/pytorch/pytorch/pull/146050 and https://github.com/pytorch/pytorch/pull/147417. Instead of simply skipping these tests for Jetson, we are introducing an xfailIfDistributedNotSupported to get better signaling for this kind of failure in the long run.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152224
Approved by: https://github.com/nWEIdia, https://github.com/eqy
2025-04-29 23:48:40 +00:00
Jagadish Krishnamoorthy
0d99b4e9e2 ROCm: Enable tf32 testing on test_nn (#148945)
Add tf32 support for ROCm tests.
test command: python test/test_nn.py -v

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148945
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-04-28 23:01:04 +00:00
Prachi Gupta
b8f4dc5a9f [ROCm] opportunistic fastatomics for ReduceAdd operations for MI300 GPUs (#146264)
In this approach, we are catching any lane within a wave that is doing fastatomics to the same destination address and computing the sum on the CU. This is leading to 3x improvement in scatter_add performance and 2x improvement in index_select.

scatter_add performance on MI300x:
dtype|Baseline (before optimizations)|opportunistic fastatomics
-------|----------------------------------|----------------------------------
f32|1.389425039|0.430447996
fp16|2.195472956|0.779729486
bf16|2.194051027|0.784599513

Using the following reproducer
```
import torch
import triton

def main():
    dtype = torch.float32
    dim = 1305301
    a = torch.rand(100, device="cuda", dtype=dtype)
    index = torch.randint(0, 100, (dim,), device="cuda")
    src = torch.rand(dim, device="cuda", dtype=dtype)

    print("=" * 20)
    print(
        triton.testing.do_bench(
            lambda: a.scatter_add(0, index, src),
            return_mode="median",
        )
    )
    print("=" * 20)

if __name__ == "__main__":
    main()
```

co-authored by: @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146264
Approved by: https://github.com/jeffdaily, https://github.com/mxz297

Co-authored-by: Hashem Hashemi <hashem.hashemi@amd.com>
2025-04-22 21:55:40 +00:00
Nichols A. Romero
a24a9c42fb [ROCm] Improve behavior of get_torch_rocm_version helper function on non-ROCm systems. (#151040)
Fixes #150041

Return a zero tuple when ROCm is _not_ supported, similar to what is done for the CUDA version of this function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151040
Approved by: https://github.com/jeffdaily
2025-04-14 22:50:07 +00:00
Wei Wang
f6e9e064a7 [CI][CUDA] xfail grouped gemm unit tests on blackwell (#150982)
On SM100OrLater, Expect failures like:

RuntimeError: torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0

To execute this test, run the following from the base repo dir:
    python test/test_matmul_cuda.py TestMatmulCudaCUDA.test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_False_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

`
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0005s] (Issue with numpy versi...) [  2%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [  4%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [  6%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [  8%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 10%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 12%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 14%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 16%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versi...) [ 18%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 20%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 22%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 25%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 27%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 29%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 31%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 33%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0002s] (Issue with numpy versi...) [ 35%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 37%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 39%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 41%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 43%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 45%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 47%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 50%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versi...) [ 52%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 54%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 56%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 58%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 60%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 62%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 64%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 66%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_False_strided_False_cuda XFAIL [0.8166s]                                        [ 68%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_False_strided_True_cuda XFAIL [0.0017s]                                         [ 70%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_True_strided_False_cuda XFAIL [0.0012s]                                         [ 72%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 75%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_False_strided_False_cuda XFAIL [0.0033s]                                        [ 77%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 79%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_cuda XFAIL [0.0015s]                                         [ 81%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 83%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_False_strided_False_cuda XFAIL [0.0012s]                                        [ 85%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 87%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_True_strided_False_cuda XFAIL [0.0011s]                                         [ 89%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 91%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_False_strided_False_cuda XFAIL [0.0014s]                                        [ 93%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 95%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_True_strided_False_cuda XFAIL [0.0011s]                                         [ 97%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_True_strided_True_cuda XFAIL [0.0011s]                                          [100%]
`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150982
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-04-12 01:53:12 +00:00
Aaron Gokaslan
a0ac63cbd9 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-18 00:46:07 +00:00
PyTorch MergeBot
24cfeec2c7 Revert "[BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)"
This reverts commit bfee141666.

Reverted https://github.com/pytorch/pytorch/pull/149257 on behalf of https://github.com/malfet due to Let's see if it helps restore compiler benchmark sanity, see 8bc7bd94a5/1 ([comment](https://github.com/pytorch/pytorch/pull/149257#issuecomment-2731133812))
2025-03-17 22:57:00 +00:00
Xinya Zhang
2a011ca904 [ROCm] testing: enable MEFF/FA unittests for gfx1100 (#148911)
Include gfx1100, and optionally enable gfx1201/gfx950 according to env var TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148911
Approved by: https://github.com/jeffdaily
2025-03-17 16:41:15 +00:00
Aaron Gokaslan
bfee141666 [BE]: Apply ruff PERF403 to use dict comprehensions more often (#149257)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149257
Approved by: https://github.com/jansel
2025-03-16 23:52:58 +00:00
drisspg
07b7b3ed4e torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-27 02:44:39 +00:00
PyTorch MergeBot
a84db75e1b Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit 12b9674cb6.

Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - similar to previous, see below ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2684134336))
2025-02-26 07:17:24 +00:00
vasiliy
12b9674cb6 torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-26 05:21:26 +00:00
PyTorch MergeBot
c82c1411c6 Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit e34c15a05b.

Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - discussed with author ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2683517851))
2025-02-25 23:28:15 +00:00
vasiliy
e34c15a05b torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-25 03:32:22 +00:00
Peter Yeh
81dccd706b [ROCm] OCP FP8 Support for new GPUs (#146632)
TLDR: Follow up/ Build on top of https://github.com/pytorch/pytorch/pull/144476. add OCP FP8 support for gfx950
refer to https://github.com/pytorch/ao/pull/1677

This pull request includes several changes to improve compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates involve adding support for new ROCm versions and GPU architectures, updating data type handling, and removing outdated checks.

### Improvements to GPU Architecture and ROCm Version Support:
* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for new GPU architectures `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated architecture support in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)

### Updates to Data Type Handling:
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Enhanced data type conversion to include new float8 types for both CUDA and ROCm environments.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated `HipDataTypeFor` template to handle new float8 types and added hard-coded enum values for ROCm versions prior to 6.3.

### Removal of Outdated Checks:
* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS` as it is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes ensure better compatibility and performance on newer hardware and software environments, particularly for users leveraging ROCm and CUDA for deep learning and scientific computing tasks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146632
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-02-24 22:47:52 +00:00
PyTorch MergeBot
3e2d9d079e Revert "[ROCm] OCP FP8 Support for new GPUs (#146632)"
This reverts commit f95ab46797.

Reverted https://github.com/pytorch/pytorch/pull/146632 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, I'll find someone to help merge this PR back to main ([comment](https://github.com/pytorch/pytorch/pull/146632#issuecomment-2676823614))
2025-02-23 12:04:50 +00:00
Peter Yeh
f95ab46797 [ROCm] OCP FP8 Support for new GPUs (#146632)
TLDR: Follow up/ Build on top of https://github.com/pytorch/pytorch/pull/144476. add OCP FP8 support for gfx950
refer to https://github.com/pytorch/ao/pull/1677

This pull request includes several changes to improve compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates involve adding support for new ROCm versions and GPU architectures, updating data type handling, and removing outdated checks.

### Improvements to GPU Architecture and ROCm Version Support:
* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for new GPU architectures `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated architecture support in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)

### Updates to Data Type Handling:
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Enhanced data type conversion to include new float8 types for both CUDA and ROCm environments.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated `HipDataTypeFor` template to handle new float8 types and added hard-coded enum values for ROCm versions prior to 6.3.

### Removal of Outdated Checks:
* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS` as it is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes ensure better compatibility and performance on newer hardware and software environments, particularly for users leveraging ROCm and CUDA for deep learning and scientific computing tasks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146632
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-02-21 23:44:08 +00:00
Nichols A. Romero
fd8ae1aa04 [ROCm] gfx940 and gfx941 cleanup (#147394)
Removing gfx architectures not supported by ROCm.

NOTE: For users wanting to build PyTorch for gfx archs that are *not* supported by the official wheels on download.pytorch.org, you can build PyTorch from source for your desired gfx arch [using the PYTORCH_ROCM_ARCH env var](https://github.com/pytorch/pytorch/blob/main/README.md#amd-rocm-support).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147394
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
2025-02-21 19:42:12 +00:00
Benjamin Glass
5aa5a5763e [inductor triton] Disable incorrect TF32 usage on CUDA capability < 8 (#145684)
Triton 2.2 and greater have a bug where allowing TF32 generation for a GPU that does not support TF32 will cause code generation errors. Patch around this problem by:

1. Adding a function to `torch.cuda` that determines whether CUDA hardware is capable of using the TF32 format.
2. Using that function to explicitly disable TF32 generation when calling Triton, where needed.

To demonstrate that this fix works, try running `test/inductor/test_max_autotune.py` on a GPU with CUDA compute capability < 8 (e.g. any NVIDIA consumer GPU) without this fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145684
Approved by: https://github.com/eqy
2025-01-28 22:01:08 +00:00
Fuzzkatt
7c7bcb1e33 update IS_JETSON check (#144725)
update IS_JETSON check to include the latest SM

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144725
Approved by: https://github.com/eqy
2025-01-16 22:34:48 +00:00
Nikita Shulga
38bbe37187 Enable CI on SM89 (#140305)
Using EC2 G6 instance, based on NVIDIA L4, added to scale config in https://github.com/pytorch/test-infra/pull/5376

To enable more balanced sharding, had to push 148ae19935

Added `@xfailIfSM89` to the following tests:
 - test_fp8_pattern_2
 - test_original_aten_preserved_split_addmm
 - test_sparse_semi_structured_scaled_mm
 - test_sparse_semi_structured_scaled_mm_fp8
 - test_sparse_fp8fp8_mm

Increased tolerance to 2e-4 for `RNNTest.BidirectionalMultilayerGRU_CPU_vs_CUDA`

Skipped following inductor tests (that either flaky OOMs or timeouts):
 - test_reduction_fn_std_float64
 - test_reduction_fn_var_mean_float64
 - test_multi_output_unbacked_custom_op

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140305
Approved by: https://github.com/wdvr, https://github.com/ZainRizvi
2024-12-03 04:49:46 +00:00
drisspg
80c7c7178e Make sure all SDPA tests are ran with tensor cores enabled (#135592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135592
Approved by: https://github.com/eqy
2024-10-29 20:53:10 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
Aaron Orenstein
524fe784ec BundledAutotuneCache (take 2) (#137902)
Summary:
Add a cache to combine individual autotune caches into a single cached bundle.  We still rely on the individual autotune caches - on a cache hit we copy the individual results into the local caches so they can retrieved later.

Attempt 2 of #134959 (D60677499).

Various configs:
env: TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE
config: bundled_autotune_remote_cache
jk: pytorch/remote_cache:bundled_autotune_remote_cache_version

Test Plan:
unit tests

Manually tested w/ EMU:
```
cd fbcode/accelerators/workloads/models/emu_flash/v1p4
make build_benchmark_model && make save_model_to_path
make test_pt2_latency
```

- on a cold run we got 0 hits and 40 misses. On a warm run it got 40 hits and 0 miss.
- perf seems a little better - for 8 runs:
  - no bundled cache averaged 14m11s
  - bundled cache averaged 14m6s
  - 125ms saved per cache entry seems reasonable

Cache Metrics for an sample run:
no bundled cache:
```
INFO: Cache Metrics:
  FbMemcacheRemoteKernelCache: {hit: 2256, miss: 0, put: 0, exception: 0}
  FbRemoteAutotuneCache: {hit: 0, miss: 0, put: 7, exception: 0}
  FbRemoteFxGraphCache: {hit: 40, miss: 0, put: 0, exception: 0}
  LocalAutotuneCache: {hit: 878, miss: 0, put: 7, exception: 0}
  backend:MemcacheCache: {hit: 2256, miss: 0, put: 7, exception: 0}
  backend:_LocalAutotuneCacheBackend: {hit: 878, miss: 0, put: 7, exception: 0}
  backend:_ManifoldCache: {hit: 40, miss: 0, put: 0, exception: 0}
```
bundled cache:
```
INFO: Cache Metrics:
  FbMemcacheRemoteKernelCache: {hit: 2258, miss: 0, put: 0, exception: 0}
  FbRemoteAutotuneCache: {hit: 0, miss: 0, put: 8, exception: 0}
  FbRemoteBundledAutotuneCache: {hit: 40, miss: 0, put: 0, exception: 0} <<<<<<
  FbRemoteFxGraphCache: {hit: 40, miss: 0, put: 0, exception: 0}
  LocalAutotuneCache: {hit: 878, miss: 0, put: 886, exception: 0}
  backend:MemcacheCache: {hit: 2258, miss: 0, put: 8, exception: 0}
  backend:_LocalAutotuneCacheBackend: {hit: 878, miss: 0, put: 886, exception: 0}
  backend:_ManifoldCache: {hit: 80, miss: 0, put: 0, exception: 0}
```

Differential Revision: D64336043

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137902
Approved by: https://github.com/oulgen
2024-10-15 18:39:47 +00:00
Wei Feng
14b4099521 [FSDP2] support torch._foreach_copy_(float8) for fully_shard(Float8Linear) (#135955)
this PR unblocks unit test with single Float8Linear module. It fixes following error
```
torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
[rank0]:E0913 13:44:29.829000 2179476 torch/testing/_internal/common_distributed.py:671] RuntimeError: "foreach_tensor_copy" not implemented for 'Float8_e4m3fn'
```

Differential Revision: [D63961071](https://our.internmc.facebook.com/intern/diff/D63961071)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135955
Approved by: https://github.com/vkuzo, https://github.com/eqy
2024-10-07 16:36:31 +00:00
PyTorch MergeBot
6cf493158e Revert "Enable FlashAttention on Windows (#131906)"
This reverts commit b90bc66766.

Reverted https://github.com/pytorch/pytorch/pull/131906 on behalf of https://github.com/atalman due to Windows nightly failures ([comment](https://github.com/pytorch/pytorch/pull/131906#issuecomment-2256421183))
2024-07-29 16:49:23 +00:00
Luca Wehrstedt
b90bc66766 Enable FlashAttention on Windows (#131906)
Let's just give this a try.

Reland of https://github.com/pytorch/pytorch/pull/131875.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131906
Approved by: https://github.com/drisspg
2024-07-26 21:41:56 +00:00
Jerry Mannil
42f647219a [ROCm] Add int4 support (#129710)
- Add AMD support for int4 kernel
  - Only supports CDNA2 and CDNA3 gpus for now
  - Uses `mfma_f32_16x16x16bf16` instruction for matrix multiply
  - Uses `v_and_or_b32` instruction and `__hfma2` instrinsic for unpacking bf16 values
  - Enable hipify for `__nv_bfloat16` and `__nv_bfloat162` data types
- Enable int4 unit tests for CDNA2 and CDNA3 AMD gpus
- Fix torchscript issues due to hipify for `__nv_bfloat16` type
  - TorchScript has its own implementation for bfloat16 type
    - Implemented in `__nv_bloat16` structure at [resource_strings.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h)
    - So, we shouldn't hipify any reference of `__nv_bfloat16` in the torchscript implementation
    - Hence moved the `__nv_bfloat16` direct references in `codegen.cpp` and `cuda_codegen.cpp` to `resource_strings.h` which is already exempted from hipify

Fixes #124699
Fixes pytorch-labs/gpt-fast/issues/154

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129710
Approved by: https://github.com/malfet
2024-07-09 19:49:12 +00:00
PyTorch MergeBot
d7b7f8b79f Revert "[ROCm] Add int4 support (#129710)"
This reverts commit d0ad13fa42.

Reverted https://github.com/pytorch/pytorch/pull/129710 on behalf of https://github.com/jeffdaily due to original ROCm PR did not have ciflow/rocm, missed signal ([comment](https://github.com/pytorch/pytorch/pull/129710#issuecomment-2214558368))
2024-07-08 16:07:53 +00:00
Jerry Mannil
d0ad13fa42 [ROCm] Add int4 support (#129710)
Add AMD support for int4 kernel using mfma_f32_16x16x16bf16 instruction.
Only supports CDNA2 and CDNA3 gpus for now.
Fixes #124699

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129710
Approved by: https://github.com/malfet
2024-07-07 23:54:22 +00:00
eqy
f845a7a91a [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-30 19:22:16 +00:00
PyTorch MergeBot
999eec8dea Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)"
This reverts commit b7e7a4cb01.

Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to break some test_transformer running on internal A100 and V100 ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2196202003))
2024-06-28 06:03:54 +00:00
Andres Lugo
b9a1c2c991 [ROCm] Enable F8 Inductor Unit tests (#128353)
First batch of inductor unit test enablement on ROCm for the fnuz f8 variant on MI300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128353
Approved by: https://github.com/jansel, https://github.com/eellison
2024-06-26 18:30:43 +00:00
Eddie Yan
b7e7a4cb01 [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-26 00:49:18 +00:00
PyTorch MergeBot
817ce6835b Revert "[cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)"
This reverts commit 4c971932e8.

Reverted https://github.com/pytorch/pytorch/pull/125343 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/125343#issuecomment-2163690162))
2024-06-12 18:47:52 +00:00
eqy
4c971932e8 [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-09 06:53:34 +00:00
Xinya Zhang
d34075e0bd Add Efficient Attention support on ROCM (#124885)
This patch implements `with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):` by reusing AOTriton's accelerated SDPA implementation

Known limitations:
- Only supports MI200/MI300X GPUs
- Does not support varlen
- Does not support `CausalVariant`
- Optional arguments `causal_diagonal` and `seqlen_k` in `_efficient_attention_forward/backward` must be null
- Does not work well with inductor's SDPA rewriter. The rewriter has been updated to only use math and flash attention on ROCM.

This PR also uses a different approach of installing AOTriton binary instead of building it from source in the base docker image. More details on motivation: https://github.com/pytorch/pytorch/pull/124885#issuecomment-2153229129

`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_TESTING_DEVICE_ONLY_FOR="cuda" python test/test_transformers.py` yields "55028 passed, 20784 skipped" results with this change.  [Previous result](https://hud.pytorch.org/pr/127528) of `test_transformers.py` was 0 error, 0 failure, 55229 skipped out of 75517 tests in total (the XML report does not contain total number of passed tests).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124885
Approved by: https://github.com/malfet
2024-06-08 22:41:05 +00:00
Fuzzkatt
1cf62e86a4 skip various unit tests for Jetson (#122531)
skip multiprocessing, cuda expandable segments, mem eff and flash attention tests on Jetson due to hanging / sigkill issues from nvidia internal testing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122531
Approved by: https://github.com/eqy, https://github.com/malfet
2024-04-16 01:26:26 +00:00
drisspg
f4e2a226aa ScoreMod API (#121845)
# Summary

This PR adds a new higher-order_op: `templated_attention`.  This op is designed to extend the functionality of torch.nn.fucntional.scaled_dot_product_attention.  PyTorch has efficient pre-written fused-attention kernels. However, users want to modify how scores are computed (a substep inside attention) -- this traditionally requires the user to write their own attention kernel. One such modification to attention scores that is not currently supported by the top level SDPA op is:[ Attention with Linear Biases (ALiBi](https://arxiv.org/abs/2108.12409)).

This higher-order op will instead accept a callable( 'score_mod') function that is through torch.compile will be used to create an efficient attention kernel instantiation.

### Details

This HOP utilizes the existing fx and HOP infra to capture and convert the User `score-mod` function and convert to an FX graph module. Inductor then consumes this HOP that has a `ir.Subgraph` input. It will inline this lowered subgraph into a triton kernel which performs fused attention with the modification to the scores matrix inlined.

### API

The API for a score_mod function should be as follows:

```Python
def score_mod(score: torch.Tensor, batch: torch.Tensor, head: torch.Tensor, token_1: torch.Tensor, token_kv: torch.Tensor) -> torch.Tensor
```

This function receives five parameters:

- `score`: A scalar tensor representing the attention score, with the same data type and device as the query, key, and value tensors.
- `batch`, `head`, `seq_len_q`, `seq_len_kv`: Scalar tensors indicating the batch index, head index, query index, and key/value index, respectively, with torch.int data type and located on the same device as the score tensor.

Consider inputs query, key, value of shapes (2, 4, 16, 8), leading to an intermediate attention score matrix of shape (2, 4, 16, 16)

The score_mod function will be vectorized over each element of this matrix. For instance, modifying the score at the position corresponding to the 0th batch, 2nd head, between the 8th query and the 9th key element, would be invoked as:

```Python
score_mod(score[0,2,8,9], torch.tensor(0), torch.tensor(2), torch.tensor(8), torch.tensor(9))
```

### Examples
```Python
import torch
from torch.nn.attention.templated_attention import templated_attention

torch.manual_seed(0)

# Lets create some input tensors
# The input tensor has shape (batch_size, num_heads, seq_len, head_dim)
query = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
key = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)
value = torch.randn(8, 8, 2048, 64, device="cuda", dtype=torch.float32)

# Lets create a fun new score_modification! I will call this
# Checkerboard. It will reduce the score for neighboring tokens (1 step apart)
# in the sequence. And increase the score for tokens 2 steps apart. For everything
# else, the score will remain the same.

def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Lets call templated_attention with this new score modification
output = templated_attention(query, key, value, score_mod=checkerboard)

compiled_templated_attention = torch.compile(templated_attention)
out_compiled = compiled_templated_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```

### Future Work
- This PR is currently only forward only. However the triton kernel for backwards where score_modifications to not rely on external buffers has been explored here: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/flash/flash_attention.py
- Kernel Improvements; There are has been some larger updates to the fused attention implementation that Triton uses in its tutorials. The implementation of this kernel is based on a prior version and should be updated.
- We may want to unify this API under the top level SDPA API and leave that as a follow up once this is more stable
- Should we error on CPU?
- There are some issues with dynamic shapes
- Capturing of free variables and lifting to inputs to the subgraph is not working correctly today

### Performance
Comparisons generated by this benchmark:

| Type    |   Speedup |   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |
|---------|-----------|--------------|-------------|-------------|-------------|------------|---------------|----------------|
| Average |     5.412 |              |             |             |             |            |               |                |
| Max     |     8.882 |           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |
| Min     |     3.645 |            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |
| Min     |     0.345 |            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |

For reference

| Configuration                                 | Forward Time (µ seconds) | Backend          | Speedup |
|-----------------------------------------------|--------------------------|------------------|---------|
| Fastest Config in Sweep (`8 16 4096 4096 64 relative_bias torch.bfloat16`) | 3608                   | Templated Attention                | 1.0  |
| Compiled SDPA (No Mask)                       | 9928                   | Math             | 2.75x   |
| Compiled SDPA (With Mask)                     | 11898                    | Math             | 3.29x   |
| Compiled SDPA (With Mask) | 8704                      | Memory Efficient Attention | 2.42x   |
| Compiled SDPA (No Mask) | 2548                     | FlashAttention2 | 0.706x   |

The speedups are measuring compiled templated attention speed versus different calls to torch.nn.functional.sdpa

<details>

<summary> FULL PERFORMANCE SWEEP NUMBERS </summary>

|   batch_size |   num_heads |   q_seq_len |   k_seq_len |   head_dim | score_mod     | dtype          |   eager_time |   compiled_time |   speedup |
|--------------|-------------|-------------|-------------|------------|---------------|----------------|--------------|-----------------|-----------|
|            1 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      331.444 |          67.221 |     4.931 |
|            1 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      335.300 |          64.187 |     5.224 |
|            1 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      352.039 |          63.806 |     5.517 |
|            1 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      371.699 |         711.349 |     0.523 |
|            1 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |      333.488 |          86.455 |     3.857 |
|            1 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |      322.363 |          82.469 |     3.909 |
|            1 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |      349.967 |          82.233 |     4.256 |
|            1 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |      486.359 |        1412.453 |     0.344 |
|            1 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |     2794.597 |         551.188 |     5.070 |
|            1 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |     3965.150 |         513.101 |     7.728 |
|            1 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |     2408.013 |         504.759 |     4.771 |
|            1 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |     6850.531 |       16733.675 |     0.409 |
|            8 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      441.939 |         123.576 |     3.576 |
|            8 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |      560.379 |         116.710 |     4.801 |
|            8 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      421.172 |         115.825 |     3.636 |
|            8 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |      994.492 |        2132.806 |     0.466 |
|            8 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     1436.430 |         309.495 |     4.641 |
|            8 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     1892.216 |         290.186 |     6.521 |
|            8 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     1360.665 |         282.956 |     4.809 |
|            8 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     3525.532 |        8359.702 |     0.422 |
|            8 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    22026.839 |        3864.604 |     5.700 |
|            8 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    31262.746 |        3609.551 |     8.661 |
|            8 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    20219.079 |        3480.402 |     5.809 |
|            8 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |    54654.647 |      116652.357 |     0.469 |
|           16 |          16 |         512 |         512 |         64 | causal_mask   | torch.bfloat16 |      820.606 |         188.683 |     4.349 |
|           16 |          16 |         512 |         512 |         64 | relative_bias | torch.bfloat16 |     1058.362 |         179.295 |     5.903 |
|           16 |          16 |         512 |         512 |         64 | head_bias     | torch.bfloat16 |      784.372 |         175.714 |     4.464 |
|           16 |          16 |         512 |         512 |         64 | pathological  | torch.bfloat16 |     1890.792 |        4212.877 |     0.449 |
|           16 |          16 |        1024 |        1024 |         64 | causal_mask   | torch.bfloat16 |     2781.830 |         557.017 |     4.994 |
|           16 |          16 |        1024 |        1024 |         64 | relative_bias | torch.bfloat16 |     3694.050 |         525.249 |     7.033 |
|           16 |          16 |        1024 |        1024 |         64 | head_bias     | torch.bfloat16 |     2634.164 |         507.613 |     5.189 |
|           16 |          16 |        1024 |        1024 |         64 | pathological  | torch.bfloat16 |     6959.917 |       15331.116 |     0.454 |
|           16 |          16 |        4096 |        4096 |         64 | causal_mask   | torch.bfloat16 |    43889.096 |        7582.018 |     5.789 |
|           16 |          16 |        4096 |        4096 |         64 | relative_bias | torch.bfloat16 |    62784.293 |        7075.846 |     8.873 |
|           16 |          16 |        4096 |        4096 |         64 | head_bias     | torch.bfloat16 |    40308.606 |        6829.587 |     5.902 |
|           16 |          16 |        4096 |        4096 |         64 | pathological  | torch.bfloat16 |   108892.137 |      233090.953 |     0.467 |
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121845
Approved by: https://github.com/Chillee, https://github.com/zou3519
2024-04-06 01:10:44 +00:00
Xinya Zhang
12116aee68 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is supported. More architectures will be added once Triton support them.
- [x] Only supports power of two sequence lengths.
    * Now it support arbitrary sequence length
- [ ] No support for varlen APIs.
    * varlen API will be supported in future release of AOTriton
- [x] Only support head dimension 16,32,64,128.
    * Now it support arbitrary head dimension <= 256
- [x] Performance is still being optimized.
    * Kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/huydhn
2024-03-28 00:27:38 +00:00
PyTorch MergeBot
764eae9c4e Revert "Add Flash Attention support on ROCM (#121561)"
This reverts commit a37e22de70.

Reverted https://github.com/pytorch/pytorch/pull/121561 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this needs more work to be able to land in fbcode because https://github.com/ROCm/aotriton is not available there atm.  We are working to reland this change before 2.3 release ([comment](https://github.com/pytorch/pytorch/pull/121561#issuecomment-2007717091))
2024-03-19 17:14:28 +00:00
Xinya Zhang
a37e22de70 Add Flash Attention support on ROCM (#121561)
This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton)

- [x] Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
    * MI300X is supported. More architectures will be added once Triton support them.
- [x] Only supports power of two sequence lengths.
    * Now it support arbitrary sequence length
- [ ] No support for varlen APIs.
    * varlen API will be supported in the next release of AOTriton
- [x] Only support head dimension 16,32,64,128.
    * Now it support arbitrary head dimension <= 256
- [x] Performance is still being optimized.
    * Kernel is selected according to autotune information from Triton.

Other improvements from AOTriton include
* Allow more flexible Tensor storage layout
* More flexible API

This is a more extensive fix to #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
2024-03-12 01:16:53 +00:00
Eddie Yan
cd380c794f [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-02-14 22:02:06 +00:00
CaoE
113138aa55 add test cases for GradScaler on CPU (#109994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109994
Approved by: https://github.com/jgong5, https://github.com/ezyang
2024-02-02 21:49:07 +00:00
Edward Z. Yang
9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00
PyTorch MergeBot
2f84a9d37c Revert "[CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)"
This reverts commit 5aa92b5090.

Reverted https://github.com/pytorch/pytorch/pull/115663 on behalf of https://github.com/PaliC due to Unfortunately, this pr breaks cuda builds internally ([comment](https://github.com/pytorch/pytorch/pull/115663#issuecomment-1899388813))
2024-01-18 23:40:30 +00:00
Eddie Yan
5aa92b5090 [CUDNN][SDPA] Experimental cuDNN Flash Attention v2 Inference (#115663)
#113713

Going to clean up some of the checks and will remove draft status after.
Can be tested on SM80+ with `TORCH_CUDNN_MHA_ENABLED=1`.

CC @drisspg @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115663
Approved by: https://github.com/drisspg
2024-01-18 01:20:36 +00:00
Xinya Zhang
e3ca7346ce Re-add initial Flash Attention support on ROCM (#115981)
Note about the Updates:

This PR:
1. skips more flash attention related UTs on MI200
2. Fix additional ATen compiling errors after hipification
3. Fix the author "root" of a specific commit
4. Includes the patch from Nikita in favor of block level static initialization.

CAVEAT: This revised PR has a commit that modifies the CI to force its running on MI200 nodes. That specific commit must be reverted before merge.

Original PR (https://github.com/pytorch/pytorch/pull/114309) Note:

This pull requests add initial Flash Attention support for AMD/ROCM platform. It added a specialized Triton repository/branch as a compile-time dependency for Flash Attention math library on AMD/ROCM. This triton submodule is not used at runtime and will not be shipped to the final pytorch package. We have the plan to release this specialized Triton as a separate project.

Know limitations:

- Only supports MI200 series GPU (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`.
- Only supports power of two sequence lengths.
- No support for varlen APIs.
- Only support head dimension 16,32,64,128.
- Performance is still being optimized.

Fixes #112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115981
Approved by: https://github.com/malfet
2024-01-04 22:21:31 +00:00
Jeff Daily
e3aefe2970 Revert "Initial Flash Attention support on ROCM (#114309)" (#115975)
This reverts commit 5bddbed399.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115975
Approved by: https://github.com/atalman, https://github.com/malfet
2023-12-16 03:40:14 +00:00