Refactor row-wise scaled MM (#149978)

1. Add config selection for SM89.
2. Only build each kernel when compiling for its target arch (see the sketch after this list).
3. Factor out the CMake code that enforces building individual files for the needed archs into a reusable function.
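
As a condensed illustration of change 2, the new ATen/native/cuda/cutlass_common.cuh header added below gates each kernel on the architecture it is compiled for: a wrapper derives from the CUTLASS kernel type and turns its device-side entry point into a no-op unless __CUDA_ARCH__ matches. For example (excerpted from the new header, with <utility> included so the fragment is self-contained):

#include <utility>
#include <cutlass/cutlass.h>

template <typename Kernel>
struct enable_3x_kernel_for_sm9x : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000
    // Run the wrapped CUTLASS 3.x kernel only when compiled for SM 90/90a.
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

Kernel types are then wrapped at their definition site, e.g. enable_3x_kernel_for_sm9x<cutlass::gemm::kernel::GemmUniversal<...>>, so object code built for other architectures contains only empty kernel bodies.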

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149978
Approved by: https://github.com/drisspg
Aleksandar Samardžić 2025-03-26 19:33:19 +01:00 committed by PyTorch MergeBot
parent 6aca002d82
commit 43cc954f88
3 changed files with 154 additions and 68 deletions

aten/src/ATen/native/cuda/RowwiseScaledMM.cu

@@ -41,6 +41,8 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include <ATen/native/cuda/cutlass_common.cuh>
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
@@ -221,10 +223,11 @@ void f8f8bf16_rowwise_impl(
typename Schedule<large_tile, FastAccum::value>::type>::
CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;
using GemmKernel = at::cuda::detail::enable_3x_kernel_for_sm9x<
cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
@@ -402,10 +405,11 @@ void f8f8bf16_rowwise_impl_sm100(
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>;
using GemmKernel = at::cuda::detail::enable_3x_kernel_for_sm10_or_later<
cutlass::gemm::kernel::GemmUniversal<
cute::Shape<int, int, int>,
CollectiveMainloop,
CollectiveEpilogue>>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
@@ -480,6 +484,9 @@ void f8f8bf16_rowwise_impl_sm100(
// Cutlass rowwise kernel for SM89
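// The tile configuration (threadblock shape, warp shape and number of
// pipeline stages) is supplied by the caller as template parameters, so
// dispatch_fp8_rowwise_kernel_sm89 below can pick it per problem size.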
template <
typename ThreadblockShape,
typename WarpShape,
int NumStages,
typename FastAccum,
typename DtypeA,
typename DtypeB,
@@ -511,12 +518,7 @@ void f8f8bf16_rowwise_impl_sm89(
using ThreadblockSwizzle =
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
// TODO: instead of fixing these values, implement logic alike to
// what is used for SM90+.
using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
constexpr auto NumStages = 4;
using Operator = std::conditional_t<
FastAccum::value,
@@ -586,23 +588,23 @@ void f8f8bf16_rowwise_impl_sm89(
Output,
EVTApplyBias>;
using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
DtypeOutput, LayoutOutput, AlignmentOutput,
DtypeAccum,
DtypeEpilogue,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTOutput,
ThreadblockSwizzle,
NumStages,
Operator,
NumEVTEpilogueStages
>::GemmKernel;
using EVTKernel = at::cuda::detail::enable_2x_kernel_for_sm89<
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
DtypeOutput, LayoutOutput, AlignmentOutput,
DtypeAccum,
DtypeEpilogue,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTOutput,
ThreadblockSwizzle,
NumStages,
Operator,
NumEVTEpilogueStages>::GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
@@ -880,6 +882,49 @@ void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(
}
}
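// Pick the SM89 tile configuration from the number of rows M of XQ:
// smaller M values map to threadblock/warp tiles with a smaller M
// dimension (and adjusted stage counts), larger M values to the wider
// 64x128 and 128x128 threadblock tiles.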
template <typename... Types>
void dispatch_fp8_rowwise_kernel_sm89(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
at::Tensor out) {
int M = XQ.size(0);
if (M <= 16) {
return f8f8bf16_rowwise_impl_sm89<
/*ThreadblockShape=*/cutlass::gemm::GemmShape<16, 64, 128>,
/*WarpShape=*/cutlass::gemm::GemmShape<16, 64, 64>,
/*NumStages=*/5,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (M <= 32) {
return f8f8bf16_rowwise_impl_sm89<
/*ThreadblockShape=*/cutlass::gemm::GemmShape<32, 64, 128>,
/*WarpShape=*/cutlass::gemm::GemmShape<16, 64, 64>,
/*NumStages=*/5,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (M <= 64) {
return f8f8bf16_rowwise_impl_sm89<
/*ThreadblockShape=*/cutlass::gemm::GemmShape<64, 64, 128>,
/*WarpShape=*/cutlass::gemm::GemmShape<32, 64, 64>,
/*NumStages=*/5,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else if (M <= 256) {
return f8f8bf16_rowwise_impl_sm89<
/*ThreadblockShape=*/cutlass::gemm::GemmShape<64, 128, 128>,
/*WarpShape=*/cutlass::gemm::GemmShape<64, 64, 64>,
/*NumStages=*/3,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
return f8f8bf16_rowwise_impl_sm89<
/*ThreadblockShape=*/cutlass::gemm::GemmShape<128, 128, 64>,
/*WarpShape=*/cutlass::gemm::GemmShape<64, 64, 64>,
/*NumStages=*/5,
Types...>(XQ, WQ, x_scale, w_scale, bias, out);
}
}
template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_sm(
at::Tensor XQ,
@@ -900,7 +945,7 @@ void dispatch_fp8_rowwise_kernel_on_sm(
if (sm9x || sm10x) {
dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
} else {
f8f8bf16_rowwise_impl_sm89<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
dispatch_fp8_rowwise_kernel_sm89<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
}
}

aten/src/ATen/native/cuda/cutlass_common.cuh

@@ -0,0 +1,38 @@
#pragma once
#include <c10/util/Exception.h>
#include <cutlass/cutlass.h>
namespace at::cuda::detail {
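// Wrappers that make a kernel safe to include in a build that also targets
// other architectures: each derives from a CUTLASS kernel type and reduces
// its device-side entry point to a no-op unless __CUDA_ARCH__ matches the
// architecture the kernel is written for.  enable_2x_kernel_for_sm89 wraps
// CUTLASS 2.x kernels (static invoke()); the other two wrap CUTLASS 3.x
// kernels (operator()).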
template <typename Kernel>
struct enable_2x_kernel_for_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_3x_kernel_for_sm9x : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_3x_kernel_for_sm10_or_later : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
} // namespace at::cuda::detail

View File

@@ -76,48 +76,51 @@ if(INTERN_BUILD_ATEN_OPS)
file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../torchgen/*.py")
# RowwiseScaled.cu requires sm89/sm90a flags
# Handle files that may need sm89/sm90a/sm100a flags (stable/nightly
# builds are not built for these archs).
if(USE_CUDA)
set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu")
# The stable/nightly builds do not enable some SM architectures, such
# as 89/90a/100a.  Still, some files need to be built specifically for
# these architectures.  This function makes it possible to build a
# given file for one of these additional architectures whenever PyTorch
# is already built for the corresponding base architecture; for
# example, it enables building for SM 90a when PyTorch is built for
# SM 90, and so on.  For examples of how to use the function, see the
# calls below the function itself.
function(_BUILD_FOR_ADDITIONAL_ARCHS file archs)
torch_cuda_get_nvcc_gencode_flag(_existing_arch_flags)
# Get existing arch flags
torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
# Check NVCC version and existing arch flags
set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
set(_file_compile_flags "")
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
foreach(_arch ${archs})
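# A file is built for an extra arch only when the build already enables
# the corresponding base arch: compute_86 -> sm_89, compute_90 -> sm_90a,
# compute_100 -> sm_100a.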
if("${_arch}" STREQUAL "89")
if(_existing_arch_flags MATCHES ".*compute_86.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89")
endif()
endif()
if("${_arch}" STREQUAL "90a")
if(_existing_arch_flags MATCHES ".*compute_90.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a")
endif()
endif()
if("${_arch}" STREQUAL "100a")
if(_existing_arch_flags MATCHES ".*compute_100.*")
list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a")
endif()
endif()
endforeach()
endif()
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
endif()
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_100.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_100a,code=sm_100a")
endif()
endif()
list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu")
# Get existing arch flags
torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
# Check NVCC version and existing arch flags
set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
endif()
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
endif()
endif()
list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
list(JOIN _file_compile_flags " " _file_compile_flags)
set_source_files_properties(${file} PROPERTIES COMPILE_FLAGS "${_file_compile_flags}")
endfunction()
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
"89;90a;100a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
"89;90a")
endif()
set(GEN_ROCM_FLAG)