From d3c2123ea61140985325a5f654d38b295eff4242 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sat, 28 Sep 2024 21:11:15 +0000 Subject: [PATCH] [BE][CUDA][Bugfix]: Enable extended MMA shapes in CUTLASS. (#133686) * This fixes a major CMake/Bazel configuration bug where we were leaving CUTLASS performance on the table, especially with FlashAttention. This now enables using MMA instructions on SM90+, which should close the gap between SDPA and the external FA2. Note these operations only affect H100 and newer GPUs. Thankfully, this seems to have been updated recently into being a no-op on the CUTLASS side. Still, it is better to set the CMake variable properly. * Also enables additional new shape kernels added in the recent CUTLASS 3.5.1+ update. This was the original motivation of the PR before I realized the basic MMA kernels were accidentally disabled since we didn't go through the submodule's CMake/Bazel files. * Adds a bit to compile time and code size, but well worth it considering it speeds up our internal flash attention significantly on H100s at the cost of some minor additional compile time. * These kernels and settings will be needed for Flash Attention 3 whenever we add that too. 
Fixes #133695 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133686 Approved by: https://github.com/ezyang --- aten/src/ATen/CMakeLists.txt | 3 +++ third_party/cutlass.BUILD | 5 +++++ torch/_inductor/codecache.py | 2 ++ 3 files changed, 10 insertions(+) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 1896530c0af..16e4641ddf2 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -467,6 +467,9 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) endif() if(USE_CUDA AND NOT USE_ROCM) + add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) + add_definitions(-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1) + add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e3e7b7b288e..10100531d9b 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -13,6 +13,11 @@ cc_library( "tools/util/include/**/*.hpp", "tools/util/include/**/*.inl", ]), + defines = [ + "CUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + ], includes = [ "include/", "tools/util/include/", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 17b0e168803..ab39e6660d7 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -3063,6 +3063,8 @@ def _nvcc_compiler_options() -> List[str]: options = [ "-t=0", "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", "-w", f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", config.cuda.compile_opt_level,