r"""
The torch package contains data structures for multi-dimensional
tensors and defines mathematical operations over these tensors.
Additionally, it provides many utilities for efficiently serializing
Tensors and arbitrary types, and other useful utilities.

It has a CUDA counterpart that enables you to run your tensor computations
on an NVIDIA GPU with compute capability >= 3.0.
"""

import os
import sys
import platform
import ctypes

from ._utils import _import_dotted_name
from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
    USE_RTLD_GLOBAL_WITH_LIBTORCH
from .version import __version__
from ._six import string_classes as _string_classes

__all__ = [
    'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
    'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
    'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
    'no_grad', 'enable_grad', 'rand', 'randn',
    'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
    'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
    'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
    'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
    'lobpcg',
]

################################################################################
# Load the extension module
################################################################################

if platform.system() == 'Windows':
    is_conda = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
    py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin')
    th_dll_path = os.path.join(os.path.dirname(__file__), 'lib')

    if not os.path.exists(os.path.join(th_dll_path, 'nvToolsExt64_1.dll')) and \
            not os.path.exists(os.path.join(py_dll_path, 'nvToolsExt64_1.dll')):
        nvtoolsext_dll_path = os.path.join(
            os.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt'), 'bin', 'x64')
    else:
        nvtoolsext_dll_path = ''

    from .version import cuda as cuda_version
    import glob
    if cuda_version and len(glob.glob(os.path.join(th_dll_path, 'cudart64*.dll'))) == 0 and \
            len(glob.glob(os.path.join(py_dll_path, 'cudart64*.dll'))) == 0:
        cuda_version_1 = cuda_version.replace('.', '_')
        cuda_path_var = 'CUDA_PATH_V' + cuda_version_1
        default_path = 'C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v' + cuda_version
        cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin')
    else:
        cuda_path = ''

    if sys.version_info >= (3, 8):
        dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, nvtoolsext_dll_path, cuda_path]))

        for dll_path in dll_paths:
            os.add_dll_directory(dll_path)

    if is_conda or sys.version_info < (3, 8):
        dll_paths = [th_dll_path, py_dll_path, nvtoolsext_dll_path, cuda_path]
        dll_paths = list(filter(os.path.exists, dll_paths)) + [os.environ['PATH']]

        os.environ['PATH'] = ';'.join(dll_paths)

    dlls = glob.glob(os.path.join(th_dll_path, '*.dll'))
    for dll in dlls:
        ctypes.CDLL(dll)
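
# The lookups above honour environment overrides; e.g. an illustrative shell
# session on a hypothetical machine with non-default install locations:
#
#   C:\> set NVTOOLSEXT_PATH=D:\NvToolsExt
#   C:\> set CUDA_PATH_V10_2=D:\CUDA\v10.2
#   C:\> python -c "import torch"
#
# where the CUDA_PATH_V<major>_<minor> name follows cuda_version with dots
# replaced by underscores, as computed above.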


# See Note [Global dependencies]
def _load_global_deps():
    if platform.system() == 'Windows':
        return

    lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so')
    here = os.path.abspath(__file__)
    lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)

    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)


if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
        platform.system() != 'Windows':
    # Do it the hard way.  You might want to load libtorch with RTLD_GLOBAL in a
    # few circumstances:
    #
    #   1. You're in a build environment (e.g., fbcode) where
    #      libtorch_global_deps is not available, but you still need
    #      to get mkl to link in with RTLD_GLOBAL or it will just
    #      not work.
    #
    #   2. You're trying to run PyTorch under UBSAN and you need
    #      to ensure that only one copy of libtorch is loaded, so
    #      vptr checks work properly.
    #
    # If you're using this setting, you must verify that all the libraries
    # you load consistently use the same libstdc++, or you may have
    # mysterious segfaults.
    #
    import os as _dl_flags
    if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
        try:
            # next, try if DLFCN exists
            import DLFCN as _dl_flags
        except ImportError:
            # as a last attempt, use compile-time constants
            import torch._dl as _dl_flags
    old_flags = sys.getdlopenflags()
    sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
    from torch._C import *
    sys.setdlopenflags(old_flags)
    del old_flags
    del _dl_flags

else:
    # Easy way.  You want this most of the time, because it will prevent
    # C++ symbols from libtorch clobbering C++ symbols from other
    # libraries, leading to mysterious segfaults.
    #
    # See Note [Global dependencies]
    _load_global_deps()
    from torch._C import *
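
# Either way, torch._C is now populated.  The RTLD_GLOBAL branch above can be
# forced from the shell via the environment variable it checks, e.g.
# (illustrative):
#
#   $ TORCH_USE_RTLD_GLOBAL=1 python -c "import torch"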

__all__ += [name for name in dir(_C)
            if name[0] != '_' and
            not name.endswith('Base')]

################################################################################
# Define basic utilities
################################################################################


def typename(o):
    if isinstance(o, torch.Tensor):
        return o.type()

    module = ''
    class_name = ''
    if hasattr(o, '__module__') and o.__module__ != 'builtins' \
            and o.__module__ != '__builtin__' and o.__module__ is not None:
        module = o.__module__ + '.'

    if hasattr(o, '__qualname__'):
        class_name = o.__qualname__
    elif hasattr(o, '__name__'):
        class_name = o.__name__
    else:
        class_name = o.__class__.__name__

    return module + class_name
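
# For example (illustrative; the first output assumes the initial default
# floating point tensor type):
#
#   >>> torch.typename(torch.rand(2, 3))
#   'torch.FloatTensor'
#   >>> torch.typename(int)
#   'int'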


def is_tensor(obj):
    r"""Returns True if `obj` is a PyTorch tensor.

    Args:
        obj (Object): Object to test
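
    Example::

        >>> torch.is_tensor(torch.rand(2, 3))
        True
        >>> torch.is_tensor([1, 2, 3])
        False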
    """
    return isinstance(obj, torch.Tensor)


def is_storage(obj):
    r"""Returns True if `obj` is a PyTorch storage object.

    Args:
        obj (Object): Object to test
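
    Example::

        >>> torch.is_storage(torch.FloatStorage(4))
        True
        >>> torch.is_storage(torch.rand(4))
        False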
    """
    return type(obj) in _storage_classes


def set_default_tensor_type(t):
    r"""Sets the default ``torch.Tensor`` type to floating point tensor type
    ``t``. This type will also be used as default floating point type for
    type inference in :func:`torch.tensor`.

    The default floating point tensor type is initially ``torch.FloatTensor``.

    Args:
        t (type or string): the floating point tensor type or its name

    Example::

        >>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_tensor_type(torch.DoubleTensor)
        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
        torch.float64

    """
    if isinstance(t, _string_classes):
        t = _import_dotted_name(t)
    _C._set_default_tensor_type(t)


def set_default_dtype(d):
    r"""Sets the default floating point dtype to :attr:`d`. This type will be
    used as default floating point type for type inference in
    :func:`torch.tensor`.

    The default floating point dtype is initially ``torch.float32``.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default

    Example::

        >>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_dtype(torch.float64)
        >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor
        torch.float64

    """
    _C._set_default_dtype(d)

# If you edit these imports, please update torch/__init__.py.in as well
from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
from .serialization import save, load
from ._tensor_str import set_printoptions

################################################################################
# Define Storage and Tensor classes
################################################################################

from .tensor import Tensor
from .storage import _StorageBase


class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass


class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass


class HalfStorage(_C.HalfStorageBase, _StorageBase):
    pass


class LongStorage(_C.LongStorageBase, _StorageBase):
    pass


class IntStorage(_C.IntStorageBase, _StorageBase):
    pass


class ShortStorage(_C.ShortStorageBase, _StorageBase):
    pass


class CharStorage(_C.CharStorageBase, _StorageBase):
    pass


class ByteStorage(_C.ByteStorageBase, _StorageBase):
    pass


class BoolStorage(_C.BoolStorageBase, _StorageBase):
    pass


class BFloat16Storage(_C.BFloat16StorageBase, _StorageBase):
    pass


class QUInt8Storage(_C.QUInt8StorageBase, _StorageBase):
    pass


class QInt8Storage(_C.QInt8StorageBase, _StorageBase):
    pass


class QInt32Storage(_C.QInt32StorageBase, _StorageBase):
    pass


_storage_classes = {
    DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
    CharStorage, ByteStorage, HalfStorage, BoolStorage, QUInt8Storage, QInt8Storage,
    QInt32Storage, BFloat16Storage
}
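
# A storage is the raw, 1-D data buffer behind a tensor.  A minimal sketch
# (outputs believed correct for this version):
#
#   >>> s = torch.FloatStorage(3)     # uninitialized buffer of 3 floats
#   >>> t = torch.FloatTensor(s)      # a tensor viewing that buffer
#   >>> t.fill_(1.0)
#   tensor([1., 1., 1.])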

# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
_tensor_classes = set()


################################################################################
# Initialize extension
################################################################################

def manager_path():
    if platform.system() == 'Windows':
        return b""
    path = get_file_path('torch', 'bin', 'torch_shm_manager')
    prepare_multiprocessing_environment(get_file_path('torch'))
    if not os.path.exists(path):
        raise RuntimeError("Unable to find torch_shm_manager at " + path)
    return path.encode('utf-8')


# The shared memory manager needs to know the exact location of the manager
# executable.
_C._initExtension(manager_path())
del manager_path

for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    globals()[name] = getattr(_C._VariableFunctions, name)
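
# The loop above re-exports the ATen operator bindings at the top level of the
# torch module, which is what makes calls such as the following work
# (illustrative):
#
#   >>> torch.add(torch.ones(2), torch.ones(2))
#   tensor([2., 2.])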

################################################################################
# Import interface functions defined in Python
################################################################################

# This needs to come after the above ATen bindings so that we can overwrite
# them from the Python side.
from .functional import *


################################################################################
# Remove unnecessary members
################################################################################

del DoubleStorageBase
del FloatStorageBase
del LongStorageBase
del IntStorageBase
del ShortStorageBase
del CharStorageBase
del ByteStorageBase
del BoolStorageBase
del QUInt8StorageBase
del BFloat16StorageBase

################################################################################
# Import most common subpackages
################################################################################

import torch.cuda
import torch.autograd
from torch.autograd import no_grad, enable_grad, set_grad_enabled
import torch.nn
import torch.nn.intrinsic
import torch.nn.quantized
import torch.optim
import torch.multiprocessing
import torch.sparse
import torch.utils.backcompat
import torch.onnx
import torch.jit
import torch.hub
import torch.random
import torch.distributions
import torch.testing
import torch.backends.cuda
import torch.backends.mkl
import torch.backends.mkldnn
import torch.backends.openmp
import torch.backends.quantized
import torch.quantization
import torch.utils.data
import torch.__config__
import torch.__future__

_C._init_names(list(torch._storage_classes))

# attach docstrings to torch and tensor functions
from . import _torch_docs, _tensor_docs, _storage_docs
del _torch_docs, _tensor_docs, _storage_docs


def compiled_with_cxx11_abi():
    r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
    return _C._GLIBCXX_USE_CXX11_ABI
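
# For example (illustrative; the value depends on how this binary was built):
#
#   >>> torch.compiled_with_cxx11_abi()
#   True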

# Import the ops "namespace"
from torch._ops import ops
from torch._classes import classes

# Import the quasi random sampler
import torch.quasirandom

# If you are seeing this, it means that this call site was not checked for
# whether the memory format could be preserved, and it was switched to the
# old default behaviour of contiguous
legacy_contiguous_format = contiguous_format

# Register fork handler to initialize OpenMP in child processes (see gh-28389)
from torch.multiprocessing._atfork import register_after_fork
register_after_fork(torch.get_num_threads)
del register_after_fork

# Import tools that require fully imported torch (for applying
# torch.jit.script as a decorator, for instance):
from ._lobpcg import lobpcg