[CUTLASS] [CUDA] SM100 GroupMM (#156203)

Closes https://github.com/pytorch/pytorch/issues/156202

This PR adds Blackwell (SM100) support for GroupMM.

Most of the SM90 code can be reused; the kernel schedule has to be changed in accordance with https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html
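
To make the new dispatch concrete, here is a small Python sketch (not part of the PR; `pick_grouped_mm_config` is a hypothetical helper) of the selection that `dispatch_bf16_grouped_kernel_on_tile_size` performs after this change. The `small` heuristic, tile shapes, and 1SM/2SM schedule choices are taken from the diff further down.

```py
def pick_grouped_mm_config(M: int, N: int, device_major: int) -> dict:
    """Sketch of the C++ dispatch in dispatch_bf16_grouped_kernel_on_tile_size (see the diff below)."""
    small = (M <= 128 or N <= 128)
    if device_major == 10:  # Blackwell / SM100 path added by this PR
        if small:
            return {"arch": "Sm100", "schedule": "1SM", "tile_MNK": (128, 256, 64)}
        return {"arch": "Sm100", "schedule": "2SM", "tile_MNK": (256, 256, 64)}
    # existing Hopper / SM90 path, behaviour unchanged
    if small:
        return {"arch": "Sm90", "schedule": "Pingpong", "tile_MNK": (64, 128, 128)}
    return {"arch": "Sm90", "schedule": "Cooperative", "tile_MNK": (128, 256, 64)}

print(pick_grouped_mm_config(M=1, N=2048, device_major=10))       # small problem -> 1SM kernel
print(pick_grouped_mm_config(M=128000, N=7168, device_major=10))  # large problem -> 2SM kernel
```

This matches the profile further below: the batch=128, M=1 shape runs the 1SM kernel, while the batch=16, M=128000 shape runs the 2SM kernel.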

Did some preliminary benchmarking of H200 vs B200

Script
```py
import torch
print(torch.__file__)
device = torch.device("cuda")
dtype = torch.bfloat16

shapes = [
    (16, 128000, 7168, 7168),
    (128, 1, 2048, 7168)
]

for batch, M, N, K in shapes:
    a = torch.randn(batch, M, K, device=device, dtype=dtype)
    b = torch.randn(batch, N, K, device=device, dtype=dtype)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for i in range(5): c = torch._grouped_mm(a, b)  # warmup

    num_iter = 50
    start_event.record()

    for i in range(num_iter): c = torch._grouped_mm(a, b)
    end_event.record()

    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_ms = elapsed_time_ms / num_iter
    print(f"batch: {batch}\tM: {M}\tN: {N}\tK: {K}")
    print(f"Time per Iteration:\t {avg_time_ms:.4f} ms")
```

On H200
```
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 298.6668 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 4.1462 ms
```

On B200
```
batch: 16       M: 128000       N: 7168 K: 7168
Time per Iteration:      190.7458 ms
batch: 128      M: 1    N: 2048 K: 7168
Time per Iteration:      3.0680 ms
```
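
For reference, the speedups implied by these numbers (simple arithmetic over the per-iteration times reported above):

```py
# Per-iteration times in ms, copied from the H200 and B200 runs above.
times_ms = {
    "batch=16, M=128000, N=7168, K=7168": (298.6668, 190.7458),
    "batch=128, M=1, N=2048, K=7168": (4.1462, 3.0680),
}
for shape, (h200, b200) in times_ms.items():
    print(f"{shape}: B200 is {h200 / b200:.2f}x faster")
# -> roughly 1.57x for the large shape and 1.35x for the small shape
```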
nsys nvprof output:
```
root@16930b42ffc6:/workspace/pytorch# nsys nvprof python gemm_test.py
WARNING: python and any of its children processes will be profiled.

Collecting data...
batch: 16	M: 128000	N: 7168	K: 7168
Time per Iteration:	 192.6420 ms
batch: 128	M: 1	N: 2048	K: 7168
Time per Iteration:	 1.2255 ms
Generating '/tmp/nsys-report-6a53.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /workspace/pytorch/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)      Med (ns)    Min (ns)   Max (ns)    StdDev (ns)                 Name
 --------  ---------------  ---------  ------------  ------------  --------  -----------  ------------  ---------------------------------
     98.9      10586895744          2  5293447872.0  5293447872.0  73786464  10513109280  7381715954.2  cudaDeviceSynchronize
      1.0        104084608          5    20816921.6    33552480.0    100800     34786208    18048125.3  cudaMalloc
      0.1          5694304          4     1423576.0     1416656.0   1258560      1602432      181668.1  cudaGetDeviceProperties_v2_v12000
      0.1          5430496        130       41773.0        4560.0      2496      3854368      345761.8  cudaLaunchKernel
      0.0           587584        110        5341.7        4992.0      4224        16992        1482.0  cudaLaunchKernelExC_v11060
      0.0           119200        660         180.6         128.0        96         4128         206.7  cudaGetDriverEntryPoint_v11030
      0.0            68352        660         103.6          64.0        32         4928         224.6  cuTensorMapEncodeTiled
      0.0            34976         49         713.8         224.0       160         6720        1343.4  cudaStreamIsCapturing_v10000
      0.0            32992          4        8248.0        7456.0      4128        13952        4804.4  cudaEventRecord
      0.0            16928          4        4232.0        3600.0      1728         8000        2764.7  cudaEventQuery
      0.0            16288          4        4072.0        3568.0      1952         7200        2396.1  cudaEventCreateWithFlags
      0.0            13632          4        3408.0        2672.0       544         7744        3408.7  cudaEventDestroy
      0.0             1056          1        1056.0        1056.0      1056         1056           0.0  cuModuleGetLoadingMode

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances   Avg (ns)     Med (ns)    Min (ns)   Max (ns)   StdDev (ns)                                                  Name
 --------  ---------------  ---------  -----------  -----------  ---------  ---------  -----------  ----------------------------------------------------------------------------------------------------
     99.0      10549232845         55  191804233.5  192944479.0  165746368  203645313    5353204.3  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.6         67327135         55    1224129.7    1330656.0     924320    1364928     182180.4  void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::Gemm…
      0.3         34854783         20    1742739.1    1597856.0      10080    3899616     818421.2  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0           354880        110       3226.2       3296.0       1920       4160        554.4  void at::cuda::detail::prepare_grouped_gemm_data<cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass:…
```

The kernel names are too long to be shown in full by nvprof, so I pasted the full names from Nsight Systems:
```
small kernel 1SM
100.0%	1.286 ms	1	1.286 ms	1.286 ms	1.286 ms	1.286 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)3, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)128, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM90_TMA_LOAD_MULTICAST, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)

large kernel 2SM
100.0%	194.178 ms	1	194.178 ms	194.178 ms	194.178 ms	194.178 ms	0 ns	void cutlass::device_kernel<at::cuda::detail::enable_3x_kernel_for_sm10<cutlass::gemm::kernel::GemmUniversal<cutlass::gemm::GroupProblemShape<cute::tuple<int, int, int>>, cutlass::gemm::collective::CollectiveMma<cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized<(int)5, (int)8, (int)2, cute::tuple<cute::C<(int)2>, cute::C<(int)1>, cute::C<(int)1>>>, cute::tuple<cute::C<(int)256>, cute::C<(int)256>, cute::C<(int)64>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<cute::C<(int)1>, long, cute::C<(int)0>> *, cute::TiledMMA<cute::MMA_Atom<cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, (int)256, (int)256, (cute::UMMA::Major)0, (cute::UMMA::Major)1, (cute::UMMA::ScaleIn)0, (cute::UMMA::ScaleIn)0>>, cute::Layout<cute::tuple<cute::C<(int)1>, cute::C<(int)1>, cute::C<(int)1>>, cute::tuple<cute::C<(int)0>, cute::C<(int)0>, cute::C<(int)0>>>, cute::tuple<cute::Underscore, cute::Underscore, cute::Underscore>>, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, void, cute::identity, cute::SM100_TMA_2SM_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)64>, cute::C<(int)8>>, cute::tuple<cute::C<(int)1>, cute::C<(int)64>>>>, void, cute::identity>, cutlass::epilogue::collective::CollectiveEpilogue<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::bfloat16_t, cute::tuple<long, cute::C<(int)1>, cute::C<(int)0>> *, cutlass::epilogue::fusion::FusionCallbacks<cutlass::epilogue::Sm100PtrArrayTmaWarpSpecialized<(int)4, (int)2, (int)64, (bool)1, (bool)0>, cutlass::epilogue::fusion::LinearCombination<cutlass::bfloat16_t, float, cutlass::bfloat16_t, float, (cutlass::FloatRoundStyle)2>, cute::tuple<cute::C<(int)128>, cute::C<(int)256>, cute::C<(int)64>>, cute::tuple<cute::Layout<cute::C<(int)128>, cute::C<(int)1>>, cute::Layout<cute::C<(int)64>, cute::C<(int)1>>>, >, cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x, cute::SM90_TMA_LOAD, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::SM90_TMA_STORE, cute::ComposedLayout<cute::Swizzle<(int)3, (int)4, (int)3>, cute::smem_ptr_flag_bits<(int)16>, cute::Layout<cute::tuple<cute::C<(int)8>, cute::C<(int)64>>, cute::tuple<cute::C<(int)64>, cute::C<(int)1>>>>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>, cute::AutoVectorizingCopyWithAssumedAlignment<(int)128>>, void, void>>>(T1::Params)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156203
Approved by: https://github.com/syed-ahmed, https://github.com/drisspg
Author: AaronWang04
Date: 2025-06-28 23:01:57 +00:00
Committed by: PyTorch MergeBot
Commit: 772d590415 (parent: 996206e66f)
5 changed files with 93 additions and 33 deletions

View File

@@ -1044,7 +1044,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
return _int_mm_out_cuda(self, mat2, result);
}
static bool _scaled_mm_allowed_device(bool sm90_only=false) {
static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) {
#ifdef USE_ROCM
static const std::vector<std::string> archs = {
"gfx942",
@@ -1058,8 +1058,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false) {
return at::detail::getCUDAHooks().isGPUArch(archs);
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (sm90_only) {
return dprops->major == 9;
if (sm90_only || sm100_only) {
return (sm90_only && dprops->major == 9) || (sm100_only && dprops->major == 10);
} else {
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
}
@@ -1675,8 +1676,8 @@ const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
std::optional<c10::ScalarType> out_dtype) {
#ifndef USE_ROCM
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true);
TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0");
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0, 10.0");
TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_b.scalar_type());

View File

@@ -8,9 +8,10 @@
#include <c10/util/irange.h>
// Two warninngs in Cutlass included header files
// Three warninngs in Cutlass included header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
// Determine if the architecture supports rowwise scaled mm
// Currently failing on windows with:
@@ -43,11 +44,14 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <ATen/native/cuda/cutlass_common.cuh>
namespace {
using Strides = at::cuda::detail::Strides; // std::array<int64_t, 3>;
template <bool PONG, typename TB_M, typename TB_N, typename TB_K>
template <typename ArchTag, bool PONGOr2SM, typename TB_M, typename TB_N, typename TB_K>
struct Schedule {
// SM90
using CooperativeSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
@@ -55,10 +59,19 @@ struct Schedule {
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using PongEpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using KernelSchedule =
cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
// SM100
using MMA1SMKernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using MMA1SMEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using MMA2SMKernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
using MMA2SMEpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using KernelSchedule = cute::conditional_t<std::is_same_v<ArchTag, cutlass::arch::Sm100>,
cute::conditional_t<PONGOr2SM, MMA2SMKernelSchedule, MMA1SMKernelSchedule>,
cute::conditional_t<PONGOr2SM, PongSchedule, CooperativeSchedule>>;
using EpilogueSchedule = cute::conditional_t<std::is_same_v<ArchTag, cutlass::arch::Sm100>,
cute::conditional_t<PONGOr2SM, MMA2SMEpilogueSchedule, MMA1SMEpilogueSchedule>,
cute::conditional_t<PONGOr2SM, PongEpilogueSchedule, CooperativeEpilogueSchedule>>;
};
int ceildiv(int a, int b) {
@@ -70,13 +83,14 @@ int round_up_to_nearest_multiple(int a, int b) {
}
template <
typename ArchTag,
bool a_row_major,
bool b_row_major,
bool Pong,
bool PONGOr2SM,
typename TB_M,
typename TB_N,
typename TB_K>
void bf16bf16_grouped_gemm_impl_sm90(
void bf16bf16_grouped_gemm_impl_sm90_sm100(
at::Tensor mat_a, // bf16
at::Tensor mat_b, // bf16
std::optional<at::Tensor> offs,
@@ -99,14 +113,13 @@ void bf16bf16_grouped_gemm_impl_sm90(
constexpr int AlignmentB = 16 / sizeof(DtypeB);
using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape = cute::Shape<TB_M, TB_N, TB_K>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using KernelSchedule =
typename Schedule<Pong, TB_M, TB_N, TB_K>::KernelSchedule;
typename Schedule<ArchTag, PONGOr2SM, TB_M, TB_N, TB_K>::KernelSchedule;
using EpilogueSchedule =
typename Schedule<Pong, TB_M, TB_N, TB_K>::EpilogueSchedule;
typename Schedule<ArchTag, PONGOr2SM, TB_M, TB_N, TB_K>::EpilogueSchedule;
using ProblemShape = cutlass::gemm::GroupProblemShape<
cute::Shape<int32_t, int32_t, int32_t>>; // <M,N,K> per
// group
@@ -146,8 +159,16 @@ void bf16bf16_grouped_gemm_impl_sm90(
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::
GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using GemmKernelBase = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
CollectiveMainloop,
CollectiveEpilogue>;
using GemmKernel = std::conditional_t<
std::is_same_v<ArchTag, cutlass::arch::Sm100>,
at::cuda::detail::enable_3x_kernel_for_sm10<GemmKernelBase>,
at::cuda::detail::enable_3x_kernel_for_sm9x<GemmKernelBase>>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
@@ -319,22 +340,49 @@ void dispatch_bf16_grouped_kernel_on_tile_size(
// ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) ||
// (K >= 2048 && N >= 2048));
bool small = (M <= 128 || N <= 128);
if (small) {
bf16bf16_grouped_gemm_impl_sm90<
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
const bool sm10x = properties != nullptr && properties->major == 10;
if (sm10x) {
if (small){
bf16bf16_grouped_gemm_impl_sm90_sm100<
cutlass::arch::Sm100,
a_row_major,
b_row_major,
/*Pong*/ true,
/*PONGOr2SM*/ false,
cute::_128,
cute::_256,
cute::_64>(mat_a, mat_b, offs, bias, out); // Tile shape taken from CUTLASS examples, 64 = 128/sizeof(bfloat16)
} else {
bf16bf16_grouped_gemm_impl_sm90_sm100<
cutlass::arch::Sm100,
a_row_major,
b_row_major,
/*PONGOr2SM*/ true,
cute::_256,
cute::_256,
cute::_64>(mat_a, mat_b, offs, bias, out); // Same as above ^
}
} else {
if(small) {
bf16bf16_grouped_gemm_impl_sm90_sm100<
cutlass::arch::Sm90,
a_row_major,
b_row_major,
/*PONGOr2SM*/ true,
cute::_64,
cute::_128,
cute::_128>(mat_a, mat_b, offs, bias, out);
} else {
bf16bf16_grouped_gemm_impl_sm90<
} else {
bf16bf16_grouped_gemm_impl_sm90_sm100<
cutlass::arch::Sm90,
a_row_major,
b_row_major,
/*Pong*/ false,
/*PONGOr2SM*/ false,
cute::_128,
cute::_256,
cute::_64>(mat_a, mat_b, offs, bias, out);
}
}
}

View File

@@ -25,6 +25,16 @@ struct enable_3x_kernel_for_sm9x : Kernel {
}
};
template <typename Kernel>
struct enable_3x_kernel_for_sm10 : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_3x_kernel_for_sm10_or_later : Kernel {
template <typename... Args>

View File

@@ -128,7 +128,7 @@ if(INTERN_BUILD_ATEN_OPS)
"90a")
_BUILD_FOR_ADDITIONAL_ARCHS(
"${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
"90a")
"90a;100a")
endif()

View File

@@ -24,6 +24,7 @@ from torch.testing._internal.common_cuda import (
SM89OrLater,
SM90OrLater,
xfailIfSM100OrLater,
xfailIfSM120OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MX_GEMM,
@@ -306,8 +307,8 @@ class TestMatmulCuda(TestCase):
self.assertEqual(bgrad, b.grad)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@xfailIfSM120OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@parametrize("b_row_major", [False, True])
@@ -345,8 +346,8 @@ class TestMatmulCuda(TestCase):
self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@xfailIfSM120OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@parametrize("b_row_major", [False, True])
@@ -402,8 +403,8 @@
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@xfailIfSM120OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@parametrize("b_row_major", [False, True])
@@ -437,8 +438,8 @@
self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@xfailIfSM100OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@xfailIfSM120OrLater
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported only on SM90 and SM100")
@parametrize("strided", [False, True])
@parametrize("a_row_major", [False, True])
@parametrize("b_row_major", [False, True])