r"""This file is allowed to initialize CUDA context when imported."""

import functools

import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM
import inspect
import contextlib


CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()


TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)

SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
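
# Note: LazyVal defers evaluation until first access. Eagerly calling
# torch.backends.cudnn.is_acceptable or torch.cuda.get_device_capability here
# would initialize the CUDA context at import time, which this module
# deliberately avoids (see the assert at the bottom of the file).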


PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM


TEST_MAGMA = TEST_CUDA
if TEST_CUDA:
    def test_has_magma():
        torch.ones(1).cuda()  # has_magma shows up after cuda is initialized
        return torch.cuda.has_magma
    TEST_MAGMA = LazyVal(test_has_magma)


if TEST_NUMBA:
    import numba.cuda
    TEST_NUMBA_CUDA = numba.cuda.is_available()
else:
    TEST_NUMBA_CUDA = False


# Used below in `initialize_cuda_context_rng` to ensure that the CUDA context
# and RNG have been initialized.
__cuda_ctx_rng_initialized = False


# After this call, the CUDA context and RNG must have been initialized on each GPU.
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # initialize cuda context and rng for memory tests
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device="cuda:{}".format(i))
        __cuda_ctx_rng_initialized = True
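
# Note: initialize_cuda_context_rng is idempotent; the module-level flag above
# makes every call after the first a no-op.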


# Tests whether the hardware TF32 math mode is enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True


@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
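
# Example usage (a minimal sketch; the tensors and sizes are hypothetical):
#
#   with tf32_off():
#       a = torch.randn(64, 64, device="cuda")
#       b = torch.randn(64, 64, device="cuda")
#       c = a @ b  # guaranteed to run in full fp32, even on Ampere and newer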


@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision
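
# Example usage (a minimal sketch; `self` is assumed to be a test-case instance
# with a `precision` attribute, as on
# torch.testing._internal.common_utils.TestCase):
#
#   with tf32_on(self, tf32_precision=0.005):
#       c = a @ b                      # may run in TF32 on Ampere and newer
#       self.assertEqual(c, expected)  # compared at the relaxed precision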


# This decorator wraps a test so that it runs twice: once with allow_tf32=True
# and once with allow_tf32=False. When running with allow_tf32=True, it uses
# the reduced precision specified by the argument. For example:
#
#   @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#   @tf32_on_and_off(0.005)
#   def test_matmul(self, device, dtype):
#       a = ...; b = ...
#       c = torch.matmul(a, b)
#       self.assertEqual(c, expected)
#
# In the above example, when testing torch.float32 and torch.complex64 on CUDA,
# on a CUDA >= 11 build, and on an Ampere or newer architecture, the matmul is
# run both with TF32 mode on and with TF32 mode off; in the TF32 run, the
# assertEqual checks values at the reduced precision.
#
# The decorator can be used for functions with or without device/dtype, such as:
#
#   @tf32_on_and_off(0.005)
#   def test_my_op(self)
#
#   @tf32_on_and_off(0.005)
#   def test_my_op(self, device)
#
#   @tf32_on_and_off(0.005)
#   def test_my_op(self, device, dtype)
#
#   @tf32_on_and_off(0.005)
#   def test_my_op(self, dtype)
#
# If neither device nor dtype is specified, it checks whether the system has an
# Ampere (or newer) device; if device is specified, it also checks that the
# device is CUDA; if dtype is specified, it also checks that the dtype is
# float32 or complex64. TF32 and fp32 differ only when all applicable checks pass.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Normalize positional arguments into kwargs so that `self`,
            # `device`, and `dtype` can be looked up by name below.
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped

    return wrapper


# This is a wrapper that wraps a test to run it with TF32 turned off.
# It is designed for tests that use matmul or convolutions, but whose purpose
# is not to test matmul or convolutions themselves. Disabling TF32 forces
# torch.float tensors to always be computed at full precision.
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped
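
# Example usage (a minimal sketch; the test name and body are hypothetical):
#
#   @with_tf32_off
#   def test_output_shapes(self, device):
#       # uses matmul internally, but only checks shapes, so TF32 rounding
#       # must not affect the outcome
#       ...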


def _get_magma_version():
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))
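
# For reference: in a MAGMA-enabled build, torch.__config__.show() contains a
# line such as "Magma 2.6.1" (version illustrative), which the helper above
# parses into the tuple (2, 6, 1).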


def _get_torch_cuda_version():
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))


def _get_torch_rocm_version():
    if not TEST_WITH_ROCM:
        return (0, 0)
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    return tuple(int(x) for x in rocm_version.split("."))
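
# Example usage (a sketch; the version threshold is illustrative):
#
#   if _get_torch_cuda_version() >= (11, 6):
#       ...  # rely on behavior only present in CUDA 11.6 and newer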


def _check_cusparse_generic_available():
    return not TEST_WITH_ROCM


def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False

    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]  # ignore git sha
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return rocm_version_tuple >= (5, 1)
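
# Note: Python compares tuples lexicographically, so e.g. (5, 0, 2) < (5, 1)
# while (5, 1, 0) >= (5, 1), which is exactly the semantics the version checks
# above rely on.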


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()

# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
    assert not torch.cuda.is_initialized()