mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] support experimental CU carveout (#149466)
Fixes #149280. Follow up to #147966, but now available for ROCm. Since hipblaslt does not support HIPBLASLT_MATMUL_DESC_CU_COUNT_TARGET, we instead create a hipStream that has a CU mask applied. We pass this masked stream to hipblaslt instead of pytorch's current stream. We ensure stream ordering between streams using hipEvents and stream synchronization. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149466 Approved by: https://github.com/malfet, https://github.com/atalman
This commit is contained in:
parent
0596323c35
commit
210632fae1
|
|
@ -17,6 +17,7 @@
|
||||||
#include <c10/core/ScalarType.h>
|
#include <c10/core/ScalarType.h>
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
|
#include <c10/cuda/CUDAStream.h>
|
||||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||||
// until hipblas has an API to accept flags, we must use rocblas here
|
// until hipblas has an API to accept flags, we must use rocblas here
|
||||||
#include <hipblas/hipblas.h>
|
#include <hipblas/hipblas.h>
|
||||||
|
|
@ -188,6 +189,60 @@ uint32_t _getAlignment(uintptr_t address) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
|
||||||
|
static int32_t last_value = 0;
|
||||||
|
static hipStream_t stream;
|
||||||
|
if (last_value == 0) {
|
||||||
|
// first request, do nothing for this case
|
||||||
|
}
|
||||||
|
else if (last_value == value) {
|
||||||
|
// stream was created previously and value hasn't changed
|
||||||
|
return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// need a new stream and a previous stream exists, delete it
|
||||||
|
AT_CUDA_CHECK(hipStreamDestroy(stream));
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we got here, we need to create a new stream
|
||||||
|
int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||||
|
// how many uint32_t do we need to cover all CUs, fill bitmask with 1
|
||||||
|
uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
|
||||||
|
std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
|
||||||
|
// starting from lowest order bits, in 32-bit chunks
|
||||||
|
// set bits to 0 based on how many CUs to carve out
|
||||||
|
int32_t full_shifts = value / 32;
|
||||||
|
int32_t remainder = value % 32;
|
||||||
|
for (int32_t i = 0; i < full_shifts; i++) {
|
||||||
|
mask[i] = uint32_t{0x00000000};
|
||||||
|
}
|
||||||
|
mask[full_shifts] = uint32_t{0xffffffff} << remainder;
|
||||||
|
|
||||||
|
// finally, create masked stream
|
||||||
|
AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));
|
||||||
|
|
||||||
|
last_value = value;
|
||||||
|
return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
|
||||||
|
}
|
||||||
|
|
||||||
|
static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) {
|
||||||
|
hipEvent_t event;
|
||||||
|
AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming));
|
||||||
|
|
||||||
|
auto current_stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
if (presync) {
|
||||||
|
AT_CUDA_CHECK(hipEventRecord(event, current_stream));
|
||||||
|
AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
AT_CUDA_CHECK(hipEventRecord(event, stream));
|
||||||
|
AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
struct CublasLtWorkspace {
|
struct CublasLtWorkspace {
|
||||||
CublasLtWorkspace() {
|
CublasLtWorkspace() {
|
||||||
size = at::cuda::getCUDABlasLtWorkspaceSize();
|
size = at::cuda::getCUDABlasLtWorkspaceSize();
|
||||||
|
|
@ -390,6 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||||
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
|
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
computeDesc.setAttribute<int32_t>(
|
computeDesc.setAttribute<int32_t>(
|
||||||
|
|
@ -397,6 +453,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
stream = _getCarveoutStream(
|
||||||
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
|
_syncCurrentWithCarveoutStream(stream, true);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T);
|
CuBlasLtMatrixLayout Adesc(abType, m, k, lda, opa == CUBLAS_OP_T);
|
||||||
CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T);
|
CuBlasLtMatrixLayout Bdesc(abType, k, n, ldb, opb == CUBLAS_OP_T);
|
||||||
|
|
@ -459,7 +521,12 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||||
&heuristicResult.algo,
|
&heuristicResult.algo,
|
||||||
ltworkspace.ptr,
|
ltworkspace.ptr,
|
||||||
ltworkspace.size,
|
ltworkspace.size,
|
||||||
at::cuda::getCurrentCUDAStream());
|
stream);
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
_syncCurrentWithCarveoutStream(stream, false);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
||||||
TORCH_WARN(
|
TORCH_WARN(
|
||||||
|
|
@ -1557,6 +1624,7 @@ bool gemm_and_bias(
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
|
||||||
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
computeDesc.setAttribute<int32_t>(
|
computeDesc.setAttribute<int32_t>(
|
||||||
|
|
@ -1564,6 +1632,12 @@ bool gemm_and_bias(
|
||||||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
stream = _getCarveoutStream(
|
||||||
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
|
_syncCurrentWithCarveoutStream(stream, true);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||||
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
|
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
|
||||||
|
|
@ -1632,7 +1706,12 @@ bool gemm_and_bias(
|
||||||
&heuristicResult.algo,
|
&heuristicResult.algo,
|
||||||
ltworkspace.ptr,
|
ltworkspace.ptr,
|
||||||
ltworkspace.size,
|
ltworkspace.size,
|
||||||
at::cuda::getCurrentCUDAStream());
|
stream);
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
_syncCurrentWithCarveoutStream(stream, false);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
if (cublasStatus != CUBLAS_STATUS_SUCCESS) {
|
||||||
TORCH_WARN(
|
TORCH_WARN(
|
||||||
|
|
@ -1818,6 +1897,7 @@ void scaled_gemm(
|
||||||
if (result_scale_ptr != nullptr) {
|
if (result_scale_ptr != nullptr) {
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
|
||||||
}
|
}
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
computeDesc.setAttribute<int32_t>(
|
computeDesc.setAttribute<int32_t>(
|
||||||
|
|
@ -1825,6 +1905,12 @@ void scaled_gemm(
|
||||||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
stream = _getCarveoutStream(
|
||||||
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
|
_syncCurrentWithCarveoutStream(stream, true);
|
||||||
|
}
|
||||||
#endif // ifndef USE_ROCM
|
#endif // ifndef USE_ROCM
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
|
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
|
||||||
|
|
@ -1870,7 +1956,6 @@ void scaled_gemm(
|
||||||
#endif // if CUDA_VERSION >= 12090
|
#endif // if CUDA_VERSION >= 12090
|
||||||
}
|
}
|
||||||
|
|
||||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
|
||||||
CuBlasLtMatmulPreference preference;
|
CuBlasLtMatmulPreference preference;
|
||||||
auto ltworkspace = CublasLtWorkspace();
|
auto ltworkspace = CublasLtWorkspace();
|
||||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
|
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
|
||||||
|
|
@ -1957,6 +2042,11 @@ void scaled_gemm(
|
||||||
ltworkspace.ptr,
|
ltworkspace.ptr,
|
||||||
ltworkspace.size,
|
ltworkspace.size,
|
||||||
stream);
|
stream);
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
_syncCurrentWithCarveoutStream(stream, false);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
cublasStatus == CUBLAS_STATUS_SUCCESS,
|
cublasStatus == CUBLAS_STATUS_SUCCESS,
|
||||||
"CUDA error: ",
|
"CUDA error: ",
|
||||||
|
|
@ -2010,6 +2100,7 @@ void int8_gemm(
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
|
||||||
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
|
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
computeDesc.setAttribute<int32_t>(
|
computeDesc.setAttribute<int32_t>(
|
||||||
|
|
@ -2017,6 +2108,12 @@ void int8_gemm(
|
||||||
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
|
||||||
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
stream = _getCarveoutStream(
|
||||||
|
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
|
||||||
|
_syncCurrentWithCarveoutStream(stream, true);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
|
CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
|
||||||
|
|
@ -2078,7 +2175,7 @@ void int8_gemm(
|
||||||
#else
|
#else
|
||||||
0,
|
0,
|
||||||
#endif
|
#endif
|
||||||
at::cuda::getCurrentCUDAStream());
|
stream);
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
cublasStatus == CUBLAS_STATUS_SUCCESS,
|
cublasStatus == CUBLAS_STATUS_SUCCESS,
|
||||||
"CUDA error: ",
|
"CUDA error: ",
|
||||||
|
|
@ -2107,6 +2204,11 @@ void int8_gemm(
|
||||||
computeType,
|
computeType,
|
||||||
" scaleType ",
|
" scaleType ",
|
||||||
scaleType);
|
scaleType);
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
|
||||||
|
_syncCurrentWithCarveoutStream(stream, false);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
|
||||||
|
|
@ -1358,7 +1358,6 @@ class TestFP8Matmul(TestCase):
|
||||||
self.assertEqual(out_dtype, out_fp8.dtype)
|
self.assertEqual(out_dtype, out_fp8.dtype)
|
||||||
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
||||||
|
|
||||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
|
|
||||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
|
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
|
||||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
||||||
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
|
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
|
||||||
|
|
@ -1387,15 +1386,31 @@ class TestFP8Matmul(TestCase):
|
||||||
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
|
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
|
||||||
|
|
||||||
prof.export_chrome_trace(f.name)
|
prof.export_chrome_trace(f.name)
|
||||||
no_carveout, carveout_0, carveout_66, no_carveout_again = [
|
if torch.version.hip:
|
||||||
math.prod(evt.get("args", {}).get("grid", []))
|
events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"]
|
||||||
for evt in json.load(open(f.name))["traceEvents"]
|
# events were returned out of order; need to be sorted on "ts" timestamp
|
||||||
if evt.get("cat", "") == "kernel"
|
events = sorted(events, key=lambda x: x['ts'])
|
||||||
]
|
# ROCm carveout is invisible except for kernels running slower on fewer CUs
|
||||||
|
no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
|
||||||
|
self.assertTrue(no_carveout < carveout_66)
|
||||||
|
self.assertTrue(carveout_0 < carveout_66)
|
||||||
|
self.assertTrue(no_carveout_again < carveout_66)
|
||||||
|
# ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
|
||||||
|
no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
|
||||||
|
self.assertTrue(no_carveout == no_carveout_again)
|
||||||
|
self.assertTrue(no_carveout != carveout_0)
|
||||||
|
self.assertTrue(no_carveout != carveout_66)
|
||||||
|
self.assertTrue(carveout_0 != carveout_66)
|
||||||
|
else:
|
||||||
|
no_carveout, carveout_0, carveout_66, no_carveout_again = [
|
||||||
|
math.prod(evt.get("args", {}).get("grid", []))
|
||||||
|
for evt in json.load(open(f.name))["traceEvents"]
|
||||||
|
if evt.get("cat", "") == "kernel"
|
||||||
|
]
|
||||||
|
|
||||||
self.assertEqual(no_carveout, no_carveout_again)
|
self.assertEqual(no_carveout, no_carveout_again)
|
||||||
self.assertNotEqual(no_carveout, carveout_66)
|
self.assertNotEqual(no_carveout, carveout_66)
|
||||||
self.assertNotEqual(carveout_66, carveout_0)
|
self.assertNotEqual(carveout_66, carveout_0)
|
||||||
|
|
||||||
def test_pack_uint4(self):
|
def test_pack_uint4(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user