diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
index 9bfc74b4ed2..04a212bb8e6 100644
--- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
+++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu
@@ -315,9 +315,6 @@ void f8f8bf16_rowwise_impl_sm89(
   using LayoutInputB = cutlass::layout::ColumnMajor;
   constexpr int AlignmentInputB = 16 / sizeof(DtypeB);
 
-  constexpr int AlignmentScale = 16 / sizeof(DtypeScale);
-  constexpr int AlignmentBias = 16 / sizeof(DtypeBias);
-
   using LayoutOutput = cutlass::layout::RowMajor;
   constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
 
@@ -330,8 +327,8 @@ void f8f8bf16_rowwise_impl_sm89(
 
   // TODO: instead of fixing these values, implement logic alike to
   // what is used for SM90+.
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
   constexpr auto NumStages = 4;
 
@@ -341,20 +338,6 @@ void f8f8bf16_rowwise_impl_sm89(
       cutlass::arch::OpMultiplyAdd>;
   constexpr auto NumEVTEpilogueStages = 1;
 
-  using ScaleTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          DtypeScale,
-          AlignmentScale,
-          NumEVTEpilogueStages>;
-  using BiasTileThreadMap =
-      cutlass::epilogue::threadblock::OutputTileThreadLayout<
-          ThreadblockShape,
-          WarpShape,
-          DtypeBias,
-          AlignmentBias,
-          NumEVTEpilogueStages>;
   using OutputTileThreadMap =
       cutlass::epilogue::threadblock::OutputTileThreadLayout<
           ThreadblockShape,
@@ -365,25 +348,19 @@ void f8f8bf16_rowwise_impl_sm89(
 
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
 
-  using XScale =
-      cutlass::epilogue::threadblock::VisitorColBroadcast<
-          ScaleTileThreadMap,
-          DtypeScale,
-          cute::Stride>;
+  using XScale = cutlass::epilogue::threadblock::VisitorColBroadcast<
+      OutputTileThreadMap, DtypeScale,
+      cute::Stride>;
   using XScaleArguments = typename XScale::Arguments;
 
-  using WScale =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          ScaleTileThreadMap,
-          DtypeScale,
-          cute::Stride>;
+  using WScale = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, DtypeScale,
+      cute::Stride>;
   using WScaleArguments = typename WScale::Arguments;
 
-  using Bias =
-      cutlass::epilogue::threadblock::VisitorRowBroadcast<
-          BiasTileThreadMap,
-          DtypeBias,
-          cute::Stride>;
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, DtypeBias,
+      cute::Stride>;
   using BiasArguments = typename Bias::Arguments;
 
   using ApplyXScale = cutlass::epilogue::threadblock::VisitorCompute<
@@ -423,8 +400,7 @@ void f8f8bf16_rowwise_impl_sm89(
       Output,
       EVTApplyBias>;
 
-  using EVTKernel =
-      typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
+  using EVTKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
       DtypeA, LayoutInputA, cutlass::ComplexTransform::kNone, AlignmentInputA,
       DtypeB, LayoutInputB, cutlass::ComplexTransform::kNone, AlignmentInputB,
       DtypeOutput, LayoutOutput, AlignmentOutput,
@@ -442,7 +418,7 @@ void f8f8bf16_rowwise_impl_sm89(
       NumEVTEpilogueStages
   >::GemmKernel;
 
-  using Gemm = cutlass::gemm::device::GemmUniversalBase<EVTKernel>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
 
   cutlass::gemm::GemmCoord problem_size(M, N, K);
   constexpr auto SplitKFactor = 1;
@@ -475,14 +451,13 @@ void f8f8bf16_rowwise_impl_sm89(
           {} // ApplyXScale
         }, // EVTApplyXScale
         w_scale_arguments, // WScale
-        {}, // ApplyWScale
+        {} // ApplyWScale
       }, // EVTApplyWScale
       bias_arguments, // Bias
       {} // ApplyBias
     }, // EVTApplyBias
     output_arguments // Output
   }; // EVTOutput
-  constexpr auto AvailSms = -1;
 
   typename Gemm::Arguments arguments(
       cutlass::gemm::GemmUniversalMode::kGemm,
@@ -500,8 +475,7 @@ void f8f8bf16_rowwise_impl_sm89(
       problem_size.k(), // stride A
       problem_size.k(), // stride B
      0, // stride C (unused)
-      0, // stride D (unused)
-      AvailSms);
+      0); // stride D (unused)
 
   Gemm gemm;
 
diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake
index 21eb9219b9a..73f36831c0b 100644
--- a/cmake/Codegen.cmake
+++ b/cmake/Codegen.cmake
@@ -76,7 +76,7 @@ if(INTERN_BUILD_ATEN_OPS)
 
   file(GLOB_RECURSE all_python "${CMAKE_CURRENT_LIST_DIR}/../torchgen/*.py")
 
-  # RowwiseScaled.cu requires sm90a flags
+  # RowwiseScaled.cu requires sm89/sm90a flags
   if(USE_CUDA)
     set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu")
 
@@ -84,11 +84,17 @@ if(INTERN_BUILD_ATEN_OPS)
     torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
 
     # Check NVCC version and existing arch flags
-    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND
-       EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
-      set_source_files_properties(${ROWWISE_SCALED_MM_FILE}
-        PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a")
+    set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
+    if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
+      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
+        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
+      endif()
+      if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
+        list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
+      endif()
     endif()
+    list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
+    set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
   endif()
 
   set(GEN_ROCM_FLAG)
diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py
index 1c0bee62ff1..739cb709d16 100644
--- a/test/distributed/tensor/test_matrix_ops.py
+++ b/test/distributed/tensor/test_matrix_ops.py
@@ -136,7 +136,10 @@ class DistMatrixOpsTest(DTensorTestBase):
 
     @with_comms
     @skip_unless_torch_gpu
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "torch._scaled_mm requires H100+")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     def test_scaled_mm(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
         shrd0 = Shard(0)
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index b2407c7fdd6..b1309ac3ee9 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -708,8 +708,8 @@ class AOTInductorTestsTemplate:
         self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
 
     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
@@ -756,8 +756,8 @@ class AOTInductorTestsTemplate:
     )
 
     @unittest.skipIf(
-        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
-        "FP8 is only supported on H100+",
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
    )
     @skipIfRocm  # _scaled_mm_out_cuda is not compiled for ROCm platform
     @skipIfXpu
@@ -3324,7 +3324,7 @@ class AOTInductorTestsTemplate:
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_runtime_checks_fp8(self):
         # cuda only
diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py
index 9f4bad06f04..9d71bb6a8f7 100644
--- a/test/inductor/test_fp8.py
+++ b/test/inductor/test_fp8.py
@@ -21,7 +21,7 @@ from torch.utils._triton import has_triton_tma_device
 
 torch.set_float32_matmul_precision("high")
 
-f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
+f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 
 # define the e4m3/e5m2 constants
 E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index 7737ee78d89..1e4f5f0213a 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -1259,7 +1259,7 @@ class TestPrologueFusion(TestCase):
     @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "FP8 is only supported on H100+ and sm_89 and MI300+ devices",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_low_precision(self):
         M = K = N = 128
diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py
index 84e6de64cd7..58400d86a81 100644
--- a/test/test_flop_counter.py
+++ b/test/test_flop_counter.py
@@ -839,7 +839,7 @@ class TestFlopCounter(TestCase):
     @unittest.skipIf(not HAS_CUDA, "CUDA not available")
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FP8,
-        "Does not support fp8 (pre-SM90 hardware on CUDA)",
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
     )
     def test_scaled_mm(self):
         dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 2d85f597c6c..e0eaa52f093 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -17,6 +17,7 @@ from torch.quantization._quantized_conversions import (
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import (
     SM53OrLater,
+    SM89OrLater,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8
 )
@@ -42,10 +43,8 @@ from torch.testing._internal.common_utils import (
 )
 
 _IS_SM8X = False
-_IS_SM9X = False
 if TEST_CUDA:
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
-    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
 
 # Protects against includes accidentally setting the default dtype
 assert torch.get_default_dtype() is torch.float32
@@ -213,7 +212,7 @@ class TestMatmulCuda(TestCase):
         self.assertEqual(out1_gpu, out2_gpu[0])
 
 
-f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
+f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
 
 if torch.version.hip:
     e4m3_type = torch.float8_e4m3fnuz
@@ -538,8 +537,7 @@ class TestFP8MatmulCuda(TestCase):
             lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
         )
 
-    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
-                     "This test is only for devices with compute capability < 8.9")
+    @unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
     def test_error_message_fp8_pre_sm89(self, device) -> None:
         (k, l, m) = (16, 48, 32)
         x = torch.rand((k, l), device=device).to(e4m3_type)
@@ -567,7 +565,7 @@ class TestFP8MatmulCuda(TestCase):
         self.assertEqual(out_fp8, out_fp8_s)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("use_fast_accum", [True, False])
     def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
         M, K, N = (1024, 512, 2048)
@@ -673,7 +671,7 @@ class TestFP8MatmulCuda(TestCase):
         )
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not _IS_SM9X, "rowwise implementation is currently sm90 specific")
+    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
     @parametrize("base_dtype", [torch.bfloat16])
     def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
         torch.manual_seed(42)
diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py
index 687e92b5df9..54efc4921c4 100644
--- a/test/test_sparse_semi_structured.py
+++ b/test/test_sparse_semi_structured.py
@@ -1047,7 +1047,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     @parametrize("dense_input_shape", [(256, 128)])
     def test_sparse_fp8fp8_mm(self, dense_input_shape, device):
@@ -1067,7 +1070,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         ):
             dense_result = torch.mm(A_fp8_sparse, B_fp8)
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None:
         (k, l, m) = (32, 64, 32)
@@ -1084,7 +1090,10 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         out_fp32_sparse = out_fp8_sparse.to(torch.float32)
         torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1)
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+ and sm_89 and MI300+ devices")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FP8,
+        "FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
+    )
     @xfailIfSM89
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float32])
     @parametrize("dense_input_shape", [(256, 128)])
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 72e2dea64e8..9d6aff15ddf 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -31,7 +31,7 @@ from torch.testing._internal.common_device_type import \
     toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
-    SM53OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
+    SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
     _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -16211,7 +16211,7 @@ op_db: list[OpInfo] = [
         supports_out=True,
         supports_forward_ad=False,
         supports_autograd=False,
-        decorators=[skipCUDAIf(not SM90OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 9.0')],
+        decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
         skips=(
             # Sample inputs isn't really parametrized on dtype
             DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',