mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Fixes https://github.com/pytorch/pytorch/issues/42265. This PR adds cusolver to the PyTorch build and enables the use of cusolver/cublas library functions for GPU `torch.inverse` on certain tensor shapes. Specifically, cusolver/cublas will be used when * the tensor is two-dimensional (single batch), or * it has >2 dimensions (multiple batches) and `batch_size <= 2`, or * magma is not linked. In all other conditions, the current MAGMA implementation will still be used (see 8c0949ae45/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu, L742-L752). The reason is that for tensors with a large batch_size, `cublasXgetrfBatched` and `cublasXgetriBatched` don't perform very well. For `batch_size > 1`, we launch cusolver functions in multiple streams. This lets the cusolver functions run in parallel and can greatly increase performance. When `batch_size > 2`, the parallel-launched cusolver functions are slightly slower than the current magma implementation, so we still use the current magma implementation. On CUDA 9.2, some numerical issues were detected, so the cusolver implementation will not be used there. The cusolver implementation will also not be used on platforms other than NVIDIA CUDA (see 060769feaf/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h, L10-L13). Note that a new heuristic is used before the cusolver/cublas calls (see 8c0949ae45/aten/src/ATen/native/cuda/MiscUtils.h, L113-L121), where `use_loop_launch = true` means launching single-batch cusolver functions in parallel, and `use_loop_launch = false` means using the cublas_X_batched functions. When magma is enabled (only `batch_size <= 2` is dispatched to cusolver/cublas), the heuristic always returns `true`, and the cusolver calls are faster than small-batch_size magma calls. When magma is disabled, this adds `torch.inverse` functionality, which was previously disabled for all shapes (though large-batch_size cublas performance may not be as good as magma's). 
Checklist: - [X] Add benchmark, cpu, gpu-before (magma), gpu-after (cusolver) - [X] Rewrite single inverse (ndim == 2) with cusolver - [X] Rewrite batched inverse (ndim > 2) with cublas - [X] Add cusolver to build - [x] Clean up functions related to `USE_MAGMA` define guard - [x] Workaround for non-cuda platform - [x] Workaround for cuda 9.2 - [x] Add zero size check - [x] Add tests Next step: If cusolver doesn't cause any problem in pytorch build, and there are no major performance regressions reported after this PR being merged, I will start porting other cusolver/cublas functions for linear algebra to improve the performance. <details> <summary> benchmark 73499c6 </summary> benchmark code: https://github.com/xwang233/code-snippet/blob/master/torch.inverse/inverse-cusolver.ipynb shape meaning: * `[] 2 torch.float32 -> torch.randn(2, 2, dtype=torch.float32)` * `[2] 4 torch.float32 -> torch.randn(2, 4, 4, dtype=torch.float32)` | shape | cpu_time (ms) | gpu_time_before (magma) (ms) | gpu_time_after (ms) | | --- | --- | --- | --- | | [] 2 torch.float32 | 0.095 | 7.534 | 0.129 | | [] 4 torch.float32 | 0.009 | 7.522 | 0.129 | | [] 8 torch.float32 | 0.011 | 7.647 | 0.138 | | [] 16 torch.float32 | 0.075 | 7.582 | 0.135 | | [] 32 torch.float32 | 0.073 | 7.573 | 0.191 | | [] 64 torch.float32 | 0.134 | 7.694 | 0.288 | | [] 128 torch.float32 | 0.398 | 8.073 | 0.491 | | [] 256 torch.float32 | 1.054 | 11.860 | 1.074 | | [] 512 torch.float32 | 5.218 | 14.130 | 2.582 | | [] 1024 torch.float32 | 19.010 | 18.780 | 6.936 | | [1] 2 torch.float32 | 0.009 | 0.113 | 0.128 ***regressed | | [1] 4 torch.float32 | 0.009 | 0.113 | 0.131 ***regressed | | [1] 8 torch.float32 | 0.011 | 0.116 | 0.129 ***regressed | | [1] 16 torch.float32 | 0.015 | 0.122 | 0.135 ***regressed | | [1] 32 torch.float32 | 0.032 | 0.177 | 0.178 ***regressed | | [1] 64 torch.float32 | 0.070 | 0.420 | 0.281 | | [1] 128 torch.float32 | 0.328 | 0.816 | 0.490 | | [1] 256 torch.float32 | 1.125 | 1.690 | 1.084 | | [1] 512 
torch.float32 | 4.344 | 4.305 | 2.576 | | [1] 1024 torch.float32 | 16.510 | 16.340 | 6.928 | | [2] 2 torch.float32 | 0.009 | 0.113 | 0.186 ***regressed | | [2] 4 torch.float32 | 0.011 | 0.115 | 0.184 ***regressed | | [2] 8 torch.float32 | 0.012 | 0.114 | 0.184 ***regressed | | [2] 16 torch.float32 | 0.019 | 0.119 | 0.173 ***regressed | | [2] 32 torch.float32 | 0.050 | 0.170 | 0.240 ***regressed | | [2] 64 torch.float32 | 0.120 | 0.429 | 0.375 | | [2] 128 torch.float32 | 0.576 | 0.830 | 0.675 | | [2] 256 torch.float32 | 2.021 | 1.748 | 1.451 | | [2] 512 torch.float32 | 9.070 | 4.749 | 3.539 | | [2] 1024 torch.float32 | 33.655 | 18.240 | 12.220 | | [4] 2 torch.float32 | 0.009 | 0.112 | 0.318 ***regressed | | [4] 4 torch.float32 | 0.010 | 0.115 | 0.319 ***regressed | | [4] 8 torch.float32 | 0.013 | 0.115 | 0.320 ***regressed | | [4] 16 torch.float32 | 0.027 | 0.120 | 0.331 ***regressed | | [4] 32 torch.float32 | 0.085 | 0.173 | 0.385 ***regressed | | [4] 64 torch.float32 | 0.221 | 0.431 | 0.646 ***regressed | | [4] 128 torch.float32 | 1.102 | 0.834 | 1.055 ***regressed | | [4] 256 torch.float32 | 4.042 | 1.811 | 2.054 ***regressed | | [4] 512 torch.float32 | 18.390 | 4.884 | 5.087 ***regressed | | [4] 1024 torch.float32 | 69.025 | 19.840 | 20.000 ***regressed | </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/42403 Reviewed By: ailzhang, mruberry Differential Revision: D23717984 Pulled By: ngimel fbshipit-source-id: 54cbd9ea72a97989cff4127089938e8a8e29a72b
135 lines
4.9 KiB
Python
135 lines
4.9 KiB
Python
r"""This file is allowed to initialize CUDA context when imported."""
|
|
|
|
import functools
|
|
import torch
|
|
import torch.cuda
|
|
from torch.testing._internal.common_utils import TEST_NUMBA
|
|
import inspect
|
|
import contextlib
|
|
|
|
|
|
# Capability flags probed once at import time; tests consult them to decide
# which CUDA-dependent cases can run.
TEST_CUDA = torch.cuda.is_available()

if TEST_CUDA:
    # The probes below run in a fixed order because several of them touch the
    # GPU (allocating a tensor on CUDA_DEVICE initializes the CUDA context).
    TEST_MULTIGPU = torch.cuda.device_count() >= 2
    CUDA_DEVICE = torch.device("cuda:0")
    # note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
    TEST_CUDNN = torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))
    TEST_CUDNN_VERSION = torch.backends.cudnn.version() if TEST_CUDNN else 0
    # has_magma shows up after cuda is initialized, so make sure it is first.
    torch.ones(1).cuda()
    TEST_MAGMA = torch.cuda.has_magma
else:
    # Without CUDA every derived capability is off.
    TEST_MULTIGPU = TEST_CUDA
    CUDA_DEVICE = None
    TEST_CUDNN = TEST_CUDA
    TEST_CUDNN_VERSION = 0
    TEST_MAGMA = TEST_CUDA
|
|
|
|
# numba's CUDA support is only probed when numba itself is usable; the flag
# defaults to False otherwise.
TEST_NUMBA_CUDA = False
if TEST_NUMBA:
    import numba.cuda
    TEST_NUMBA_CUDA = numba.cuda.is_available()
|
|
|
|
# Module-level flag read and written by `initialize_cuda_context_rng`; once the
# per-device warm-up loop has completed it is set to True so later calls skip it.
__cuda_ctx_rng_initialized = False


# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
    """Create a CUDA context and initialize the RNG on every visible GPU.

    Each device gets a single ``torch.randn(1, device=...)`` call. The work is
    done only on the first successful call; subsequent calls return
    immediately.

    Raises:
        AssertionError: if ``TEST_CUDA`` is false.
    """
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if __cuda_ctx_rng_initialized:
        return
    # initialize cuda context and rng for memory tests
    for device_index in range(torch.cuda.device_count()):
        torch.randn(1, device=f"cuda:{device_index}")
    __cuda_ctx_rng_initialized = True
|
|
|
|
|
|
# Report whether hardware TF32 math mode can be enabled. That requires all of:
#   - CUDA available and torch.version.cuda set
#   - current device compute capability major >= 8 (Ampere or newer)
#   - CUDA >= 11
def tf32_is_not_fp32():
    """Return True if TF32 results on this machine may differ from FP32.

    The checks run in this order, stopping at the first failure:
    CUDA availability / ``torch.version.cuda`` being set, the current
    device's compute-capability major version (>= 8), then the CUDA major
    version torch was built with (>= 11).
    """
    if torch.cuda.is_available() and torch.version.cuda is not None:
        device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
        if device_props.major >= 8:
            cuda_major = int(torch.version.cuda.split('.')[0])
            return cuda_major >= 11
    return False
|
|
|
|
|
|
@contextlib.contextmanager
def tf32_off():
    """Context manager that disables TF32 for cuBLAS matmul and cuDNN.

    On entry ``torch.backends.cuda.matmul.allow_tf32`` is set to False and
    ``torch.backends.cudnn.flags(..., allow_tf32=False)`` is entered (the
    other cuDNN flags are passed as None). The previous matmul setting is
    restored on exit, including when the body raises.
    """
    matmul_backend = torch.backends.cuda.matmul
    previous_allow_tf32 = matmul_backend.allow_tf32
    try:
        matmul_backend.allow_tf32 = False
        cudnn_without_tf32 = torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=False)
        with cudnn_without_tf32:
            yield
    finally:
        matmul_backend.allow_tf32 = previous_allow_tf32
|
|
|
|
|
|
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    """Context manager that enables TF32 and temporarily loosens ``self.precision``.

    Args:
        self: object with a ``precision`` attribute (presumably the running
            test case) whose value is replaced by ``tf32_precision`` for the
            duration of the block.
        tf32_precision: tolerance to use while TF32 is enabled.

    ``torch.backends.cuda.matmul.allow_tf32`` is set to True and
    ``torch.backends.cudnn.flags(..., allow_tf32=True)`` is entered. Both the
    matmul setting and ``self.precision`` are restored on exit, even if the
    body raises.
    """
    matmul_backend = torch.backends.cuda.matmul
    saved_allow_tf32 = matmul_backend.allow_tf32
    saved_precision = self.precision
    try:
        matmul_backend.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        matmul_backend.allow_tf32 = saved_allow_tf32
        self.precision = saved_precision
|
|
|
|
|
|
# Decorator factory that runs a test twice when TF32 can matter: once with
# allow_tf32=False and once with allow_tf32=True. The TF32 pass uses the
# reduced precision specified by the argument. For example:
#
#   @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#   @tf32_on_and_off(0.005)
#   def test_matmul(self, device, dtype):
#       a = ...; b = ...;
#       c = torch.matmul(a, b)
#       self.assertEqual(c, expected)
#
# In the above example, when testing torch.float32 and torch.complex64 on CUDA
# with a CUDA >= 11 build on an Ampere-or-newer GPU, the matmul runs with TF32
# mode off and then with TF32 mode on. In TF32 mode assertEqual uses the
# reduced precision. In every other case the test body runs exactly once.
def tf32_on_and_off(tf32_precision=1e-5):
    """Build a decorator that reruns a device test with TF32 enabled.

    The decorated function must take ``(self, device)`` or
    ``(self, device, dtype)``; any other arity fails an assertion at
    decoration time. The wrapper always returns None.
    """
    def run_both_modes(self, call):
        # Full-precision pass first, then the TF32 pass with relaxed tolerance.
        with tf32_off():
            call()
        with tf32_on(self, tf32_precision):
            call()

    def decorator(f):
        param_count = len(inspect.signature(f).parameters)

        if param_count == 2:
            @functools.wraps(f)
            def wrapped(self, device):
                if self.device_type == 'cuda' and tf32_is_not_fp32():
                    run_both_modes(self, lambda: f(self, device))
                else:
                    f(self, device)
            return wrapped

        assert param_count == 3, "this decorator only support function with signature (self, device) or (self, device, dtype)"

        @functools.wraps(f)
        def wrapped(self, device, dtype):
            # Only single-precision real/complex dtypes are affected by TF32.
            tf32_relevant = (self.device_type == 'cuda'
                             and dtype in {torch.float32, torch.complex64}
                             and tf32_is_not_fp32())
            if tf32_relevant:
                run_both_modes(self, lambda: f(self, device, dtype))
            else:
                f(self, device, dtype)
        return wrapped

    return decorator
|
|
|
|
def _get_torch_cuda_version():
    """Return the CUDA version torch was built with as a list of ints.

    Each dot-separated component of ``torch.version.cuda`` becomes one int
    (e.g. ``"11.2"`` -> ``[11, 2]``). ``[0, 0]`` is returned when
    ``torch.version.cuda`` is None.
    """
    raw_version = torch.version.cuda
    if raw_version is not None:
        return list(map(int, str(raw_version).split(".")))
    return [0, 0]
|