121 Commits

Author SHA1 Message Date
Xuehai Pan
fc0376e8b1 [BE][2/6] fix typos in test/ (test/test_*.py) (#157636)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636
Approved by: https://github.com/yewentao256, https://github.com/mlazos
ghstack dependencies: #156311, #156609
2025-07-09 11:02:23 +00:00
Jeff Daily
210632fae1 [ROCm] support experimental CU carveout (#149466)
Fixes #149280.  Follow up to #147966, but now available for ROCm.

Since hipblaslt does not support HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET, we instead create a hipStream that has a CU mask applied.  We pass this masked stream to hipblaslt instead of pytorch's current stream.  We ensure stream ordering between streams using hipEvents and stream synchronization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149466
Approved by: https://github.com/malfet, https://github.com/atalman
2025-07-01 08:54:52 +00:00
AaronWang04
772d590415 [CUTLASS] [CUDA] SM100 GroupMM (#156203)
Closes https://github.com/pytorch/pytorch/issues/156202

This PR adds Blackwell support for GroupMM.

Most of the code used for SM90 can be reused; the kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html

Did some preliminary benchmarking of H200 vs B200

Script
```py
import torch
print(torch.__file__)
device = torch.device("cuda")
dtype = torch.bfloat16

shapes = [
    (16, 128000, 7168, 7168),
    (128, 1, 2048, 7168)
]

for batch, M, N, K in shapes:
    a = torch.randn(batch, M, K, device=device, dtype=dtype)
    b = torch.randn(batch, N, K, device=device, dtype=dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Warm up so one-time setup cost is excluded from the timing.
    for _ in range(5):
        c = torch._grouped_mm(a, b)

    num_iter = 50
    start_event.record()
    for _ in range(num_iter):
        c = torch._grouped_mm(a, b)
    end_event.record()

    # Wait for all queued kernels to finish before reading the elapsed time.
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iter
    print(f"batch: {batch}\tM: {M}\tN: {N}\tK: {K}")
    print(f"Time per Iteration:\t {avg_time_ms:.4f} ms")
```

On H200
```
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 298.6668 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 4.1462 ms
```

B200
```
batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      190.7458 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      3.0680 ms
```
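
As a rough cross-check of the two runs above, the sketch below computes the B200-over-H200 speedup and an effective throughput per shape. It assumes each call performs `2 * batch * M * N * K` floating-point operations, which is the usual GEMM count and not something the script itself reports; `effective_tflops` is an illustrative helper, not part of the PR.

```python
# Per-iteration times (ms) copied from the H200 and B200 runs above.
shapes = [(16, 128000, 7168, 7168), (128, 1, 2048, 7168)]
h200_ms = [298.6668, 4.1462]
b200_ms = [190.7458, 3.0680]


def effective_tflops(batch, m, n, k, time_ms):
    # Assumption: one grouped GEMM call costs 2 * batch * M * N * K FLOPs.
    flops = 2 * batch * m * n * k
    return flops / (time_ms * 1e-3) / 1e12


speedups = [h / b for h, b in zip(h200_ms, b200_ms)]
for shape, h, b, s in zip(shapes, h200_ms, b200_ms, speedups):
    print(
        f"shape={shape} speedup={s:.2f}x "
        f"H200={effective_tflops(*shape, h):.1f} TFLOPS "
        f"B200={effective_tflops(*shape, b):.1f} TFLOPS"
    )
```

With these numbers the speedup is roughly 1.57x for the large shape and 1.35x for the small one.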
nsys nvprof
```
root@16930b42ffc6:/workspace/pytorch# nsys nvprof python gemm_test.py
WARNING: python and any of its children processes will be profiled.

Collecting data...
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 192.6420 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 1.2255 ms
Generating '/tmp/nsys-report-6a53.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /workspace/pytorch/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)    StdDev (ns)                 Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  ---------------------------------
     98.9      10586895744          2  5293447872.0  5293447872.0  73786464  10513109280  7381715954.2  cudaDeviceSynchronize
      1.0        104084608          5    20816921.6    33552480.0    100800     34786208    18048125.3  cudaMalloc
      0.1          5694304          4     1423576.0     1416656.0   1258560      1602432      181668.1  cudaGetDeviceProperties_v2_v12000
      0.1          5430496        130       41773.0        4560.0      2496      3854368      345761.8  cudaLaunchKernel
      0.0           587584        110        5341.7        4992.0      4224        16992        1482.0  cudaLaunchKernelExC_v11060
      0.0           119200        660         180.6         128.0        96         4128         206.7  cudaGetDriverEntryPoint_v11030
      0.0            68352        660         103.6          64.0        32         4928         224.6  cuTensorMapEncodeTiled
      0.0            34976         49         713.8         224.0       160         6720        1343.4  cudaStreamIsCapturing_v10000
      0.0            32992          4        8248.0        7456.0      4128        13952        4804.4  cudaEventRecord
      0.0            16928          4        4232.0        3600.0      1728         8000        2764.7  cudaEventQuery
      0.0            16288          4        4072.0        3568.0      1952         7200        2396.1  cudaEventCreateWithFlags
      0.0            13632          4        3408.0        2672.0       544         7744        3408.7  cudaEventDestroy
      0.0             1056          1        1056.0        1056.0      1056         1056           0.0  cuModuleGetLoadingMode

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     99.0      10549232845         55  191804233.5  192944479.0  165746368  203645313    5353204.3  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.6         67327135         55    1224129.7    1330656.0     924320    1364928     182180.4  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.3         34854783         20    1742739.1    1597856.0      10080    3899616     818421.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0           354880        110       3226.2       3296.0       1920       4160        554.4  void at::cuda::detail::prepare_grouped_gemm_data<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass:…
```

The kernel names are too long to be shown via nvprof, so I pasted these from Nsight Systems
```
small kernel 1SM
100.0%	1.286 ms	1	1.286 ms	1.286 ms	1.286 ms	1.286 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)3, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)128, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, 
cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

large kernel 2SM
100.0%	194.178 ms	1	194.178 ms	194.178 ms	194.178 ms	194.178 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)5, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)256>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)256, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, 
cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156203
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
2025-06-28 23:02:00 +00:00
Aleksandar Samardžić
6ed85bfe6a Refine alignment check along dynamic dimension for grouped MMs (#155466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155466
Approved by: https://github.com/ngimel
2025-06-20 19:42:57 +00:00
PyTorch MergeBot
0b62465b99 Revert "Refine alignment check along dynamic dimension for grouped MMs (#155466)"
This reverts commit 830a335a7d.

Reverted https://github.com/pytorch/pytorch/pull/155466 on behalf of https://github.com/atalman due to breaks internal builds ([comment](https://github.com/pytorch/pytorch/pull/155466#issuecomment-2988285117))
2025-06-19 14:25:38 +00:00
Aleksandar Samardžić
830a335a7d Refine alignment check along dynamic dimension for grouped MMs (#155466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155466
Approved by: https://github.com/ngimel
2025-06-18 15:15:05 +00:00
Thien Tran
fc177801af Enable FP8 row-wise scaled-mm for sm12x (#155991)
## Update using Cutlass 3.x (2025/06/15)

Following @alexsamardzic's advice, I tried out the Cutlass 3.x API and the results are impressive (the rated spec is 419 TFLOPS)

 M | N | K | TFLOPS
---|---|---|--------
16|4096|4096|17.56
64|4096|4096|69.63
256|4096|4096|266.57
1024|4096|4096|339.28
4096|4096|4096|388.91

This uses the same SM100 template. The only differences are:
- Cluster size is fixed to `<1,1,1>` since sm120 does not have multicast feature
- ~~Tile size is fixed to `<128,128,128>` due to default kernel schedule does not support `<64,128,128>`. I will work a bit on improve perf for small M.~~ Fixed. Use `KernelTmaWarpSpecializedPingpong` when TileShape.M == 64

Perf for small M is still bad since it seems like Cutlass does not support TileShape.M < 64 for this kernel. It's possible to boost perf a bit by using TileShape `<64,64,128>`.

## Original using SM89

I tried using cutlass FP8 row-wise scaled-mm for sm89 on sm120 (5090) and it works. I guess it makes sense because sm120 matmul uses the standard sm80 PTX instructions (`cp.async`+`mma` and friends).

Simple benchmark script

```python
import torch
from torch._inductor.utils import do_bench_using_profiling

N, K = 4096, 4096
for M in [16, 64, 256, 1024, 4096]:
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).T
    scale_A = torch.ones(M, 1).cuda()
    scale_B = torch.ones(1, N).cuda()

    out = torch._scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16)
    out_ref = ((A.float() @ B.float()) * scale_A * scale_B).bfloat16()
    torch.testing.assert_close(out, out_ref)

    # do_bench_using_profiling reports latency in milliseconds
    latency_ms = do_bench_using_profiling(lambda: torch._scaled_mm(A, B, scale_A, scale_B, out_dtype=torch.bfloat16))
    tflops = (2 * M * N * K) / latency_ms / 1e9
    print(f"{M=}\t{N=}\t{K=}\t{tflops:.2f} TFLOPS")
```

M | N | K | TFLOPS
---|---|---|---
16 | 4096 | 4096 | 25.73 TFLOPS
64 | 4096 | 4096 | 71.84 TFLOPS
256 | 4096 | 4096 | 86.40 TFLOPS
1024 | 4096 | 4096 | 112.12 TFLOPS
4096 | 4096 | 4096 | 121.24 TFLOPS

According to the [RTX Blackwell Whitepaper](https://images.nvidia.com/aem-dam/Solutions/geforce/blackwell/nvidia-rtx-blackwell-gpu-architecture.pdf), FP8 MMA with FP32 accumulate is 419 TFLOPS. So the result is quite bad here...

However, if I change `ThreadblockSwizzle` to `cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>`

 M | N | K | TFLOPS
---|---|---|--------
16|4096|4096|27.13 TFLOPS
64|4096|4096|84.84 TFLOPS
256|4096|4096|96.75 TFLOPS
1024|4096|4096|110.21 TFLOPS
4096|4096|4096|122.98 TFLOPS

Small M slightly improves, but large M is still bad.

If I further change `ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3` for M>256, which is taken from [cutlass example 58](https://github.com/NVIDIA/cutlass/blob/v3.9.2/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu), I get the following results

 M | N | K | TFLOPS
---|---|---|--------
1024|4096|4096|313.28
4096|4096|4096|376.73

This is much closer to the hardware limit. It also means this kernel is sufficient to get the most perf out of sm120; it only needs better-tuned configs.

To make sure this high perf is only obtainable with `GemmIdentityThreadblockSwizzle<1>` + `ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3`, I also try using `ThreadblockSwizzleStreamK` + `ThreadBlockShape=<128, 64, 128>, WarpShape=<64, 32, 128>, NumStages=3`

 M | N | K | TFLOPS
---|---|---|--------
1024|4096|4096|144.03
4096|4096|4096|156.86

A bit better than current configs, but still very far away from hardware limit.

@alexsamardzic I noticed you chose these configs in #149978. Do you have any numbers on how the current configs perform on sm89?

Update: Using triton codegen-ed from inductor `compiled_scaled_mm = torch.compile(torch._scaled_mm, dynamic=False, mode="max-autotune-no-cudagraphs")`

 M | N | K | TFLOPS
---|---|---|--------
16|4096|4096|25.60
64|4096|4096|71.74
256|4096|4096|161.64
1024|4096|4096|185.89
4096|4096|4096|215.53

Better than default configs, but still far away from the config above for compute-bound
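
To put the M=4096 rows of the tables above side by side, the sketch below expresses each result as a fraction of the 419 TFLOPS figure quoted from the whitepaper. The dictionary labels are shorthand introduced here for illustration, and `tflops_from_latency` restates the benchmark script's formula under the assumption that the latency is in milliseconds, which is what its `/ 1e9` scaling implies.

```python
PEAK_TFLOPS = 419.0  # FP8 MMA with FP32 accumulate, as quoted above

# M=4096, N=K=4096 results copied from the tables in this description.
results_m4096 = {
    "sm89 kernel, current configs": 121.24,
    "GemmIdentityThreadblockSwizzle<1>": 122.98,
    "Swizzle<1> + <128, 64, 128> tiles": 376.73,
    "StreamK + <128, 64, 128> tiles": 156.86,
    "Triton via max-autotune": 215.53,
    "Cutlass 3.x (SM100 template)": 388.91,
}


def fraction_of_peak(tflops, peak=PEAK_TFLOPS):
    return tflops / peak


def tflops_from_latency(m, n, k, latency_ms):
    # 2*M*N*K FLOPs per matmul; FLOPs per millisecond divided by 1e9 is TFLOPS.
    return (2 * m * n * k) / latency_ms / 1e9


for label, tflops in results_m4096.items():
    print(f"{label:36s} {tflops:7.2f} TFLOPS  {fraction_of_peak(tflops):6.1%}")
```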

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155991
Approved by: https://github.com/drisspg, https://github.com/eqy
2025-06-17 18:52:44 +00:00
Aleksandar Samardžić
62fa3f5aeb Support tuning of _grouped_mm (#153953)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153953
Approved by: https://github.com/ngimel
2025-06-12 15:39:35 +00:00
Aleksandar Samardžić
f8baec8984 Update auto-tuning support for _scaled_grouped_mm (#150944)
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Updated meta registration
7. Update synthetic offsets creation
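
Item 5 concerns groups whose extent along K is not a multiple of the kernel's block size along K. The plain-Python sketch below shows one way such a range can be covered: full blocks followed by a shorter tail block whose valid length would be used as a load mask. `split_k_blocks` and `block_k` are hypothetical names for illustration; this is not the Triton template from the PR.

```python
def split_k_blocks(k_start, k_end, block_k):
    """Return (offset, valid_length) pairs covering [k_start, k_end)."""
    if block_k <= 0:
        raise ValueError("block_k must be positive")
    blocks = []
    for offset in range(k_start, k_end, block_k):
        # The last block may be partial; only valid_length elements are in range.
        blocks.append((offset, min(block_k, k_end - offset)))
    return blocks


blocks = split_k_blocks(0, 80, 32)
print(blocks)
```

For a group of size 80 and a block size of 32, this yields two full blocks and a tail block of 16 elements.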

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150944
Approved by: https://github.com/ngimel, https://github.com/davidberard98
2025-06-11 19:12:52 +00:00
PyTorch MergeBot
e12597090c Revert "Update auto-tuning support for _scaled_grouped_mm (#150944)"
This reverts commit 09328eb02f.

Reverted https://github.com/pytorch/pytorch/pull/150944 on behalf of https://github.com/davidberard98 due to breaks internal usage & complicates triton pin update - more details in https://github.com/pytorch/pytorch/pull/150944#issuecomment-2957246463 ([comment](https://github.com/pytorch/pytorch/pull/150944#issuecomment-2957248841))
2025-06-09 23:12:56 +00:00
Aleksandar Samardžić
09328eb02f Update auto-tuning support for _scaled_grouped_mm (#150944)
1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when group size along K dimension is not multiple of block size along K
6. Updated meta registration
7. Update synthetic offsets creation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150944
Approved by: https://github.com/ngimel
2025-06-08 10:18:13 +00:00
Eddie Yan
5163bf0069 [CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)
Some tests may not set the preferred backend, which leads to unexpected behavior when multiple tests are run vs. standalone

Tests that should exercise both backends should explicitly parametrize this setting

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153655
Approved by: https://github.com/ngimel
2025-05-20 16:18:35 +00:00
eqy
6ae0c42278 [CUDA][cuBLASLt] Respect allow[FP16/BF16]ReductionCuBLAS in cuBLASLt (#153095)
cuBLASLt matmuls have been silently allowing all reduction types, which meant that e.g., `allow_fp16_reduced_precision_reduction = False` had no effect.

In practice, split-K with reduced-precision reductions was unlikely to happen, as the default `CUBLASLT_WORKSPACE_SIZE` of 1MiB tends to prevent this.

However this isn't guaranteed and we are on the path to increasing the default workspace size following #151163

This setting is effectively already tested in e.g., `test_cublas_addmm_size_100_cuda_float16` and `test_cublas_addmm_size_100_cuda_bfloat16` but the backend selection is not deterministic. Running the full `test_matmul_cuda.py` seems to exercise the Lt interface, but running a standalone test does not (apparently due to spurious alignment differences).
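
As a toy illustration of why the reduction precision matters for split-K, the sketch below sums the same partial results once with FP32 rounding and once with FP16 rounding after every addition. It uses Python's `struct` half-precision format to emulate the rounding; the partial values are made up for the example, and this is not how cuBLASLt implements split-K.

```python
import struct


def to_fp16(x):
    # Round a Python float to the nearest IEEE half-precision value.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x):
    # Round a Python float to the nearest IEEE single-precision value.
    return struct.unpack("<f", struct.pack("<f", x))[0]


def reduce_partials(partials, round_fn):
    acc = 0.0
    for partial in partials:
        acc = round_fn(acc + partial)  # round after every accumulation step
    return acc


# Hypothetical split-K partial sums: one large contribution, several small ones.
partials = [2048.0, 1.0, 1.0, 1.0, 1.0]
fp32_sum = reduce_partials(partials, to_fp32)
fp16_sum = reduce_partials(partials, to_fp16)
print(fp32_sum, fp16_sum)
```

In FP16, integers above 2048 are spaced two apart, so each `+ 1.0` is rounded away and the small contributions are lost.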

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153095
Approved by: https://github.com/cyyever, https://github.com/Skylion007
2025-05-19 20:05:37 +00:00
PyTorch MergeBot
40339c1e99 Revert "[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)"
This reverts commit 3bde364996.

Reverted https://github.com/pytorch/pytorch/pull/153655 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail a test in trunk ([comment](https://github.com/pytorch/pytorch/pull/153655#issuecomment-2888212597))
2025-05-17 08:11:54 +00:00
Eddie Yan
3bde364996 [CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)
Some tests may not set the preferred backend, which leads to unexpected behavior when multiple tests are run vs. standalone

Tests that should exercise both backends should explicitly parametrize this setting

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153655
Approved by: https://github.com/ngimel
2025-05-16 21:31:13 +00:00
Eddie Yan
d965fa2c4b [CUDA][cuBLAS] Remove IS_ARM64 skip in test_matmul_cuda.py (#153660)
Original skip seems stale and the test appears to run fine on Grace + Hopper and Grace + Blackwell

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153660
Approved by: https://github.com/Skylion007
2025-05-16 07:31:16 +00:00
Natalia Gimelshein
9c99ea2991 error out on negative offs or on K=0 in group gemm (#153226)
Error out if K=0 in one of the grouped gemms to avoid hangs in #152668
Also, adds meta function for _scaled_grouped_mm (TODO: do the same for _grouped_mm, unless it's done already)
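
The kind of check described above can be sketched on the host in plain Python. The sketch assumes `offs` holds cumulative end offsets of the groups along the dynamic dimension, so a group with K=0 shows up as two equal consecutive offsets. `check_group_offsets` is a hypothetical helper; it does not reproduce where or how the PR performs the check.

```python
def check_group_offsets(offs, total):
    """Validate cumulative group end offsets and return per-group sizes."""
    sizes = []
    previous = 0
    for i, end in enumerate(offs):
        if end < 0:
            raise ValueError(f"offs[{i}] is negative: {end}")
        if end <= previous:
            # Equal offsets mean an empty group (K=0); smaller means a negative size.
            raise ValueError(f"group {i} has non-positive size {end - previous}")
        sizes.append(end - previous)
        previous = end
    if previous > total:
        raise ValueError(f"last offset {previous} exceeds dimension size {total}")
    return sizes


print(check_group_offsets([16, 48, 64], 64))
```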

One weird thing I'm seeing, when running all grouped_gemm tests, I'm erroring out with
```
  File "/data/users/ngimel/pytorch/torch/_inductor/graph.py", line 1246, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/data/users/ngimel/pytorch/torch/_inductor/lowering.py", line 445, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 444, in tuned_scaled_grouped_mm
    if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias):
  File "/data/users/ngimel/pytorch/torch/_inductor/kernel/mm_scaled_grouped.py", line 375, in can_use_triton_kernel
    offs is not None
  File "/home/ngimel/.conda/envs/pytorch_monarch/lib/python3.10/site-packages/sympy/core/relational.py", line 516, in __bool__
    raise TypeError("cannot determine truth value of Relational")
torch._inductor.exc.InductorError: LoweringException: TypeError: cannot determine truth value of Relational
```
which is weird, there's no relational that sympy has to evaluate in `offs is not None`, and when running this test separately (`test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_use_torch_compile_True_cuda`) it passes. I suspect some autotuning cache has to be reset between runs, but don't know what to look for.
Edit: that error is "fixed" by setting `dynamic=False`; now, with the correct meta function, something's wrong with dynamic shapes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153226
Approved by: https://github.com/kwen2501
2025-05-10 01:13:18 +00:00
eqy
b30d276abc [CUDA][cuBLASLt] Fix scale setting for allowFP16AccumulationCuBLAS true case (#153083)
Also add some missing `@onlyCUDA` / support check decorators in `test_matmul_cuda.py`
Should help resolve #151890

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153083
Approved by: https://github.com/janeyx99
2025-05-09 02:27:17 +00:00
Aidyn-A
086e2c2399 [TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on sm_120+ (#152814)
The float8 row-wise scaled matmuls are not supported on Blackwell yet. This PR adds skips to those tests to decrease the noise on `sm_120+` machines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152814
Approved by: https://github.com/eqy, https://github.com/Skylion007
2025-05-08 19:34:20 +00:00
drisspg
14f8066910 Ensure mxfp8 scaled_mm works w/ max-autotune (#152744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152744
Approved by: https://github.com/Skylion007
2025-05-06 01:16:57 +00:00
eqy
cc072af74a [CUDA][MXFP8] bump tolerances for test_blockwise_mxfp8_nvfp4_numerics (#151811)
Got a slightly lower SQNR on a smaller GPU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151811
Approved by: https://github.com/albanD
2025-04-30 01:12:51 +00:00
PaulZhang12
3ed5f1fb77 [CUDA][cuBLAS] Aten GEMM overload for FP32 output from FP16/BF16 inputs (#150812)
Enable FP32 output from FP16/BF16 GEMMs in aten with cuBLAS. Accumulation for these GEMMs is generally already done in FP32. Adds the functionality to the following aten operators:
* mm
* bmm
* addmm
* baddmm

Follow up of customer issue: https://github.com/pytorch/pytorch/issues/146241#issuecomment-2781889390

Differential Revision: [D73126191](https://our.internmc.facebook.com/intern/diff/D73126191)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150812
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-04-18 01:53:26 +00:00
Joel Schlosser
ae53510b9e Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 + adds test

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`.

Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) One reason this matters is that it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated `TestFooCUDA` does not derive from `TestFoo`, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:

070f389745/test/test_ops.py (L218-L221)
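
The `super()` failure described in (1) can be reproduced in plain Python, without PyTorch. The class names mirror those in the text, but this is a simplified illustration that assumes the classmethod is copied onto a class unrelated to `TestFoo`; it is not the code from `common_device_type.py`.

```python
class CUDATestBase:
    @classmethod
    def setUpClass(cls):
        cls.primary_device = "cuda"


class TestFoo:
    @classmethod
    def setUpClass(cls):
        # Zero-argument super() is tied to TestFoo, the class where it is defined.
        super().setUpClass()
        cls.extra = "loaded"


class TestFoo_base:
    """Empty stand-in: same role as TestFoo, but without its members."""


class TestFooCUDA(CUDATestBase, TestFoo_base):
    pass


# Copy the custom classmethod onto the generated class (illustrative only).
TestFooCUDA.setUpClass = classmethod(TestFoo.__dict__["setUpClass"].__func__)

try:
    TestFooCUDA.setUpClass()
    error = None
except TypeError as exc:
    # super(TestFoo, TestFooCUDA) is invalid: TestFooCUDA is not a TestFoo subclass.
    error = exc

print(type(error).__name__)
```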

(2) Further, there is some precedent within a custom `setUpClass()` impl for storing things on the `cls` object to be accessed at test time. This must be the device-specific test class (`TestFooCUDA`) and not `TestFoo` for this to work. As an example, the open device registration tests load a module during setup and use it in the test logic:

070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77)

070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)

To accomplish both (1) and (2) at the same time, I decided to revisit the idea of utilizing a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have `TestFooCPU` / `TestFooCUDA` **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been set up with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong ofc).
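
The inheritance-plus-`delattr()` idea above can be sketched with the standard `unittest` module alone. `instantiate_device_variants` and the simplified `*TestBase` classes are hypothetical stand-ins for `instantiate_device_type_tests(...)` and the real device bases. The sketch relies on cooperative `super()` calls and covers only the hierarchy and the removal of generic tests, not the rest of the PR's setup logic.

```python
import io
import unittest


class DeviceTestBase(unittest.TestCase):
    device_type = "generic"

    @classmethod
    def setUpClass(cls):
        cls.primary_device = cls.device_type  # device setup runs first
        super().setUpClass()                  # then TestFoo.setUpClass via the MRO


class CPUTestBase(DeviceTestBase):
    device_type = "cpu"


class CUDATestBase(DeviceTestBase):
    device_type = "cuda"  # label only; no GPU is touched in this sketch


class TestFoo(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Stored on the device-specific class, as required by point (2).
        cls.resource = f"resource-for-{cls.__name__}"

    def test_bar(self, device):
        self.assertEqual(device, self.primary_device)
        self.assertEqual(self.resource, f"resource-for-{type(self).__name__}")


def instantiate_device_variants(generic_cls, device_bases, scope):
    generic_tests = {
        name: fn for name, fn in vars(generic_cls).items()
        if name.startswith("test_") and callable(fn)
    }
    for base in device_bases:
        class_name = f"{generic_cls.__name__}{base.device_type.upper()}"
        # The generated class really derives from the generic test class.
        device_cls = type(class_name, (base, generic_cls), {})
        for name, fn in generic_tests.items():
            def device_test(self, _fn=fn, _device=base.device_type):
                return _fn(self, _device)
            setattr(device_cls, f"{name}_{base.device_type}", device_test)
        scope[class_name] = device_cls
    # Remove generic tests so only the device-specific names are discoverable.
    for name in generic_tests:
        delattr(generic_cls, name)


generated = {}
instantiate_device_variants(TestFoo, [CPUTestBase, CUDATestBase], generated)

suite = unittest.TestSuite()
loader = unittest.TestLoader()
for device_cls in generated.values():
    suite.addTests(loader.loadTestsFromTestCase(device_cls))
result = unittest.TextTestRunner(stream=io.StringIO()).run(suite)
print(sorted(generated), result.testsRun, result.wasSuccessful())
```

Because `TestFooCUDA` is a real subclass of `TestFoo`, the zero-argument `super()` call in `TestFoo.setUpClass()` resolves normally, and `cls` inside it is the device-specific class.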

**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here:
4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)

To address this, I moved this block of logic to before the first call to `instantiate_test()`, as that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). Goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work and the behavior is deterministic now AFAICT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet
2025-04-16 02:18:42 +00:00
PyTorch MergeBot
98b1e82ba8 Revert "Fix setUpClass() / tearDownClass() for device-specific tests (#151129)"
This reverts commit bd4cf30e31.

Reverted https://github.com/pytorch/pytorch/pull/151129 on behalf of https://github.com/jbschlosser due to flex attention tests failing ([comment](https://github.com/pytorch/pytorch/pull/151129#issuecomment-2807632119))
2025-04-15 22:07:25 +00:00
Joel Schlosser
bd4cf30e31 Fix setUpClass() / tearDownClass() for device-specific tests (#151129)
Finishes up the work started in #121686 + adds test

Update: this was not as straightforward as I originally imagined. Context below.

**TL;DR:** `TestFoo{CPU, CUDA}` now actually derive from `TestFoo`! Also, `{CPU, CUDA}TestBase` setup / teardown logic is now always called (it is required to set the primary device), regardless of whether `super().setUpClass()` / `super().tearDownClass()` are called or not.

**Background:** The typical way to get device-specific tests is to write a generic `TestFoo` and call `instantiate_device_type_tests(TestFoo, locals())` to get `TestFooCPU`, `TestFooCUDA`, etc. After this, generic tests (e.g. `TestFoo.test_bar()`) become `TestFooCPU.test_bar_cpu()` / `TestFooCUDA.test_bar_cuda()`.

Behind the scenes, this was historically accomplished by creating a `TestFooCUDA` that derives from both a `CUDATestBase` and an *empty class* called `TestFoo_base`. This `TestFoo_base` has the same bases as `TestFoo`, but none of the test functions (e.g. `test_bar()`). The documented reason for this is to avoid things like a derived `TestFooCUDA.test_bar()` being discovered in addition to the real device-specific test `TestFooCUDA.test_bar_cuda()`.

(1) A reason this matters is because it should be possible to call e.g. `super().setUpClass()` from a custom setup / teardown classmethod. If the generated TestFooCUDA does not derive from TestFoo, but instead derives from the empty class described above, this syntax does not work; in fact there is no way to form a proper `super()` call that works across the device-specific test variants. Here's an example that breaks in the OpInfo tests:

070f389745/test/test_ops.py (L218-L221)

(2) Further, there is some precedent within a custom `setUpClass()` impl for storing things on the `cls` object to be accessed at test time. This must be the device-specific test class (`TestFooCUDA`) and not `TestFoo` for this to work. As an example, the open device registration tests load a module during setup and use it in the test logic:

070f389745/test/test_cpp_extensions_open_device_registration.py (L63-L77)

070f389745/test/test_cpp_extensions_open_device_registration.py (L79-L80)

To accomplish both (1) and (2) at the same time, I decided to revisit the idea of utilizing a proper inheritance hierarchy for `TestFoo` -> `{TestFooCPU, TestFooCUDA}`. That is: have `TestFooCPU` / `TestFooCUDA` **actually** derive from `TestFoo`. This achieves both (1) and (2). The only thing left is to make sure the generic tests (e.g. `TestFoo.test_bar()`) are not discoverable, as was the stated reason for diverging from this in the first place. It turns out we can simply `delattr()` these generic tests from `TestFoo` once `TestFooCPU` / `TestFooCUDA` have been set up with the device-specific variants, and all works well. The `instantiate_device_type_tests(...)` logic already deletes `TestFoo` from scope, so I don't see a problem with deleting generic tests from this base class as well (CI will prove me right or wrong ofc).
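For illustration, here is a minimal sketch of the inheritance-plus-`delattr()` idea using only `unittest`. The helper `instantiate_device_tests()`, the `device_type` attribute, and the `generated` dict are hypothetical stand-ins, not the actual `instantiate_device_type_tests(...)` implementation, and the "devices" here are just strings.

```python
import unittest


def instantiate_device_tests(generic_cls, scope, devices=("cpu", "cuda")):
    """Hypothetical, simplified stand-in for instantiate_device_type_tests()."""
    test_names = [n for n in list(vars(generic_cls)) if n.startswith("test_")]
    for device in devices:
        # The device-specific class really derives from the generic class,
        # so super().setUpClass() resolves and `cls` is the device class.
        device_cls = type(
            f"{generic_cls.__name__}{device.upper()}",
            (generic_cls,),
            {"device_type": device},
        )
        for name in test_names:
            generic_test = vars(generic_cls)[name]

            # Bind the generic test as a default argument to avoid late binding.
            def device_test(self, _test=generic_test):
                return _test(self, self.device_type)

            setattr(device_cls, f"{name}_{device}", device_test)
        scope[device_cls.__name__] = device_cls
    # Delete the generic tests so they are not discovered on the subclasses.
    for name in test_names:
        delattr(generic_cls, name)


class TestFoo(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Stored on the device-specific class, e.g. TestFooCUDA.
        cls.resource = f"resource-for-{cls.device_type}"

    def test_bar(self, device):
        self.assertEqual(self.resource, f"resource-for-{device}")


generated = {}
instantiate_device_tests(TestFoo, generated)
```

In this sketch the generated classes land in a plain dict; the real helper is called with `locals()` so that the test runner discovers them.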

**Side note:** I was encountering a weird race condition where sometimes the custom `setUpClass()` / `tearDownClass()` defined & swapped in [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L940-L955)) would be used, and sometimes it wouldn't. This non-deterministic behavior was called out previously by @ngimel here:
4a47dd9b3f/test/inductor/test_torchinductor_dynamic_shapes.py (L128-L130)

To address this, I moved this block of logic to before the first call to `instantiate_test()`, as that method queries for the primary device, and the primary device identification logic may manually invoke `setUpClass()` (see [here](4a47dd9b3f/torch/testing/_internal/common_device_type.py (L381-L384))). Goal: define the `setUpClass()` / `tearDownClass()` we want for correctness before they're ever called. This seems to work and the behavior is deterministic now AFAICT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151129
Approved by: https://github.com/janeyx99, https://github.com/masnesral, https://github.com/malfet
2025-04-15 20:13:26 +00:00
Wei Wang
f6e9e064a7 [CI][CUDA] xfail grouped gemm unit tests on blackwell (#150982)
On SM100OrLater, expect failures like:

RuntimeError: torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0

To execute this test, run the following from the base repo dir:
    python test/test_matmul_cuda.py TestMatmulCudaCUDA.test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_False_cuda

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

`
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0005s] (Issue with numpy versi...) [  2%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [  4%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [  6%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [  8%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 10%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 12%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 14%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_2d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 16%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versi...) [ 18%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 20%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 22%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 25%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 27%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 29%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 31%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_2d_3d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 33%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0002s] (Issue with numpy versi...) [ 35%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 37%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 39%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 41%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 43%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 45%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 47%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_2d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 50%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versi...) [ 52%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 54%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 56%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_False_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 58%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_False_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy versio...) [ 60%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_False_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 62%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_True_b_row_major_False_cuda SKIPPED [0.0001s] (Issue with numpy version...) [ 64%]
test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_grouped_gemm_3d_3d_strided_True_a_row_major_True_b_row_major_True_cuda SKIPPED [0.0001s] (Issue with numpy version ...) [ 66%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_False_strided_False_cuda XFAIL [0.8166s]                                        [ 68%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_False_strided_True_cuda XFAIL [0.0017s]                                         [ 70%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_True_strided_False_cuda XFAIL [0.0012s]                                         [ 72%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_2d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 75%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_False_strided_False_cuda XFAIL [0.0033s]                                        [ 77%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 79%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_False_cuda XFAIL [0.0015s]                                         [ 81%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_2d_3d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 83%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_False_strided_False_cuda XFAIL [0.0012s]                                        [ 85%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 87%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_True_strided_False_cuda XFAIL [0.0011s]                                         [ 89%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_2d_fast_accum_True_strided_True_cuda XFAIL [0.0012s]                                          [ 91%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_False_strided_False_cuda XFAIL [0.0014s]                                        [ 93%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_False_strided_True_cuda XFAIL [0.0012s]                                         [ 95%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_True_strided_False_cuda XFAIL [0.0011s]                                         [ 97%]
test/test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_scaled_grouped_gemm_3d_3d_fast_accum_True_strided_True_cuda XFAIL [0.0011s]                                          [100%]
`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150982
Approved by: https://github.com/ngimel, https://github.com/eqy
2025-04-12 01:53:12 +00:00
Bert Maher
2d187bf7e6 Support tuning of _scaled_grouped_mm (#150421)
This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150421
Approved by: https://github.com/ngimel
2025-04-11 23:03:49 +00:00
Jiang, Yanbing
1e92579126 Add torch._scaled_mm for CPU (#150410)
This PR is a duplicate of https://github.com/pytorch/pytorch/pull/139975.

This PR adds `torch._scaled_mm` for the CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are newly added and included in the `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. This PR also updates the various FP8-related UTs to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150410
Approved by: https://github.com/atalman
2025-04-11 02:23:03 +00:00
PyTorch MergeBot
6a65f2c4fe Revert "Support tuning of _scaled_grouped_mm (#150421)"
This reverts commit 8efcf21fff.

Reverted https://github.com/pytorch/pytorch/pull/150421 on behalf of https://github.com/malfet due to Looks like it broke lint, see a0ab243c3a/1 ([comment](https://github.com/pytorch/pytorch/pull/150421#issuecomment-2795218547))
2025-04-10 21:36:41 +00:00
Bert Maher
8efcf21fff Support tuning of _scaled_grouped_mm (#150421)
This includes the default aten implementation, as well as a Triton
implementation imported from FBGEMM
(https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150421
Approved by: https://github.com/ngimel
2025-04-10 20:34:16 +00:00
Natalia Gimelshein
55e62ff74a bf16 grouped gemm (#150374)
Enabled bf16 grouped gemm with an API similar to _scaled_group_gemm, except without scale and fast accum arguments. All transpose variants are enabled, unlike scaled gemm. Ideally we'd factor out a lot more code from scaled gemm; currently there's a lot of repetition between the scaled and non-scaled versions. I factored out only a helper kernel that prepares arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150374
Approved by: https://github.com/drisspg
2025-04-06 04:53:24 +00:00
PyTorch MergeBot
4854926aeb Revert "Add torch._scaled_mm for CPU (#150410)"
This reverts commit 3b02f795c5.

Reverted https://github.com/pytorch/pytorch/pull/150410 on behalf of https://github.com/malfet due to It breaks ROCM tests ([comment](https://github.com/pytorch/pytorch/pull/150410#issuecomment-2777704212))
2025-04-04 06:52:54 +00:00
Jiang, Yanbing
3b02f795c5 Add torch._scaled_mm for CPU (#150410)
This PR is a duplicate of https://github.com/pytorch/pytorch/pull/139975.

This PR adds `torch._scaled_mm` for the CPU backend.

`_scaled_mm_out_cpu` and `_scaled_mm_cpu` are newly added and included in the `torch._scaled_mm` CPU dispatch. We also add `_scaled_mm_out_cpu_emulated` as a fallback function if the current platform cannot run FP8 matmul using oneDNN. This PR also updates the various FP8-related UTs to support CPU tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150410
Approved by: https://github.com/atalman
2025-04-03 19:43:45 +00:00
vasiliy
c974b5322a enable torch.compile for torch._scaled_mm nvfp4 recipe (#150462)
Summary:

Updates the meta registration for `torch._scaled_mm` to work for the
nvfp4 recipe.

Test Plan:

```bash
pytest test/test_matmul_cuda.py -s -k test_blockwise_nvfp4
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150462
Approved by: https://github.com/eellison
2025-04-02 01:08:40 +00:00
Prachi Gupta
47cdad2995 [ROCm] Enable several fsdp related UTs (#149369)
Enabling 26 UTs for ROCm in the following files:

-  distributed._shard.sharded_optim.test_sharded_optim - 2 UTs
-  distributed._shard.sharded_tensor.ops.test_binary_cmp - 4 UTs
-  distributed._shard.sharded_tensor.ops.test_init - 3 UTs
-  distributed._shard.sharded_tensor.ops.test_embedding - 2 UTs
-  distributed._shard.sharded_tensor.ops.test_embedding_bag - 2 UTs
-  distributed._composable.test_replicate_with_compiler - 4 UTs
-  distributed._composable.fsdp.test_fully_shard_grad_scaler - 1 UT
-  distributed.tensor.test_attention - 4 UTs
-  distributed.tensor.test_matrix_ops - 1 UT
-  distributed.tensor.test_tensor_ops - 1 UT
-  distributed.fsdp.test_fsdp_grad_acc - 2 UTs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149369
Approved by: https://github.com/jeffdaily
2025-03-31 16:15:57 +00:00
vasiliy
01cb3519b3 wire torch._scaled_mm with fp4 operands to the cublas nvfp4 kernel (#148792)
Summary:

When `a` and `b` have dtype `torch.float4_e2m1fn_x2` and `a_scale` and `b_scale` have dtype `torch.float8_e4m3fn`, makes

```python
c = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=torch.bfloat16)
```

call the cuBLAS fp4 gemm kernel, as specified in https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-for-fp8-and-fp4-data-types

note: output scale (`scale_in_D` from the cuBLAS docs) is not tested in this PR - we can enable in a follow-up.

Test Plan:

```bash
pytest test/test_matmul_cuda.py -s -k mxfp8_nvfp4
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148792
Approved by: https://github.com/eqy
ghstack dependencies: #148791
2025-03-27 17:32:20 +00:00
vasiliy
dad0854d48 meta registration for torch._scaled_mm with mxfp8 (#148461)
Summary:

Adds the meta registration logic for torch.compile to work with
`torch._scaled_mm` with mxfp8.  Thanks to @eellison  for the pointer to make inductor work with this.

Test Plan:

```
pytest test/test_matmul_cuda.py -k test_blockwise_mxfp8_compile -s
```

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148461
Approved by: https://github.com/drisspg, https://github.com/eellison
2025-03-27 02:32:40 +00:00
Natalia Gimelshein
53a1a022a9 [WIP] Initial implementation of Grouped Gemm API (#148531)
This PR provides an initial cutlass implementation of the grouped gemm API as described in this [document](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9). Any combination of 2d and 3d inputs is supported, with 2d input being jagged, and the offsets of the jagged input being given by device tensor `offs`. Only H100 is supported, and only fp8_e4m3 with bf16 output and rowwise scaling. All the dimensions of each individual gemm have to be multiples of 16; that's a cutlass limitation.
I'll need to add those checks; for dynamic dimensions, unfortunately, the checks will have to be device asserts.
I had to copy-paste cutlass's `Sm90RowBroadcast` and `Sm90ColBroadcast` structs with minor changes to enable scales given as pointer arrays, ideally those should be part of cutlass itself.
I copied the schedules from the similar grouped gemm in FBGEMM, but there's a lot of room to improve perf, especially for `fast_accum=False`.
Next steps would be perf tuning and increasing coverage to B100, I don't know how cutlass grouped gemm example handles blockwise scaling on B100.
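
As a rough illustration of the jagged 2d x 3d case, here is a pure-Python reference of the intended semantics. This is a sketch, not the cutlass kernel: `grouped_mm_reference` and `matmul` are hypothetical helpers, scaling and fp8 are left out, the toy sizes ignore the multiple-of-16 requirement, and it assumes `offs[g]` is the exclusive end row of group `g` (the cumulative sum of group sizes).

```python
def matmul(a, b):
    """Plain nested-list matrix product: (m x k) @ (k x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_mm_reference(a, b_groups, offs):
    """Multiply each row slice of jagged 2d `a` by its group's 2d `b`.

    Assumption: offs[g] is the exclusive end row of group g in `a`.
    """
    out, start = [], 0
    for b, end in zip(b_groups, offs):
        out.extend(matmul(a[start:end], b))
        start = end
    return out


a = [[1, 0], [0, 1], [2, 2]]      # 3 rows in total, K = 2
b_groups = [
    [[1, 2], [3, 4]],             # group 0: K x N = 2 x 2
    [[10, 0], [0, 10]],           # group 1: K x N = 2 x 2
]
offs = [2, 3]                     # group 0 owns rows 0-1, group 1 owns row 2

result = grouped_mm_reference(a, b_groups, offs)
print(result)  # [[1, 2], [3, 4], [20, 20]]
```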

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148531
Approved by: https://github.com/drisspg
2025-03-11 21:49:46 +00:00
PyTorch MergeBot
c983e1124c Revert "[WIP] Initial implementation of Grouped Gemm API (#148531)"
This reverts commit ff29791ed8.

Reverted https://github.com/pytorch/pytorch/pull/148531 on behalf of https://github.com/janeyx99 due to Sorry but this broke ROCm jobs on trunk ([comment](https://github.com/pytorch/pytorch/pull/148531#issuecomment-2714577498))
2025-03-11 14:40:58 +00:00
Natalia Gimelshein
ff29791ed8 [WIP] Initial implementation of Grouped Gemm API (#148531)
This PR provides an initial cutlass implementation of the grouped gemm API as described in this [document](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9). Any combination of 2d and 3d inputs is supported, with 2d input being jagged, and the offsets of the jagged input being given by device tensor `offs`. Only H100 is supported, and only fp8_e4m3 with bf16 output and rowwise scaling. All the dimensions of each individual gemm have to be multiples of 16; that's a cutlass limitation.
I'll need to add those checks; for dynamic dimensions, unfortunately, the checks will have to be device asserts.
I had to copy-paste cutlass's `Sm90RowBroadcast` and `Sm90ColBroadcast` structs with minor changes to enable scales given as pointer arrays, ideally those should be part of cutlass itself.
I copied the schedules from the similar grouped gemm in FBGEMM, but there's a lot of room to improve perf, especially for `fast_accum=False`.
Next steps would be perf tuning and increasing coverage to B100, I don't know how cutlass grouped gemm example handles blockwise scaling on B100.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148531
Approved by: https://github.com/drisspg
2025-03-11 02:41:09 +00:00
drisspg
07b7b3ed4e torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-27 02:44:39 +00:00
Luca Wehrstedt
60d94ea22b Add option to limit number of SMs used by matmul kernels (#147966)
Resubmission of #144974 which was reverted for unrelated reasons.

Newer matmul kernels, e.g. those targeting Hopper GPUs, sometimes use a "persistent" schedule, which consists of launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This eliminates the overhead of starting and finishing each tile, effectively doing cross-tile pipelining. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM and thus this needs to be taken care of in software.

Persistent kernels become an issue when other kernels are running concurrently. The classical example is an NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of them are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.
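
To make the wave-quantization arithmetic concrete, here is a small back-of-the-envelope sketch. The SM counts are made-up illustration values and `num_waves()` is a hypothetical helper; it only models "one block per SM at a time" and does not call any PyTorch API.

```python
import math


def num_waves(num_blocks: int, available_sms: int) -> int:
    """Waves needed when each SM can run only one block at a time."""
    return math.ceil(num_blocks / available_sms)


total_sms = 132   # assumed SM count, for illustration only
busy_sms = 8      # assumed SMs held by a concurrent communication kernel

# Persistent schedule sized for the whole GPU: one block per SM,
# but only (total_sms - busy_sms) SMs are actually free.
waves_default = num_waves(total_sms, total_sms - busy_sms)

# With a carveout equal to the busy SMs, the matmul launches fewer blocks
# and they all fit on the free SMs at once.
sm_carveout = 8
waves_carveout = num_waves(total_sms - sm_carveout, total_sms - busy_sms)

print(waves_default, waves_carveout)  # 2 1
```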

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted-in later.

I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147966
Approved by: https://github.com/danthe3rd
2025-02-26 12:01:12 +00:00
PyTorch MergeBot
a84db75e1b Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit 12b9674cb6.

Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - similar to previous, see below ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2684134336))
2025-02-26 07:17:24 +00:00
vasiliy
12b9674cb6 torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-26 05:21:26 +00:00
PyTorch MergeBot
c82c1411c6 Revert "torch._scaled_mm with MXFP8 (#147548)"
This reverts commit e34c15a05b.

Reverted https://github.com/pytorch/pytorch/pull/147548 on behalf of https://github.com/wdvr due to failing internal build - discussed with author ([comment](https://github.com/pytorch/pytorch/pull/147548#issuecomment-2683517851))
2025-02-25 23:28:15 +00:00
PyTorch MergeBot
1e894d2635 Revert "Add option to limit number of SMs used by matmul kernels (#144974)"
This reverts commit af2d63637e.

Reverted https://github.com/pytorch/pytorch/pull/144974 on behalf of https://github.com/wdvr due to reverting in order to revert #147548 that causes a merge conflict ([comment](https://github.com/pytorch/pytorch/pull/144974#issuecomment-2683461733))
2025-02-25 22:46:38 +00:00
Luca Wehrstedt
af2d63637e Add option to limit number of SMs used by matmul kernels (#144974)
Newer matmul kernels, e.g. those targeting Hopper GPUs, sometimes use a "persistent" schedule, which consists of launching as many CUDA blocks as there are SMs on the GPU, with each such block then working on multiple output tiles in a row. This eliminates the overhead of starting and finishing each tile, effectively doing cross-tile pipelining. In previous generations these latencies could be hidden by having multiple CUDA blocks per SM but, with blocks becoming larger, only one can run at a time per SM and thus this needs to be taken care of in software.

Persistent kernels become an issue when other kernels are running concurrently. The classical example is an NCCL communication kernel running in the background. In such cases the matmul expects to be able to use all the SMs but is prevented from doing so because some of them are busy. This can lead to its blocks being scheduled as two separate waves on the available SMs. This "wave quantization" can double the latency of the matmul kernels.

While we wait for smarter solutions, such as automatic load balancing among the blocks, an easy way to unblock ourselves is to tell the matmuls to only use a subset of the GPU's SMs. For this, I am introducing a global `sm_carveout` flag which can be used to specify how many SMs should be left available for other kernels.

For now I only change the cuBLAS kernels and the scaled-mm CUTLASS kernel. More kernels can be opted-in later.

I tested this change manually, by using the Kineto profiler to look up the grid size of a scaled-mm kernel with different values of `sm_carveout`, and making sure it changed. Suggestions are welcome for a more automated test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144974
Approved by: https://github.com/eqy, https://github.com/albanD
2025-02-25 10:19:19 +00:00
vasiliy
e34c15a05b torch._scaled_mm with MXFP8 (#147548)
# summary

Add blockwise MXFP8 support to `torch._scaled_mm` on CUDA capability 10.0 and higher devices.  If the scales for A and B are of dtype `torch.float8_e8m0fnu`, we dispatch to the blockwise kernel from cuBLAS.

This is a skeleton PR where we test basic functionality (numerics of various simple matrices, as well as one end to end quantization + gemm).

- Scales are flipped based on transpose_result
- Handles boundary conditions

Note that MXFP4 is not added in this PR - we can tackle that in a future PR.

This PR was created by taking https://github.com/pytorch/pytorch/pull/145562, switching e8m0 to in-core dtype, removing fp4 for now, and adding test cases.

# test plan

```
pytest test/test_matmul_cuda.py -k blockwise_mxfp8 -s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147548
Approved by: https://github.com/drisspg

Co-authored-by: drisspg <drisspguessous@gmail.com>
2025-02-25 03:32:22 +00:00
Peter Yeh
81dccd706b [ROCm] OCP FP8 Support for new GPUs (#146632)
TLDR: Follow-up to, and built on top of, https://github.com/pytorch/pytorch/pull/144476. Adds OCP FP8 support for gfx950.
Refer to https://github.com/pytorch/ao/pull/1677

This pull request includes several changes to improve compatibility and support for new GPU architectures and data types, particularly for ROCm. The key updates involve adding support for new ROCm versions and GPU architectures, updating data type handling, and removing outdated checks.

### Improvements to GPU Architecture and ROCm Version Support:
* [`aten/src/ATen/Context.cpp`](diffhunk://#diff-33de472d304acbe57d693c8567370c638068bedc1aa0ce8e9dc115dad05a7810L323-R326): Added support for new GPU architectures `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks.
* [`aten/src/ATen/native/cuda/Blas.cpp`](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199): Updated architecture support in multiple functions to include `gfx1200`, `gfx1201`, and `gfx950` based on ROCm version checks. [[1]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL196-R199) [[2]](diffhunk://#diff-e8a569efee1e650172f120a0fdcda024fe3e4703a4ee3336425c8f685af6b3abL865-R876)

### Updates to Data Type Handling:
* [`aten/src/ATen/cuda/CUDADataType.h`](diffhunk://#diff-9188bb13b1a49f459141f5f9b875593d1c5ce2beb5ad711fdbaf5bc7089ec015L81-L98): Enhanced data type conversion to include new float8 types for both CUDA and ROCm environments.
* [`aten/src/ATen/cuda/tunable/GemmHipblaslt.h`](diffhunk://#diff-bfa1a3b5d4bef1892bf50338775f3b0fd8cd31fc1868148f3968b98aefb68e3fL29-R80): Updated `HipDataTypeFor` template to handle new float8 types and added hard-coded enum values for ROCm versions prior to 6.3.

### Removal of Outdated Checks:
* [`cmake/public/LoadHIP.cmake`](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197): Removed the check for `HIP_NEW_TYPE_ENUMS` as it is no longer necessary with the updated ROCm versions. [[1]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L169-L197) [[2]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5L211-R182)

These changes ensure better compatibility and performance on newer hardware and software environments, particularly for users leveraging ROCm and CUDA for deep learning and scientific computing tasks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146632
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-02-24 22:47:52 +00:00
PyTorch MergeBot
3e2d9d079e Revert "[ROCm] OCP FP8 Support for new GPUs (#146632)"
This reverts commit f95ab46797.

Reverted https://github.com/pytorch/pytorch/pull/146632 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, I'll find someone to help merge this PR back to main ([comment](https://github.com/pytorch/pytorch/pull/146632#issuecomment-2676823614))
2025-02-23 12:04:50 +00:00