[CUTLASS] [CUDA] SM100 GroupMM (#156203)
Closes https://github.com/pytorch/pytorch/issues/156202

This PR adds Blackwell support for GroupMM. Most of the code used for SM90 can be reused; the kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html

Did some preliminary benchmarking of H200 vs B200.

Script:

```py
import torch

print(torch.__file__)

device = torch.device("cuda")
dtype = torch.bfloat16

shapes = [
    (16, 128000, 7168, 7168),
    (128, 1, 2048, 7168)
]

for batch, M, N, K in shapes:
    a = torch.randn(batch, M, K, device=device, dtype=dtype)
    b = torch.randn(batch, N, K, device=device, dtype=dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    for i in range(5):
        c = torch._grouped_mm(a, b)

    num_iter = 50
    start_event.record()
    for i in range(num_iter):
        c = torch._grouped_mm(a, b)
    end_event.record()
    torch.cuda.synchronize()

    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iter
    print(f"batch: {batch}\tM: {M}\tN: {N}\tK: {K}")
    print(f"Time per Iteration:\t {avg_time_ms:.4f} ms")
```

On H200:

```
batch: 16   M: 128000   N: 7168   K: 7168
Time per Iteration:  298.6668 ms
batch: 128  M: 1        N: 2048   K: 7168
Time per Iteration:  4.1462 ms
```

B200:

```
batch: 16   M: 128000   N: 7168   K: 7168
Time per Iteration:  190.7458 ms
batch: 128  M: 1        N: 2048   K: 7168
Time per Iteration:  3.0680 ms
```

nsys nvprof:

```
root@16930b42ffc6:/workspace/pytorch# nsys nvprof python gemm_test.py
WARNING: python and any of its children processes will be profiled.

Collecting data...
batch: 16   M: 128000   N: 7168   K: 7168
Time per Iteration:  192.6420 ms
batch: 128  M: 1        N: 2048   K: 7168
Time per Iteration:  1.2255 ms
Generating '/tmp/nsys-report-6a53.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /workspace/pytorch/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)    Max (ns)    StdDev (ns)   Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  ---------------------------------
     98.9      10586895744          2  5293447872.0  5293447872.0  73786464  10513109280  7381715954.2  cudaDeviceSynchronize
      1.0        104084608          5    20816921.6    33552480.0    100800     34786208    18048125.3  cudaMalloc
      0.1          5694304          4     1423576.0     1416656.0   1258560      1602432      181668.1  cudaGetDeviceProperties_v2_v12000
      0.1          5430496        130       41773.0        4560.0      2496      3854368      345761.8  cudaLaunchKernel
      0.0           587584        110        5341.7        4992.0      4224        16992        1482.0  cudaLaunchKernelExC_v11060
      0.0           119200        660         180.6         128.0        96         4128         206.7  cudaGetDriverEntryPoint_v11030
      0.0            68352        660         103.6          64.0        32         4928         224.6  cuTensorMapEncodeTiled
      0.0            34976         49         713.8         224.0       160         6720        1343.4  cudaStreamIsCapturing_v10000
      0.0            32992          4        8248.0        7456.0      4128        13952        4804.4  cudaEventRecord
      0.0            16928          4        4232.0        3600.0      1728         8000        2764.7  cudaEventQuery
      0.0            16288          4        4072.0        3568.0      1952         7200        2396.1  cudaEventCreateWithFlags
      0.0            13632          4        3408.0        2672.0       544         7744        3408.7  cudaEventDestroy
      0.0             1056          1        1056.0        1056.0      1056         1056           0.0  cuModuleGetLoadingMode

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances    Avg (ns)      Med (ns)    Min (ns)    Max (ns)   StdDev (ns)   Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     99.0      10549232845         55  191804233.5  192944479.0  165746368  203645313    5353204.3  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.6         67327135         55    1224129.7    1330656.0     924320    1364928     182180.4  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.3         34854783         20    1742739.1    1597856.0      10080    3899616     818421.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0           354880        110       3226.2       3296.0       1920       4160        554.4  void at::cuda::detail::prepare_grouped_gemm_data<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass:…
```

The kernel names are too long to be shown via nvprof, so I pasted this from Nsight Systems:

```
small kernel 1SM
100.0%  1.286 ms  1  1.286 ms  1.286 ms  1.286 ms  1.286 ms  0 ns
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)3, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)128, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

large kernel 2SM
100.0%  194.178 ms  1  194.178 ms  194.178 ms  194.178 ms  194.178 ms  0 ns
void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)5, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)256>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)256, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156203
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
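For context on the H200 vs B200 numbers above (roughly a 1.57x speedup on the large shape), the timings can be converted into achieved throughput. This is a back-of-the-envelope sketch, not part of the PR, and it assumes each of the `batch` groups is a dense M x N x K bf16 GEMM, i.e. 2*M*N*K FLOPs per group per iteration:

```py
# Back-of-the-envelope throughput from the timings reported above (not part of the PR).
# Each of the `batch` groups is assumed to be a dense M x N x K bf16 GEMM,
# contributing 2 * M * N * K FLOPs per iteration.
def achieved_tflops(batch, M, N, K, time_ms):
    flops = 2 * batch * M * N * K
    return flops / (time_ms / 1e3) / 1e12

print(achieved_tflops(16, 128000, 7168, 7168, 298.6668))  # H200: ~705 TFLOP/s
print(achieved_tflops(16, 128000, 7168, 7168, 190.7458))  # B200: ~1103 TFLOP/s
```

That works out to roughly 0.70 PFLOP/s on H200 and 1.10 PFLOP/s on B200 for the 16-group 128000 x 7168 x 7168 case.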
parent 996206e66f
commit 772d590415
@@ -1044,7 +1044,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
   return _int_mm_out_cuda(self, mat2, result);
 }
 
-static bool _scaled_mm_allowed_device(bool sm90_only=false) {
+static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
 #ifdef USE_ROCM
   static const std::vector<std::string> archs = {
       "gfx942",
@@ -1058,8 +1058,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false) {
   return at::detail::getCUDAHooks().isGPUArch(archs);
 #else
   auto dprops = at::cuda::getCurrentDeviceProperties();
-  if (sm90_only) {
-    return dprops->major == 9;
+  if (sm90_only || sm100_only) {
+    return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
   } else {
     return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
   }
@@ -1675,8 +1676,8 @@ const std::optional<at::Tensor>& offs,
 const std::optional<at::Tensor>& bias,
 std::optional<c10::ScalarType> out_dtype) {
 #ifndef USE_ROCM
-  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true);
-  TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0");
+  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
+  TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0, 10.0");
 
   TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_a.scalar_type());
   TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_b.scalar_type());
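With this change the gate for torch._grouped_mm accepts compute capability 9.0 (Hopper) and 10.0 (Blackwell) instead of 9.0 only. As a small sanity check, not part of the PR, one can query the local device's capability from Python to see whether the new gate would pass:

```py
# Check the local device's compute capability to see whether the updated
# _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) gate would pass.
import torch

major, minor = torch.cuda.get_device_capability()
print(f"compute capability: {major}.{minor}")
print("torch._grouped_mm allowed:", major in (9, 10))
```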
@@ -8,9 +8,10 @@
 #include <c10/util/irange.h>
 
 
-// Two warninngs in Cutlass included header files
+// Three warninngs in Cutlass included header files
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
+C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
 
 // Determine if the architecture supports rowwise scaled mm
 // Currently failing on windows with:
@@ -43,11 +44,14 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
 #include <cutlass/gemm/dispatch_policy.hpp>
 #include <cutlass/gemm/kernel/gemm_universal.hpp>
 
+#include <ATen/native/cuda/cutlass_common.cuh>
+
 namespace {
 using Strides = at::cuda::detail::Strides; // std::array<int64_t, 3>;
 
-template <bool PONG, typename TB_M, typename TB_N, typename TB_K>
+template <typename ArchTag, bool PONGOr2SM, typename TB_M, typename TB_N, typename TB_K>
 struct Schedule {
+  // SM90
   using CooperativeSchedule =
       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
   using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
@@ -55,10 +59,19 @@ struct Schedule {
       cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
   using PongEpilogueSchedule =
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using KernelSchedule =
-      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
-  using EpilogueSchedule = cute::
-      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
+  // SM100
+  using MMA1SMKernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
+  using MMA1SMEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+  using MMA2SMKernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
+  using MMA2SMEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
+
+  using KernelSchedule = cute::conditional_t<std::is_same_v<ArchTag, cutlass::arch::Sm100>,
+      cute::conditional_t<PONGOr2SM, MMA2SMKernelSchedule, MMA1SMKernelSchedule>,
+      cute::conditional_t<PONGOr2SM, PongSchedule, CooperativeSchedule>>;
+  using EpilogueSchedule = cute::conditional_t<std::is_same_v<ArchTag, cutlass::arch::Sm100>,
+      cute::conditional_t<PONGOr2SM, MMA2SMEpilogueSchedule, MMA1SMEpilogueSchedule>,
+      cute::conditional_t<PONGOr2SM, PongEpilogueSchedule, CooperativeEpilogueSchedule>>;
+
 };
 
 int ceildiv(int a, int b) {
@@ -70,13 +83,14 @@ int round_up_to_nearest_multiple(int a, int b) {
 }
 
 template <
+    typename ArchTag,
     bool a_row_major,
     bool b_row_major,
-    bool Pong,
+    bool PONGOr2SM,
     typename TB_M,
     typename TB_N,
    typename TB_K>
-void bf16bf16_grouped_gemm_impl_sm90(
+void bf16bf16_grouped_gemm_impl_sm90_sm100(
     at::Tensor mat_a, // bf16
     at::Tensor mat_b, // bf16
     std::optional<at::Tensor> offs,
@@ -99,14 +113,13 @@ void bf16bf16_grouped_gemm_impl_sm90(
   constexpr int AlignmentB = 16 / sizeof(DtypeB);
   using LayoutOutput = cutlass::layout::RowMajor;
   constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
-  using ArchTag = cutlass::arch::Sm90;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
   using TileShape = cute::Shape<TB_M, TB_N, TB_K>;
   using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
   using KernelSchedule =
-      typename Schedule<Pong, TB_M, TB_N, TB_K>::KernelSchedule;
+      typename Schedule<ArchTag, PONGOr2SM, TB_M, TB_N, TB_K>::KernelSchedule;
   using EpilogueSchedule =
-      typename Schedule<Pong, TB_M, TB_N, TB_K>::EpilogueSchedule;
+      typename Schedule<ArchTag, PONGOr2SM, TB_M, TB_N, TB_K>::EpilogueSchedule;
   using ProblemShape = cutlass::gemm::GroupProblemShape<
       cute::Shape<int32_t, int32_t, int32_t>>; // <M,N,K> per
                                                // group
@@ -146,8 +159,16 @@ void bf16bf16_grouped_gemm_impl_sm90(
           cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
               sizeof(typename CollectiveEpilogue::SharedStorage))>,
           KernelSchedule>::CollectiveOp;
-  using GemmKernel = cutlass::gemm::kernel::
-      GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+
+  using GemmKernelBase = cutlass::gemm::kernel::GemmUniversal<
+      ProblemShape,
+      CollectiveMainloop,
+      CollectiveEpilogue>;
+
+  using GemmKernel = std::conditional_t<
+      std::is_same_v<ArchTag, cutlass::arch::Sm100>,
+      at::cuda::detail::enable_3x_kernel_for_sm10<GemmKernelBase>,
+      at::cuda::detail::enable_3x_kernel_for_sm9x<GemmKernelBase>>;
 
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
   using StrideA = typename Gemm::GemmKernel::InternalStrideA;
@@ -319,22 +340,49 @@ void dispatch_bf16_grouped_kernel_on_tile_size(
   //     ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) ||
   //      (K >= 2048 && N >= 2048));
   bool small = (M <= 128 || N <= 128);
-  if (small) {
-    bf16bf16_grouped_gemm_impl_sm90<
-        a_row_major,
-        b_row_major,
-        /*Pong*/ true,
-        cute::_64,
-        cute::_128,
-        cute::_128>(mat_a, mat_b, offs, bias, out);
-  } else {
-    bf16bf16_grouped_gemm_impl_sm90<
-        a_row_major,
-        b_row_major,
-        /*Pong*/ false,
-        cute::_128,
-        cute::_256,
-        cute::_64>(mat_a, mat_b, offs, bias, out);
+  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
+  const bool sm10x = properties != nullptr && properties->major == 10;
+
+  if (sm10x) {
+    if (small){
+      bf16bf16_grouped_gemm_impl_sm90_sm100<
+          cutlass::arch::Sm100,
+          a_row_major,
+          b_row_major,
+          /*PONGOr2SM*/ false,
+          cute::_128,
+          cute::_256,
+          cute::_64>(mat_a, mat_b, offs, bias, out); // Tile shape taken from CUTLASS examples, 64 = 128/sizeof(bfloat16)
+    } else {
+      bf16bf16_grouped_gemm_impl_sm90_sm100<
+          cutlass::arch::Sm100,
+          a_row_major,
+          b_row_major,
+          /*PONGOr2SM*/ true,
+          cute::_256,
+          cute::_256,
+          cute::_64>(mat_a, mat_b, offs, bias, out); // Same as above ^
+    }
+  } else {
+    if(small) {
+      bf16bf16_grouped_gemm_impl_sm90_sm100<
+          cutlass::arch::Sm90,
+          a_row_major,
+          b_row_major,
+          /*PONGOr2SM*/ true,
+          cute::_64,
+          cute::_128,
+          cute::_128>(mat_a, mat_b, offs, bias, out);
+    } else {
+      bf16bf16_grouped_gemm_impl_sm90_sm100<
+          cutlass::arch::Sm90,
+          a_row_major,
+          b_row_major,
+          /*PONGOr2SM*/ false,
+          cute::_128,
+          cute::_256,
+          cute::_64>(mat_a, mat_b, offs, bias, out);
+    }
   }
 }
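To summarize the new dispatch logic: on SM100 devices, small problems (M <= 128 or N <= 128) go to the 1SM schedule with a 128x256x64 tile and everything else to the 2SM schedule with a 256x256x64 tile, while SM90 keeps the existing pingpong 64x128x128 and cooperative 128x256x64 paths. The snippet below is only an illustrative restatement of those four branches, not code from the PR:

```py
# Illustrative summary of dispatch_bf16_grouped_kernel_on_tile_size above (not code from the PR).
# Key: (arch, small) -> (schedule picked via Schedule<ArchTag, PONGOr2SM, ...>, (TB_M, TB_N, TB_K)),
# where "small" means M <= 128 or N <= 128.
DISPATCH = {
    ("sm100", True):  ("1SM kernel/epilogue schedule", (128, 256, 64)),
    ("sm100", False): ("2SM kernel/epilogue schedule", (256, 256, 64)),
    ("sm90",  True):  ("pingpong schedule",            (64, 128, 128)),
    ("sm90",  False): ("cooperative schedule",         (128, 256, 64)),
}
```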
@@ -25,6 +25,16 @@ struct enable_3x_kernel_for_sm9x : Kernel {
   }
 };
 
+template <typename Kernel>
+struct enable_3x_kernel_for_sm10 : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
+
 template <typename Kernel>
 struct enable_3x_kernel_for_sm10_or_later : Kernel {
   template <typename... Args>
@@ -128,7 +128,7 @@ if(INTERN_BUILD_ATEN_OPS)
       "90a")
     _BUILD_FOR_ADDITIONAL_ARCHS(
       "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
-      "90a")
+      "90a;100a")
 
 endif()
 
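The CMake change additionally builds GroupMM.cu for the 100a architecture next to 90a via _BUILD_FOR_ADDITIONAL_ARCHS. As a rough, hedged sanity check (not part of the PR), the globally targeted architectures of a build can be listed from Python; note that per-file additions like 90a/100a may not appear in this global list:

```py
# Rough sanity check (not part of the PR): list the CUDA architectures this build targets globally.
# GroupMM.cu additionally gets 90a/100a through _BUILD_FOR_ADDITIONAL_ARCHS, which is a
# per-file flag and may not show up here.
import torch

print(torch.version.cuda)          # CUDA toolkit version the build used
print(torch.cuda.get_arch_list())  # e.g. ['sm_80', 'sm_90', ...]
```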
@@ -24,6 +24,7 @@ from torch.testing._internal.common_cuda import (
     SM89OrLater,
     SM90OrLater,
     xfailIfSM100OrLater,
+    xfailIfSM120OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_MX_GEMM,
@@ -306,8 +307,8 @@ class TestMatmulCuda(TestCase):
         self.assertEqual(bgrad, b.grad)
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
-    @xfailIfSM100OrLater
-    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
+    @xfailIfSM120OrLater
+    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])
@@ -345,8 +346,8 @@ class TestMatmulCuda(TestCase):
         self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out)
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
-    @xfailIfSM100OrLater
-    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
+    @xfailIfSM120OrLater
+    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])
@@ -402,8 +403,8 @@ class TestMatmulCuda(TestCase):
 
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
-    @xfailIfSM100OrLater
-    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
+    @xfailIfSM120OrLater
+    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])
@@ -437,8 +438,8 @@ class TestMatmulCuda(TestCase):
         self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
-    @xfailIfSM100OrLater
-    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
+    @xfailIfSM120OrLater
+    @unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])