Refactor row-wise scaled MM (#149978)
1. Add config selection for SM89.
2. Only build the kernels if compiling for the given arch.
3. Factor the CMake code that enforces compiling individual files for the needed archs out into a function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149978
Approved by: https://github.com/drisspg
This commit is contained in:
parent 6aca002d82
commit 43cc954f88
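The heart of item 2 is guarding kernel bodies with __CUDA_ARCH__, so that in a fat binary built for several architectures only the matching architecture gets real device code; for the other architectures the guarded body is empty and no device code has to be generated for it. A minimal standalone sketch of that pattern (illustration only, not code from this commit; the kernel name is made up) compiles with nvcc:

// guard_sketch.cu -- illustration only, not part of this commit.
// Compiled with several -gencode targets, the body below is emitted only
// into the sm_89 cubin; every other architecture gets an empty kernel.
#include <cuda_runtime.h>

__global__ void sm89_only_kernel(float* out) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 890
  out[threadIdx.x] = 1.0f;  // real work only when compiling for sm_89
#endif
}

int main() {
  float* out = nullptr;
  cudaMalloc(&out, 32 * sizeof(float));
  sm89_only_kernel<<<1, 32>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}

On any other device the launch is still valid but does nothing, so the host-side dispatch on the detected SM (dispatch_fp8_rowwise_kernel_on_sm in the hunks below) remains responsible for routing to a kernel that was actually built for the running device.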
@@ -41,6 +41,8 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
 #include <cutlass/gemm/kernel/gemm_universal.hpp>
 #include <cutlass/util/packed_stride.hpp>
 
+#include <ATen/native/cuda/cutlass_common.cuh>
+
 C10_DIAGNOSTIC_POP()
 C10_DIAGNOSTIC_POP()
 
@@ -221,10 +223,11 @@ void f8f8bf16_rowwise_impl(
           typename Schedule<large_tile, FastAccum::value>::type>::
           CollectiveOp;
 
-  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-      cute::Shape<int, int, int>,
-      CollectiveMainloop,
-      CollectiveEpilogue>;
+  using GemmKernel = at::cuda::detail::enable_3x_kernel_for_sm9x<
+      cutlass::gemm::kernel::GemmUniversal<
+          cute::Shape<int, int, int>,
+          CollectiveMainloop,
+          CollectiveEpilogue>>;
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
@@ -402,10 +405,11 @@ void f8f8bf16_rowwise_impl_sm100(
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
          MainloopScheduleType>::CollectiveOp;
 
-  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-      cute::Shape<int, int, int>,
-      CollectiveMainloop,
-      CollectiveEpilogue>;
+  using GemmKernel = at::cuda::detail::enable_3x_kernel_for_sm10_or_later<
+      cutlass::gemm::kernel::GemmUniversal<
+          cute::Shape<int, int, int>,
+          CollectiveMainloop,
+          CollectiveEpilogue>>;
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 
@@ -480,6 +484,9 @@ void f8f8bf16_rowwise_impl_sm100(
 
 // Cutlass rowwise kernel for SM89
 template <
+    typename ThreadblockShape,
+    typename WarpShape,
+    int NumStages,
     typename FastAccum,
     typename DtypeA,
     typename DtypeB,
@@ -511,12 +518,7 @@ void f8f8bf16_rowwise_impl_sm89(
   using ThreadblockSwizzle =
       cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
 
-  // TODO: instead of fixing these values, implement logic alike to
-  // what is used for SM90+.
-  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
-  using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
-  constexpr auto NumStages = 4;
 
   using Operator = std::conditional_t<
       FastAccum::value,
@@ -586,23 +588,23 @@ void f8f8bf16_rowwise_impl_sm89(
       Output,
       EVTApplyBias>;
 
-  using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
-      DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
-      DtypeOutput, LayoutOutput, AlignmentOutput,
-      DtypeAccum,
-      DtypeEpilogue,
-      OperatorClass,
-      ArchTag,
-      ThreadblockShape,
-      WarpShape,
-      InstructionShape,
-      EVTOutput,
-      ThreadblockSwizzle,
-      NumStages,
-      Operator,
-      NumEVTEpilogueStages
-  >::GemmKernel;
+  using EVTKernel = at::cuda::detail::enable_2x_kernel_for_sm89<
+      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+          DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
+          DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
+          DtypeOutput, LayoutOutput, AlignmentOutput,
+          DtypeAccum,
+          DtypeEpilogue,
+          OperatorClass,
+          ArchTag,
+          ThreadblockShape,
+          WarpShape,
+          InstructionShape,
+          EVTOutput,
+          ThreadblockSwizzle,
+          NumStages,
+          Operator,
+          NumEVTEpilogueStages>::GemmKernel>;
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
 
@@ -880,6 +882,49 @@ void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(
   }
 }
 
+template <typename... Types>
+void dispatch_fp8_rowwise_kernel_sm89(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    std::optional<at::Tensor> bias,
+    at::Tensor out) {
+  int M = XQ.size(0);
+
+  if (M <= 16) {
+    return f8f8bf16_rowwise_impl_sm89<
+        /*ThreadblockShape=*/cutlass::gemm::GemmShape<16, 64, 128>,
+        /*WarpShape=*/cutlass::gemm::GemmShape<16, 64, 64>,
+        /*NumStages=*/5,
+        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
+  } else if (M <= 32) {
+    return f8f8bf16_rowwise_impl_sm89<
+        /*ThreadblockShape=*/cutlass::gemm::GemmShape<32, 64, 128>,
+        /*WarpShape=*/cutlass::gemm::GemmShape<16, 64, 64>,
+        /*NumStages=*/5,
+        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
+  } else if (M <= 64) {
+    return f8f8bf16_rowwise_impl_sm89<
+        /*ThreadblockShape=*/cutlass::gemm::GemmShape<64, 64, 128>,
+        /*WarpShape=*/cutlass::gemm::GemmShape<32, 64, 64>,
+        /*NumStages=*/5,
+        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
+  } else if (M <= 256) {
+    return f8f8bf16_rowwise_impl_sm89<
+        /*ThreadblockShape=*/cutlass::gemm::GemmShape<64, 128, 128>,
+        /*WarpShape=*/cutlass::gemm::GemmShape<64, 64, 64>,
+        /*NumStages=*/3,
+        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
+  } else {
+    return f8f8bf16_rowwise_impl_sm89<
+        /*ThreadblockShape=*/cutlass::gemm::GemmShape<128, 128, 64>,
+        /*WarpShape=*/cutlass::gemm::GemmShape<64, 64, 64>,
+        /*NumStages=*/5,
+        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
+  }
+}
+
 template <typename... Types>
 void dispatch_fp8_rowwise_kernel_on_sm(
     at::Tensor XQ,
@@ -900,7 +945,7 @@ void dispatch_fp8_rowwise_kernel_on_sm(
   if (sm9x || sm10x) {
     dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
   } else {
-    f8f8bf16_rowwise_impl_sm89<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
+    dispatch_fp8_rowwise_kernel_sm89<Types...>(XQ, WQ, x_scale, w_scale, bias, out);
   }
 }
 
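Item 1 is the dispatch_fp8_rowwise_kernel_sm89 helper above: tile shape and stage count become template parameters of the SM89 implementation, and a thin dispatcher picks a configuration from the runtime M (for example, M = 100 lands in the M <= 256 bucket, i.e. a 64x128x128 threadblock with 3 stages). A self-contained sketch of the same selection style, with made-up names (run_gemm, dispatch_on_m) and without the CUTLASS machinery:

// config_selection_sketch.cpp -- illustration only; names are made up.
#include <cstdio>

// Tile shape and stage count as template parameters, as in the SM89 impl.
template <int TbM, int TbN, int TbK, int NumStages>
void run_gemm(int M) {
  std::printf("M=%4d -> threadblock %dx%dx%d, %d stages\n",
              M, TbM, TbN, TbK, NumStages);
}

// Thin dispatcher: pick a configuration from the runtime problem size,
// mirroring the M-bucket boundaries used above (16/32/64/256).
void dispatch_on_m(int M) {
  if (M <= 16) {
    run_gemm<16, 64, 128, 5>(M);
  } else if (M <= 32) {
    run_gemm<32, 64, 128, 5>(M);
  } else if (M <= 64) {
    run_gemm<64, 64, 128, 5>(M);
  } else if (M <= 256) {
    run_gemm<64, 128, 128, 3>(M);
  } else {
    run_gemm<128, 128, 64, 5>(M);
  }
}

int main() {
  dispatch_on_m(8);     // small-M config
  dispatch_on_m(100);   // mid-size config
  dispatch_on_m(4096);  // large-M config
  return 0;
}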
aten/src/ATen/native/cuda/cutlass_common.cuh (new file, 38 lines)
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+#include <cutlass/cutlass.h>
+
+namespace at::cuda::detail {
+
+template <typename Kernel>
+struct enable_2x_kernel_for_sm89 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE static void invoke(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 890
+    Kernel::invoke(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_3x_kernel_for_sm9x : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+template <typename Kernel>
+struct enable_3x_kernel_for_sm10_or_later : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
+
+} // namespace at::cuda::detail
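These wrappers rely on plain inheritance: each one derives from the kernel type and re-declares its entry point, so the launching adapter picks up the guarded version while Params, shared storage and everything else are inherited unchanged. A host-only sketch of the same shadowing idea, with made-up names (HeavyKernel, enable_if_flag) and if constexpr standing in for the preprocessor guard:

// wrapper_sketch.cpp -- illustration only; names are made up.
#include <cstdio>
#include <utility>

struct HeavyKernel {
  // Stands in for a CUTLASS kernel's operator(); everything else the kernel
  // defines (Params, SharedStorage, ...) would be inherited by the wrapper.
  void operator()(int x) { std::printf("running heavy kernel, x=%d\n", x); }
};

template <typename Kernel, bool Enabled>
struct enable_if_flag : Kernel {
  template <typename... Args>
  void operator()(Args&&... args) {
    if constexpr (Enabled) {
      Kernel::operator()(std::forward<Args>(args)...);
    }
    // When not enabled, the body is empty -- the analogue of the
    // #if __CUDA_ARCH__ guards in cutlass_common.cuh, which additionally
    // keep the call from being compiled at all for that architecture.
  }
};

int main() {
  enable_if_flag<HeavyKernel, true> on;
  enable_if_flag<HeavyKernel, false> off;
  on(42);   // forwards to HeavyKernel::operator()
  off(42);  // no-op
  return 0;
}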
@@ -76,48 +76,51 @@ if(INTERN_BUILD_ATEN_OPS)
 
   file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../torchgen/*.py")
 
-  # RowwiseScaled.cu requires sm89/sm90a flags
+  # Handle files that may need sm89/sm90a/sm100a flags (stable/nightly
+  # builds are not built for these archs).
   if(USE_CUDA)
-    set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu")
+    # The stable/nightly builds do not enable some SM architectures,
+    # like 89/90a/100a. Still, some files need to be built for these
+    # architectures specifically. This function makes it possible to
+    # enable building a given file for such an architecture, in case
+    # PyTorch is built for the corresponding base architecture; for
+    # example, it will enable building for SM 90a if PyTorch is built
+    # for SM 90, etc. For examples of how to use the function, see the
+    # calls below the function itself.
+    function(_BUILD_FOR_ADDITIONAL_ARCHS file archs)
+      torch_cuda_get_nvcc_gencode_flag(_existing_arch_flags)
 
-    # Get existing arch flags
-    torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
-
-    # Check NVCC version and existing arch flags
-    set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
-    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
-      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
-        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
+      set(_file_compile_flags "")
+      if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
+        foreach(_arch ${archs})
+          if("${_arch}" STREQUAL "89")
+            if(_existing_arch_flags MATCHES ".*compute_86.*")
+              list(APPEND _file_compile_flags "-gencode;arch=compute_89,code=sm_89")
+            endif()
+          endif()
+          if("${_arch}" STREQUAL "90a")
+            if(_existing_arch_flags MATCHES ".*compute_90.*")
+              list(APPEND _file_compile_flags "-gencode;arch=compute_90a,code=sm_90a")
+            endif()
+          endif()
+          if("${_arch}" STREQUAL "100a")
+            if(_existing_arch_flags MATCHES ".*compute_100.*")
+              list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a")
+            endif()
+          endif()
+        endforeach()
       endif()
-      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
-        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
-      endif()
-      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_100.*")
-        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_100a,code=sm_100a")
-      endif()
-    endif()
-    list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
-    set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
-
-    set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu")
-
-    # Get existing arch flags
-    torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
-
-    # Check NVCC version and existing arch flags
-    set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
-    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
-      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
-        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
-      endif()
-      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
-        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
-      endif()
-    endif()
-    list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
-    set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
+      list(JOIN _file_compile_flags " " _file_compile_flags)
+
+      set_source_files_properties(${file} PROPERTIES COMPILE_FLAGS "${_file_compile_flags}")
+    endfunction()
+
+    _BUILD_FOR_ADDITIONAL_ARCHS(
+      "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
+      "89;90a;100a")
+    _BUILD_FOR_ADDITIONAL_ARCHS(
+      "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
+      "89;90a")
   endif()
 
   set(GEN_ROCM_FLAG)