Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00.
Commit: Support the "device" keyword argument (#79).
Adds an optional "device" keyword argument to the Tensor and Storage constructors and to their .new() methods.
This commit is contained in:
parent
e034f258e3
commit
2bc9da4f5e
|
|
@ -21,7 +21,7 @@ def _cuda(self, idx=None, async=False):
|
||||||
else:
|
else:
|
||||||
return self
|
return self
|
||||||
else:
|
else:
|
||||||
ctx = torch.cuda.device(idx) if idx else torch.cuda._dummy_ctx()
|
ctx = torch.cuda.device(idx if idx else -1)
|
||||||
with ctx:
|
with ctx:
|
||||||
return self.type(getattr(torch.cuda, self.__class__.__name__), async)
|
return self.type(getattr(torch.cuda, self.__class__.__name__), async)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,14 +36,17 @@ of the CUDA driver.""".format(str(torch._C._cuda_getDriverVersion())))
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def device(idx):
|
def device(idx):
|
||||||
_lazy_init()
|
if idx is -1:
|
||||||
prev_idx = torch._C._cuda_getDevice()
|
|
||||||
if prev_idx != idx:
|
|
||||||
torch._C._cuda_setDevice(idx)
|
|
||||||
yield
|
yield
|
||||||
torch._C._cuda_setDevice(prev_idx)
|
|
||||||
else:
|
else:
|
||||||
yield
|
_lazy_init()
|
||||||
|
prev_idx = torch._C._cuda_getDevice()
|
||||||
|
if prev_idx != idx:
|
||||||
|
torch._C._cuda_setDevice(idx)
|
||||||
|
yield
|
||||||
|
torch._C._cuda_setDevice(prev_idx)
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
|
|
@ -55,15 +58,11 @@ def device_of(tensor):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def _dummy_ctx():
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def device_count():
    """Return the number of CUDA devices visible to this process."""
    _lazy_init()
    count = torch._C._cuda_getDeviceCount()
    return count
|
||||||
|
|
||||||
|
|
||||||
def current_device():
    """Return the index of the currently selected CUDA device."""
    _lazy_init()
    idx = torch._C._cuda_getDevice()
    return idx
|
||||||
|
|
@ -76,8 +75,8 @@ from .random import *
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
|
|
||||||
from .tensor import _CudaTensorBase
|
from ..tensor import _TensorBase
|
||||||
from .storage import _CudaStorageBase
|
from ..storage import _StorageBase
|
||||||
|
|
||||||
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
|
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
|
||||||
# Define dummy base classes
|
# Define dummy base classes
|
||||||
|
|
@ -88,51 +87,64 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
|
||||||
torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
|
torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
|
||||||
torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
|
torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
|
||||||
|
|
||||||
class _CudaBase(object):
    """Mixin shared by all CUDA tensor and storage classes.

    Ensures CUDA is initialized before construction and that type/new
    operations run with the owning (or requested) device selected.
    """

    # Every CUDA tensor/storage reports itself as device-resident.
    is_cuda = True

    def type(self, *args, **kwargs):
        """Cast this object while the device holding its data is selected."""
        with device(self.get_device()):
            return super(_CudaBase, self).type(*args, **kwargs)

    def new(self, *args, **kwargs):
        """Allocate a new object, honoring an optional ``device`` kwarg.

        Defaults to the device that owns this object's data.
        """
        target = kwargs.pop('device', self.get_device())
        with device(target):
            return super(_CudaBase, self).new(*args, **kwargs)

    def __new__(cls, *args, **kwargs):
        _lazy_init()
        # ``device=-1`` (the default) means "use the current device".
        target = kwargs.pop('device', -1)
        with device(target):
            return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
|
||||||
|
|
||||||
# Concrete CUDA storage classes: each combines the shared _CudaBase mixin,
# the C-implemented per-dtype base, and the generic _StorageBase behavior.
class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
    pass


class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
    pass


class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
    pass


class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
    pass


class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
    pass


class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
    pass


class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
    pass


class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
    pass
|
||||||
|
|
||||||
# Concrete CUDA tensor classes. Each declares only its signedness; all
# shared behavior comes from _CudaBase and the generic _TensorBase.
class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
    def is_signed(self):
        return True


class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
    def is_signed(self):
        return True


class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
    def is_signed(self):
        return True


class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
    def is_signed(self):
        return True


class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
    def is_signed(self):
        return True


class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
    def is_signed(self):
        # TODO: char signedness is platform-dependent; revisit.
        return False


class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
    def is_signed(self):
        return False


class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
    def is_signed(self):
        return True
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
from . import device, _dummy_ctx
from ..storage import _StorageBase


class _CudaStorageBase(_StorageBase):
    """Base class for CUDA storages; runs ops on the owning device."""

    is_cuda = True

    def type(self, *args, **kwargs):
        """Cast with the source device selected (no-op context if -1)."""
        dev_idx = self.get_device()
        guard = _dummy_ctx() if dev_idx == -1 else device(dev_idx)
        with guard:
            return super(_CudaStorageBase, self).type(*args, **kwargs)

    def new(self, *args, **kwargs):
        """Allocate with the source device selected (no-op context if -1)."""
        dev_idx = self.get_device()
        guard = _dummy_ctx() if dev_idx == -1 else device(dev_idx)
        with guard:
            return super(_CudaStorageBase, self).new(*args, **kwargs)
|
@ -1,19 +0,0 @@
|
||||||
from . import device, _dummy_ctx
from ..tensor import _TensorBase


class _CudaTensorBase(_TensorBase):
    """Base class for CUDA tensors; runs ops on the owning device."""

    is_cuda = True

    def type(self, *args, **kwargs):
        """Cast with the source device selected (no-op context if -1)."""
        dev_idx = self.get_device()
        guard = _dummy_ctx() if dev_idx == -1 else device(dev_idx)
        with guard:
            return super(_CudaTensorBase, self).type(*args, **kwargs)

    def new(self, *args, **kwargs):
        """Allocate with the source device selected (no-op context if -1)."""
        dev_idx = self.get_device()
        guard = _dummy_ctx() if dev_idx == -1 else device(dev_idx)
        with guard:
            return super(_CudaTensorBase, self).new(*args, **kwargs)
||||||
Loading…
Reference in New Issue
Block a user