Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12794

common.py is used as a base module for almost all tests in test/. The name of this file is so common that it can easily conflict with other dependencies if they happen to have another common.py in the base module. Rename the file to avoid the conflict.

Reviewed By: orionr

Differential Revision: D10438204

fbshipit-source-id: 6a996c14980722330be0a9fd3a54c20af4b3d380
r"""This file is allowed to initialize CUDA context when imported."""
|
|
|
|
import torch
|
|
import torch.cuda
|
|
from common_utils import TEST_WITH_ROCM, TEST_NUMBA
|
|
|
|
|
|
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
TEST_CUDNN_VERSION = TEST_CUDNN and torch.backends.cudnn.version()

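# Illustrative sketch, not part of the original file: test modules typically
# consume the flags above as skip conditions via the standard-library
# unittest.skipIf decorator, along these lines:
#
#     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
#     @unittest.skipIf(not TEST_MULTIGPU, "fewer than two GPUs detected")
#     def test_something_multigpu(self):
#         ...
#
# The test name is hypothetical; only the flag names come from this file.
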
if TEST_NUMBA:
    import numba.cuda
    TEST_NUMBA_CUDA = numba.cuda.is_available()
else:
    TEST_NUMBA_CUDA = False

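# The conditional import above keeps this module importable when numba is
# missing; TEST_NUMBA_CUDA is simply False in that case. Illustrative use
# (the test name is hypothetical, only TEST_NUMBA_CUDA comes from this file):
#
#     @unittest.skipIf(not TEST_NUMBA_CUDA, "numba CUDA support not available")
#     def test_numba_cuda_interop(self):
#         ...
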
# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False


# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # initialize cuda context and rng for memory tests
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device="cuda:{}".format(i))
        __cuda_ctx_rng_initialized = True
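
# Illustrative sketch, not part of the original file: a CUDA memory or RNG
# test would typically call the helper above before taking measurements, e.g.
#
#     class TestCudaMemory(unittest.TestCase):
#         def setUp(self):
#             if TEST_CUDA:
#                 initialize_cuda_context_rng()
#
# The class name is hypothetical; only initialize_cuda_context_rng and
# TEST_CUDA come from this file.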