mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] use correct workspace for hipblaslt, silence warning (#150227)
Follow-up to #145130. That PR caused hipblaslt to emit a warning on ROCm the first time it was called, for every workload. Fixes #ISSUE_NUMBER. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150227. Approved by: https://github.com/jeffdaily. Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This commit is contained in:
parent
51f0403f46
commit
c158eac0de
|
|
@ -222,19 +222,35 @@ static size_t _getWorkspaceSize() {
|
|||
return workspace_size;
|
||||
}
|
||||
|
||||
// Allocates a fresh workspace buffer through the CUDA caching allocator.
// The buffer is sized by _getWorkspaceSize() and owned by the returned DataPtr.
static at::DataPtr _getNewWorkspace() {
  const size_t workspace_bytes = _getWorkspaceSize();
  return c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);
}
|
||||
|
||||
// See Note [hipblaslt handles].
|
||||
// ROCm's hipblas and hipblaslt do not share handles, unlike with CUDA.
|
||||
// Using getCurrentCUDABlasLtHandle is on purpose. For CUDA it's the same as
|
||||
// getCurrentCUDABlasHandle, but for ROCm it's a unique handle.
|
||||
void* _getWorkspaceWithoutHandle() {
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
cublasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle();
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
|
||||
#ifdef USE_ROCM
|
||||
// The first call to _getWorkspaceWithoutHandle could be empty, so allocate and store.
|
||||
if (workspace_it == at::cuda::cublas_handle_stream_to_workspace().end()) {
|
||||
workspace_it = at::cuda::cublas_handle_stream_to_workspace().insert(workspace_it, {key, _getNewWorkspace()});
|
||||
}
|
||||
#else
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
|
||||
#endif
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
|
||||
// Returns the cuBLASLt workspace pointer for the current handle/stream pair
// and reports the effective workspace size through the `workspaceSize`
// out-parameter.
void* _getWorkspace(size_t& workspaceSize) {
|
||||
// #ifdef (defined(USE_ROCM) || defined(FBCODE_CAFFE2))
|
||||
// Start from the requested CUBLASLT workspace size.
workspaceSize = _getWorkspaceSize();
|
||||
#ifndef USE_ROCM
|
||||
// See Note [hipblaslt handles].
|
||||
// NOTE(review): the size is clamped to the cuBLAS workspace size below,
// presumably because on CUDA the Lt path reuses the cuBLAS workspace
// buffer — confirm against _getWorkspaceWithoutHandle's CUDA branch.
auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize();
|
||||
if (cublasWorkspaceSize < workspaceSize) {
|
||||
// Warn once if the user asked for more CUBLASLT workspace than the
// cuBLAS workspace can back. (Middle arguments of this message are
// collapsed in this diff view — see the hunk marker below.)
TORCH_WARN_ONCE("Requested CUBLASLT workspace size of ", workspaceSize,
|
||||
|
|
@ -245,9 +261,7 @@ void* _getWorkspace(size_t& workspaceSize) {
|
|||
" size will be limited to the CUBLAS workspace size.");
|
||||
// Clamp to what the cuBLAS workspace actually provides.
workspaceSize = cublasWorkspaceSize;
|
||||
}
|
||||
// #else
|
||||
// workspaceSize = at::cuda::getChosenWorkspaceSize();
|
||||
// #endif
|
||||
#endif
|
||||
// Fetch (or, on ROCm, lazily create) the cached per-(handle, stream) buffer.
auto workspace_ptr = _getWorkspaceWithoutHandle();
|
||||
return workspace_ptr;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user