From ac99fc7e5753d2c37c61f5517f9d4a3631bd2034 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 4 Mar 2025 05:23:38 +0000
Subject: [PATCH] Updates to build rowwise scaled mm kernel on SM10.0a (#148274)

## Summary
Update the cmake files and RowwiseScaledMM.cu to build on the SM10.0a arch.

**NOTE**: performance optimization will be done in separate follow-up PRs; the goal of this PR is just to ensure it builds correctly.

## Steps to verify build
1. Access a devgpu/machine with B200 GPUs; verify the B200s are visible with `nvidia-smi`
2. Install CUDA toolkit 12.8 - e.g. see the [Nvidia docs](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Rocky&target_version=9&target_type=rpm_local)
3. Verify the CUDA toolkit installation - e.g. `nvcc --version` should show `... Cuda compilation tools, release 12.8 ...` in its output
4. Set the env var `TORCH_CUDA_ARCH_LIST=10.0a`
5. Build pytorch from source with this PR ([steps](https://github.com/pytorch/pytorch#from-source))
6. Uninstall `pytorch-triton` with `pip uninstall pytorch-triton`
7. Build and install triton from source: https://github.com/triton-lang/triton?tab=readme-ov-file#install-from-source
8. Run the tests shown in the test plan below

## Test plan
- `python test/distributed/tensor/test_matrix_ops.py -k scaled_mm`: OK
- `python test/test_matmul_cuda.py -k rowwise`: OK
- `python test/test_flop_counter.py -k scaled_mm`: OK
- `python test/inductor/test_aot_inductor.py -k fp8`: OK
- `python test/inductor/test_fp8.py`: OK
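As an additional quick smoke test after the build, the sketch below exercises the rowwise-scaled FP8 matmul path directly. It is illustrative only: it assumes a single B200 visible as device 0 and a build done with `TORCH_CUDA_ARCH_LIST=10.0a`, and `torch._scaled_mm` is a private API whose signature may differ between versions.

```python
# Illustrative smoke test -- assumes a B200 (SM 10.0) as device 0 and a source build
# with TORCH_CUDA_ARCH_LIST=10.0a. torch._scaled_mm is a private API; signature may vary.
import torch

assert torch.cuda.is_available()
print(torch.cuda.get_device_capability(0))  # expect (10, 0) on B200
print(torch.cuda.get_arch_list())           # expect 'sm_100a' to be listed for this build

M, K, N = 256, 512, 256                     # K >= 32 (see the mm_scaled.py change below)
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major mat2
scale_a = torch.ones(M, 1, device="cuda", dtype=torch.float32)    # rowwise scales for a
scale_b = torch.ones(1, N, device="cuda", dtype=torch.float32)    # columnwise scales for b
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)                 # torch.Size([256, 256]) torch.bfloat16
```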
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148274
Approved by: https://github.com/drisspg
---
 aten/src/ATen/native/cuda/RowwiseScaledMM.cu |  3 ++-
 cmake/Codegen.cmake                          |  3 +++
 test/inductor/test_fp8.py                    |  8 ++++----
 torch/_inductor/kernel/mm_scaled.py          | 17 +++++++++++++++++
 4 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
index 12803f3b9c1..8cfd7661281 100644
--- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
+++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
@@ -693,7 +693,8 @@ void dispatch_fp8_rowwise_kernel_on_sm(
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
   const bool sm9x = properties != nullptr && properties->major == 9;
-  if (!(sm89 || sm9x)) {
+  const bool sm10x = properties != nullptr && properties->major == 10;
+  if (!(sm89 || sm9x || sm10x)) {
     TORCH_CHECK(
         false, "Rowwise scaling is not currently supported on your device");
   }
diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake
index e04d5c8830d..e643c5228bc 100644
--- a/cmake/Codegen.cmake
+++ b/cmake/Codegen.cmake
@@ -92,6 +92,9 @@ if(INTERN_BUILD_ATEN_OPS)
     if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
       list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
     endif()
+    if(EXISTING_ARCH_FLAGS MATCHES ".*compute_100a.*")
+      list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_100a,code=sm_100a")
+    endif()
   endif()
   list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
   set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py
index 9d71bb6a8f7..64086e5071c 100644
--- a/test/inductor/test_fp8.py
+++ b/test/inductor/test_fp8.py
@@ -416,7 +416,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("dtype", (torch.bfloat16, torch.float32))
-    @parametrize("shape", ("16,16,32", "1024,1024,512"))
+    @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False, True))
     @parametrize(
@@ -493,7 +493,7 @@ class TestFP8Lowering(TestCase):

     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("shape", ("16,16,32", "1024,1024,512"))
+    @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False, True))
     @parametrize(
@@ -560,7 +560,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
-    @parametrize("K", (16, 1024))
+    @parametrize("K", (16, 32, 1024))
     @parametrize("N", (16, 2048))
     @parametrize(
         "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
     )
@@ -618,7 +618,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
-    @parametrize("K", (16, 1024))
+    @parametrize("K", (16, 32, 1024))
     @parametrize("N", (16, 2048))
     @parametrize(
         "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
     )
diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py
index 79aea44407e..fd51b95556e 100644
--- a/torch/_inductor/kernel/mm_scaled.py
+++ b/torch/_inductor/kernel/mm_scaled.py
@@ -1,3 +1,4 @@
+import functools
 import logging
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -553,6 +554,12 @@ def tuned_scaled_mm(
     for config in scaled_mm_configs(m, n, k):
         if k == 16 and config.kwargs["BLOCK_M"] >= 64:
             continue  # Triton crashes in this case
+
+        # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid
+        # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape
+        if using_b200() and k < 32:
+            continue
+
         kwargs = scaled_mm_options(
             config, m, n, k, layout, scale_a, scale_b, use_fast_accum
         )
@@ -584,3 +591,13 @@ def tuned_scaled_mm(
             "All choices for scaled_mm were invalid, using ATen backend as fallback"
         )
         return aten_choice.output_node()
+
+
+@functools.lru_cache
+def using_b200() -> bool:
+    """Returns true if the device is a NVIDIA B200, otherwise returns false."""
+    if not torch.cuda.is_available():
+        return False
+    # compute capability 10.0 or 10.0a is NVIDIA B200
+    device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
+    return device_properties.major == 10
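For readers outside the inductor codebase, the capability gate added in `mm_scaled.py` can be reproduced standalone. The sketch below is illustrative only and uses hypothetical helper names; the in-tree implementation is the `using_b200()` helper in the diff above.

```python
# Standalone sketch of the B200 gating logic added in torch/_inductor/kernel/mm_scaled.py.
# Helper names here are hypothetical; only public torch.cuda APIs are used.
import functools

import torch


@functools.lru_cache
def is_sm10x() -> bool:
    """True when the current CUDA device reports compute capability 10.x (e.g. B200)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major == 10


def config_is_usable(k: int) -> bool:
    """Skip Triton scaled-mm configs with K < 32 on SM 10.x, mirroring the check above."""
    # tcgen05.mma.kind::f8f6f4.* requires K >= 32
    # (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape)
    return not (is_sm10x() and k < 32)


if __name__ == "__main__":
    print(config_is_usable(16))  # False on B200, True on other GPUs
    print(config_is_usable(64))  # True everywhere
```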