Revert "Add option to limit number of SMs used by matmul kernels (#144974)"

This reverts commit af2d63637e.

Reverted https://github.com/pytorch/pytorch/pull/144974 on behalf of https://github.com/wdvr due to reverting in order to revert #147548 that causes a merge conflict ([comment](https://github.com/pytorch/pytorch/pull/144974#issuecomment-2683461733))
PyTorch MergeBot 2025-02-25 22:46:38 +00:00
parent cc444e75d5
commit 1e894d2635
8 changed files with 5 additions and 127 deletions

View File

@@ -433,18 +433,6 @@ void Context::setAllowFP16AccumulationCuBLAS(bool b) {
allow_fp16_accumulation_cublas = b;
}
std::optional<int32_t> Context::_SMCarveout_EXPERIMENTAL() const {
return sm_carveout;
}
void Context::_setSMCarveout_EXPERIMENTAL(std::optional<int32_t> c) {
if (c.has_value()) {
TORCH_WARN_ONCE(
"Setting the SM carveout for matmuls is a temporary experimental mitigation for performance issues, "
"while more robust solutions are developed. It may be removed at any moment without notice.");
}
sm_carveout = c;
}
bool Context::hasMKL() {
#if AT_MKL_ENABLED()

View File

@@ -345,19 +345,6 @@ class TORCH_API Context {
void setAllowBF16ReductionCuBLAS(bool);
bool allowFP16AccumulationCuBLAS() const;
void setAllowFP16AccumulationCuBLAS(bool);
// Matmuls can use a so-called "persistent" kernel which launches one CUDA
// block for each SM on the GPU, and each block then iterates over multiple
// output tiles. This allows software pipelining to hide the begin/end
// latencies (e.g., epilogue), especially when only one tile fits per SM.
// However, if some SMs are busy (e.g., with a background NCCL kernel), the
// matmul's blocks will be scheduled in two waves and, in the absence of
// smart load balancing, the kernel will take twice as long. This flag makes
// matmuls target only a subset of the SMs, so they can fully schedule even
// next to a comms kernel, and only be a few percent slower.
std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t>);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();
@@ -436,7 +423,6 @@ class TORCH_API Context {
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
bool allow_fp16_accumulation_cublas = false;
std::optional<int32_t> sm_carveout = std::nullopt;
bool enabled_mkldnn = true;
bool allow_tf32_onednn = false;
bool enabled_nnpack = true;
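
Before this revert, the carveout declared above was driven from Python through the private experimental bindings removed further down in this commit. A minimal sketch of that usage, based on the removed test and type stubs below; the value 66 is only an illustrative carveout:

import torch

# No carveout is set by default.
assert torch._C._get_sm_carveout_experimental() is None

# Ask matmul kernels to leave 66 SMs free (e.g., for a concurrent NCCL kernel);
# the cuBLASLt/CUTLASS paths then target multiProcessorCount - 66 SMs.
torch._C._set_sm_carveout_experimental(66)
assert torch._C._get_sm_carveout_experimental() == 66

# Passing None clears the carveout again.
torch._C._set_sm_carveout_experimental(None)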

View File

@@ -406,14 +406,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
CuBlasLtMatrixLayout Cdesc(abcType, m, n, ldc);
@@ -1340,14 +1332,6 @@ void gemm_and_bias(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
@@ -1561,14 +1545,6 @@ void scaled_gemm(
if (result_scale_ptr != nullptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
#ifndef USE_ROCM
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
@@ -1738,14 +1714,7 @@ void int8_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#endif
CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
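
Each of the removed blocks above maps the carveout onto cuBLASLt's CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET attribute as multiProcessorCount minus the carveout. A rough sketch of that arithmetic from the Python side (illustrative only; the real attribute is set in C++ on the matmul descriptor, and 66 is just an example value):

import torch

carveout = 66  # example value
props = torch.cuda.get_device_properties(0)
sm_count_target = props.multi_processor_count - carveout
# On an H100 with 132 SMs this leaves 66 SMs to the matmul and 66 SMs free
# for background work such as a NCCL kernel.
print(sm_count_target)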

View File

@@ -46,6 +46,8 @@ C10_DIAGNOSTIC_POP()
namespace {
constexpr int kNumSMsForH100 = 132;
using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;
@@ -261,13 +263,6 @@ void f8f8bf16_rowwise_impl(
// multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Ensure persistent kernels leave enough free SMs for NCCL background ops.
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
arguments.hw_info.sm_count =
at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
// Set the swizzle size
arguments.scheduler.max_swizzle_size = swizzle;
@@ -526,17 +521,12 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
int M = XQ.size(0);
int N = WQ.size(1);
int smTarget = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount;
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
smTarget -= at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
// We prefer to use smaller tiles (less wasted compute in case of padding),
// but if this causes us to have more CUDA blocks than there are SMs on the
// GPU then we'll hit wave quantization, hence we'll switch to larger tiles.
if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) *
ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <=
smTarget / cute::size(ClusterShape{})) {
kNumSMsForH100 / cute::size(ClusterShape{})) {
return f8f8bf16_rowwise_impl<
/*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
ClusterShape,
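
For context on the heuristic above: with the hard-coded kNumSMsForH100 = 132 that this revert restores, the dispatcher keeps the small 64x128 tile only while the resulting number of CUDA blocks fits in a single wave. A small worked example, with an assumed 1x1 cluster shape and illustrative problem sizes:

def ceildiv(a, b):
    return -(-a // b)

M, N = 8192, 2048              # illustrative problem size
cluster_m, cluster_n = 1, 1    # assumed ClusterShape for this example
cluster_size = cluster_m * cluster_n

blocks = ceildiv(M, 64 * cluster_m) * ceildiv(N, 128 * cluster_n)
# 128 * 16 = 2048 blocks far exceeds 132 SMs, so the 64x128 tile would hit
# wave quantization here and the dispatcher falls through to a larger tile.
use_small_tiles = blocks <= 132 // cluster_size
print(blocks, use_small_tiles)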

View File

@@ -1,14 +1,11 @@
# Owner(s): ["module: linear algebra"]
import contextlib
import json
import math
import re
import tempfile
import unittest
from itertools import product
from functools import partial
from typing import Optional
import re
import torch
@@ -21,7 +18,6 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
SM53OrLater,
SM89OrLater,
SM90OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MX_GEMM
@@ -847,45 +843,6 @@ class TestFP8MatmulCuda(TestCase):
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support row-wise scaling")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
def test_honor_sm_carveout(self) -> None:
torch.manual_seed(42)
x = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
y = torch.randn(8192, 2048, device="cuda", dtype=torch.float32).t()
x_scales = tensor_to_scale(x, e4m3_type, dim=1).reciprocal()
y_scales = tensor_to_scale(y, e4m3_type, dim=0).reciprocal()
x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)
with tempfile.NamedTemporaryFile() as f:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
self.assertIsNone(torch._C._get_sm_carveout_experimental())
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(0)
self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(66)
self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(None)
self.assertIsNone(torch._C._get_sm_carveout_experimental())
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
prof.export_chrome_trace(f.name)
no_carveout, carveout_0, carveout_66, no_carveout_again = [
math.prod(evt.get("args", {}).get("grid", []))
for evt in json.load(open(f.name))["traceEvents"]
if evt.get("cat", "") == "kernel"
]
self.assertEqual(no_carveout, no_carveout_again)
self.assertNotEqual(no_carveout, carveout_66)
self.assertNotEqual(carveout_66, carveout_0)
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
@parametrize("test_case_name", [
"a_eye_b_eye",

View File

@@ -1216,8 +1216,6 @@ def _get_cublas_allow_fp16_accumulation() -> _bool: ... # THPModule_allowFP16Acc
def _set_cublas_allow_fp16_accumulation(
arg: _bool,
) -> None: ... # THPModule_setAllowFP16AccumulationCuBLAS
def _get_sm_carveout_experimental() -> Optional[_int]: ...
def _set_sm_carveout_experimental(arg: Optional[_int]) -> None: ...
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...

View File

@@ -639,7 +639,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._get_privateuse1_backend_name",
"torch._C._get_qengine",
"torch._C._get_schema",
"torch._C._get_sm_carveout_experimental",
"torch._C._get_nested_int",
"torch._C._get_tensor_metadata",
"torch._C._get_tracing_state",
@@ -1158,7 +1157,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._set_math_sdp_allow_fp16_bf16_reduction",
"torch._C._set_sdp_use_mem_efficient",
"torch._C._set_should_use_format_with_string_table",
"torch._C._set_sm_carveout_experimental",
"torch._C._set_storage_access_error_msg",
"torch._C._set_tensor_metadata",
"torch._C._set_tracing_state",

View File

@@ -2262,14 +2262,6 @@ Call this whenever a new thread is created in order to propagate values from
return at::globalContext().getROCmFAPreferredBackend();
});
py_module.def(
"_set_sm_carveout_experimental", [](std::optional<int32_t> val) {
at::globalContext()._setSMCarveout_EXPERIMENTAL(val);
});
py_module.def("_get_sm_carveout_experimental", []() {
return at::globalContext()._SMCarveout_EXPERIMENTAL();
});
py_module.def(
"_construct_storage_from_data_pointer",
[](int64_t data_ptr, c10::Device device, size_t size_bytes) {