Support "device" keyword argument (#79)

Adds the optional "device" keyword argument to Tensor and Storage constructors and their .new() methods.
Sam Gross authored on 2016-10-01 19:32:55 -04:00; committed by Soumith Chintala
parent e034f258e3
commit 2bc9da4f5e
4 changed files with 44 additions and 70 deletions
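A quick sketch of the new keyword in use (illustrative only; assumes a machine with at least two CUDA devices):

    import torch

    # Constructors accept an explicit device index.
    x = torch.cuda.FloatTensor(10, device=1)    # tensor allocated on GPU 1
    s = torch.cuda.FloatStorage(10, device=1)   # storages take it too

    # .new() defaults to the source object's device.
    y = x.new(10)            # also lands on GPU 1
    z = x.new(10, device=0)  # explicit target device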

View File

@@ -21,7 +21,7 @@ def _cuda(self, idx=None, async=False):
         else:
             return self
     else:
-        ctx = torch.cuda.device(idx) if idx else torch.cuda._dummy_ctx()
+        ctx = torch.cuda.device(idx if idx else -1)
         with ctx:
             return self.type(getattr(torch.cuda, self.__class__.__name__), async)
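With device() now accepting -1 as a "stay on the current device" sentinel, the dummy-context branch disappears. A behavior sketch (assuming .cuda() is the public binding of this helper, and at least two GPUs):

    t = torch.FloatTensor(4)
    a = t.cuda()    # idx is None -> device(-1) -> copies to the current device
    b = t.cuda(1)   # copies to GPU 1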

View File

@@ -36,14 +36,17 @@ of the CUDA driver.""".format(str(torch._C._cuda_getDriverVersion())))
 
 @contextlib.contextmanager
 def device(idx):
-    _lazy_init()
-    prev_idx = torch._C._cuda_getDevice()
-    if prev_idx != idx:
-        torch._C._cuda_setDevice(idx)
+    if idx is -1:
         yield
-        torch._C._cuda_setDevice(prev_idx)
     else:
-        yield
+        _lazy_init()
+        prev_idx = torch._C._cuda_getDevice()
+        if prev_idx != idx:
+            torch._C._cuda_setDevice(idx)
+            yield
+            torch._C._cuda_setDevice(prev_idx)
+        else:
+            yield
 
 
 @contextlib.contextmanager
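The rework folds the old dummy context into device() itself: an index of -1 yields immediately, without switching devices or even triggering lazy initialization, while any other index swaps devices and restores the previous one on normal exit. A behavioral sketch (assumes at least two GPUs):

    prev = torch.cuda.current_device()

    with torch.cuda.device(1):
        assert torch.cuda.current_device() == 1   # switched for the block

    assert torch.cuda.current_device() == prev    # restored on exit

    with torch.cuda.device(-1):
        pass                                      # no-op: no switch, no init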
@@ -55,15 +58,11 @@ def device_of(tensor):
         yield
 
 
-@contextlib.contextmanager
-def _dummy_ctx():
-    yield
-
-
 def device_count():
     _lazy_init()
     return torch._C._cuda_getDeviceCount()
 
 
 def current_device():
     _lazy_init()
     return torch._C._cuda_getDevice()
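_dummy_ctx is now dead code, since every former caller can pass -1 to device() instead. device_of (context above) remains the way to match an existing tensor's device; a sketch:

    y = torch.cuda.FloatTensor(3, device=1)   # assumes >= 2 GPUs

    with torch.cuda.device_of(y):
        z = torch.cuda.FloatTensor(3)         # allocated on y's device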
@@ -76,8 +75,8 @@ from .random import *
 
 ################################################################################
 
-from .tensor import _CudaTensorBase
-from .storage import _CudaStorageBase
+from ..tensor import _TensorBase
+from ..storage import _StorageBase
 
 if not hasattr(torch._C, 'CudaDoubleStorageBase'):
     # Define dummy base classes
@@ -88,51 +87,64 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
         torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
         torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
 
-class InitCuda(object):
+class _CudaBase(object):
+    is_cuda = True
+
+    def type(self, *args, **kwargs):
+        with device(self.get_device()):
+            return super(_CudaBase, self).type(*args, **kwargs)
+
+    def new(self, *args, **kwargs):
+        with device(kwargs.pop('device', self.get_device())):
+            return super(_CudaBase, self).new(*args, **kwargs)
+
     def __new__(cls, *args, **kwargs):
-        _lazy_init()
-        return super(InitCuda, cls).__new__(cls, *args, **kwargs)
+        with device(kwargs.pop('device', -1)):
+            return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
 
-class DoubleStorage(InitCuda, torch._C.CudaDoubleStorageBase, _CudaStorageBase):
+class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
     pass
-class FloatStorage(InitCuda, torch._C.CudaFloatStorageBase, _CudaStorageBase):
+class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
     pass
-class LongStorage(InitCuda, torch._C.CudaLongStorageBase, _CudaStorageBase):
+class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
     pass
-class IntStorage(InitCuda, torch._C.CudaIntStorageBase, _CudaStorageBase):
+class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
     pass
-class ShortStorage(InitCuda, torch._C.CudaShortStorageBase, _CudaStorageBase):
+class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
     pass
-class CharStorage(InitCuda, torch._C.CudaCharStorageBase, _CudaStorageBase):
+class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
     pass
-class ByteStorage(InitCuda, torch._C.CudaByteStorageBase, _CudaStorageBase):
+class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
     pass
-class HalfStorage(InitCuda, torch._C.CudaHalfStorageBase, _CudaStorageBase):
+class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
     pass
-class DoubleTensor(InitCuda, torch._C.CudaDoubleTensorBase, _CudaTensorBase):
+class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class FloatTensor(InitCuda, torch._C.CudaFloatTensorBase, _CudaTensorBase):
+class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class LongTensor(InitCuda, torch._C.CudaLongTensorBase, _CudaTensorBase):
+class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class IntTensor(InitCuda, torch._C.CudaIntTensorBase, _CudaTensorBase):
+class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class ShortTensor(InitCuda, torch._C.CudaShortTensorBase, _CudaTensorBase):
+class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class CharTensor(InitCuda, torch._C.CudaCharTensorBase, _CudaTensorBase):
+class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
     def is_signed(self):
         # TODO
         return False
-class ByteTensor(InitCuda, torch._C.CudaByteTensorBase, _CudaTensorBase):
+class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
     def is_signed(self):
         return False
-class HalfTensor(InitCuda, torch._C.CudaHalfTensorBase, _CudaTensorBase):
+class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
     def is_signed(self):
         return True
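The consolidated mixin is the heart of the change: __new__, new and type all run the underlying allocation inside a device(...) context, and kwargs.pop('device', ...) strips the keyword before delegating, since the C-level base constructors would reject an unknown keyword. Type conversions get the same routing (sketch; assumes at least two GPUs):

    x = torch.cuda.FloatTensor(4, device=1)   # __new__ pops 'device'
    d = x.type(torch.cuda.DoubleTensor)       # conversion runs on, and lands on, GPU 1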

View File

@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..storage import _StorageBase
-
-
-class _CudaStorageBase(_StorageBase):
-    is_cuda = True
-
-    def type(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaStorageBase, self).type(*args, **kwargs)
-
-    def new(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaStorageBase, self).new(*args, **kwargs)

View File

@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..tensor import _TensorBase
-
-
-class _CudaTensorBase(_TensorBase):
-    is_cuda = True
-
-    def type(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaTensorBase, self).type(*args, **kwargs)
-
-    def new(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaTensorBase, self).new(*args, **kwargs)
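Both deleted modules hand-rolled the same guard; with the new device() semantics the conditional collapses, which is what lets the single _CudaBase mixin replace them (equivalence sketch):

    # old, per-module:
    ctx = device(source_device) if source_device != -1 else _dummy_ctx()

    # new, centralized in _CudaBase -- device(-1) is itself a no-op:
    with device(source_device):
        ...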