[ROCm] support experimental CU carveout (#149466)

Fixes #149280. Follow-up to #147966, but now available for ROCm.

Since hipBLASLt does not support HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET, we instead create a hipStream_t with a CU mask applied and pass this masked stream to hipBLASLt in place of PyTorch's current stream. Ordering between the masked stream and the current stream is enforced with hipEvent_t records and stream waits.
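The ordering scheme is easiest to see in isolation. Below is a minimal sketch of the idea (not code from this patch): a hypothetical `orderedLaunch` helper brackets work on a CU-masked stream with two events, one per direction, so stream-order semantics are preserved without a host-side sync. Error handling and the hipblaslt call itself are omitted.

```cpp
#include <hip/hip_runtime.h>

// Sketch: run one piece of work on a CU-masked stream while keeping it
// ordered with respect to the surrounding (current) stream.
void orderedLaunch(hipStream_t current, hipStream_t masked) {
  hipEvent_t before, after;
  hipEventCreateWithFlags(&before, hipEventDisableTiming);
  hipEventCreateWithFlags(&after, hipEventDisableTiming);

  hipEventRecord(before, current);        // edge: current -> masked
  hipStreamWaitEvent(masked, before, 0);  // masked waits for prior work

  // ... enqueue the hipblaslt GEMM on `masked` here ...

  hipEventRecord(after, masked);          // edge: masked -> current
  hipStreamWaitEvent(current, after, 0);  // later work waits for the GEMM

  hipEventDestroy(before);  // destroying a recorded event defers cleanup
  hipEventDestroy(after);   // until the event has completed
}
```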

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149466
Approved by: https://github.com/malfet, https://github.com/atalman
Author: Jeff Daily, 2025-07-01 08:54:48 +00:00 (committed by PyTorch MergeBot)
parent 0596323c35
commit 210632fae1
2 changed files with 130 additions and 13 deletions

File 1 of 2

@@ -17,6 +17,7 @@
 #include <c10/core/ScalarType.h>
 
 #ifdef USE_ROCM
+#include <c10/cuda/CUDAStream.h>
 #include <hipblaslt/hipblaslt-ext.hpp>
 // until hipblas has an API to accept flags, we must use rocblas here
 #include <hipblas/hipblas.h>
@@ -188,6 +189,60 @@ uint32_t _getAlignment(uintptr_t address) {
 }
 #endif
 
+#ifdef USE_ROCM
+static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
+  static int32_t last_value = 0;
+  static hipStream_t stream;
+
+  if (last_value == 0) {
+    // first request, do nothing for this case
+  }
+  else if (last_value == value) {
+    // stream was created previously and value hasn't changed
+    return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
+  }
+  else {
+    // need a new stream and a previous stream exists, delete it
+    AT_CUDA_CHECK(hipStreamDestroy(stream));
+  }
+
+  // if we got here, we need to create a new stream
+  int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  // how many uint32_t do we need to cover all CUs, fill bitmask with 1
+  uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
+  std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
+  // starting from lowest order bits, in 32-bit chunks
+  // set bits to 0 based on how many CUs to carve out
+  int32_t full_shifts = value / 32;
+  int32_t remainder = value % 32;
+  for (int32_t i = 0; i < full_shifts; i++) {
+    mask[i] = uint32_t{0x00000000};
+  }
+  mask[full_shifts] = uint32_t{0xffffffff} << remainder;
+
+  // finally, create masked stream
+  AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));
+  last_value = value;
+  return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
+}
+
+static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) {
+  hipEvent_t event;
+  AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming));
+
+  auto current_stream = at::cuda::getCurrentCUDAStream();
+
+  if (presync) {
+    AT_CUDA_CHECK(hipEventRecord(event, current_stream));
+    AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0));
+  }
+  else {
+    AT_CUDA_CHECK(hipEventRecord(event, stream));
+    AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0));
+  }
+}
+#endif
+
 struct CublasLtWorkspace {
   CublasLtWorkspace() {
     size = at::cuda::getCUDABlasLtWorkspaceSize();
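As a sanity check of the bitmask arithmetic in `_getCarveoutStream` above (commentary, not part of the patch), here is the same computation as a standalone program, assuming an illustrative device with 304 CUs and a carveout of 66:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t CUs = 304, value = 66;            // illustrative numbers
  uint32_t mask_size = (CUs + 32 - 1) / 32;       // 10 words cover 304 bits
  std::vector<uint32_t> mask(mask_size, 0xffffffffu);
  int32_t full_shifts = value / 32;               // 66 / 32 = 2 words fully cleared
  int32_t remainder = value % 32;                 // 66 % 32 = 2 more bits cleared
  for (int32_t i = 0; i < full_shifts; i++) {
    mask[i] = 0;
  }
  mask[full_shifts] = 0xffffffffu << remainder;
  // prints: 00000000 00000000 fffffffc ffffffff ffffffff ...
  for (uint32_t w : mask) {
    printf("%08x ", static_cast<unsigned>(w));
  }
  printf("\n");
  return 0;
}
```

This clears the 66 lowest-order bits of the mask, leaving 238 of the 304 CUs available to the masked stream.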
@@ -390,6 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
@@ -397,6 +453,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
   CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T);
   CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T);
@@ -459,7 +521,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
         &heuristicResult.algo,
         ltworkspace.ptr,
         ltworkspace.size,
-        at::cuda::getCurrentCUDAStream());
+        stream);
+#ifdef USE_ROCM
+    if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+      _syncCurrentWithCarveoutStream(stream, false);
+    }
+#endif
   }
   if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
     TORCH_WARN(
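The three hunks above establish the pattern that every hipblaslt call site in this change follows on ROCm. Condensed here for reference (a restatement of the diff, not compilable as-is; argument lists are elided):

```cpp
auto stream = at::cuda::getCurrentCUDAStream();
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
  // swap in the CU-masked stream and make it wait on the current stream
  stream = _getCarveoutStream(
      at::globalContext()._SMCarveout_EXPERIMENTAL().value());
  _syncCurrentWithCarveoutStream(stream, /*presync=*/true);
}
cublasLtMatmul(/* ..., */ stream);  // the GEMM runs on the masked stream
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
  // make the current stream wait for the GEMM before anything else runs
  _syncCurrentWithCarveoutStream(stream, /*presync=*/false);
}
```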
@@ -1557,6 +1624,7 @@ bool gemm_and_bias(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
   cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
@@ -1564,6 +1632,12 @@ bool gemm_and_bias(
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
   if (activation == GEMMAndBiasActivationEpilogue::RELU) {
@@ -1632,7 +1706,12 @@ bool gemm_and_bias(
         &heuristicResult.algo,
         ltworkspace.ptr,
         ltworkspace.size,
-        at::cuda::getCurrentCUDAStream());
+        stream);
+#ifdef USE_ROCM
+    if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+      _syncCurrentWithCarveoutStream(stream, false);
+    }
+#endif
   }
   if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
     TORCH_WARN(
@@ -1818,6 +1897,7 @@ void scaled_gemm(
   if (result_scale_ptr != nullptr) {
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
   }
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
@@ -1825,6 +1905,12 @@ void scaled_gemm(
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif // ifndef USE_ROCM
 #ifndef USE_ROCM
   const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
@@ -1870,7 +1956,6 @@ void scaled_gemm(
 #endif // if CUDA_VERSION >= 12090
   }
 
-  auto stream = c10::cuda::getCurrentCUDAStream();
   CuBlasLtMatmulPreference preference;
   auto ltworkspace = CublasLtWorkspace();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
@@ -1957,6 +2042,11 @@ void scaled_gemm(
       ltworkspace.ptr,
       ltworkspace.size,
       stream);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -2010,6 +2100,7 @@ void int8_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
   cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute<int32_t>(
@@ -2017,6 +2108,12 @@ void int8_gemm(
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
 
   CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
@@ -2078,7 +2175,7 @@ void int8_gemm(
 #else
       0,
 #endif
-      at::cuda::getCurrentCUDAStream());
+      stream);
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -2107,6 +2204,11 @@ void int8_gemm(
       computeType,
       " scaleType ",
       scaleType);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
 }
 
 template <>

File 2 of 2

@@ -1358,7 +1358,6 @@ class TestFP8Matmul(TestCase):
         self.assertEqual(out_dtype, out_fp8.dtype)
         self.assertEqual(out_fp32, out_fp8.to(torch.float))
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
@@ -1387,15 +1386,31 @@ class TestFP8Matmul(TestCase):
                 torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
             prof.export_chrome_trace(f.name)
 
-        no_carveout, carveout_0, carveout_66, no_carveout_again = [
-            math.prod(evt.get("args", {}).get("grid", []))
-            for evt in json.load(open(f.name))["traceEvents"]
-            if evt.get("cat", "") == "kernel"
-        ]
+        if torch.version.hip:
+            events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"]
+            # events were returned out of order; need to be sorted on "ts" timestamp
+            events = sorted(events, key=lambda x: x['ts'])
+            # ROCm carveout is invisible except for kernels running slower on fewer CUs
+            no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
+            self.assertTrue(no_carveout < carveout_66)
+            self.assertTrue(carveout_0 < carveout_66)
+            self.assertTrue(no_carveout_again < carveout_66)
+            # ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
+            no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
+            self.assertTrue(no_carveout == no_carveout_again)
+            self.assertTrue(no_carveout != carveout_0)
+            self.assertTrue(no_carveout != carveout_66)
+            self.assertTrue(carveout_0 != carveout_66)
+        else:
+            no_carveout, carveout_0, carveout_66, no_carveout_again = [
+                math.prod(evt.get("args", {}).get("grid", []))
+                for evt in json.load(open(f.name))["traceEvents"]
+                if evt.get("cat", "") == "kernel"
+            ]
 
         self.assertEqual(no_carveout, no_carveout_again)
         self.assertNotEqual(no_carveout, carveout_66)
         self.assertNotEqual(carveout_66, carveout_0)
 
     def test_pack_uint4(self):
         """