Support "device" keyword argument (#79)

Adds the optional "device" keyword argument to Tensor and Storage
constructors and .new methods.
This commit is contained in:
Sam Gross 2016-10-01 19:32:55 -04:00 committed by Soumith Chintala
parent e034f258e3
commit 2bc9da4f5e
4 changed files with 44 additions and 70 deletions

View File

@ -21,7 +21,7 @@ def _cuda(self, idx=None, async=False):
else: else:
return self return self
else: else:
ctx = torch.cuda.device(idx) if idx else torch.cuda._dummy_ctx() ctx = torch.cuda.device(idx if idx else -1)
with ctx: with ctx:
return self.type(getattr(torch.cuda, self.__class__.__name__), async) return self.type(getattr(torch.cuda, self.__class__.__name__), async)

View File

@ -36,14 +36,17 @@ of the CUDA driver.""".format(str(torch._C._cuda_getDriverVersion())))
@contextlib.contextmanager @contextlib.contextmanager
def device(idx): def device(idx):
_lazy_init() if idx is -1:
prev_idx = torch._C._cuda_getDevice()
if prev_idx != idx:
torch._C._cuda_setDevice(idx)
yield yield
torch._C._cuda_setDevice(prev_idx)
else: else:
yield _lazy_init()
prev_idx = torch._C._cuda_getDevice()
if prev_idx != idx:
torch._C._cuda_setDevice(idx)
yield
torch._C._cuda_setDevice(prev_idx)
else:
yield
@contextlib.contextmanager @contextlib.contextmanager
@ -55,15 +58,11 @@ def device_of(tensor):
yield yield
@contextlib.contextmanager
def _dummy_ctx():
yield
def device_count(): def device_count():
_lazy_init() _lazy_init()
return torch._C._cuda_getDeviceCount() return torch._C._cuda_getDeviceCount()
def current_device(): def current_device():
_lazy_init() _lazy_init()
return torch._C._cuda_getDevice() return torch._C._cuda_getDevice()
@ -76,8 +75,8 @@ from .random import *
################################################################################ ################################################################################
from .tensor import _CudaTensorBase from ..tensor import _TensorBase
from .storage import _CudaStorageBase from ..storage import _StorageBase
if not hasattr(torch._C, 'CudaDoubleStorageBase'): if not hasattr(torch._C, 'CudaDoubleStorageBase'):
# Define dummy base classes # Define dummy base classes
@ -88,51 +87,64 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
torch._C.__dict__[storage_name] = type(storage_name, (object,), {}) torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {}) torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
class InitCuda(object):
class _CudaBase(object):
is_cuda = True
def type(self, *args, **kwargs):
with device(self.get_device()):
return super(_CudaBase, self).type(*args, **kwargs)
def new(self, *args, **kwargs):
with device(kwargs.pop('device', self.get_device())):
return super(_CudaBase, self).new(*args, **kwargs)
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
_lazy_init() _lazy_init()
return super(InitCuda, cls).__new__(cls, *args, **kwargs) with device(kwargs.pop('device', -1)):
return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
class DoubleStorage(InitCuda, torch._C.CudaDoubleStorageBase, _CudaStorageBase):
class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
pass pass
class FloatStorage(InitCuda, torch._C.CudaFloatStorageBase, _CudaStorageBase): class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
pass pass
class LongStorage(InitCuda, torch._C.CudaLongStorageBase, _CudaStorageBase): class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
pass pass
class IntStorage(InitCuda, torch._C.CudaIntStorageBase, _CudaStorageBase): class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
pass pass
class ShortStorage(InitCuda, torch._C.CudaShortStorageBase, _CudaStorageBase): class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
pass pass
class CharStorage(InitCuda, torch._C.CudaCharStorageBase, _CudaStorageBase): class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
pass pass
class ByteStorage(InitCuda, torch._C.CudaByteStorageBase, _CudaStorageBase): class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
pass pass
class HalfStorage(InitCuda, torch._C.CudaHalfStorageBase, _CudaStorageBase): class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
pass pass
class DoubleTensor(InitCuda, torch._C.CudaDoubleTensorBase, _CudaTensorBase): class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
class FloatTensor(InitCuda, torch._C.CudaFloatTensorBase, _CudaTensorBase): class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
class LongTensor(InitCuda, torch._C.CudaLongTensorBase, _CudaTensorBase): class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
class IntTensor(InitCuda, torch._C.CudaIntTensorBase, _CudaTensorBase): class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
class ShortTensor(InitCuda, torch._C.CudaShortTensorBase, _CudaTensorBase): class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True
class CharTensor(InitCuda, torch._C.CudaCharTensorBase, _CudaTensorBase): class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
# TODO # TODO
return False return False
class ByteTensor(InitCuda, torch._C.CudaByteTensorBase, _CudaTensorBase): class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return False return False
class HalfTensor(InitCuda, torch._C.CudaHalfTensorBase, _CudaTensorBase): class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
def is_signed(self): def is_signed(self):
return True return True

View File

@ -1,19 +0,0 @@
from . import device, _dummy_ctx
from ..storage import _StorageBase
class _CudaStorageBase(_StorageBase):
    """Mixin adding CUDA-awareness to storage classes.

    Ensures that ``type()`` and ``new()`` run with the storage's own CUDA
    device active, so the result is allocated on the same device.
    """

    is_cuda = True

    def type(self, *args, **kwargs):
        # get_device() reports -1 when no device is associated; in that
        # case fall back to a no-op context instead of switching devices.
        dev = self.get_device()
        if dev == -1:
            ctx = _dummy_ctx()
        else:
            ctx = device(dev)
        with ctx:
            return super(_CudaStorageBase, self).type(*args, **kwargs)

    def new(self, *args, **kwargs):
        dev = self.get_device()
        if dev == -1:
            ctx = _dummy_ctx()
        else:
            ctx = device(dev)
        with ctx:
            return super(_CudaStorageBase, self).new(*args, **kwargs)

View File

@ -1,19 +0,0 @@
from . import device, _dummy_ctx
from ..tensor import _TensorBase
class _CudaTensorBase(_TensorBase):
    """Mixin adding CUDA-awareness to tensor classes.

    Ensures that ``type()`` and ``new()`` run with the tensor's own CUDA
    device active, so the result is allocated on the same device.
    """

    is_cuda = True

    def type(self, *args, **kwargs):
        # get_device() reports -1 when no device is associated; in that
        # case fall back to a no-op context instead of switching devices.
        dev = self.get_device()
        if dev == -1:
            ctx = _dummy_ctx()
        else:
            ctx = device(dev)
        with ctx:
            return super(_CudaTensorBase, self).type(*args, **kwargs)

    def new(self, *args, **kwargs):
        dev = self.get_device()
        if dev == -1:
            ctx = _dummy_ctx()
        else:
            ctx = device(dev)
        with ctx:
            return super(_CudaTensorBase, self).new(*args, **kwargs)