Allow setting TORCH_LINALG_PREFER_CUSOLVER=1 to prefer cusolver as linear algebra library globally (#106226)

setting TORCH_LINALG_PREFER_CUSOLVER=1

This will allow users to prefer cusolver as linear algebra backend in their container use case. The switch is not enabled by default so it won't change any existing default behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106226
Approved by: https://github.com/lezcano
This commit is contained in:
Xiao Wang 2023-07-30 09:38:46 +00:00 committed by PyTorch MergeBot
parent 858ca65c8a
commit 21fd2bc32e
3 changed files with 10 additions and 2 deletions

View File

@ -313,7 +313,10 @@ class TORCH_API Context {
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
bool enabled_mkldnn = true;
at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
at::LinalgBackend linalg_preferred_backend =
c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true
? at::LinalgBackend::Cusolver
: at::LinalgBackend::Default;
#ifdef C10_MOBILE
bool release_original_weights = true;
#else

View File

@ -158,6 +158,10 @@ def preferred_linalg_library(
* If `"default"` (the default) is set then heuristics will be used to pick between
cuSOLVER and MAGMA if both are available.
* When no input is given, this function returns the currently preferred library.
* User may use the environment variable TORCH_LINALG_PREFER_CUSOLVER=1 to set the preferred library to cuSOLVER
globally.
This flag only sets the initial value of the preferred library and the preferred library
may still be overridden by this function call later in your script.
Note: When a library is preferred other libraries may still be used if the preferred library
doesn't implement the operation(s) called.

View File

@ -1323,10 +1323,11 @@ def skipIfNotMiopenSuggestNHWC(fn):
def setLinalgBackendsToDefaultFinally(fn):
@wraps(fn)
def _fn(*args, **kwargs):
_preferred_backend = torch.backends.cuda.preferred_linalg_library()
try:
fn(*args, **kwargs)
finally:
torch.backends.cuda.preferred_linalg_library('default')
torch.backends.cuda.preferred_linalg_library(_preferred_backend)
return _fn