mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[ROCm] support experimental CU carveout (#149466)
Fixes #149280. Follow up to #147966, but now available for ROCm. Since hipblaslt does not support HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET, we instead create a hipStream that has a CU mask applied. We pass this masked stream to hipblaslt instead of pytorch's current stream. We ensure stream ordering between streams using hipEvents and stream synchronization. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149466 Approved by: https://github.com/malfet, https://github.com/atalman
This commit is contained in:
parent
0596323c35
commit
210632fae1
|
|
@ -17,6 +17,7 @@
|
|||
#include <c10/core/ScalarType.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||
// until hipblas has an API to accept flags, we must use rocblas here
|
||||
#include <hipblas/hipblas.h>
|
||||
|
|
@ -188,6 +189,60 @@ uint32_t _getAlignment(uintptr_t address) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
|
||||
static int32_t last_value = 0;
|
||||
static hipStream_t stream;
|
||||
if (last_value == 0) {
|
||||
// first request, do nothing for this case
|
||||
}
|
||||
else if (last_value == value) {
|
||||
// stream was created previously and value hasn't changed
|
||||
return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
|
||||
}
|
||||
else {
|
||||
// need a new stream and a previous stream exists, delete it
|
||||
AT_CUDA_CHECK(hipStreamDestroy(stream));
|
||||
}
|
||||
|
||||
// if we got here, we need to create a new stream
|
||||
int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
// how many uint32_t do we need to cover all CUs, fill bitmask with 1
|
||||
uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
|
||||
std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
|
||||
// starting from lowest order bits, in 32-bit chunks
|
||||
// set bits to 0 based on how many CUs to carve out
|
||||
int32_t full_shifts = value / 32;
|
||||
int32_t remainder = value % 32;
|
||||
for (int32_t i = 0; i < full_shifts; i++) {
|
||||
mask[i] = uint32_t{0x00000000};
|
||||
}
|
||||
mask[full_shifts] = uint32_t{0xffffffff} << remainder;
|
||||
|
||||
// finally, create masked stream
|
||||
AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));
|
||||
|
||||
last_value = value;
|
||||
return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
|
||||
}
|
||||
|
||||
static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) {
|
||||
hipEvent_t event;
|
||||
AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming));
|
||||
|
||||
auto current_stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (presync) {
|
||||
AT_CUDA_CHECK(hipEventRecord(event, current_stream));
|
||||
AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0));
|
||||
}
|
||||
else {
|
||||
AT_CUDA_CHECK(hipEventRecord(event, stream));
|
||||
AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
struct CublasLtWorkspace {
|
||||
CublasLtWorkspace() {
|
||||
size = at::cuda::getCUDABlasLtWorkspaceSize();
|
||||
|
|
@ -390,6 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
|||
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
#ifndef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
computeDesc.setAttribute<int32_t>(
|
||||
|
|
@ -397,6 +453,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
|||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
}
|
||||
#else
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
stream = _getCarveoutStream(
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
_syncCurrentWithCarveoutStream(stream, true);
|
||||
}
|
||||
#endif
|
||||
CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T);
|
||||
CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T);
|
||||
|
|
@ -459,7 +521,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
|||
&heuristicResult.algo,
|
||||
ltworkspace.ptr,
|
||||
ltworkspace.size,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
stream);
|
||||
#ifdef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
_syncCurrentWithCarveoutStream(stream, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
||||
TORCH_WARN(
|
||||
|
|
@ -1557,6 +1624,7 @@ bool gemm_and_bias(
|
|||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
|
||||
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
#ifndef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
computeDesc.setAttribute<int32_t>(
|
||||
|
|
@ -1564,6 +1632,12 @@ bool gemm_and_bias(
|
|||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
}
|
||||
#else
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
stream = _getCarveoutStream(
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
_syncCurrentWithCarveoutStream(stream, true);
|
||||
}
|
||||
#endif
|
||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
|
||||
|
|
@ -1632,7 +1706,12 @@ bool gemm_and_bias(
|
|||
&heuristicResult.algo,
|
||||
ltworkspace.ptr,
|
||||
ltworkspace.size,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
stream);
|
||||
#ifdef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
_syncCurrentWithCarveoutStream(stream, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
||||
TORCH_WARN(
|
||||
|
|
@ -1818,6 +1897,7 @@ void scaled_gemm(
|
|||
if (result_scale_ptr != nullptr) {
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
||||
}
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
#ifndef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
computeDesc.setAttribute<int32_t>(
|
||||
|
|
@ -1825,6 +1905,12 @@ void scaled_gemm(
|
|||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
}
|
||||
#else
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
stream = _getCarveoutStream(
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
_syncCurrentWithCarveoutStream(stream, true);
|
||||
}
|
||||
#endif // ifndef USE_ROCM
|
||||
#ifndef USE_ROCM
|
||||
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
|
||||
|
|
@ -1870,7 +1956,6 @@ void scaled_gemm(
|
|||
#endif // if CUDA_VERSION >= 12090
|
||||
}
|
||||
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
CuBlasLtMatmulPreference preference;
|
||||
auto ltworkspace = CublasLtWorkspace();
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
|
||||
|
|
@ -1957,6 +2042,11 @@ void scaled_gemm(
|
|||
ltworkspace.ptr,
|
||||
ltworkspace.size,
|
||||
stream);
|
||||
#ifdef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
_syncCurrentWithCarveoutStream(stream, false);
|
||||
}
|
||||
#endif
|
||||
TORCH_CHECK(
|
||||
cublasStatus == CUBLAS_STATUS_SUCCESS,
|
||||
"CUDA error: ",
|
||||
|
|
@ -2010,6 +2100,7 @@ void int8_gemm(
|
|||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
|
||||
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
#ifndef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
computeDesc.setAttribute<int32_t>(
|
||||
|
|
@ -2017,6 +2108,12 @@ void int8_gemm(
|
|||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
}
|
||||
#else
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
stream = _getCarveoutStream(
|
||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||
_syncCurrentWithCarveoutStream(stream, true);
|
||||
}
|
||||
#endif
|
||||
|
||||
CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
|
||||
|
|
@ -2078,7 +2175,7 @@ void int8_gemm(
|
|||
#else
|
||||
0,
|
||||
#endif
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
stream);
|
||||
TORCH_CHECK(
|
||||
cublasStatus == CUBLAS_STATUS_SUCCESS,
|
||||
"CUDA error: ",
|
||||
|
|
@ -2107,6 +2204,11 @@ void int8_gemm(
|
|||
computeType,
|
||||
" scaleType ",
|
||||
scaleType);
|
||||
#ifdef USE_ROCM
|
||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||
_syncCurrentWithCarveoutStream(stream, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
|
|
|
|||
|
|
@ -1358,7 +1358,6 @@ class TestFP8Matmul(TestCase):
|
|||
self.assertEqual(out_dtype, out_fp8.dtype)
|
||||
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
||||
|
||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
|
||||
|
|
@ -1387,15 +1386,31 @@ class TestFP8Matmul(TestCase):
|
|||
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
|
||||
|
||||
prof.export_chrome_trace(f.name)
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [
|
||||
math.prod(evt.get("args", {}).get("grid", []))
|
||||
for evt in json.load(open(f.name))["traceEvents"]
|
||||
if evt.get("cat", "") == "kernel"
|
||||
]
|
||||
if torch.version.hip:
|
||||
events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"]
|
||||
# events were returned out of order; need to be sorted on "ts" timestamp
|
||||
events = sorted(events, key=lambda x: x['ts'])
|
||||
# ROCm carveout is invisible except for kernels running slower on fewer CUs
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
|
||||
self.assertTrue(no_carveout < carveout_66)
|
||||
self.assertTrue(carveout_0 < carveout_66)
|
||||
self.assertTrue(no_carveout_again < carveout_66)
|
||||
# ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
|
||||
self.assertTrue(no_carveout == no_carveout_again)
|
||||
self.assertTrue(no_carveout != carveout_0)
|
||||
self.assertTrue(no_carveout != carveout_66)
|
||||
self.assertTrue(carveout_0 != carveout_66)
|
||||
else:
|
||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [
|
||||
math.prod(evt.get("args", {}).get("grid", []))
|
||||
for evt in json.load(open(f.name))["traceEvents"]
|
||||
if evt.get("cat", "") == "kernel"
|
||||
]
|
||||
|
||||
self.assertEqual(no_carveout, no_carveout_again)
|
||||
self.assertNotEqual(no_carveout, carveout_66)
|
||||
self.assertNotEqual(carveout_66, carveout_0)
|
||||
self.assertEqual(no_carveout, no_carveout_again)
|
||||
self.assertNotEqual(no_carveout, carveout_66)
|
||||
self.assertNotEqual(carveout_66, carveout_0)
|
||||
|
||||
def test_pack_uint4(self):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user