mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Support "device" keyword argument (#79)
Adds the optional "device" keyword argument to Tensor and Storage constructors and .new methods.
This commit is contained in:
parent
e034f258e3
commit
2bc9da4f5e
|
|
@ -21,7 +21,7 @@ def _cuda(self, idx=None, async=False):
|
|||
else:
|
||||
return self
|
||||
else:
|
||||
ctx = torch.cuda.device(idx) if idx else torch.cuda._dummy_ctx()
|
||||
ctx = torch.cuda.device(idx if idx else -1)
|
||||
with ctx:
|
||||
return self.type(getattr(torch.cuda, self.__class__.__name__), async)
|
||||
|
||||
|
|
|
|||
|
|
@ -36,14 +36,17 @@ of the CUDA driver.""".format(str(torch._C._cuda_getDriverVersion())))
|
|||
|
||||
@contextlib.contextmanager
|
||||
def device(idx):
|
||||
_lazy_init()
|
||||
prev_idx = torch._C._cuda_getDevice()
|
||||
if prev_idx != idx:
|
||||
torch._C._cuda_setDevice(idx)
|
||||
if idx is -1:
|
||||
yield
|
||||
torch._C._cuda_setDevice(prev_idx)
|
||||
else:
|
||||
yield
|
||||
_lazy_init()
|
||||
prev_idx = torch._C._cuda_getDevice()
|
||||
if prev_idx != idx:
|
||||
torch._C._cuda_setDevice(idx)
|
||||
yield
|
||||
torch._C._cuda_setDevice(prev_idx)
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
|
@ -55,15 +58,11 @@ def device_of(tensor):
|
|||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _dummy_ctx():
|
||||
yield
|
||||
|
||||
|
||||
def device_count():
|
||||
_lazy_init()
|
||||
return torch._C._cuda_getDeviceCount()
|
||||
|
||||
|
||||
def current_device():
|
||||
_lazy_init()
|
||||
return torch._C._cuda_getDevice()
|
||||
|
|
@ -76,8 +75,8 @@ from .random import *
|
|||
################################################################################
|
||||
|
||||
|
||||
from .tensor import _CudaTensorBase
|
||||
from .storage import _CudaStorageBase
|
||||
from ..tensor import _TensorBase
|
||||
from ..storage import _StorageBase
|
||||
|
||||
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
|
||||
# Define dummy base classes
|
||||
|
|
@ -88,51 +87,64 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
|
|||
torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
|
||||
torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
|
||||
|
||||
class InitCuda(object):
|
||||
|
||||
class _CudaBase(object):
|
||||
is_cuda = True
|
||||
|
||||
def type(self, *args, **kwargs):
|
||||
with device(self.get_device()):
|
||||
return super(_CudaBase, self).type(*args, **kwargs)
|
||||
|
||||
def new(self, *args, **kwargs):
|
||||
with device(kwargs.pop('device', self.get_device())):
|
||||
return super(_CudaBase, self).new(*args, **kwargs)
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
_lazy_init()
|
||||
return super(InitCuda, cls).__new__(cls, *args, **kwargs)
|
||||
with device(kwargs.pop('device', -1)):
|
||||
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
|
||||
|
||||
class DoubleStorage(InitCuda, torch._C.CudaDoubleStorageBase, _CudaStorageBase):
|
||||
|
||||
class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
|
||||
pass
|
||||
class FloatStorage(InitCuda, torch._C.CudaFloatStorageBase, _CudaStorageBase):
|
||||
class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
|
||||
pass
|
||||
class LongStorage(InitCuda, torch._C.CudaLongStorageBase, _CudaStorageBase):
|
||||
class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
|
||||
pass
|
||||
class IntStorage(InitCuda, torch._C.CudaIntStorageBase, _CudaStorageBase):
|
||||
class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
|
||||
pass
|
||||
class ShortStorage(InitCuda, torch._C.CudaShortStorageBase, _CudaStorageBase):
|
||||
class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
|
||||
pass
|
||||
class CharStorage(InitCuda, torch._C.CudaCharStorageBase, _CudaStorageBase):
|
||||
class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
|
||||
pass
|
||||
class ByteStorage(InitCuda, torch._C.CudaByteStorageBase, _CudaStorageBase):
|
||||
class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
|
||||
pass
|
||||
class HalfStorage(InitCuda, torch._C.CudaHalfStorageBase, _CudaStorageBase):
|
||||
class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
class DoubleTensor(InitCuda, torch._C.CudaDoubleTensorBase, _CudaTensorBase):
|
||||
class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
return True
|
||||
class FloatTensor(InitCuda, torch._C.CudaFloatTensorBase, _CudaTensorBase):
|
||||
class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
return True
|
||||
class LongTensor(InitCuda, torch._C.CudaLongTensorBase, _CudaTensorBase):
|
||||
class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
return True
|
||||
class IntTensor(InitCuda, torch._C.CudaIntTensorBase, _CudaTensorBase):
|
||||
class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
return True
|
||||
class ShortTensor(InitCuda, torch._C.CudaShortTensorBase, _CudaTensorBase):
|
||||
class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
return True
|
||||
class CharTensor(InitCuda, torch._C.CudaCharTensorBase, _CudaTensorBase):
|
||||
class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
# TODO
|
||||
return False
|
||||
class ByteTensor(InitCuda, torch._C.CudaByteTensorBase, _CudaTensorBase):
|
||||
class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
return False
|
||||
class HalfTensor(InitCuda, torch._C.CudaHalfTensorBase, _CudaTensorBase):
|
||||
class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
|
||||
def is_signed(self):
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +0,0 @@
|
|||
from . import device, _dummy_ctx
|
||||
from ..storage import _StorageBase
|
||||
|
||||
|
||||
class _CudaStorageBase(_StorageBase):
|
||||
is_cuda = True
|
||||
|
||||
def type(self, *args, **kwargs):
|
||||
source_device = self.get_device()
|
||||
ctx = device(source_device) if source_device != -1 else _dummy_ctx()
|
||||
with ctx:
|
||||
return super(_CudaStorageBase, self).type(*args, **kwargs)
|
||||
|
||||
def new(self, *args, **kwargs):
|
||||
source_device = self.get_device()
|
||||
ctx = device(source_device) if source_device != -1 else _dummy_ctx()
|
||||
with ctx:
|
||||
return super(_CudaStorageBase, self).new(*args, **kwargs)
|
||||
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from . import device, _dummy_ctx
|
||||
from ..tensor import _TensorBase
|
||||
|
||||
|
||||
class _CudaTensorBase(_TensorBase):
|
||||
is_cuda = True
|
||||
|
||||
def type(self, *args, **kwargs):
|
||||
source_device = self.get_device()
|
||||
ctx = device(source_device) if source_device != -1 else _dummy_ctx()
|
||||
with ctx:
|
||||
return super(_CudaTensorBase, self).type(*args, **kwargs)
|
||||
|
||||
def new(self, *args, **kwargs):
|
||||
source_device = self.get_device()
|
||||
ctx = device(source_device) if source_device != -1 else _dummy_ctx()
|
||||
with ctx:
|
||||
return super(_CudaTensorBase, self).new(*args, **kwargs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user