[ROCm] use correct workspace for hipblaslt, silence warning (#150227)

Follow-up to #145130. That PR caused a spurious warning on ROCm the first time hipblaslt was invoked, regardless of workload: the workspace lookup always missed on the first call.
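
For context, PyTorch caches one blas workspace per (handle, stream) pair in a map, and the warning fired because on ROCm the hipblaslt handle never matches the hipblas handle that populated that map, so the first lookup always missed. Below is a minimal standalone sketch of that pattern, not the PyTorch internals: Key, workspace_cache, and get_workspace are hypothetical stand-ins, and the miss-then-allocate branch mirrors the ROCm behavior this PR adds.

#include <cstdio>
#include <map>
#include <memory>
#include <tuple>
#include <vector>

using Key = std::tuple<void*, void*>;  // (blas handle, raw stream), mirroring the real key
static std::map<Key, std::unique_ptr<std::vector<char>>> workspace_cache;

// Hypothetical helper: look up the workspace for (handle, stream), allocating
// on a miss -- the ROCm-side behavior this PR adds.
static std::vector<char>* get_workspace(void* handle, void* stream, size_t bytes) {
  Key key{handle, stream};
  auto it = workspace_cache.find(key);
  if (it == workspace_cache.end()) {
    // Pre-fix behavior on this path was warn/assert: the hipblaslt handle
    // never matched the hipblas handle that populated the map, so the entry
    // could not exist yet. Post-fix, just allocate and store it.
    it = workspace_cache.emplace_hint(it, key, std::make_unique<std::vector<char>>(bytes));
  }
  return it->second.get();
}

int main() {
  int blas = 0, blaslt = 0, stream = 0;      // dummies; only their addresses matter
  get_workspace(&blas, &stream, 1 << 20);    // entry created for the hipblas handle
  get_workspace(&blaslt, &stream, 1 << 20);  // distinct hipblaslt handle: miss -> allocate
  std::printf("cached workspaces: %zu\n", workspace_cache.size());  // prints 2
  return 0;
}

Run standalone, this prints "cached workspaces: 2": two distinct handles produce two cache entries, which is exactly why a map populated under the hipblas handle can never satisfy a lookup keyed by the hipblaslt handle.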


Pull Request resolved: https://github.com/pytorch/pytorch/pull/150227
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author: Ethan Wee
Date: 2025-03-31 09:49:40 +00:00
Committed-by: PyTorch MergeBot
Commit: c158eac0de (parent: 51f0403f46)

@@ -222,19 +222,35 @@ static size_t _getWorkspaceSize() {
   return workspace_size;
 }
 
+static at::DataPtr _getNewWorkspace() {
+  return c10::cuda::CUDACachingAllocator::get()->allocate(_getWorkspaceSize());
+}
+
+// See Note [hipblaslt handles].
+// ROCm's hipblas and hipblaslt do not share handles, unlike with CUDA.
+// Using getCurrentCUDABlasLtHandle is on purpose. For CUDA it's the same as
+// getCurrentCUDABlasHandle, but for ROCm it's a unique handle.
 void* _getWorkspaceWithoutHandle() {
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  cublasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle();
   auto stream = c10::cuda::getCurrentCUDAStream();
   cudaStream_t _stream = stream;
   auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
   auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
+#ifdef USE_ROCM
+  // The first call to _getWorkspaceWithoutHandle could be empty, so allocate and store.
+  if (workspace_it == at::cuda::cublas_handle_stream_to_workspace().end()) {
+    workspace_it = at::cuda::cublas_handle_stream_to_workspace().insert(workspace_it, {key, _getNewWorkspace()});
+  }
+#else
   TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
+#endif
   return workspace_it->second.mutable_get();
 }
 
 void* _getWorkspace(size_t& workspaceSize) {
-// #ifdef (defined(USE_ROCM) || defined(FBCODE_CAFFE2))
   workspaceSize = _getWorkspaceSize();
+#ifndef USE_ROCM
+  // See Note [hipblaslt handles].
   auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize();
   if (cublasWorkspaceSize < workspaceSize) {
     TORCH_WARN_ONCE("Requested CUBLASLT workspace size of ", workspaceSize,
@@ -245,9 +261,7 @@ void* _getWorkspace(size_t& workspaceSize) {
     " size will be limited to the CUBLAS workspace size.");
     workspaceSize = cublasWorkspaceSize;
   }
-// #else
-//   workspaceSize = at::cuda::getChosenWorkspaceSize();
-// #endif
+#endif
   auto workspace_ptr = _getWorkspaceWithoutHandle();
   return workspace_ptr;
 }
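
For reference, a repro sketch against the libtorch C++ API (an assumption on my part: this exercises the GEMM path, which dispatches to hipblaslt only on ROCm builds where that backend is selected, something that depends on build flags and GPU architecture). Before this change, the first such call on an affected ROCm build printed the workspace warning once per process; after it, the workspace is allocated silently.

#include <torch/torch.h>
#include <iostream>

int main() {
  if (!torch::cuda::is_available()) {  // on ROCm builds this reports HIP devices
    std::cout << "no GPU device available\n";
    return 0;
  }
  auto opts = torch::device(torch::kCUDA).dtype(torch::kBFloat16);
  auto a = torch::randn({512, 512}, opts);
  auto b = torch::randn({512, 512}, opts);
  auto c = torch::matmul(a, b);  // first GEMM: the blaslt workspace lookup happens here
  std::cout << "result sum: " << c.sum().item<float>() << "\n";
  return 0;
}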