Updates to build rowwise scaled mm kernel on SM10.0a (#148274)

## Summary
Update the cmake files and RowwiseScaledMM.cu so the rowwise scaled MM kernel builds on the SM10.0a arch (NVIDIA B200).

**NOTE**: The goal of this PR is only to get the kernel building correctly; performance optimization will be done in separate follow-up PRs.
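
For context, the op this kernel backs is rowwise-scaled FP8 matmul via `torch._scaled_mm`, with one scale per row of the first operand and one per column of the second. A minimal sketch of such a call (shapes are arbitrary, and the exact `_scaled_mm` keyword signature is an assumption; it is a private API that varies between releases):

```python
import torch

# Rowwise-scaled FP8 matmul sketch; assumes a CUDA device with FP8 support
# (sm89 / sm9x, or sm10x once this PR is built in).
M, K, N = 1024, 512, 2048
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# The second operand must be column-major: create (N, K) and transpose to (K, N).
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)  # one scale per row of a
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)  # one scale per column of b

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([1024, 2048])
```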

## Steps to verify build
1. Access a devgpu/machine with B200 GPUs and verify the B200s are visible with `nvidia-smi`
2. Install CUDA toolkit 12.8
    - e.g. see [Nvidia docs](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Rocky&target_version=9&target_type=rpm_local)
3. Verify the CUDA toolkit installation
    - e.g. `nvcc --version` should have `... Cuda compilation tools, release 12.8 ...` in its output
4. Set the env var `TORCH_CUDA_ARCH_LIST=10.0a`
5. Build PyTorch from source with this PR ([steps](https://github.com/pytorch/pytorch#from-source))
6. Uninstall `pytorch-triton` with `pip uninstall pytorch-triton`
7. Build and install Triton from source: https://github.com/triton-lang/triton?tab=readme-ov-file#install-from-source
8. Run the tests shown in the test plan below (a quick post-build sanity check is also sketched right after this list)
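
A quick post-build sanity check (the expected values are assumptions based on the setup above, i.e. CUDA 12.8 and `TORCH_CUDA_ARCH_LIST=10.0a`):

```python
import torch

print(torch.version.cuda)                   # expect "12.8"
print(torch.cuda.get_device_name(0))        # expect an NVIDIA B200
print(torch.cuda.get_device_capability(0))  # expect (10, 0)
print(torch.cuda.get_arch_list())           # expect "sm_100a" to appear in the list
```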


## Test plan
- `python test/distributed/tensor/test_matrix_ops.py -k scaled_mm`: OK
- `python test/test_matmul_cuda.py -k rowwise`: OK
- `python test/test_flop_counter.py -k scaled_mm`: OK
- `python test/inductor/test_aot_inductor.py -k fp8`: OK
- `python test/inductor/test_fp8.py`: OK

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148274
Approved by: https://github.com/drisspg
Author: Daniel Vega-Myhre
Date: 2025-03-04 05:23:38 +00:00
Committed by: PyTorch MergeBot
Commit: ac99fc7e57 (parent: 7ab6749ec7)
4 changed files with 26 additions and 5 deletions

**RowwiseScaledMM.cu**

```diff
@@ -693,7 +693,8 @@ void dispatch_fp8_rowwise_kernel_on_sm(
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
   const bool sm9x = properties != nullptr && properties->major == 9;
-  if (!(sm89 || sm9x)) {
+  const bool sm10x = properties != nullptr && properties->major == 10;
+  if (!(sm89 || sm9x || sm10x)) {
     TORCH_CHECK(
         false, "Rowwise scaling is not currently supported on your device");
   }
```

**CMake (rowwise scaled MM compile flags)**

```diff
@@ -92,6 +92,9 @@ if(INTERN_BUILD_ATEN_OPS)
     if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
       list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
     endif()
+    if(EXISTING_ARCH_FLAGS MATCHES ".*compute_100a.*")
+      list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_100a,code=sm_100a")
+    endif()
   endif()
   list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
   set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
```

**test/inductor/test_fp8.py**

```diff
@@ -416,7 +416,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("dtype", (torch.bfloat16, torch.float32))
-    @parametrize("shape", ("16,16,32", "1024,1024,512"))
+    @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False, True))
     @parametrize(
@@ -493,7 +493,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("shape", ("16,16,32", "1024,1024,512"))
+    @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False, True))
     @parametrize(
@@ -560,7 +560,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
-    @parametrize("K", (16, 1024))
+    @parametrize("K", (16, 32, 1024))
     @parametrize("N", (16, 2048))
     @parametrize(
         "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
@@ -618,7 +618,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
-    @parametrize("K", (16, 1024))
+    @parametrize("K", (16, 32, 1024))
     @parametrize("N", (16, 2048))
     @parametrize(
         "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
```

**Inductor scaled MM lowering (`tuned_scaled_mm`)**

```diff
@@ -1,3 +1,4 @@
+import functools
 import logging
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -553,6 +554,12 @@ def tuned_scaled_mm(
     for config in scaled_mm_configs(m, n, k):
         if k == 16 and config.kwargs["BLOCK_M"] >= 64:
             continue  # Triton crashes in this case
+        # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid
+        # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape
+        if using_b200() and k < 32:
+            continue
         kwargs = scaled_mm_options(
             config, m, n, k, layout, scale_a, scale_b, use_fast_accum
         )
@@ -584,3 +591,13 @@ def tuned_scaled_mm(
         "All choices for scaled_mm were invalid, using ATen backend as fallback"
     )
     return aten_choice.output_node()
+
+
+@functools.lru_cache
+def using_b200() -> bool:
+    """Returns true if the device is a NVIDIA B200, otherwise returns false."""
+    if not torch.cuda.is_available():
+        return False
+    # compute capability 10.0 or 10.0a is NVIDIA B200
+    device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
+    return device_properties.major == 10
```
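
As a usage note, the filter above means that on a B200 every candidate Triton config is skipped when K < 32, so those shapes fall back to the ATen `_scaled_mm` choice. A standalone sketch of the same check, with illustrative helper names that are not part of the PR:

```python
import functools

import torch


@functools.lru_cache
def _is_sm10x() -> bool:
    # Compute capability 10.x corresponds to B200-class GPUs.
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_properties(torch.cuda.current_device()).major == 10


def triton_fp8_config_allowed(k: int) -> bool:
    # On B200, tcgen05.mma.kind::f8f6f4.* requires the K dim to be >= 32,
    # so smaller K must not be offered to the Triton scaled_mm template.
    return not (_is_sm10x() and k < 32)


print(triton_fp8_config_allowed(16))   # False on a B200, True elsewhere
print(triton_fp8_config_allowed(512))  # True
```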