[Reland] Add BUILD_LAZY_CUDA_LINALG option (#73447)

Summary:
When enabled, it generates the `torch_cuda_linalg` library, which depends on cuSOLVER and MAGMA; dynamic bindings to it are registered from LinearAlgebraStubs

Avoid symbol clashes that can result in infinite recursion by moving all symbols in the library to its own namespace.

Add checks to `LinearAlgebraStubs.cpp` that should prevent the lazy wrappers from calling themselves recursively

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73447

Reviewed By: albanD

Differential Revision: D34538827

Pulled By: malfet

fbshipit-source-id: f2535b471d3524768a84b2e169b6aa24c26c03bf
(cherry picked from commit 4ec24b079c861c1122f0fa86e280b977c3c2f7ac)
This commit is contained in:
Nikita Shulga 2022-03-01 12:58:49 -08:00 committed by PyTorch MergeBot
parent ebf0ca3307
commit 6302cdb9bc
11 changed files with 312 additions and 32 deletions

View File

@ -189,6 +189,8 @@ option(USE_CUDA "Use CUDA" ON)
cmake_dependent_option( cmake_dependent_option(
BUILD_SPLIT_CUDA "Split torch_cuda library into torch_cuda_cu and torch_cuda_cpp" OFF BUILD_SPLIT_CUDA "Split torch_cuda library into torch_cuda_cu and torch_cuda_cpp" OFF
"USE_CUDA AND NOT CUDA_SEPARABLE_COMPILATION" OFF) "USE_CUDA AND NOT CUDA_SEPARABLE_COMPILATION" OFF)
# Build the CUDA linear algebra operators as a separate library.
# Defaults to ON when USE_CUDA, LINUX and BUILD_PYTHON are all true;
# otherwise the option is forced to OFF.
cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
option(USE_FAST_NVCC "Use parallel NVCC build" OFF) option(USE_FAST_NVCC "Use parallel NVCC build" OFF)
option(USE_ROCM "Use ROCm" ON) option(USE_ROCM "Use ROCm" ON)
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF) option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)

View File

@ -23,6 +23,7 @@ set(ATen_CPU_INCLUDE)
set(ATen_THIRD_PARTY_INCLUDE) set(ATen_THIRD_PARTY_INCLUDE)
set(ATen_CUDA_CPP_SRCS) set(ATen_CUDA_CPP_SRCS)
set(ATen_CUDA_CU_SRCS) set(ATen_CUDA_CU_SRCS)
set(ATen_CUDA_LINALG_SRCS)
set(ATen_CUDA_SRCS_W_SORT_BY_KEY) set(ATen_CUDA_SRCS_W_SORT_BY_KEY)
set(ATen_CUDA_TEST_SRCS) set(ATen_CUDA_TEST_SRCS)
set(ATen_CUDA_INCLUDE) set(ATen_CUDA_INCLUDE)
@ -99,6 +100,7 @@ set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE) set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE) set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE)
set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE) set(ATen_HIP_SRCS ${ATen_HIP_SRCS} PARENT_SCOPE)

View File

@ -194,7 +194,6 @@ if(USE_CUDA)
list(APPEND ATen_CUDA_CU_SRCS list(APPEND ATen_CUDA_CU_SRCS
${cuda_cu} ${cuda_cu}
${native_cuda_cu} ${native_cuda_cu}
${native_cuda_linalg_cpp}
${native_sparse_cuda_cu} ${native_sparse_cuda_cu}
${native_quantized_cuda_cu} ${native_quantized_cuda_cu}
${cuda_generated_sources} ${cuda_generated_sources}
@ -208,6 +207,10 @@ if(USE_CUDA)
${native_quantized_cudnn_cpp} ${native_quantized_cudnn_cpp}
${native_sparse_cuda_cpp} ${native_sparse_cuda_cpp}
) )
set(ATen_CUDA_LINALG_SRCS ${native_cuda_linalg_cpp})
if(NOT BUILD_LAZY_CUDA_LINALG)
list(APPEND ATen_CUDA_CU_SRCS ${native_cuda_linalg_cpp})
endif()
if(CAFFE2_USE_CUDNN) if(CAFFE2_USE_CUDNN)
list(APPEND ATen_CUDA_CPP_SRCS ${cudnn_cpp}) list(APPEND ATen_CUDA_CPP_SRCS ${cudnn_cpp})
endif() endif()
@ -392,17 +395,25 @@ if(USE_CUDA AND NOT USE_ROCM)
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcurand_static.a ${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcurand_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a ${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcufft_static_nocallback.a ${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcufft_static_nocallback.a
)
if(NOT BUILD_LAZY_CUDA_LINALG)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusolver_static.a ${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusolver_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/liblapack_static.a # needed for libcusolver_static ${CUDA_TOOLKIT_ROOT_DIR}/lib64/liblapack_static.a # needed for libcusolver_static
) )
endif()
else() else()
list(APPEND ATen_CUDA_DEPENDENCY_LIBS list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES} ${CUDA_LIBRARIES}
${CUDA_cusparse_LIBRARY} ${CUDA_cusparse_LIBRARY}
${CUDA_curand_LIBRARY} ${CUDA_curand_LIBRARY}
)
if(NOT BUILD_LAZY_CUDA_LINALG)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_cusolver_LIBRARY} ${CUDA_cusolver_LIBRARY}
) )
endif() endif()
endif()
if(CAFFE2_USE_CUDNN) if(CAFFE2_USE_CUDNN)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDNN_LIBRARIES}) list(APPEND ATen_CUDA_DEPENDENCY_LIBS ${CUDNN_LIBRARIES})
@ -415,9 +426,9 @@ endif()
if(USE_MAGMA) if(USE_MAGMA)
if(USE_CUDA) if(USE_CUDA AND NOT BUILD_LAZY_CUDA_LINALG)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS torch::magma) list(APPEND ATen_CUDA_DEPENDENCY_LIBS torch::magma)
endif(USE_CUDA) endif(USE_CUDA AND NOT BUILD_LAZY_CUDA_LINALG)
if(USE_ROCM) if(USE_ROCM)
list(APPEND ATen_HIP_DEPENDENCY_LIBS torch::magma) list(APPEND ATen_HIP_DEPENDENCY_LIBS torch::magma)
endif(USE_ROCM) endif(USE_ROCM)
@ -536,6 +547,7 @@ set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE) set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_SRCS_W_SORT_BY_KEY ${ATen_CUDA_SRCS_W_SORT_BY_KEY} PARENT_SCOPE)
set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE) set(ATen_CUDA_CU_SRCS_W_SORT_BY_KEY ${ATen_CUDA_CU_SRCS_W_SORT_BY_KEY} PARENT_SCOPE)
set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE) set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE)

View File

@ -25,7 +25,7 @@ static void* checkDL(void* x) {
return x; return x;
} }
DynamicLibrary::DynamicLibrary(const char* name, const char* alt_name) { DynamicLibrary::DynamicLibrary(const char* name, const char* alt_name, bool leak_handle_): leak_handle(leak_handle_) {
// NOLINTNEXTLINE(hicpp-signed-bitwise) // NOLINTNEXTLINE(hicpp-signed-bitwise)
handle = dlopen(name, RTLD_LOCAL | RTLD_NOW); handle = dlopen(name, RTLD_LOCAL | RTLD_NOW);
if (!handle) { if (!handle) {
@ -46,8 +46,9 @@ void* DynamicLibrary::sym(const char* name) {
} }
DynamicLibrary::~DynamicLibrary() { DynamicLibrary::~DynamicLibrary() {
if (!handle) if (!handle || leak_handle) {
return; return;
}
dlclose(handle); dlclose(handle);
} }
@ -55,7 +56,7 @@ DynamicLibrary::~DynamicLibrary() {
// Windows // Windows
DynamicLibrary::DynamicLibrary(const char* name, const char* alt_name) { DynamicLibrary::DynamicLibrary(const char* name, const char* alt_name, bool leak_handle_): leak_handle(leak_handle_) {
// NOLINTNEXTLINE(hicpp-signed-bitwise) // NOLINTNEXTLINE(hicpp-signed-bitwise)
HMODULE theModule; HMODULE theModule;
bool reload = true; bool reload = true;
@ -97,7 +98,7 @@ void* DynamicLibrary::sym(const char* name) {
} }
DynamicLibrary::~DynamicLibrary() { DynamicLibrary::~DynamicLibrary() {
if (!handle) { if (!handle || leak_handle) {
return; return;
} }
FreeLibrary((HMODULE)handle); FreeLibrary((HMODULE)handle);

View File

@ -8,13 +8,14 @@ namespace at {
struct DynamicLibrary { struct DynamicLibrary {
AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary);
TORCH_API DynamicLibrary(const char* name, const char* alt_name = nullptr); TORCH_API DynamicLibrary(const char* name, const char* alt_name = nullptr, bool leak_handle = false);
TORCH_API void* sym(const char* name); TORCH_API void* sym(const char* name);
TORCH_API ~DynamicLibrary(); TORCH_API ~DynamicLibrary();
private: private:
bool leak_handle;
void* handle = nullptr; void* handle = nullptr;
}; };

View File

@ -0,0 +1,220 @@
// LinearAlgebraStubs.cpp
// Mostly a no-op unless BUILD_LAZY_CUDA_LINALG is defined
// In that case the linalg library is loaded dynamically when the first linalg call is made
// This helps reduce the size of the GPU memory context if linear algebra functions are not used
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/DynamicLibrary.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/BatchLinearAlgebra.h>
#if defined(BUILD_LAZY_CUDA_LINALG)
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>
#if AT_MAGMA_ENABLED()
#include <ATen/cuda/detail/CUDAHooks.h>
namespace {
// File-local object whose constructor runs during static initialization of
// this translation unit and installs an empty callback via
// set_magma_init_fn().
// NOTE(review): presumably MAGMA is initialized by the lazily loaded linalg
// library instead, which is why nothing needs to happen here -- confirm.
struct MagmaInitializer {
  MagmaInitializer() {
    ::at::cuda::detail::set_magma_init_fn([] {});
  }
} initializer;
} // namespace (anonymous)
#endif
#endif
namespace at {
namespace native {
#if defined(BUILD_LAZY_CUDA_LINALG)
namespace {
// Dispatch table for the old-style linalg entry points.
// Every slot starts out pointing at the forwarding function of the same name
// in this file; registerLinalgDispatch() replaces the whole table. The
// forwarders compare their slot against themselves to detect the case where
// nothing was registered.
cuda::detail::LinalgDispatch disp = {
    /*solve_helper=*/_solve_helper_cuda,
    /*symeig_helper=*/_symeig_helper_cuda,
    /*qr_helper=*/_linalg_qr_helper_cuda,
    /*cholesky_solve_helper=*/_cholesky_solve_helper_cuda,
    /*legacy_lstsq=*/legacy_lstsq_cuda,
    /*inv_out_helper=*/_linalg_inv_out_helper_cuda,
};
// Returns the handle to libtorch_cuda_linalg.so, opening the library the
// first time this function is called (function-local static).
// The handle is created with leak_handle = true, so the DynamicLibrary
// destructor will not close it.
at::DynamicLibrary& getTorchLinalgLibrary() {
  static at::DynamicLibrary lib(
      "libtorch_cuda_linalg.so", /*alt_name=*/nullptr, /*leak_handle=*/true);
  return lib;
}
// The lazy dispatch wrappers do nothing except load the linalg library and
// then call the stub again. Loading the library is expected to replace the
// stub registrations with the real implementations.
// getTorchLinalgLibrary() throws if the library cannot be found, so no
// explicit error handling is required here.
// The invocation counter makes a second pass through this function fail with
// an error instead of letting a wrapper call itself without bound.
void loadLazyTorchLinalgLibrary() {
  static int invoke_count = 0;
  getTorchLinalgLibrary();
  TORCH_CHECK(invoke_count++ == 0, "lazy wrapper should be called at most once");
}
void lazy_cholesky_kernel(const Tensor& input, const Tensor& info, bool upper) {
loadLazyTorchLinalgLibrary();
cholesky_stub(DeviceType::CUDA, input, info, upper);
}
Tensor& lazy_cholesky_inverse_kernel(Tensor &result, Tensor& infos, bool upper) {
loadLazyTorchLinalgLibrary();
return cholesky_inverse_stub(DeviceType::CUDA, result, infos, upper);
}
void lazy_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
loadLazyTorchLinalgLibrary();
lu_factor_stub(DeviceType::CUDA, input, pivots, infos, compute_pivots);
}
void lazy_triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
loadLazyTorchLinalgLibrary();
triangular_solve_stub(DeviceType::CUDA, A, B, left, upper, transpose, unitriangular);
}
Tensor& lazy_orgqr_kernel(Tensor& result, const Tensor& tau) {
loadLazyTorchLinalgLibrary();
return orgqr_stub(DeviceType::CUDA, result, tau);
}
void lazy_ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
loadLazyTorchLinalgLibrary();
ormqr_stub(DeviceType::CUDA, input, tau, other, left, transpose);
}
void lazy_geqrf_kernel(const Tensor& input, const Tensor& tau) {
loadLazyTorchLinalgLibrary();
geqrf_stub(DeviceType::CUDA, input, tau);
}
void lazy_linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
loadLazyTorchLinalgLibrary();
linalg_eigh_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
}
std::tuple<Tensor, Tensor> lazy_eig_kernel(const Tensor& self, bool& eigenvectors) {
loadLazyTorchLinalgLibrary();
return eig_stub(DeviceType::CUDA, self, eigenvectors);
}
void lazy_linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
getTorchLinalgLibrary();
linalg_eig_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
}
void lazy_svd_kernel(const Tensor& A,
const bool full_matrices,
const bool compute_uv,
const Tensor& U,
const Tensor& S,
const Tensor& Vh,
const Tensor& info) {
getTorchLinalgLibrary();
svd_stub(DeviceType::CUDA, A, full_matrices, compute_uv, U, S, Vh, info);
}
void lazy_lu_solve_trans(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
getTorchLinalgLibrary();
lu_solve_trans_stub(DeviceType::CUDA, b, lu, pivots, trans);
}
void lazy_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
getTorchLinalgLibrary();
lu_solve_stub(DeviceType::CUDA, b, lu, pivots);
}
void lazy_lstsq_kernel(const Tensor& a, Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, std::string driver_name) {
getTorchLinalgLibrary();
lstsq_stub(DeviceType::CUDA, a, b, rank, singular_values, infos, rcond, driver_name);
}
// Register the lazy wrappers as the CUDA implementations of the linalg stubs.
// Loading the linalg library is expected to replace these registrations with
// the real kernels.
REGISTER_CUDA_DISPATCH(cholesky_stub, &lazy_cholesky_kernel);
REGISTER_CUDA_DISPATCH(cholesky_inverse_stub, &lazy_cholesky_inverse_kernel);
REGISTER_CUDA_DISPATCH(lu_factor_stub, &lazy_lu_factor);
REGISTER_CUDA_DISPATCH(triangular_solve_stub, &lazy_triangular_solve_kernel);
REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel);
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel);
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel);
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel);
REGISTER_CUDA_DISPATCH(eig_stub, &lazy_eig_kernel);
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel);
REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel);
REGISTER_CUDA_DISPATCH(lu_solve_trans_stub, &lazy_lu_solve_trans);
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve);
REGISTER_CUDA_DISPATCH(lstsq_stub, &lazy_lstsq_kernel);
} // anonymous namespace
// Old style dispatches
// torch_cuda_linalg dynamic library should have a global constructor
// that calls registerLinalgDispatch, so in order to lazy bind an
// old style dispatch all one has to do is to load the library and call disp.func_name
// Protect from infinite recursion by initializing dispatch to self and checking
// that values are different after the linalg library was loaded
namespace cuda {
namespace detail {
// Replaces the file-local old-style dispatch table with a copy of `disp_`.
void registerLinalgDispatch(const LinalgDispatch& disp_) {
  disp = disp_;
}
} // namespace detail
} // namespace cuda
// Old-style entry points. Each one:
//   1. makes sure the linalg library is loaded,
//   2. verifies that its slot in the dispatch table no longer points back at
//      this very function (which would mean nothing was registered and the
//      call below would recurse forever),
//   3. forwards the arguments to the registered implementation.

Tensor& _linalg_inv_out_helper_cuda(Tensor &result, Tensor& infos_lu, Tensor& infos_getri) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.inv_out_helper != _linalg_inv_out_helper_cuda, "Can't find _linalg_inv_out_helper_cuda");
  return disp.inv_out_helper(result, infos_lu, infos_getri);
}

std::tuple<Tensor, Tensor> legacy_lstsq_cuda(const Tensor &B, const Tensor &A) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.legacy_lstsq != legacy_lstsq_cuda, "Can't find legacy_lstsq_cuda");
  return disp.legacy_lstsq(B, A);
}

Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.cholesky_solve_helper != _cholesky_solve_helper_cuda, "Can't find _cholesky_solve_helper_cuda");
  return disp.cholesky_solve_helper(self, A, upper);
}

std::tuple<Tensor, Tensor> _linalg_qr_helper_cuda(const Tensor& input, c10::string_view mode) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.qr_helper != _linalg_qr_helper_cuda, "Can't find _linalg_qr_helper_cuda");
  return disp.qr_helper(input, mode);
}

std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvectors, bool upper) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.symeig_helper != _symeig_helper_cuda, "Can't find _symeig_helper_cuda");
  return disp.symeig_helper(self, eigenvectors, upper);
}

std::tuple<Tensor, Tensor> _solve_helper_cuda(const Tensor& self, const Tensor& A) {
  getTorchLinalgLibrary();
  TORCH_CHECK(disp.solve_helper != _solve_helper_cuda, "Can't find _solve_helper_cuda");
  return disp.solve_helper(self, A);
}
#endif /*defined(BUILD_LAZY_CUDA_LINALG)*/
std::tuple<Tensor&, Tensor&> legacy_lstsq_out_cuda(
const Tensor& B, const Tensor& A, Tensor& B_out, Tensor& A_out) {
const auto dtype = A.scalar_type();
TORCH_CHECK(B.scalar_type() == dtype, "exepected A and B dtypes to match but found ",
A.scalar_type(), " and ", B.scalar_type());
TORCH_CHECK(A_out.scalar_type() == dtype, "A_out to have scalar type ", dtype,
" but found", A_out.scalar_type());
TORCH_CHECK(B_out.scalar_type() == dtype, "A_out to have scalar type ", dtype,
" but found", B_out.scalar_type());
Tensor A_tmp, B_tmp;
std::tie(B_tmp, A_tmp) = native::legacy_lstsq_cuda(B, A);
resize_output(A_out, A_tmp.sizes());
A_out.copy_(A_tmp);
resize_output(B_out, B_tmp.sizes());
B_out.copy_(B_tmp);
return std::tuple<Tensor&, Tensor&>(B_out, A_out);
}
}} // namespace at::native

View File

@ -9,7 +9,6 @@
#include <ATen/native/LinearAlgebraUtils.h> #include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cuda/MiscUtils.h> #include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/LinearAlgebra.h> #include <ATen/native/LinearAlgebra.h>
#include <ATen/native/BatchLinearAlgebra.h> #include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h> #include <ATen/native/cuda/linalg/BatchLinearAlgebraLib.h>
@ -26,8 +25,12 @@ const bool use_magma_ = true;
namespace { namespace {
struct MagmaInitializer { struct MagmaInitializer {
MagmaInitializer() { MagmaInitializer() {
#if defined(BUILD_LAZY_CUDA_LINALG)
magma_init();
#else
::at::cuda::detail::set_magma_init_fn([]{ magma_init(); }); ::at::cuda::detail::set_magma_init_fn([]{ magma_init(); });
}; #endif
}
} initializer; } initializer;
} // namespace (anonymous) } // namespace (anonymous)
@ -38,6 +41,12 @@ const bool use_magma_ = false;
namespace at { namespace at {
namespace native { namespace native {
#if defined(BUILD_LAZY_CUDA_LINALG)
// All registrations with PyTorch runtime should be done dynamically
// so if library is lazy loaded it must not export anything, otherwise
// it can result in symbol clashes
namespace lazy_linalg {
#endif
#if AT_MAGMA_ENABLED() #if AT_MAGMA_ENABLED()
template<class scalar_t> template<class scalar_t>
@ -3245,25 +3254,22 @@ std::tuple<Tensor, Tensor> legacy_lstsq_cuda(const Tensor &B, const Tensor &A) {
#endif // AT_MAGMA_ENABLED() #endif // AT_MAGMA_ENABLED()
} }
std::tuple<Tensor&, Tensor&> legacy_lstsq_out_cuda(
const Tensor& B, const Tensor& A, Tensor& B_out, Tensor& A_out) {
const auto dtype = A.scalar_type();
TORCH_CHECK(B.scalar_type() == dtype, "exepected A and B dtypes to match but found ",
A.scalar_type(), " and ", B.scalar_type());
TORCH_CHECK(A_out.scalar_type() == dtype, "A_out to have scalar type ", dtype,
" but found", A_out.scalar_type());
TORCH_CHECK(B_out.scalar_type() == dtype, "A_out to have scalar type ", dtype,
" but found", B_out.scalar_type());
Tensor A_tmp, B_tmp;
std::tie(B_tmp, A_tmp) = native::legacy_lstsq_cuda(B, A);
resize_output(A_out, A_tmp.sizes());
A_out.copy_(A_tmp);
resize_output(B_out, B_tmp.sizes());
B_out.copy_(B_tmp);
return std::tuple<Tensor&, Tensor&>(B_out, A_out);
}
#if defined(BUILD_LAZY_CUDA_LINALG)
// Namespace-scope object whose constructor runs when this library's static
// initializers execute. It fills a LinalgDispatch table with the old-style
// helper functions and hands it to cuda::detail::registerLinalgDispatch().
// NOTE(review): the helper names presumably resolve to the implementations
// inside the enclosing lazy_linalg namespace -- confirm.
struct DispatchInitializer {
DispatchInitializer() {
cuda::detail::LinalgDispatch disp{ _solve_helper_cuda,
_symeig_helper_cuda,
_linalg_qr_helper_cuda,
_cholesky_solve_helper_cuda,
legacy_lstsq_cuda,
_linalg_inv_out_helper_cuda};
cuda::detail::registerLinalgDispatch(disp);
};
} initializer;
} // namespace lazy_linalg
#endif
}} // namespace at::native }} // namespace at::native
#undef ALLOCATE_ARRAY #undef ALLOCATE_ARRAY

View File

@ -65,4 +65,20 @@ void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const T
#endif // USE_CUSOLVER #endif // USE_CUSOLVER
#if defined(BUILD_LAZY_CUDA_LINALG)
namespace cuda { namespace detail {
// Table of function pointers used only for old-style dispatches.
// Please do not add any new entries to it.
struct LinalgDispatch {
std::tuple<Tensor, Tensor> (*solve_helper)(const Tensor& self, const Tensor& A);
std::tuple<Tensor, Tensor> (*symeig_helper)(const Tensor& self, bool eigenvectors, bool upper);
std::tuple<Tensor, Tensor> (*qr_helper)(const Tensor& input, c10::string_view mode);
Tensor (*cholesky_solve_helper)(const Tensor& self, const Tensor& A, bool upper);
std::tuple<Tensor, Tensor> (*legacy_lstsq)(const Tensor &B, const Tensor &A);
Tensor& (*inv_out_helper)(Tensor &result, Tensor& infos_lu, Tensor& infos_getri);
};
// Installs the given table as the active old-style dispatch table.
C10_EXPORT void registerLinalgDispatch(const LinalgDispatch&);
}} // namespace cuda::detail
#endif
}} // namespace at::native }} // namespace at::native

View File

@ -901,6 +901,22 @@ elseif(USE_CUDA)
target_link_libraries(torch_cuda PRIVATE __caffe2_nccl) target_link_libraries(torch_cuda PRIVATE __caffe2_nccl)
target_compile_definitions(torch_cuda PRIVATE USE_NCCL) target_compile_definitions(torch_cuda PRIVATE USE_NCCL)
endif() endif()
# When lazy CUDA linalg is enabled, build the linalg sources into their own
# torch_cuda_linalg library, linked against torch_cpu, torch_cuda and cuSOLVER
# (plus MAGMA when USE_MAGMA is set), and install it next to the other libraries.
if(BUILD_LAZY_CUDA_LINALG)
add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS})
target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG)
target_link_libraries(torch_cuda_linalg PRIVATE
torch_cpu
torch_cuda
${CUDA_cusolver_LIBRARY}
)
if(USE_MAGMA)
target_link_libraries(torch_cuda_linalg PRIVATE torch::magma)
# CUDAHooks reports version of MAGMA PyTorch was compiled against, i.e. needs to be able to include magma headers
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/CUDAHooks.cpp PROPERTIES INCLUDE_DIRECTORIES "${MAGMA_INCLUDE_DIR}")
endif()
# LinearAlgebraStubs.cpp is not part of torch_cuda_linalg's own sources, so it
# gets the BUILD_LAZY_CUDA_LINALG define as a per-source flag; the define
# switches on its lazy-loading wrappers.
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp PROPERTIES COMPILE_FLAGS "-DBUILD_LAZY_CUDA_LINALG")
install(TARGETS torch_cuda_linalg DESTINATION "${TORCH_INSTALL_LIB_DIR}")
endif()
if(USE_PRECOMPILED_HEADERS) if(USE_PRECOMPILED_HEADERS)
if(BUILD_SPLIT_CUDA) if(BUILD_SPLIT_CUDA)

View File

@ -51,15 +51,18 @@ if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None
# todo(mkozuki): Figure out the root cause # todo(mkozuki): Figure out the root cause
if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None: if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
# malfet: One should not assume that PyTorch re-exports CUDA dependencies
cublas_extension = CUDAExtension( cublas_extension = CUDAExtension(
name='torch_test_cpp_extension.cublas_extension', name='torch_test_cpp_extension.cublas_extension',
sources=['cublas_extension.cpp'] sources=['cublas_extension.cpp'],
libraries=['cublas'] if torch.version.hip is None else [],
) )
ext_modules.append(cublas_extension) ext_modules.append(cublas_extension)
cusolver_extension = CUDAExtension( cusolver_extension = CUDAExtension(
name='torch_test_cpp_extension.cusolver_extension', name='torch_test_cpp_extension.cusolver_extension',
sources=['cusolver_extension.cpp'] sources=['cusolver_extension.cpp'],
libraries=['cusolver'] if torch.version.hip is None else [],
) )
ext_modules.append(cusolver_extension) ext_modules.append(cusolver_extension)

View File

@ -1355,6 +1355,7 @@ aten_cuda_cu_source_list = [
"aten/src/ATen/cuda/CUDASparseBlas.cpp", "aten/src/ATen/cuda/CUDASparseBlas.cpp",
"aten/src/ATen/cuda/CublasHandlePool.cpp", "aten/src/ATen/cuda/CublasHandlePool.cpp",
"aten/src/ATen/native/cuda/Activation.cpp", "aten/src/ATen/native/cuda/Activation.cpp",
"aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp",
"aten/src/ATen/native/cuda/Blas.cpp", "aten/src/ATen/native/cuda/Blas.cpp",
"aten/src/ATen/native/cuda/Distributions.cpp", "aten/src/ATen/native/cuda/Distributions.cpp",
"aten/src/ATen/native/cuda/Equal.cpp", "aten/src/ATen/native/cuda/Equal.cpp",