Build RowwiseScaledMM.cu for SM89 (#145676)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145676
Approved by: https://github.com/drisspg, https://github.com/malfet, https://github.com/eqy
Author: Aleksandar Samardžić (2025-01-31 23:48:47 +01:00), committed by PyTorch MergeBot
parent f40e013787
commit 2b00d211f0
10 changed files with 57 additions and 67 deletions
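
For orientation (not part of this commit): a minimal usage sketch of the row-wise scaled FP8 matmul that RowwiseScaledMM.cu backs, now also built for SM 8.9. Shapes, dtypes, and the bf16 output assumption follow the existing rowwise tests touched below.

import torch

# Hedged sketch: row-wise scaled FP8 GEMM via torch._scaled_mm.
# Assumes a CUDA device with compute capability >= 8.9 and the shapes used in
# test_float8_rowwise_scaling_sanity.
M, K, N = 1024, 512, 2048
x = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
x_scale = torch.rand(M, 1, device="cuda", dtype=torch.float32)  # one scale per row of x
w_scale = torch.rand(1, N, device="cuda", dtype=torch.float32)  # one scale per column of w.t()
out = torch._scaled_mm(
    x,
    w.t(),                     # second operand is passed column-major
    scale_a=x_scale,
    scale_b=w_scale,
    out_dtype=torch.bfloat16,  # the rowwise kernel writes bf16 output
    use_fast_accum=True,
)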

View File

@@ -315,9 +315,6 @@ void f8f8bf16_rowwise_impl_sm89(
   using LayoutInputB = cutlass::layout::ColumnMajor;
   constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
-  constexpr int AlignmentScale = 16 / sizeof(DtypeScale);
-  constexpr int AlignmentBias = 16 / sizeof(DtypeBias);
   using LayoutOutput = cutlass::layout::RowMajor;
   constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
@@ -330,8 +327,8 @@ void f8f8bf16_rowwise_impl_sm89(
   // TODO: instead of fixing these values, implement logic alike to
   // what is used for SM90+.
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
   constexpr auto NumStages = 4;
@@ -341,20 +338,6 @@ void f8f8bf16_rowwise_impl_sm89(
       cutlass::arch::OpMultiplyAdd>;
   constexpr auto NumEVTEpilogueStages = 1;
-  using ScaleTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          DtypeScale,
-          AlignmentScale,
-          NumEVTEpilogueStages>;
-  using BiasTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          DtypeBias,
-          AlignmentBias,
-          NumEVTEpilogueStages>;
   using OutputTileThreadMap =
       cutlass::epilogue::threadblock::OutputTileThreadLayout<
           ThreadblockShape,
@@ -365,25 +348,19 @@ void f8f8bf16_rowwise_impl_sm89(
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
-  using XScale =
-      cutlass::epilogue::threadblock::VisitorColBroadcast<
-          ScaleTileThreadMap,
-          DtypeScale,
-          cute::Stride<cute::_1, cute::_0, int64_t>>;
+  using XScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
+      OutputTileThreadMap, DtypeScale,
+      cute::Stride<cute::_1, cute::_0, int64_t>>;
   using XScaleArguments = typename XScale::Arguments;
-  using WScale =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          ScaleTileThreadMap,
-          DtypeScale,
-          cute::Stride<cute::_0, cute::_1, int64_t>>;
+  using WScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, DtypeScale,
+      cute::Stride<cute::_0, cute::_1, int64_t>>;
   using WScaleArguments = typename WScale::Arguments;
-  using Bias =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          BiasTileThreadMap,
-          DtypeBias,
-          cute::Stride<cute::_0, cute::_1, int32_t>>;
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, DtypeBias,
+      cute::Stride<cute::_0, cute::_1, int64_t>>;
   using BiasArguments = typename Bias::Arguments;
   using ApplyXScale = cutlass::epilogue::threadblock::VisitorCompute<
@@ -423,8 +400,7 @@ void f8f8bf16_rowwise_impl_sm89(
           Output,
           EVTApplyBias>;
-  using EVTKernel =
-      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+  using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
       DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
       DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
       DtypeOutput, LayoutOutput, AlignmentOutput,
@@ -442,7 +418,7 @@ void f8f8bf16_rowwise_impl_sm89(
       NumEVTEpilogueStages
   >::GemmKernel;
-  using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
   cutlass::gemm::GemmCoord problem_size(M, N, K);
   constexpr auto SplitKFactor = 1;
@@ -475,14 +451,13 @@ void f8f8bf16_rowwise_impl_sm89(
           {}                 // ApplyXScale
         },                   // EVTApplyXScale
         w_scale_arguments,   // WScale
-        {},                  // ApplyWScale
+        {}                   // ApplyWScale
       },                     // EVTApplyWScale
       bias_arguments,        // Bias
       {}                     // ApplyBias
     },                       // EVTApplyBias
     output_arguments         // Output
   };                         // EVTOutput
-  constexpr auto AvailSms = -1;
   typename Gemm::Arguments arguments(
       cutlass::gemm::GemmUniversalMode::kGemm,
@@ -500,8 +475,7 @@ void f8f8bf16_rowwise_impl_sm89(
       problem_size.k(),  // stride A
       problem_size.k(),  // stride B
       0,                 // stride C (unused)
-      0,                 // stride D (unused)
-      AvailSms);
+      0);                // stride D (unused)
   Gemm gemm;
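
Reading the epilogue visitor tree above: the fp32 accumulator is scaled per output row by x_scale (column broadcast), per output column by w_scale (row broadcast), the row-broadcast bias is added, and the result is cast to bf16. A hedged Python emulation of those semantics, mirroring what test_scaled_mm_vs_emulated_row_wise checks (not code from this commit):

import torch

def emulated_rowwise_mm(a, b, x_scale, w_scale, bias=None):
    # Accumulate A @ B in fp32, apply the per-row and per-column scales,
    # add the optional row-broadcast bias, cast to the bf16 output dtype.
    out = (a.float() @ b.float()) * x_scale * w_scale
    if bias is not None:
        out = out + bias
    return out.to(torch.bfloat16)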

View File

@@ -76,7 +76,7 @@ if(INTERN_BUILD_ATEN_OPS)
   file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../torchgen/*.py")
-  # RowwiseScaled.cu requires sm90a flags
+  # RowwiseScaled.cu requires sm89/sm90a flags
   if(USE_CUDA)
     set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu")
@@ -84,11 +84,17 @@ if(INTERN_BUILD_ATEN_OPS)
     torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
     # Check NVCC version and existing arch flags
-    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND
-       EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
-      set_source_files_properties(${ROWWISE_SCALED_MM_FILE}
-        PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
+    set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
+    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
+      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
+        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
+      endif()
+      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
+        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
+      endif()
     endif()
+    list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
+    set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
   endif()
   set(GEN_ROCM_FLAG)

View File

@@ -136,7 +136,10 @@ class DistMatrixOpsTest(DTensorTestBase):
     @with_comms
     @skip_unless_torch_gpu
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "torch._scaled_mm requires H100+")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     def test_scaled_mm(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
         shrd0 = Shard(0)

View File

@@ -708,8 +708,8 @@ class AOTInductorTestsTemplate:
         self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)

     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
@@ -756,8 +756,8 @@ class AOTInductorTestsTemplate:
     )
     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+",
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
@@ -3324,7 +3324,7 @@ class AOTInductorTestsTemplate:
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_runtime_checks_fp8(self):
         # cuda only

View File

@@ -21,7 +21,7 @@ from torch.utils._triton import has_triton_tma_device
 torch.set_float32_matmul_precision("high")

-f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
+f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"

 # define the e4m3/e5m2 constants
 E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max

View File

@@ -1259,7 +1259,7 @@ class TestPrologueFusion(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_low_precision(self):
         M = K = N = 128

View File

@@ -839,7 +839,7 @@ class TestFlopCounter(TestCase):
     @unittest.skipIf(not HAS_CUDA, "CUDA not available")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "Does not support fp8 (pre-SM90 hardware on CUDA)",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_scaled_mm(self):
         dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn

View File

@@ -17,6 +17,7 @@ from torch.quantization._quantized_conversions import (
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import (
     SM53OrLater,
+    SM89OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8
 )
@@ -42,10 +43,8 @@ from torch.testing._internal.common_utils import (
 )

 _IS_SM8X = False
-_IS_SM9X = False
 if TEST_CUDA:
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
-    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9

 # Protects against includes accidentally setting the default dtype
 assert torch.get_default_dtype() is torch.float32
@@ -213,7 +212,7 @@ class TestMatmulCuda(TestCase):
         self.assertEqual(out1_gpu, out2_gpu[0])

-f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
+f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"

 if torch.version.hip:
     e4m3_type = torch.float8_e4m3fnuz
@@ -538,8 +537,7 @@ class TestFP8MatmulCuda(TestCase):
             lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
         )

-    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
-                     "This test is only for devices with compute capability < 8.9")
+    @unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
     def test_error_message_fp8_pre_sm89(self, device) -> None:
         (k, l, m) = (16, 48, 32)
         x = torch.rand((k, l), device=device).to(e4m3_type)
@@ -567,7 +565,7 @@ class TestFP8MatmulCuda(TestCase):
         self.assertEqual(out_fp8, out_fp8_s)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("use_fast_accum", [True, False])
     def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
         M, K, N = (1024, 512, 2048)
@@ -673,7 +671,7 @@ class TestFP8MatmulCuda(TestCase):
         )

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("base_dtype", [torch.bfloat16])
     def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         torch.manual_seed(42)
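
The test updates above drop the sm90-only _IS_SM9X gate in favor of SM89OrLater and PLATFORM_SUPPORTS_FP8. A hedged sketch of an equivalent local check (hypothetical helper; the real flags live in torch.testing._internal.common_cuda):

import torch

def fp8_rowwise_supported() -> bool:
    # Hypothetical helper mirroring the PLATFORM_SUPPORTS_FP8 / SM89OrLater gating.
    if not torch.cuda.is_available():
        return False
    if torch.version.hip:
        # ROCm support is device-specific (MI300-class); defer to the library flag.
        from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
        return PLATFORM_SUPPORTS_FP8
    # CUDA: the rowwise kernel now builds for SM 8.9 (Ada) as well as SM 9.0+.
    return torch.cuda.get_device_capability() >= (8, 9)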

View File

@@ -1047,7 +1047,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     @parametrize("dense_input_shape", [(256, 128)])
     def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
@@ -1067,7 +1070,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         ):
             dense_result = torch.mm(A_fp8_sparse, B_fp8)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
         (k, l, m) = (32, 64, 32)
@@ -1084,7 +1090,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         out_fp32_sparse = out_fp8_sparse.to(torch.float32)
         torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
     @parametrize("dense_input_shape", [(256, 128)])

View File

@@ -31,7 +31,7 @@ from torch.testing._internal.common_device_type import \
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
-    SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
+    SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
     _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -16211,7 +16211,7 @@ op_db: list[OpInfo] = [
            supports_out=True,
            supports_forward_ad=False,
            supports_autograd=False,
-           decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
+           decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
            skips=(
                # Sample inputs isn't really parametrized on dtype
                DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',