Revert "Add option to limit number of SMs used by matmul kernels (#144974)"
This reverts commit af2d63637e.
Reverted https://github.com/pytorch/pytorch/pull/144974 on behalf of https://github.com/wdvr in order to revert #147548, which causes a merge conflict ([comment](https://github.com/pytorch/pytorch/pull/144974#issuecomment-2683461733))
This commit is contained in:
parent cc444e75d5
commit 1e894d2635

@@ -433,18 +433,6 @@ void Context::setAllowFP16AccumulationCuBLAS(bool b) {
  allow_fp16_accumulation_cublas = b;
}

std::optional<int32_t> Context::_SMCarveout_EXPERIMENTAL() const {
  return sm_carveout;
}

void Context::_setSMCarveout_EXPERIMENTAL(std::optional<int32_t> c) {
  if (c.has_value()) {
    TORCH_WARN_ONCE(
        "Setting the SM carveout for matmuls is a temporary experimental mitigation for performance issues, "
        "while more robust solutions are developed. It may be removed at any moment without notice.");
  }
  sm_carveout = c;
}

bool Context::hasMKL() {
#if AT_MKL_ENABLED()

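The two accessors removed above were exposed to Python through private bindings that this commit also removes (see the pybind changes further down). A minimal sketch of the round trip, assuming a build that still includes the feature, i.e. one from before this revert:

import torch

# Sketch only: these private bindings exist on builds that still carry the
# carveout feature from #144974; this commit removes them.
assert torch._C._get_sm_carveout_experimental() is None   # unset by default

torch._C._set_sm_carveout_experimental(8)    # ask matmuls to leave 8 SMs free
assert torch._C._get_sm_carveout_experimental() == 8      # warns once (TORCH_WARN_ONCE)

torch._C._set_sm_carveout_experimental(None)  # clear the carveout again
assert torch._C._get_sm_carveout_experimental() is None
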
@@ -345,19 +345,6 @@ class TORCH_API Context {
  void setAllowBF16ReductionCuBLAS(bool);
  bool allowFP16AccumulationCuBLAS() const;
  void setAllowFP16AccumulationCuBLAS(bool);

  // Matmuls can use a so-called "persistent" kernel which launches one CUDA
  // block for each SM on the GPU, and each block then iterates over multiple
  // output tiles. This allows software pipelining to hide the begin/end
  // latencies (e.g., epilogue), especially when only one tile fits per SM.
  // However, if some SMs are busy (e.g., with a background NCCL kernel), the
  // matmul's blocks will be scheduled in two waves and, in the absence of some
  // smart load balancing, the kernel will take twice as long. This flag makes
  // matmuls target only a subset of the SMs, so they can fully schedule even
  // next to a comms kernel, and only be a few percent slower.
  std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
  void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t>);

  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();

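The scheduling effect described in that comment is plain arithmetic: once the grid no longer fits on the free SMs, a persistent kernel needs a second wave. A back-of-the-envelope sketch with illustrative numbers only (132 SMs as on an H100, one block per SM; these are assumptions, not measurements from the PR):

import math

TOTAL_SMS = 132  # assumed H100-class device

def waves(num_blocks, usable_sms):
    # Each wave runs at most `usable_sms` blocks concurrently.
    return math.ceil(num_blocks / usable_sms)

print(waves(132, TOTAL_SMS))          # 1 wave: the kernel fits in one pass
print(waves(132, TOTAL_SMS - 8))      # 2 waves: 8 SMs busy with NCCL -> ~2x runtime
print(waves(132 - 8, TOTAL_SMS - 8))  # 1 wave: carveout shrinks the grid to match
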
@@ -436,7 +423,6 @@ class TORCH_API Context {
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool allow_fp16_accumulation_cublas = false;
  std::optional<int32_t> sm_carveout = std::nullopt;
  bool enabled_mkldnn = true;
  bool allow_tf32_onednn = false;
  bool enabled_nnpack = true;

@@ -406,14 +406,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif
  CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, ldc);

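The cuBLASLt hook above, and the similar ones in the hunks below, all compute the same effective target: the device's SM count minus the requested carveout. A small sketch of that arithmetic from the Python side (the carveout value here is a hypothetical example):

import torch

# Sketch of the SM-count-target arithmetic used by the removed hooks:
# effective_sms = multiProcessorCount - carveout, when a carveout is set.
props = torch.cuda.get_device_properties(torch.cuda.current_device())
carveout = 8  # hypothetical value; unset (None) means "use every SM"

effective_sms = props.multi_processor_count - carveout
print(f"{props.multi_processor_count} SMs total, matmuls limited to {effective_sms}")
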
@@ -1340,14 +1332,6 @@ void gemm_and_bias(
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;

@@ -1561,14 +1545,6 @@ void scaled_gemm(
  if (result_scale_ptr != nullptr) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
  }
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif
#ifndef USE_ROCM
  const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);

@@ -1738,14 +1714,7 @@ void int8_gemm(
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
#ifndef USE_ROCM
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    computeDesc.setAttribute<int32_t>(
        CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
            at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  }
#endif

  CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);

@@ -46,6 +46,8 @@ C10_DIAGNOSTIC_POP()

namespace {

constexpr int kNumSMsForH100 = 132;

using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;

@@ -261,13 +263,6 @@ void f8f8bf16_rowwise_impl(
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Ensure persistent kernels leave enough free SMs for NCCL background ops.
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    arguments.hw_info.sm_count =
        at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount -
        at::globalContext()._SMCarveout_EXPERIMENTAL().value();
  }

  // Set the swizzle size
  arguments.scheduler.max_swizzle_size = swizzle;

@@ -526,17 +521,12 @@ void dispatch_fp8_rowwise_kernel_on_tile_size(
  int M = XQ.size(0);
  int N = WQ.size(1);

  int smTarget = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount;
  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
    smTarget -= at::globalContext()._SMCarveout_EXPERIMENTAL().value();
  }

  // We prefer to use smaller tiles (less wasted compute in case of padding),
  // but if this causes us to have more CUDA blocks than there are SMs on the
  // GPU then we'll hit wave quantization, hence we'll switch to larger tiles.
  if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) *
          ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <=
      smTarget / cute::size(ClusterShape{})) {
      kNumSMsForH100 / cute::size(ClusterShape{})) {
    return f8f8bf16_rowwise_impl<
        /*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
        ClusterShape,

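The branch above picks the smaller tile only while the resulting grid still fits in a single wave of SM clusters. Restated as a hedged Python sketch: the 64x128 tile dimensions come from the kernel template above, while the cluster shape, M, N and SM count below are placeholder values, not taken from the PR:

from math import ceil

def prefers_small_tile(M, N, sm_count, cluster_m=2, cluster_n=1):
    # Mirrors the dispatch condition above: keep the 64x128 tile as long as the
    # grid of output tiles fits in one wave of SM clusters.
    blocks = ceil(M / (64 * cluster_m)) * ceil(N / (128 * cluster_n))
    return blocks <= sm_count // (cluster_m * cluster_n)

print(prefers_small_tile(M=512,  N=1024, sm_count=132))   # small problem -> True
print(prefers_small_tile(M=8192, N=8192, sm_count=132))   # large problem -> False
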
@@ -1,14 +1,11 @@
# Owner(s): ["module: linear algebra"]

import contextlib
import json
import math
import re
import tempfile
import unittest
from itertools import product
from functools import partial
from typing import Optional
import re

import torch

@@ -21,7 +18,6 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM53OrLater,
    SM89OrLater,
    SM90OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8,
    PLATFORM_SUPPORTS_MX_GEMM

@@ -847,45 +843,6 @@ class TestFP8MatmulCuda(TestCase):
        self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support row-wise scaling")
    @unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
    def test_honor_sm_carveout(self) -> None:
        torch.manual_seed(42)

        x = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
        y = torch.randn(8192, 2048, device="cuda", dtype=torch.float32).t()
        x_scales = tensor_to_scale(x, e4m3_type, dim=1).reciprocal()
        y_scales = tensor_to_scale(y, e4m3_type, dim=0).reciprocal()
        x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)

        with tempfile.NamedTemporaryFile() as f:
            with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(0)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(66)
                self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                torch._C._set_sm_carveout_experimental(None)
                self.assertIsNone(torch._C._get_sm_carveout_experimental())
                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)

            prof.export_chrome_trace(f.name)
            no_carveout, carveout_0, carveout_66, no_carveout_again = [
                math.prod(evt.get("args", {}).get("grid", []))
                for evt in json.load(open(f.name))["traceEvents"]
                if evt.get("cat", "") == "kernel"
            ]

            self.assertEqual(no_carveout, no_carveout_again)
            self.assertNotEqual(no_carveout, carveout_66)
            self.assertNotEqual(carveout_66, carveout_0)

    @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
    @parametrize("test_case_name", [
        "a_eye_b_eye",

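The removed test verifies the knob through profiler grid sizes; the motivating usage was to overlap a matmul with a communication kernel. A hedged sketch of that pattern follows (process-group setup omitted; the private binding is assumed to exist, i.e. a build prior to this revert, and plain `@` only benefits when it dispatches to one of the cuBLASLt/CUTLASS paths patched above):

import torch
import torch.distributed as dist

def overlapped_matmul(a, b, grad_bucket, carveout_sms=8):
    # Launch the NCCL all-reduce first so it runs in the background...
    work = dist.all_reduce(grad_bucket, async_op=True)
    # ...then ask matmul kernels to leave `carveout_sms` SMs free for it.
    torch._C._set_sm_carveout_experimental(carveout_sms)
    try:
        out = a @ b
    finally:
        torch._C._set_sm_carveout_experimental(None)  # always restore the default
    work.wait()
    return out
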
@@ -1216,8 +1216,6 @@ def _get_cublas_allow_fp16_accumulation() -> _bool: ...  # THPModule_allowFP16Acc
def _set_cublas_allow_fp16_accumulation(
    arg: _bool,
) -> None: ...  # THPModule_setAllowFP16AccumulationCuBLAS
def _get_sm_carveout_experimental() -> Optional[_int]: ...
def _set_sm_carveout_experimental(arg: Optional[_int]) -> None: ...
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...

@@ -639,7 +639,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
    "torch._C._get_privateuse1_backend_name",
    "torch._C._get_qengine",
    "torch._C._get_schema",
    "torch._C._get_sm_carveout_experimental",
    "torch._C._get_nested_int",
    "torch._C._get_tensor_metadata",
    "torch._C._get_tracing_state",

@@ -1158,7 +1157,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
    "torch._C._set_math_sdp_allow_fp16_bf16_reduction",
    "torch._C._set_sdp_use_mem_efficient",
    "torch._C._set_should_use_format_with_string_table",
    "torch._C._set_sm_carveout_experimental",
    "torch._C._set_storage_access_error_msg",
    "torch._C._set_tensor_metadata",
    "torch._C._set_tracing_state",

@@ -2262,14 +2262,6 @@ Call this whenever a new thread is created in order to propagate values from
        return at::globalContext().getROCmFAPreferredBackend();
      });

  py_module.def(
      "_set_sm_carveout_experimental", [](std::optional<int32_t> val) {
        at::globalContext()._setSMCarveout_EXPERIMENTAL(val);
      });
  py_module.def("_get_sm_carveout_experimental", []() {
    return at::globalContext()._SMCarveout_EXPERIMENTAL();
  });

  py_module.def(
      "_construct_storage_from_data_pointer",
      [](int64_t data_ptr, c10::Device device, size_t size_bytes) {