mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move magma utils to its own header (#73058)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73058
And keep it in cuda/linalg folder to make sure all MAGMA and CUSolver usage in codebase is restricted to linalg
Test Plan: Imported from OSS
Reviewed By: suo
Differential Revision: D34327978
Pulled By: malfet
fbshipit-source-id: dd4539a2a76bce68cced94fba943bf8a1155db1e
(cherry picked from commit 15d8c9b5dd)
This commit is contained in:
parent
5843fea94d
commit
87f882b056
|
|
@ -4,89 +4,9 @@
|
|||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <ATen/cuda/PinnedMemoryAllocator.h>
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
#include <magma_types.h>
|
||||
#include <magma_v2.h>
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
|
||||
// RAII for a MAGMA Queue
|
||||
struct MAGMAQueue {
|
||||
|
||||
// Default constructor without a device will cause
|
||||
// destroying a queue which has not been initialized.
|
||||
MAGMAQueue() = delete;
|
||||
|
||||
// Constructor
|
||||
explicit MAGMAQueue(int64_t device_id) {
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
// Magma operations is numerically sensitive, so TF32 should be off
|
||||
// regardless of the global flag.
|
||||
TORCH_CUDABLAS_CHECK(cublasGetMathMode(handle, &original_math_mode));
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
#endif
|
||||
magma_queue_create_from_cuda(
|
||||
device_id,
|
||||
at::cuda::getCurrentCUDAStream(),
|
||||
handle,
|
||||
at::cuda::getCurrentCUDASparseHandle(),
|
||||
&magma_queue_);
|
||||
}
|
||||
|
||||
// Getter
|
||||
magma_queue_t get_queue() const { return magma_queue_; }
|
||||
|
||||
// Destructor
|
||||
~MAGMAQueue() {
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
// We've manually set the math mode to CUBLAS_DEFAULT_MATH, now we
|
||||
// should restore the original math mode back
|
||||
cublasHandle_t handle = magma_queue_get_cublas_handle(magma_queue_);
|
||||
cublasSetMathMode(handle, original_math_mode);
|
||||
#endif
|
||||
magma_queue_destroy(magma_queue_);
|
||||
}
|
||||
|
||||
private:
|
||||
magma_queue_t magma_queue_;
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
cublasMath_t original_math_mode;
|
||||
#endif
|
||||
};
|
||||
|
||||
// Safely narrows a 64-bit value to magma_int_t (whose width depends on
// how MAGMA was built). Raises a c10::Error naming `varname` when the
// value does not round-trip through the narrower type.
// Uses TORCH_CHECK for consistency with cuda_int_cast below (AT_ERROR is
// the deprecated spelling of the same failure path).
static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<magma_int_t>(value);
  TORCH_CHECK(static_cast<int64_t>(result) == value,
              "magma: The value of ", varname, "(", (long long)value,
              ") is too large to fit into a magma_int_t (", sizeof(magma_int_t), " bytes)");
  return result;
}
|
||||
|
||||
// MAGMA functions that don't take a magma_queue_t aren't stream safe
|
||||
// Work around this by synchronizing with the default stream
|
||||
struct MagmaStreamSyncGuard {
|
||||
MagmaStreamSyncGuard() {
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (stream != at::cuda::getDefaultCUDAStream()) {
|
||||
at::cuda::stream_synchronize(stream);
|
||||
}
|
||||
}
|
||||
|
||||
~MagmaStreamSyncGuard() noexcept(false) {
|
||||
auto default_stream = at::cuda::getDefaultCUDAStream();
|
||||
if (at::cuda::getCurrentCUDAStream() != default_stream) {
|
||||
at::cuda::stream_synchronize(default_stream);
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
static inline int cuda_int_cast(int64_t value, const char* varname) {
|
||||
auto result = static_cast<int>(value);
|
||||
TORCH_CHECK(static_cast<int64_t>(result) == value,
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include <ATen/native/LinearAlgebra.h>
|
||||
#include <ATen/native/BatchLinearAlgebra.h>
|
||||
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>
|
||||
#include <ATen/native/cuda/linalg/MagmaUtils.h>
|
||||
#include <ATen/native/cpu/zmath.h>
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
|
|
|
|||
88
aten/src/ATen/native/cuda/linalg/MagmaUtils.h
Normal file
88
aten/src/ATen/native/cuda/linalg/MagmaUtils.h
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
#pragma once
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
#include <magma_types.h>
|
||||
#include <magma_v2.h>
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
|
||||
// RAII for a MAGMA Queue
|
||||
struct MAGMAQueue {
|
||||
|
||||
// Default constructor without a device will cause
|
||||
// destroying a queue which has not been initialized.
|
||||
MAGMAQueue() = delete;
|
||||
|
||||
// Constructor
|
||||
explicit MAGMAQueue(int64_t device_id) {
|
||||
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
// Magma operations is numerically sensitive, so TF32 should be off
|
||||
// regardless of the global flag.
|
||||
TORCH_CUDABLAS_CHECK(cublasGetMathMode(handle, &original_math_mode));
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
#endif
|
||||
magma_queue_create_from_cuda(
|
||||
device_id,
|
||||
at::cuda::getCurrentCUDAStream(),
|
||||
handle,
|
||||
at::cuda::getCurrentCUDASparseHandle(),
|
||||
&magma_queue_);
|
||||
}
|
||||
|
||||
// Getter
|
||||
magma_queue_t get_queue() const { return magma_queue_; }
|
||||
|
||||
// Destructor
|
||||
~MAGMAQueue() {
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
// We've manually set the math mode to CUBLAS_DEFAULT_MATH, now we
|
||||
// should restore the original math mode back
|
||||
cublasHandle_t handle = magma_queue_get_cublas_handle(magma_queue_);
|
||||
cublasSetMathMode(handle, original_math_mode);
|
||||
#endif
|
||||
magma_queue_destroy(magma_queue_);
|
||||
}
|
||||
|
||||
private:
|
||||
magma_queue_t magma_queue_;
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
cublasMath_t original_math_mode;
|
||||
#endif
|
||||
};
|
||||
|
||||
// Safely narrows a 64-bit value to magma_int_t (whose width depends on
// how MAGMA was built). Raises a c10::Error naming `varname` when the
// value does not round-trip through the narrower type.
// Uses TORCH_CHECK for consistency with cuda_int_cast (AT_ERROR is the
// deprecated spelling of the same failure path).
static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<magma_int_t>(value);
  TORCH_CHECK(static_cast<int64_t>(result) == value,
              "magma: The value of ", varname, "(", (long long)value,
              ") is too large to fit into a magma_int_t (", sizeof(magma_int_t), " bytes)");
  return result;
}
|
||||
|
||||
// MAGMA functions that don't take a magma_queue_t aren't stream safe
|
||||
// Work around this by synchronizing with the default stream
|
||||
struct MagmaStreamSyncGuard {
|
||||
MagmaStreamSyncGuard() {
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
if (stream != at::cuda::getDefaultCUDAStream()) {
|
||||
at::cuda::stream_synchronize(stream);
|
||||
}
|
||||
}
|
||||
|
||||
~MagmaStreamSyncGuard() noexcept(false) {
|
||||
auto default_stream = at::cuda::getDefaultCUDAStream();
|
||||
if (at::cuda::getCurrentCUDAStream() != default_stream) {
|
||||
at::cuda::stream_synchronize(default_stream);
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
Loading…
Reference in New Issue
Block a user