Commit Graph

180 Commits

Author SHA1 Message Date
Luca Wehrstedt
5ab0eb28f7 Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)
cuBLAS added support for them in CUDA 12.9. It's rather easy to call into them, the hardest thing is allowing the lhs and rhs operands to have different scaling types, as that changes the whole callstack.

The scaling format is still detected from the sizes of the scale tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037
Approved by: https://github.com/eqy, https://github.com/drisspg
2025-07-24 20:10:51 +00:00
Eddie Yan
55ff4f85e9 [FP8][CUTLASS] xFail honor_sm_carveout on sm100 (#152378)
CUTLASS only supports SM carveout via green contexts on `sm100`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152378
Approved by: https://github.com/Skylion007, https://github.com/albanD, https://github.com/nWEIdia
2025-07-22 18:39:50 +00:00
PyTorch MergeBot
32aade9d8d Revert "Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)"
This reverts commit 39ac189808.

Reverted https://github.com/pytorch/pytorch/pull/158037 on behalf of https://github.com/jithunnair-amd due to Ignored ROCm failures while ROCm was unstable, but HUD clearly shows this PR introduced failures on trunk ([comment](https://github.com/pytorch/pytorch/pull/158037#issuecomment-3087982975))
2025-07-18 07:47:46 +00:00
Luca Wehrstedt
39ac189808 Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)
cuBLAS added support for them in CUDA 12.9. It's rather easy to call into them, the hardest thing is allowing the lhs and rhs operands to have different scaling types, as that changes the whole callstack.

The scaling format is still detected from the sizes of the scale tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037
Approved by: https://github.com/eqy, https://github.com/drisspg
2025-07-17 08:26:27 +00:00
PyTorch MergeBot
9513b9d03f Revert "Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)"
This reverts commit bc65253369.

Reverted https://github.com/pytorch/pytorch/pull/158037 on behalf of https://github.com/lw due to OSX failures are real ([comment](https://github.com/pytorch/pytorch/pull/158037#issuecomment-3079042171))
2025-07-16 15:04:10 +00:00
Luca Wehrstedt
bc65253369 Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)
cuBLAS added support for them in CUDA 12.9. It's rather easy to call into them, the hardest thing is allowing the lhs and rhs operands to have different scaling types, as that changes the whole callstack.

The scaling format is still detected from the sizes of the scale tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037
Approved by: https://github.com/eqy, https://github.com/drisspg
2025-07-16 13:54:09 +00:00
Aleksandar Samardžić
90618581e9 Fix grouped MM output strides when compiled but not max-autotuned (#158143)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158143
Approved by: https://github.com/ngimel
2025-07-15 11:53:13 +00:00
Luca Wehrstedt
27c50799c1 Use new cuBLAS row-wise fp8 matmul for scaled-mm (#157905)
Most of the work had already been done by @jeffdaily in #154680, but there was one remaining check that needed to be modified in order for `torch._scaled_mm` to use cuBLAS over CUTLASS when available.

I tested this change by rebuilding PyTorch locally with CUDA 12.9 and ran `torch._scaled_mm` under the profiler, and observed that the kernel being launched is called `nvjet_qqtst_128x128_128x6_1x1_h_bz_coopA_algo2_ovscale_TNT` (where `ovscale` stands for "outer vector scaling", I believe, which is how cuBLAS calls this scaling mode).

I then benchmarked the new kernels against the old CUTLASS ones on a standard 700W H100 GPU. I used the same approach as in #134781, and obtained these speed-ups:
![image](https://github.com/user-attachments/assets/43dfb816-9ccf-40c5-8b2a-571ce9cb511d)
![image](https://github.com/user-attachments/assets/be7ac6f2-e16c-479b-ad5c-f8039caba4b1)

We see that the two kernels perform very closely (I'm surprised, I would have expected cuBLAS to outperform CUTLASS across the board), with some thin/skewed shapes becoming worse but some very large shapes becoming better.

I guess the questions are whether we consider this a net-zero change (given that there's improvements _and_ degradations), and how large we consider the burden of maintaining our own CUTLASS kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157905
Approved by: https://github.com/eqy, https://github.com/Skylion007, https://github.com/drisspg
2025-07-11 16:11:55 +00:00
Aleksandar Samardžić
a3ec6d64b2 Update test after CUTLASS upgrade (#157903)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157903
Approved by: https://github.com/ngimel
2025-07-10 20:10:20 +00:00
Xuehai Pan
fc0376e8b1 [BE][2/6] fix typos in test/ (test/test_*.py) (#157636)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636
Approved by: https://github.com/yewentao256, https://github.com/mlazos
ghstack dependencies: #156311, #156609
2025-07-09 11:02:23 +00:00
Jeff Daily
210632fae1 [ROCm] support experimental CU carveout (#149466)
Fixes #149280.  Follow up to #147966, but now available for ROCm.

Since hipblaslt does not support HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET, we instead create a hipStream that has a CU mask applied.  We pass this masked stream to hipblaslt instead of pytorch's current stream.  We ensure stream ordering between streams using hipEvents and stream synchronization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149466
Approved by: https://github.com/malfet, https://github.com/atalman
2025-07-01 08:54:52 +00:00
AaronWang04
772d590415 [CUTLASS] [CUDA] SM100 GroupMM (#156203)
Closes https://github.com/pytorch/pytorch/issues/156202

PR adds blackwell support for GroupMM

Most of the code that is used for SM90 can be reused, kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html

Did some preliminary benchmarking of H200 vs B200

Script
```py
import torch
print(torch.__file__)
device = torch.device("cuda")
dtype = torch.bfloat16

shapes = [
    (16, 128000, 7168, 7168),
    (128, 1, 2048, 7168)
]

for batch, M, N, K in shapes:
    a = torch.randn(batch, M, K, device=device, dtype=dtype)
    b = torch.randn(batch, N, K, device=device, dtype=dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for i in range(5): c = torch._grouped_mm(a, b)

    num_iter = 50
    start_event.record()

    for i in range(num_iter): c = torch._grouped_mm(a, b)
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iter
    print(f"batch: {batch}\tM: {M}\tN: {N}\tK: {K}")
    print(f"Time per Iteration:\t {avg_time_ms:.4f} ms")
```

On H200
```
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 298.6668 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 4.1462 ms
```

B200
```
batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      190.7458 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      3.0680 ms
```
nsys nvprof
```
root@16930b42ffc6:/workspace/pytorch# nsys nvprof python gemm_test.py
WARNING: python and any of its children processes will be profiled.

Collecting data...
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 192.6420 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 1.2255 ms
Generating '/tmp/nsys-report-6a53.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /workspace/pytorch/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)    StdDev (ns)                 Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  ---------------------------------
     98.9      10586895744          2  5293447872.0  5293447872.0  73786464  10513109280  7381715954.2  cudaDeviceSynchronize
      1.0        104084608          5    20816921.6    33552480.0    100800     34786208    18048125.3  cudaMalloc
      0.1          5694304          4     1423576.0     1416656.0   1258560      1602432      181668.1  cudaGetDeviceProperties_v2_v12000
      0.1          5430496        130       41773.0        4560.0      2496      3854368      345761.8  cudaLaunchKernel
      0.0           587584        110        5341.7        4992.0      4224        16992        1482.0  cudaLaunchKernelExC_v11060
      0.0           119200        660         180.6         128.0        96         4128         206.7  cudaGetDriverEntryPoint_v11030
      0.0            68352        660         103.6          64.0        32         4928         224.6  cuTensorMapEncodeTiled
      0.0            34976         49         713.8         224.0       160         6720        1343.4  cudaStreamIsCapturing_v10000
      0.0            32992          4        8248.0        7456.0      4128        13952        4804.4  cudaEventRecord
      0.0            16928          4        4232.0        3600.0      1728         8000        2764.7  cudaEventQuery
      0.0            16288          4        4072.0        3568.0      1952         7200        2396.1  cudaEventCreateWithFlags
      0.0            13632          4        3408.0        2672.0       544         7744        3408.7  cudaEventDestroy
      0.0             1056          1        1056.0        1056.0      1056         1056           0.0  cuModuleGetLoadingMode

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     99.0      10549232845         55  191804233.5  192944479.0  165746368  203645313    5353204.3  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.6         67327135         55    1224129.7    1330656.0     924320    1364928     182180.4  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.3         34854783         20    1742739.1    1597856.0      10080    3899616     818421.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0           354880        110       3226.2       3296.0       1920       4160        554.4  void at::cuda::detail::prepare_grouped_gemm_data<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass:…
```

The kernel names are too long to be shown via nvprof, I pasted this from nsight systems
```
small kernel 1SM
100.0%	1.286 ms	1	1.286 ms	1.286 ms	1.286 ms	1.286 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)3, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)128, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

large kernel 2SM
100.0%	194.178 ms	1	194.178 ms	194.178 ms	194.178 ms	194.178 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)5, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)256>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)256, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156203
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
2025-06-28 23:02:00 +00:00
Aleksandar Samardžić
6ed85bfe6a Refine alignment check along dynamic dimension for grouped MMs (#155466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155466
Approved by: https://github.com/ngimel
2025-06-20 19:42:57 +00:00
PyTorch MergeBot
0b62465b99 Revert "Refine alignment check along dynamic dimension for grouped MMs (#155466)"
This reverts commit 830a335a7d.

Reverted https://github.com/pytorch/pytorch/pull/155466 on behalf of https://github.com/atalman due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/155466#issuecomment-2988285117))
2025-06-19 14:25:38 +00:00
Aleksandar Samardžić
830a335a7d Refine alignment check along dynamic dimension for grouped MMs (#155466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155466
Approved by: https://github.com/ngimel
2025-06-18 15:15:05 +00:00
Thien Tran
fc177801af Enable FP8 row-wise scaled-mm for sm12x (#155991)
## Update using Cutlass 3.x (2025/06/15)

Following @alexsamardzic's advice, I tried out Cutlass 3.x API and it's impressive (rated specs is 419 TFLOPS)

 M | N | K | TFLOPS
---|---|---|--------
16|4096|4096|17.56
64|4096|4096|69.63
256|4096|4096|266.57
1024|4096|4096|339.28
4096|4096|4096|388.91

This uses the same SM100 template. The only difference is
- Cluster size is fixed to `<1,1,1>` since sm120 does not have multicast feature
- ~~Tile size is fixed to `<128,128,128>` due to default kernel schedule does not support `<64,128,128>`. I will work a bit on improve perf for small M.~~ Fixed. Use `KernelTmaWarpSpecializedPingpong` when TileShape.M == 64

Perf for small M is still bad since it seems like Cutlass does not support TileShape.M < 64 for this kernel. It's possible to boost perf a bit by using TileShape `<64,64,128>`.

## Original using SM89

I tried using cutlass FP8 row-wise scaled-mm for sm89 on sm120 (5090) and it works. I guess it makes sense because sm120 matmul uses the standard sm80 PTX instructions (`cp.async`+`mma` and friends).

Simple benchmark script

```python
import torch
from torch._inductor.utils import do_bench_using_profiling

N, K = 4096, 4096
for M in [16, 64, 256, 1024, 4096]:
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).T
    scale_A = torch.ones(M, 1).cuda()
    scale_B = torch.ones(1, N).cuda()

    out = torch._scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16)
    out_ref = ((A.float() @ B.float()) * scale_A * scale_B).bfloat16()
    torch.testing.assert_close(out, out_ref)

    latency_us = do_bench_using_profiling(lambda: torch._scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16))
    tflops = (2 * M * N * K) / latency_us / 1e9
    print(f"{M=}\t{N=}\t{K=}\t{tflops:.2f} TFLOPS")
```

M | N | K | TFLOPS
---|---|---|---
16 | 4096 | 4096 | 25.73 TFLOPS
64 | 4096 | 4096 | 71.84 TFLOPS
256 | 4096 | 4096 | 86.40 TFLOPS
1024 | 4096 | 4096 | 112.12 TFLOPS
4096 | 4096 | 4096 | 121.24 TFLOPS

Accodring to [RTX Blackwell Whitepaper](https://images.nvidia.com/aem-dam/Solutions/geforce/blackwell/nvidia-rtx-blackwell-gpu-architecture.pdf), FP8 MMA with FP32 accumulate is 419 TFLOPS. So the result is quite bad here...

However, if I change `ThreadblockSwizzle` to `cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>`

 M | N | K | TFLOPS
---|---|---|--------
16|4096|4096|27.13 TFLOPS
64|4096|4096|84.84 TFLOPS
256|4096|4096|96.75 TFLOPS
1024|4096|4096|110.21 TFLOPS
4096|4096|4096|122.98 TFLOPS

Small M slightly improves, but large M is still bad.

If I further change `ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3` for M>256, which is taken from [cutlass example 58](https://github.com/NVIDIA/cutlass/blob/v3.9.2/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu), I get the following results

 M | N | K | TFLOPS
---|---|---|--------
1024|4096|4096|313.28
4096|4096|4096|376.73

Which is much closer to hardware limit. And it also means this kernel is sufficient to get the most perf out of sm120. Only need better tuned configs.

To make sure this high perf is only obtainable with `GemmIdentityThreadblockSwizzle<1>` + `ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3`, I also try using `ThreadblockSwizzleStreamK` + `ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3`

 M | N | K | TFLOPS
---|---|---|--------
1024|4096|4096|144.03
4096|4096|4096|156.86

A bit better than current configs, but still very far away from hardware limit.

@alexsamardzic I noticed you chose this configs in #149978. Do you have any numbers how the current configs perform on sm89?

Update: Using triton codegen-ed from inductor `compiled_scaled_mm = torch.compile(torch._scaled_mm, dynamic=False, mode="max-autotune-no-cudagraphs")`

 M | N | K | TFLOPS
---|---|---|--------
16|4096|4096|25.60
64|4096|4096|71.74
256|4096|4096|161.64
1024|4096|4096|185.89
4096|4096|4096|215.53

Better than default configs, but still far away from the config above for compute-bound

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155991
Approved by: https://github.com/drisspg, https://github.com/eqy
2025-06-17 18:52:44 +00:00
Aleksandar Samardžić
62fa3f5aeb Support tuning of _grouped_mm (#153953)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153953
Approved by: https://github.com/ngimel
2025-06-12 15:39:35 +00:00
Aleksandar Samardžić
f8baec8984 Update auto-tuning support for _scaled_grouped_mm (#150944)
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Updated meta registration
7. Update synthetic offsets creation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150944
Approved by: https://github.com/ngimel, https://github.com/davidberard98
2025-06-11 19:12:52 +00:00
PyTorch MergeBot
e12597090c Revert "Update auto-tuning support for _scaled_grouped_mm (#150944)"
This reverts commit 09328eb02f.

Reverted https://github.com/pytorch/pytorch/pull/150944 on behalf of https://github.com/davidberard98 due to breaks internal usage & complicates triton pin update - more details in https://github.com/pytorch/pytorch/pull/150944#issuecomment-2957246463 ([comment](https://github.com/pytorch/pytorch/pull/150944#issuecomment-2957248841))
2025-06-09 23:12:56 +00:00
Aleksandar Samardžić
09328eb02f Update auto-tuning support for _scaled_grouped_mm (#150944)
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Updated meta registration
7. Update synthetic offsets creation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150944
Approved by: https://github.com/ngimel
2025-06-08 10:18:13 +00:00
Eddie Yan
5163bf0069 [CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)
Some tests may not set the preferred backend, which leads to unexpected behavior when multiple tests are run vs. standalone

Tests that should exercise both backends should explicitly parametrize this setting

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153655
Approved by: https://github.com/ngimel
2025-05-20 16:18:35 +00:00
eqy
6ae0c42278 [CUDA][cuBLASLt] Respect allow[FP16/BF16]ReductionCuBLAS in cuBLASLt (#153095)
cuBLASLt matmuls have been silently allowing all reduction types, which meant that e.g., `allow_fp16_reduced_precision_reduction = False` had no effect.

In practice split-K with reduced precision reductions were unlikely to happen as the default `CUBLASLT_WORKSPACE_SIZE` of 1MiB tends to prevent this.

However this isn't guaranteed and we are on the path to increasing the default workspace size following #151163

This setting is effectively already tested in e.g., `test_cublas_addmm_size_100_cuda_float16` and `test_cublas_addmm_size_100_cuda_bfloat16` but the backend selection is not deterministic. Running the full `test_matmul_cuda.py` seems to exercise the Lt interface, but running a standalone test does not (apparently due to spurious alignment differences).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153095
Approved by: https://github.com/cyyever, https://github.com/Skylion007
2025-05-19 20:05:37 +00:00
PyTorch MergeBot
40339c1e99 Revert "[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)"
This reverts commit 3bde364996.

Reverted https://github.com/pytorch/pytorch/pull/153655 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail a test in trunk ([comment](https://github.com/pytorch/pytorch/pull/153655#issuecomment-2888212597))
2025-05-17 08:11:54 +00:00
Eddie Yan
3bde364996 [CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)
Some tests may not set the preferred backend, which leads to unexpected behavior when multiple tests are run vs. standalone

Tests that should exercise both backends should explicitly parametrize this setting

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153655
Approved by: https://github.com/ngimel
2025-05-16 21:31:13 +00:00
Eddie Yan
d965fa2c4b [CUDA][cuBLAS] Remove IS_ARM64 skip in test_matmul_cuda.py (#153660)
Original skip seems stale and the test appears to run fine on Grace + Hopper and Grace + Blackwell

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153660
Approved by: https://github.com/Skylion007
2025-05-16 07:31:16 +00:00
Natalia Gimelshein
9c99ea2991 error out on negative offs or on K=0 in group gemm (#153226)
Error out if K=0 in one of the grouped gemms to avoid hangs in #152668
Also, adds meta function for _scaled_grouped_mm (TODO: do the same for _grouped_mm, unless it's done already)

One weird thing I'm seeing, when running all grouped_gemm tests, I'm erroring out with
```
  File "/data/users/ngimel/pytorch/torch/_inductor/graph.py", line 1246, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/data/users/ngimel/pytorch/torch/_inductor/lowering.py", line 445, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 444, in tuned_scaled_grouped_mm
    if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias):
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 375, in can_use_triton_kernel
    offs is not None
  File "/home/ngimel/.conda/envs/pytorch_monarch/lib/python3.10/site-packages/sympy/core/relational.py", line 516, in __bool__
    raise TypeError("cannot determine truth value of Relational")
torch._inductor.exc.InductorError: LoweringException: TypeError: cannot determine truth value of Relational
```
which is weird, there's no relational that sympy has to evaluate in `offs is not None`, and when running this test separately (`test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_use_torch_compile_True_cuda`) it passes. I suspect some autotuning cache has to be reset between runs, but don't know what to look for.
Edit: that error is "fixed" by setting `dynamic=False`, now with correct meat function something's wrong with dynamic shapes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153226
Approved by: https://github.com/kwen2501
2025-05-10 01:13:18 +00:00
eqy
b30d276abc [CUDA][cuBLASLt] Fix scale setting for allowFP16AccumulationCuBLAS true case (#153083)
Also add some missing `@onlyCUDA` / support check decorators in `test_matmul_cuda.py`
Should help resolve #151890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153083
Approved by: https://github.com/janeyx99
2025-05-09 02:27:17 +00:00
Aidyn-A
086e2c2399 [TEST][ATen][CUDA] Skip row-wise scaled matrix mmultiplication tests on sm_120+ (#152814)
The float8 row-wise scaled matmuls are not supported on Blackwell yet. This PR adds skips to those tests to decrease the noise on `sm_120+` machines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152814
Approved by: https://github.com/eqy, https://github.com/Skylion007
2025-05-08 19:34:20 +00:00
drisspg
14f8066910 Ensure mxfp8 scaled_mm works w/ max-autotune (#152744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152744
Approved by: https://github.com/Skylion007
2025-05-06 01:16:57 +00:00
eqy
cc072af74a [CUDA][MXFP8] bump tolerances for test_blockwise_mxfp8_nvfp4_numerics (#151811)
got a slightly lower sqnr on a smaller GPU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151811
Approved by: https://github.com/albanD
2025-04-30 01:12:51 +00:00
PaulZhang12
3ed5f1fb77 [CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs (#150812)
Enable FP32 output from FP16/BF16 GEMMs in aten with cuBLAS. Accumulation for these GEMMs are generally already done in FP32. Adds the functionality to the following aten operators:
* mm
* bmm
* addmm
* baddmm

Follow up of customer issue: https://github.com/pytorch/pytorch/issues/146241#issuecomment-2781889390

Differential Revision: [D73126191](https://our.internmc.facebook.com/intern/diff/D73126191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150812
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-04-18 01:53:26 +00:00
Joel Schlosser
ae53510b9e Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 + adds test

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`.

Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) A reason this matters is because it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated TestFooCUDA does not derive from TestFoo, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:

070f389745/test/test_ops.py (L218-L221)

(2) Further, there is some precedent within a custom `setUpClass()` impl for storing things on the `cls` object to be accessed at test time. This must be the device-specific test class (`TestFooCUDA`) and not `TestFoo` for this to work. As an example, the open device registration tests load a module during setup and use it in the test logic:

070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77)

070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)

To accomplish both (1) and (2) at the same time, I decided to revisit the idea of utilizing a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have TestFooCPU / TestFooCUDA **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been setup with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong ofc).

**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here:
4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)

To address this, I moved this block of logic to before the first call to `instantiate_test()`, as that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). Goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work and the behavior is deterministic now AFAICT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet
2025-04-16 02:18:42 +00:00
PyTorch MergeBot
98b1e82ba8 Revert "Fix setUpClass() / tearDownClass() for device-specific tests (#151129)"
This reverts commit bd4cf30e31.

Reverted https://github.com/pytorch/pytorch/pull/151129 on behalf of https://github.com/jbschlosser due to flex attention tests failing ([comment](https://github.com/pytorch/pytorch/pull/151129#issuecomment-2807632119))
2025-04-15 22:07:25 +00:00
Joel Schlosser
bd4cf30e31 Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 + adds test

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`.

Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) A reason this matters is because it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated TestFooCUDA does not derive from TestFoo, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:

070f389745/test/test_ops.py (L218-L221)

(2) Further, there is some precedent within a custom `setUpClass()` impl for storing things on the `cls` object to be accessed at test time. This must be the device-specific test class (`TestFooCUDA`) and not `TestFoo` for this to work. As an example, the open device registration tests load a module during setup and use it in the test logic:

070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77)

070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)

To accomplish both (1) and (2) at the same time, I decided to revisit the idea of utilizing a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have TestFooCPU / TestFooCUDA **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been setup with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong ofc).

**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here:
4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)

To address this, I moved this block of logic to before the first call to `instantiate_test()`, as that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). Goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work and the behavior is deterministic now AFAICT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet
2025-04-15 20:13:26 +00:00
Wei Wang
f6e9e064a7 [CI][CUDA] xfail grouped gemm unit tests on blackwell (#150982)
On SM100OrLater, Expect failures like:

RuntimeError: torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0

To execute this test, run the following from the base repo dir:
    python test/test_matmul_cuda.py TestMatmulCudaCUDA.test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_False_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

`
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0005s] (Issue with numpy versi...) [  2%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [  4%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [  6%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [  8%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 10%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 12%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 14%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 16%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versi...) [ 18%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 20%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 22%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 25%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 27%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 29%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 31%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 33%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0002s] (Issue with numpy versi...) [ 35%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 37%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 39%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 41%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 43%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 45%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 47%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 50%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versi...) [ 52%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 54%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 56%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 58%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 60%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 62%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 64%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 66%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_False_strided_False_cuda XFAIL [0.8166s]                                        [ 68%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_False_strided_True_cuda XFAIL [0.0017s]                                         [ 70%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_True_strided_False_cuda XFAIL [0.0012s]                                         [ 72%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 75%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_False_strided_False_cuda XFAIL [0.0033s]                                        [ 77%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 79%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_cuda XFAIL [0.0015s]                                         [ 81%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 83%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_False_strided_False_cuda XFAIL [0.0012s]                                        [ 85%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 87%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_True_strided_False_cuda XFAIL [0.0011s]                                         [ 89%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 91%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_False_strided_False_cuda XFAIL [0.0014s]                                        [ 93%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 95%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_True_strided_False_cuda XFAIL [0.0011s]                                         [ 97%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_True_strided_True_cuda XFAIL [0.0011s]                                          [100%]
`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150982
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-04-12 01:53:12 +00:00
Bert Maher
2d187bf7e6 Support tuning of _scaled_grouped_mm (#150421)
This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150421
Approved by: https://github.com/ngimel
2025-04-11 23:03:49 +00:00
Jiang, Yanbing
1e92579126 Add torch._scaled_mm for CPU (#150410)
This PR is the duplicated one for https://github.com/pytorch/pytorch/pull/139975.

This PR is to add torch._scaled_mm for CPU backend.

_scaled_mm_out_cpu and _scaled_mm_cpu are new added and included in torch._scaled_mm CPU dispatch. We also add _scaled_mm_out_cpu_emulated as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150410
Approved by: https://github.com/atalman
2025-04-11 02:23:03 +00:00
PyTorch MergeBot
6a65f2c4fe Revert "Support tuning of _scaled_grouped_mm (#150421)"
This reverts commit 8efcf21fff.

Reverted https://github.com/pytorch/pytorch/pull/150421 on behalf of https://github.com/malfet due to Looks like it broke lint, see a0ab243c3a/1 ([comment](https://github.com/pytorch/pytorch/pull/150421#issuecomment-2795218547))
2025-04-10 21:36:41 +00:00
Bert Maher
8efcf21fff Support tuning of _scaled_grouped_mm (#150421)
This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150421
Approved by: https://github.com/ngimel
2025-04-10 20:34:16 +00:00
Natalia Gimelshein
55e62ff74a bf16 grouped gemm (#150374)
Enabled bf16 grouped gemm with an API similar to _scaled_group_gemm, except without scale and fast accum arguments. All transpose variants are enabled, unlike scaled gemm. Ideally we'd factor out a lot more code from scaled gemm, currently there's a lot of repetition between scaled and non-scaled versions. I factored out only a helper kernel that prepares arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150374
Approved by: https://github.com/drisspg
2025-04-06 04:53:24 +00:00
PyTorch MergeBot
4854926aeb Revert "Add torch._scaled_mm for CPU (#150410)"
This reverts commit 3b02f795c5.

Reverted https://github.com/pytorch/pytorch/pull/150410 on behalf of https://github.com/malfet due to It breaks ROCM tests ([comment](https://github.com/pytorch/pytorch/pull/150410#issuecomment-2777704212))
2025-04-04 06:52:54 +00:00
Jiang, Yanbing
3b02f795c5 Add torch._scaled_mm for CPU (#150410)
This PR is the duplicated one for https://github.com/pytorch/pytorch/pull/139975.

This PR is to add torch._scaled_mm for CPU backend.

_scaled_mm_out_cpu and _scaled_mm_cpu are new added and included in torch._scaled_mm CPU dispatch. We also add _scaled_mm_out_cpu_emulated as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150410
Approved by: https://github.com/atalman
2025-04-03 19:43:45 +00:00
vasiliy
c974b5322a enable torch.compile for torch._scaled_mm nvfp4 recipe (#150462)
Summary:

Updates the meta registration for `torch._scaled_mm` to work for the
nvfp4 recipe.

Test Plan:

```bash
pytest test/test_matmul_cuda.py -s -k test_blockwise_nvfp4
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150462
Approved by: https://github.com/eellison
2025-04-02 01:08:40 +00:00
Prachi Gupta
47cdad2995 [ROCm] Enable several fsdp related UTs (#149369)
Enabling 26 UTs for ROCm in the following files:

-  distributed._shard.sharded_optim.test_sharded_optim - 2 UTs
-  distributed._shard.sharded_tensor.ops.test_binary_cmp - 4 UTs
-  distributed._shard.sharded_tensor.ops.test_init - 3 UTs
-  distributed._shard.sharded_tensor.ops.test_embedding - 2 UTs
-  distributed._shard.sharded_tensor.ops.test_embedding_bag - 2 UTs
-  distributed._composable.test_replicate_with_compiler - 4 UTs
-  distributed._composable.fsdp.test_fully_shard_grad_scaler - 1 UTs
-  distributed.tensor.test_attention - 4 UTs
-  distributed.tensor.test_matrix_ops - 1 UTs
-  distributed.tensor.test_tensor_ops - 1 UTs
-  distributed.fsdp.test_fsdp_grad_acc - 2 UTs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149369
Approved by: https://github.com/jeffdaily
2025-03-31 16:15:57 +00:00
vasiliy
01cb3519b3 wire torch._scaled_mm with fp4 operands to the cublas nvfp4 kernel (#148792)
Summary:

When `a` and `b` have dtype `torch.float4_e2m1fn_x2` and `a_scale` and `b_scale` have dtype `torch.float8_e4m3fn`, makes

```python
c = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16)
```

call the cuBLAS fp4 gemm kernel, as specified in https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-for-fp8-and-fp4-data-types

note: output scale (`scale_in_D` from the cuBLAS docs) is not tested in this PR - we can enable in a follow-up.

Test Plan:

```bash
pytest test/test_matmul_cuda.py -s -k mxfp8_nvfp4
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148792
Approved by: https://github.com/eqy
ghstack dependencies: #148791
2025-03-27 17:32:20 +00:00
vasiliy
dad0854d48 meta registration for torch._scaled_mm with mxfp8 (#148461)
Summary:

Adds the meta registration logic for torch.compile to work with
`torch._scaled_mm` with mxfp8.  Thanks to @eellison  for the pointer to make inductor work with this.

Test Plan:

```
pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148461
Approved by: https://github.com/drisspg, https://github.com/eellison
2025-03-27 02:32:40 +00:00
Natalia Gimelshein
53a1a022a9 [WIP] Initial implementation of Grouped Gemm API (#148531)
This PR provides initial cutlass implementation of grouped gemm api as described in this [document](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9). Any combination of 2d and 3d inputs is supported, with 2d input being jagged, and the offsets of the jagged input being given by device tensor `offs`. Only H100 is supported, and only fp8_e4m3 with bf16 output and rowwise scaling. All the dimensions of each individual gemm have to be multiple of 16, that's cutlass limitation.
I'll need to add those checks, for dynamic dimensions unfortunately the checks will have to be a device assert.
I had to copy-paste cutlass's `Sm90RowBroadcast` and `Sm90ColBroadcast` structs with minor changes to enable scales given as pointer arrays, ideally those should be part of cutlass itself.
I copied the schedules from the similar grouped gemm in FBGEMM, but there's a lot of room to improve perf, especially for `fast_accum=False`.
Next steps would be perf tuning and increasing coverage to B100, I don't know how cutlass grouped gemm example handles blockwise scaling on B100.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148531
Approved by: https://github.com/drisspg
2025-03-11 21:49:46 +00:00
PyTorch MergeBot
c983e1124c Revert "[WIP] Initial implementation of Grouped Gemm API (#148531)"
This reverts commit ff29791ed8.

Reverted https://github.com/pytorch/pytorch/pull/148531 on behalf of https://github.com/janeyx99 due to Sorry but this broke ROCm jobs on trunk ([comment](https://github.com/pytorch/pytorch/pull/148531#issuecomment-2714577498))
2025-03-11 14:40:58 +00:00
Natalia Gimelshein
ff29791ed8 [WIP] Initial implementation of Grouped Gemm API (#148531)
This PR provides initial cutlass implementation of grouped gemm api as described in this [document](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9). Any combination of 2d and 3d inputs is supported, with 2d input being jagged, and the offsets of the jagged input being given by device tensor `offs`. Only H100 is supported, and only fp8_e4m3 with bf16 output and rowwise scaling. All the dimensions of each individual gemm have to be multiple of 16, that's cutlass limitation.
I'll need to add those checks, for dynamic dimensions unfortunately the checks will have to be a device assert.
I had to copy-paste cutlass's `Sm90RowBroadcast` and `Sm90ColBroadcast` structs with minor changes to enable scales given as pointer arrays, ideally those should be part of cutlass itself.
I copied the schedules from the similar grouped gemm in FBGEMM, but there's a lot of room to improve perf, especially for `fast_accum=False`.
Next steps would be perf tuning and increasing coverage to B100, I don't know how cutlass grouped gemm example handles blockwise scaling on B100.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148531
Approved by: https://github.com/drisspg
2025-03-11 02:41:09 +00:00
drisspg
07b7b3ed4e torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-27 02:44:39 +00:00
Luca Wehrstedt
60d94ea22b Add option to limit number of SMs used by matmul kernels (#147966)
Resubmission of #144974 which was reverted for unrelated reasons.

Newer matmul kernels, e.g. those targeting Hopper GPUs, sometime use a "persistent" schedule which consists in launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This allows to eliminate the overhead of starting and finishing each tile, effectively doing cross-tile pipelining. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM and thus this needs to be taken care of in software.

Persistent kernels become an issue when other kernels are running concurrently. The classical example is a NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of the are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted-in later.

I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147966
Approved by: https://github.com/danthe3rd
2025-02-26 12:01:12 +00:00
PyTorch MergeBot
a84db75e1b Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit 12b9674cb6.

Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - similar to previous, see below ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2684134336))
2025-02-26 07:17:24 +00:00
vasiliy
12b9674cb6 torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-26 05:21:26 +00:00
PyTorch MergeBot
c82c1411c6 Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit e34c15a05b.

Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - discussed with author ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2683517851))
2025-02-25 23:28:15 +00:00
PyTorch MergeBot
1e894d2635 Revert "Add option to limit number of SMs used by matmul kernels (#144974)"
This reverts commit af2d63637e.

Reverted https://github.com/pytorch/pytorch/pull/144974 on behalf of https://github.com/wdvr due to reverting in order to revert #147548 that causes a merge conflict ([comment](https://github.com/pytorch/pytorch/pull/144974#issuecomment-2683461733))
2025-02-25 22:46:38 +00:00
Luca Wehrstedt
af2d63637e Add option to limit number of SMs used by matmul kernels (#144974)
Newer matmul kernels, e.g. those targeting Hopper GPUs, sometime use a "persistent" schedule which consists in launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This allows to eliminate the overhead of starting and finishing each tile, effectively doing cross-tile pipelining. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM and thus this needs to be taken care of in software.

Persistent kernels become an issue when other kernels are running concurrently. The classical example is a NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of the are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted-in later.

I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144974
Approved by: https://github.com/eqy, https://github.com/albanD
2025-02-25 10:19:19 +00:00
vasiliy
e34c15a05b torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-25 03:32:22 +00:00
Peter Yeh
81dccd706b [ROCm] OCP FP8 Support for new GPUs (#146632)
TLDR: Follow up/ Build on top of https://github.com/pytorch/pytorch/pull/144476. add OCP FP8 support for gfx950
refer to https://github.com/pytorch/ao/pull/1677

This pull request includes several changes to improve compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates involve adding support for new ROCm versions and GPU architectures, updating data type handling, and removing outdated checks.

### Improvements to GPU Architecture and ROCm Version Support:
* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for new GPU architectures `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated architecture support in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)

### Updates to Data Type Handling:
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Enhanced data type conversion to include new float8 types for both CUDA and ROCm environments.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated `HipDataTypeFor` template to handle new float8 types and added hard-coded enum values for ROCm versions prior to 6.3.

### Removal of Outdated Checks:
* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS` as it is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes ensure better compatibility and performance on newer hardware and software environments, particularly for users leveraging ROCm and CUDA for deep learning and scientific computing tasks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146632
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-02-24 22:47:52 +00:00
PyTorch MergeBot
3e2d9d079e Revert "[ROCm] OCP FP8 Support for new GPUs (#146632)"
This reverts commit f95ab46797.

Reverted https://github.com/pytorch/pytorch/pull/146632 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, I'll find someone to help merge this PR back to main ([comment](https://github.com/pytorch/pytorch/pull/146632#issuecomment-2676823614))
2025-02-23 12:04:50 +00:00
Peter Yeh
f95ab46797 [ROCm] OCP FP8 Support for new GPUs (#146632)
TLDR: Follow up/ Build on top of https://github.com/pytorch/pytorch/pull/144476. add OCP FP8 support for gfx950
refer to https://github.com/pytorch/ao/pull/1677

This pull request includes several changes to improve compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates involve adding support for new ROCm versions and GPU architectures, updating data type handling, and removing outdated checks.

### Improvements to GPU Architecture and ROCm Version Support:
* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for new GPU architectures `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated architecture support in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)

### Updates to Data Type Handling:
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Enhanced data type conversion to include new float8 types for both CUDA and ROCm environments.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated `HipDataTypeFor` template to handle new float8 types and added hard-coded enum values for ROCm versions prior to 6.3.

### Removal of Outdated Checks:
* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS` as it is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes ensure better compatibility and performance on newer hardware and software environments, particularly for users leveraging ROCm and CUDA for deep learning and scientific computing tasks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146632
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-02-21 23:44:08 +00:00
PyTorch MergeBot
babb2dc2af Revert "Add torch._scaled_mm for CPU (#139975)"
This reverts commit 6f7e67c43c.

Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/wdvr due to failing inductor mkldnn_pattern_matcher_cpu tests ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2667186865))
2025-02-18 23:58:31 +00:00
Jiang, Yanbing
6f7e67c43c Add torch._scaled_mm for CPU (#139975)
This PR is to add `torch._scaled_mm` for CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are new added and included in `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
2025-02-18 18:44:26 +00:00
PyTorch MergeBot
49e8f9c965 Revert "Add torch._scaled_mm for CPU (#139975)"
This reverts commit 22fae4c5f9.

Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/huydhn due to third time is the charm ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2664622598))
2025-02-18 05:11:32 +00:00
Jiang, Yanbing
22fae4c5f9 Add torch._scaled_mm for CPU (#139975)
This PR is to add `torch._scaled_mm` for CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are new added and included in `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
2025-02-17 18:39:10 +00:00
PyTorch MergeBot
aac5d1a289 Revert "Add torch._scaled_mm for CPU (#139975)"
This reverts commit f0bdc27f74.

Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it looks like internal ideep version is too old to support this ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2660008996))
2025-02-14 18:31:54 +00:00
Jiang, Yanbing
f0bdc27f74 Add torch._scaled_mm for CPU (#139975)
This PR is to add `torch._scaled_mm` for CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are new added and included in `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
2025-02-14 02:03:53 +00:00
Eddie Yan
9ee506bd93 [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee, https://github.com/malfet
2025-02-06 19:04:50 +00:00
Aleksandar Samardžić
2b00d211f0 Build RowwiseScaledMM.cu for SM89 (#145676)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145676
Approved by: https://github.com/drisspg, https://github.com/malfet, https://github.com/eqy
2025-02-01 11:44:58 +00:00
PyTorch MergeBot
c3f71eb61b Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"
This reverts commit e2917245fb.

Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/ZainRizvi due to Sorry but this still fails internally with the same error.  @Chillee or @malfet, can you please help the change get tested? (See D68783351) ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2627886999))
2025-01-31 17:43:09 +00:00
Eddie Yan
e2917245fb [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee, https://github.com/malfet
2025-01-30 22:33:50 +00:00
PyTorch MergeBot
c986eba560 Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"
This reverts commit abf28982a8.

Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/ZainRizvi due to Sorry but this is failing internally. @Chillee can you please help change get remerged? See  D68720562 ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2616726406))
2025-01-27 19:38:26 +00:00
Eddie Yan
abf28982a8 [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee
2025-01-27 18:05:23 +00:00
PyTorch MergeBot
dad9bc3461 Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"
This reverts commit de945d78da.

Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/izaitsevfb due to unused variables again :( ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2611182461))
2025-01-23 22:59:25 +00:00
Eddie Yan
de945d78da [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee
2025-01-22 22:42:48 +00:00
Dmitry Nikolaev
57b2b64acf Fix always true scaled_mm test (#143912)
Looks like `out_fp8` should use matmul without scales and `out_fp8_s` with
Scales were optional arguments before PR https://github.com/pytorch/pytorch/pull/128683
Then test_float8_scale started comparing two identical results and lost its meaning
Reason of making scales required https://github.com/pytorch/pytorch/pull/128683#issuecomment-2169146402UMBER

This PR uses scale=1.0 to compare result with scaled matmul

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143912
Approved by: https://github.com/drisspg, https://github.com/malfet, https://github.com/pruthvistony
2025-01-20 16:17:46 +00:00
PyTorch MergeBot
4ea189422d Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"
This reverts commit a6763b7b81.

Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/kit1980 due to breaking internal builds: unused variable 'halpha' ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2596895865))
2025-01-16 21:12:41 +00:00
eqy
a6763b7b81 [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee
2025-01-15 18:37:55 +00:00
Jeff Daily
6ac0616504 [ROCm] hipblaslt rowwise f8 gemm (#144432)
hipblaslt added rowwise f8 gemm support.  Integrate with scaled_mm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144432
Approved by: https://github.com/drisspg
2025-01-15 18:23:44 +00:00
PyTorch MergeBot
64bcf39180 Revert "[CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)"
This reverts commit 388b75edec.

Reverted https://github.com/pytorch/pytorch/pull/144441 on behalf of https://github.com/kit1980 due to breaking internal builds: unused variable 'halpha' ([comment](https://github.com/pytorch/pytorch/pull/144441#issuecomment-2588517060))
2025-01-14 00:48:28 +00:00
eqy
388b75edec [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441)
Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441
Approved by: https://github.com/Chillee
2025-01-11 15:30:38 +00:00
Dmitry Nikolaev
d4871750d9 [ROCm] Enable post-merge trunk workflow on MI300 runners; skip and fix MI300 related failed tests (#143673)
This PR
* makes changes to the workflow files and scripts so we can run CI workflows on the MI300 runners
* skips and fixes several tests, failed on MI300, observed in https://github.com/pytorch/pytorch/pull/140989

Skipped due to unsupported Float8_e4m3fn data type on MI300 (need to update test code to use datatypes supported by MI300):
- distributed.tensor.parallel.test_micro_pipeline_tp.py::MicroPipelineTPTest::test_fuse_all_gather_scaled_matmul_A_dims_\*_gather_dim_\* (24 tests across inductor/distributed configs)
- distributed.tensor.parallel.test_micro_pipeline_tp.py::test_fuse_scaled_matmul_reduce_scatter_A_dims_\*_scatter_dim_\* (12 tests across inductor/distributed configs))
- inductor.test_loop_ordering::LoopOrderingTest::test_fp8_cast_and_t
- inductor.test_loop_ordering::LoopOrderingTest::test_fp8_pattern_2

Skipped due to AssertionError on MI300:
- inductor.test_mkldnn_pattern_matcher.py::test_qconv2d_int8_mixed_bf16
- distributed._tools.test_sac_ilp::TestSACILP::test_sac_ilp_case1

Skipped:
- test_cuda.py::TestCudaMallocAsync::test_clock_speed
- test_cuda.py::TestCudaMallocAsync::test_power_draw
- test_torch.py::TestTorchDeviceTypeCUDA::test_deterministic_cumsum_cuda

Skipped flaky tests on MI300:
- distributed.test_c10d_gloo.py::ProcessGroupGlooTest::test_gather_stress_cuda
- inductor.test_cpu_repro::CPUReproTests::test_lstm_packed_unbatched_False* (256 tests)

Fixed:
- test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_basics_cuda

Features:
- inductor/test_fp8.py - declare a new function to convert FP8 datatypes to ROCm supported FP8 datatypes. It keeps test names for CUDA and ROCm and allows to enable Inductor FP8 tests on CPU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143673
Approved by: https://github.com/jeffdaily, https://github.com/malfet, https://github.com/pruthvistony

Co-authored-by: saienduri <saimanas.enduri@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-01-09 05:18:57 +00:00
drisspg
9bf2a9a616 [ScaledMM] Fix NaNs in test for garbage input data (#144042)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144042
Approved by: https://github.com/janeyx99
2025-01-03 21:02:25 +00:00
PyTorch MergeBot
45a709d9ec Revert "Add torch._scaled_mm for CPU (#139975)"
This reverts commit cbc4cf3043.

Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/malfet due to It broke the same test, but on ROCM this time, though it was classified as flaky for some reason, see d8c3900d80/1 ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2564378146))
2024-12-28 16:49:38 +00:00
Jiang, Yanbing
cbc4cf3043 Add torch._scaled_mm for CPU (#139975)
This PR is to add `torch._scaled_mm` for CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are new added and included in `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
2024-12-28 05:49:06 +00:00
PyTorch MergeBot
fca457b5db Revert "Add torch._scaled_mm for CPU (#139975)"
This reverts commit 3f80632c80.

Reverted https://github.com/pytorch/pytorch/pull/139975 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing some tests in trunk ([comment](https://github.com/pytorch/pytorch/pull/139975#issuecomment-2563331259))
2024-12-27 05:25:17 +00:00
Jiang, Yanbing
3f80632c80 Add torch._scaled_mm for CPU (#139975)
This PR is to add `torch._scaled_mm` for CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are new added and included in `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139975
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #139974
2024-12-26 22:22:42 +00:00
eqy
bd4071f0b0 [Matmul][CUDA][FP8] Skip rowwise scaling tests on non-sm90 (#141596)
Since the current kernel is using sm90-specific features, just pre-emptively skip the test for any non-sm90 compute capabilities

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141596
Approved by: https://github.com/drisspg
2024-12-10 23:16:19 +00:00
PyTorch MergeBot
dbd7b820dd Revert "[ROCm] port CK rowwise F8 from fbgemm (#140856)"
This reverts commit 291626fb22.

Reverted https://github.com/pytorch/pytorch/pull/140856 on behalf of https://github.com/atalman due to Failing internal build ([comment](https://github.com/pytorch/pytorch/pull/140856#issuecomment-2518911997))
2024-12-05 01:51:40 +00:00
Jeff Daily
291626fb22 [ROCm] port CK rowwise F8 from fbgemm (#140856)
This ports (copies) FBGEMM's implementation from @jwfromm.

https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140856
Approved by: https://github.com/drisspg, https://github.com/atalman
2024-12-04 00:32:24 +00:00
vasiliy
3d5fe0ce78 torch._scaled_mm: support dims of size 0 for tensorwise scaling (#140967)
Summary:

Ensures we support dims of size 0 properly in `torch._scaled_mm`. Follows the behavior from `torch.mm`.

For now only enable support for tensorwise, we can tackle rowwise in a future PR.

Test Plan:

```
python test/test_matmul_cuda.py -k test_zero_dim
```

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140967
Approved by: https://github.com/eqy, https://github.com/drisspg
2024-11-27 04:07:52 +00:00
eqy
c0e98a485b [FP8][CUDA] Fix stale expected error message (#136581)
CC @nWEIdia as I think we have seen internal failures on this

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136581
Approved by: https://github.com/mikaylagawarecki
2024-09-26 20:57:38 +00:00
eqy
2519e5a8de [CUDA][FP8] Skip rowwise scaling test on sm89 (#135718)
Same reason as #https://github.com/pytorch/pytorch/pull/133612, rowwise scaling implementation is sm90+ specific (e.g., uses TMA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135718
Approved by: https://github.com/Skylion007
2024-09-13 15:07:20 +00:00
eqy
7ad3108ef2 [CUTLASS][FP8] Skip scaled_mm rowwise test on sm89 (#133612)
Rowwise implementation currently uses sm90-specific features incl. TMA
CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133612
Approved by: https://github.com/Skylion007
2024-08-15 23:43:30 +00:00
drisspg
4d11a9b783 [CI] Fix rowwise scaling tests on h100 (#133384)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133384
Approved by: https://github.com/malfet, https://github.com/nWEIdia
2024-08-14 05:58:33 +00:00
drisspg
9108b74bbc Updates to scaled_mm for rowwise scaling (#130059)
# Summary

This updates _scaled_mm's API to enforce that input scales are always 2 dimensional. This resolves ambiguity around scaling scheme
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130059
Approved by: https://github.com/vkuzo
2024-07-04 00:53:17 +00:00
y-sq
ff026f3d0a Fix an issue in meta_scaled_mm (#129521)
Summary:
To fix the following failure cases:

For example, when `M, K, N = 245760, 656, 6560`, fp8 with compile fails due to `RuntimeError: mat2 must be col_major`.

---------
From the inductor generated code (https://fburl.com/everpaste/epcagkrd)
```
V0625 01:38:55.551000 140329914449920 torch/_inductor/scheduler.py:1623] [0/0] scheduling ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.float8_e4m3fn, size=[656, 6560], stride=[6656, 1]),
... ...
V0625 01:38:56.194000 140329914449920 torch/_inductor/graph.py:1680] [0/0] [__output_code]         buf12 = empty_strided_cuda((656, 6560), (6656, 1), torch.float8_e4m3fn)
... ...
V0625 01:38:56.194000 140329914449920 torch/_inductor/graph.py:1680] [0/0] [__output_code]     return (buf10, buf2, buf5, buf6, reinterpret_tensor(buf11, (245760, 656), (1, 245760), 0), reinterpret_tensor(buf12, (6560, 656), (1, 6656), 0), )
... ...
V0625 01:39:12.098000 140312968167424 torch/_inductor/graph.py:1680] [1/0_1] [__output_code]     assert_size_stride(permute_10, (6560, 656), (1, 6656))
... ...
V0625 01:39:12.098000 140312968167424 torch/_inductor/graph.py:1680] [1/0_1] [__output_code]         buf8 = aten._scaled_mm.default(buf6, permute_10, buf7, reciprocal_3, None, None, torch.bfloat16)
```

Inductor gives the mat2 (`permute_10`) a different stride (`6656`) instead of using its shape[0] (`(6560, 656)`).

Therefore, the `stride[1] == shape[0]` condition fails.

To fix the issue, simply modify the `is_col_major` check to exclude this condition as it doesn't hold for all valid cases.

Test Plan:
Run the failed case again. It works with the fix.
-----
Sandcastle / GitHub CI will make sure the existing tests could still pass.

Reviewed By: vkuzo

Differential Revision: D58994704

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129521
Approved by: https://github.com/drisspg
2024-06-27 07:03:34 +00:00
Andres Lugo
b9a1c2c991 [ROCm] Enable F8 Inductor Unit tests (#128353)
First batch of inductor unit test enablement on ROCm for the fnuz f8 variant on MI300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128353
Approved by: https://github.com/jansel, https://github.com/eellison
2024-06-26 18:30:43 +00:00
drisspg
fcf2a1378b Enable fp8 rowwise scaling kernel on cuda, TAKE 2: #125204 (#128989)
# Summary
First PR got reverted and needed a redo

This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.

It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".

The following two PRs were required to enable local builds:
- [PR #126185](https://github.com/pytorch/pytorch/pull/126185)
- [PR #125523](https://github.com/pytorch/pytorch/pull/125523)

### Todo
We still do not build our Python wheels with this architecture.

@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?

The NVRTC TMA shadowing feels wrong, but I a not sure the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954

#### ifdef

I tried to use : `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && \
    defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel. I was having a hell of a time with this.. so I am not really sure the right way to do this

Kernel Credit:
@jwfromm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128989
Approved by: https://github.com/yangsiyu007, https://github.com/vkuzo
2024-06-19 04:49:39 +00:00
drisspg
fc2913fb80 Remove amax return from _scaled_mm (#128683)
# Summary
The primary reason for the change was lack of current use case and the need to work around an two Inductor issue.
- Tensor arguments as kwarg only
- multiple outputs from triton templates

If the need for the amax return type arises we can consider either adding it, more likely creating a separate op.

In principle PyTorch is moving away from ops that bundle lots of functionality into "mega ops". We instead rely upon the compiler to generate appropriate fused kernels.

### Changes:
- This removes the amax return type from scaled_mm. We have found that the common use case is to return in "high-precision" ( a type with more precision than fp8). This is only relevant when returning in low-precision.
- We currently still allow for fp8 returns and scaled result.  Perhaps we should also ban this as well...

New signature:
```Python
def meta_scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128683
Approved by: https://github.com/vkuzo
2024-06-17 16:48:00 +00:00
PyTorch MergeBot
a5b86a1ec0 Revert "FP8 rowwise scaling (#125204)"
This reverts commit 5dc9128229.

Reverted https://github.com/pytorch/pytorch/pull/125204 on behalf of https://github.com/atalman due to Sorry need to revert this failing, on internal CI. I suggest to reimport this and try to land internally resolving all issues ([comment](https://github.com/pytorch/pytorch/pull/125204#issuecomment-2152905513))
2024-06-06 16:12:34 +00:00