Updates to build rowwise scaled mm kernel on SM10.0a (#148274)
## Summary
Update CMake files and RowwiseScaledMM.cu so the rowwise scaled-mm kernel builds on the SM10.0a arch.
**NOTE**: The goal of this PR is only to ensure the kernel builds correctly; performance optimization will be done in separate follow-up PRs.
## Steps to verify build
1. Access a devgpu/machine with B200 GPUs and verify the B200s are visible with `nvidia-smi`
2. Install CUDA toolkit 12.8
   - e.g. see [Nvidia docs](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Rocky&target_version=9&target_type=rpm_local)
3. Verify the CUDA toolkit installation
   - e.g. `nvcc --version` should have `... Cuda compilation tools, release 12.8 ...` in its output
4. Set the env var `TORCH_CUDA_ARCH_LIST=10.0a`
5. Build PyTorch from source with this PR ([steps](https://github.com/pytorch/pytorch#from-source))
6. Uninstall `pytorch-triton` with `pip uninstall pytorch-triton`
7. Build and install Triton from source: https://github.com/triton-lang/triton?tab=readme-ov-file#install-from-source
8. Run the tests shown in the test plan below; a quick interactive sanity check is sketched after this list
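Once the build completes, a quick smoke test of the rowwise-scaled path can be run before the full test plan. This is a minimal sketch, assuming a B200 and the build above; the shapes and the `(M, 1)`/`(1, N)` float32 scale layout follow the rowwise convention exercised by `test_matmul_cuda.py`:

```python
import torch

assert torch.cuda.is_available()
print(torch.cuda.get_device_capability())  # expect (10, 0) on a B200

M, K, N = 64, 128, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)  # one scale per row of a
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)  # one scale per column of b
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([64, 64])
```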
## Test plan
- `python test/distributed/tensor/test_matrix_ops.py -k scaled_mm`: OK
- `python test/test_matmul_cuda.py -k rowwise`: OK
- `python test/test_flop_counter.py -k scaled_mm`: OK
- `python test/inductor/test_aot_inductor.py -k fp8`: OK
- `python test/inductor/test_fp8.py`: OK
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148274
Approved by: https://github.com/drisspg
Parent: 7ab6749ec7
Commit: ac99fc7e57
In `RowwiseScaledMM.cu`, extend the SM gate so SM 10.x devices are accepted:

```diff
@@ -693,7 +693,8 @@ void dispatch_fp8_rowwise_kernel_on_sm(
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9;
   const bool sm9x = properties != nullptr && properties->major == 9;
-  if (!(sm89 || sm9x)) {
+  const bool sm10x = properties != nullptr && properties->major == 10;
+  if (!(sm89 || sm9x || sm10x)) {
     TORCH_CHECK(
         false, "Rowwise scaling is not currently supported on your device");
   }
```
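For orientation, the same gate expressed from Python; `rowwise_fp8_supported` is a hypothetical helper for illustration, not a PyTorch API:

```python
import torch

def rowwise_fp8_supported() -> bool:
    """Mirrors the C++ check above: SM 8.9 (Ada), SM 9.x (Hopper), or SM 10.x (Blackwell)."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major == 8 and minor == 9) or major in (9, 10)
```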
In the CMake logic for the rowwise scaled-mm source file, emit an `sm_100a` gencode when `compute_100a` is present in the existing arch flags:

```diff
@@ -92,6 +92,9 @@ if(INTERN_BUILD_ATEN_OPS)
     if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
       list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
     endif()
+    if(EXISTING_ARCH_FLAGS MATCHES ".*compute_100a.*")
+      list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_100a,code=sm_100a")
+    endif()
   endif()
   list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
   set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
```
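After rebuilding, one way to confirm the new gencode took effect (assuming `TORCH_CUDA_ARCH_LIST=10.0a` was set as in the build steps above):

```python
import torch

# Reports the compute architectures this PyTorch binary was compiled for;
# on this build an entry like 'sm_100a' should appear in the list.
print(torch.cuda.get_arch_list())
```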
In `test/inductor/test_fp8.py`, widen the parametrizations with K = 32 cases, the smallest K that remains valid on B200 (see the `k < 32` guard below):

```diff
@@ -416,7 +416,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("dtype", (torch.bfloat16, torch.float32))
-    @parametrize("shape", ("16,16,32", "1024,1024,512"))
+    @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False, True))
     @parametrize(
@@ -493,7 +493,7 @@ class TestFP8Lowering(TestCase):
 
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
-    @parametrize("shape", ("16,16,32", "1024,1024,512"))
+    @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512"))
     @parametrize("has_bias", (False, True))
     @parametrize("use_fast_accum", (False, True))
     @parametrize(
@@ -560,7 +560,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
-    @parametrize("K", (16, 1024))
+    @parametrize("K", (16, 32, 1024))
     @parametrize("N", (16, 2048))
     @parametrize(
         "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
@@ -618,7 +618,7 @@ class TestFP8Lowering(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @parametrize("M", (1, 3, 33, 257, 1024))
-    @parametrize("K", (16, 1024))
+    @parametrize("K", (16, 32, 1024))
     @parametrize("N", (16, 2048))
     @parametrize(
         "persistent_matmul", [False, True] if has_triton_tma_device() else [False]
```
In the Inductor `scaled_mm` lowering (`tuned_scaled_mm`), skip Triton configs with K < 32 when running on a B200, and add a cached device check:

```diff
@@ -1,3 +1,4 @@
+import functools
 import logging
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -553,6 +554,12 @@ def tuned_scaled_mm(
     for config in scaled_mm_configs(m, n, k):
         if k == 16 and config.kwargs["BLOCK_M"] >= 64:
             continue  # Triton crashes in this case
+
+        # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid
+        # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape
+        if using_b200() and k < 32:
+            continue
+
         kwargs = scaled_mm_options(
             config, m, n, k, layout, scale_a, scale_b, use_fast_accum
         )
@@ -584,3 +591,13 @@
             "All choices for scaled_mm were invalid, using ATen backend as fallback"
         )
         return aten_choice.output_node()
+
+
+@functools.lru_cache
+def using_b200() -> bool:
+    """Returns true if the device is a NVIDIA B200, otherwise returns false."""
+    if not torch.cuda.is_available():
+        return False
+    # compute capability 10.0 or 10.0a is NVIDIA B200
+    device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
+    return device_properties.major == 10
```
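Taken together, the filter in `tuned_scaled_mm` first dodges the known Triton crash and then the B200 shape constraint. A standalone sketch of the combined predicate, with illustrative names rather than the actual Inductor code path:

```python
import functools

import torch

@functools.lru_cache
def on_b200() -> bool:
    # Same idea as using_b200() above; cached so the device query runs once per process.
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10

def config_is_usable(k: int, block_m: int) -> bool:
    if k == 16 and block_m >= 64:
        return False  # Triton crashes in this case
    if on_b200() and k < 32:
        return False  # tcgen05 f8f6f4 MMA requires K >= 32
    return True
```

The `lru_cache` is worthwhile because the lowering can be invoked many times during compilation, while the device never changes within a process.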