Patch summary (text before the first "diff --git" line is ignored by patch tools):
  * torch.cuda.device(-1) becomes a no-op context, replacing torch.cuda._dummy_ctx.
  * _CudaTensorBase/_CudaStorageBase are folded into one _CudaBase mixin, and
    torch/cuda/tensor.py and torch/cuda/storage.py are deleted.
  * _CudaBase.new() and _CudaBase.__new__() accept an optional `device` keyword.

Review amendments relative to the patch as submitted:
  1. device(): compare with `idx == -1`, not `idx is -1` (identity on an int literal).
  2. _cuda(): use `idx is not None` so that device index 0 is not treated as "unset".
  3. device(): restore the previous device in a `finally` so an exception raised
     inside the `with` body does not leave the wrong device selected.

NOTE(review): the "index" blob hashes below are the submitted ones and do not
reflect the amendments. Indentation was reconstructed assuming 4-space indents;
verify the context lines against the target tree before applying.

diff --git a/torch/_utils.py b/torch/_utils.py
index 18e6f97e6dc..b9a325081c6 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -21,7 +21,7 @@ def _cuda(self, idx=None, async=False):
         else:
             return self
     else:
-        ctx = torch.cuda.device(idx) if idx else torch.cuda._dummy_ctx()
+        ctx = torch.cuda.device(idx if idx is not None else -1)
         with ctx:
             return self.type(getattr(torch.cuda, self.__class__.__name__), async)
 
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index ced513538c6..ccaa77a84de 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -36,14 +36,19 @@ of the CUDA driver.""".format(str(torch._C._cuda_getDriverVersion())))
 
 @contextlib.contextmanager
 def device(idx):
-    _lazy_init()
-    prev_idx = torch._C._cuda_getDevice()
-    if prev_idx != idx:
-        torch._C._cuda_setDevice(idx)
+    if idx == -1:
         yield
-        torch._C._cuda_setDevice(prev_idx)
     else:
-        yield
+        _lazy_init()
+        prev_idx = torch._C._cuda_getDevice()
+        if prev_idx != idx:
+            torch._C._cuda_setDevice(idx)
+            try:
+                yield
+            finally:
+                torch._C._cuda_setDevice(prev_idx)
+        else:
+            yield
 
 
 @contextlib.contextmanager
@@ -55,15 +60,11 @@ def device_of(tensor):
         yield
 
 
-@contextlib.contextmanager
-def _dummy_ctx():
-    yield
-
-
 def device_count():
     _lazy_init()
     return torch._C._cuda_getDeviceCount()
 
+
 def current_device():
     _lazy_init()
     return torch._C._cuda_getDevice()
@@ -76,8 +77,8 @@ from .random import *
 
 ################################################################################
 
-from .tensor import _CudaTensorBase
-from .storage import _CudaStorageBase
+from ..tensor import _TensorBase
+from ..storage import _StorageBase
 
 if not hasattr(torch._C, 'CudaDoubleStorageBase'):
     # Define dummy base classes
@@ -88,51 +89,64 @@ if not hasattr(torch._C, 'CudaDoubleStorageBase'):
         torch._C.__dict__[storage_name] = type(storage_name, (object,), {})
         torch._C.__dict__[tensor_name] = type(tensor_name, (object,), {})
 
-class InitCuda(object):
+
+class _CudaBase(object):
+    is_cuda = True
+
+    def type(self, *args, **kwargs):
+        with device(self.get_device()):
+            return super(_CudaBase, self).type(*args, **kwargs)
+
+    def new(self, *args, **kwargs):
+        with device(kwargs.pop('device', self.get_device())):
+            return super(_CudaBase, self).new(*args, **kwargs)
+
     def __new__(cls, *args, **kwargs):
         _lazy_init()
-        return super(InitCuda, cls).__new__(cls, *args, **kwargs)
+        with device(kwargs.pop('device', -1)):
+            return super(_CudaBase, cls).__new__(cls, *args, **kwargs)
 
-class DoubleStorage(InitCuda, torch._C.CudaDoubleStorageBase, _CudaStorageBase):
+
+class DoubleStorage(_CudaBase, torch._C.CudaDoubleStorageBase, _StorageBase):
     pass
-class FloatStorage(InitCuda, torch._C.CudaFloatStorageBase, _CudaStorageBase):
+class FloatStorage(_CudaBase, torch._C.CudaFloatStorageBase, _StorageBase):
     pass
-class LongStorage(InitCuda, torch._C.CudaLongStorageBase, _CudaStorageBase):
+class LongStorage(_CudaBase, torch._C.CudaLongStorageBase, _StorageBase):
     pass
-class IntStorage(InitCuda, torch._C.CudaIntStorageBase, _CudaStorageBase):
+class IntStorage(_CudaBase, torch._C.CudaIntStorageBase, _StorageBase):
     pass
-class ShortStorage(InitCuda, torch._C.CudaShortStorageBase, _CudaStorageBase):
+class ShortStorage(_CudaBase, torch._C.CudaShortStorageBase, _StorageBase):
     pass
-class CharStorage(InitCuda, torch._C.CudaCharStorageBase, _CudaStorageBase):
+class CharStorage(_CudaBase, torch._C.CudaCharStorageBase, _StorageBase):
     pass
-class ByteStorage(InitCuda, torch._C.CudaByteStorageBase, _CudaStorageBase):
+class ByteStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
     pass
-class HalfStorage(InitCuda, torch._C.CudaHalfStorageBase, _CudaStorageBase):
+class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
     pass
 
-class DoubleTensor(InitCuda, torch._C.CudaDoubleTensorBase, _CudaTensorBase):
+class DoubleTensor(_CudaBase, torch._C.CudaDoubleTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class FloatTensor(InitCuda, torch._C.CudaFloatTensorBase, _CudaTensorBase):
+class FloatTensor(_CudaBase, torch._C.CudaFloatTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class LongTensor(InitCuda, torch._C.CudaLongTensorBase, _CudaTensorBase):
+class LongTensor(_CudaBase, torch._C.CudaLongTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class IntTensor(InitCuda, torch._C.CudaIntTensorBase, _CudaTensorBase):
+class IntTensor(_CudaBase, torch._C.CudaIntTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class ShortTensor(InitCuda, torch._C.CudaShortTensorBase, _CudaTensorBase):
+class ShortTensor(_CudaBase, torch._C.CudaShortTensorBase, _TensorBase):
     def is_signed(self):
         return True
-class CharTensor(InitCuda, torch._C.CudaCharTensorBase, _CudaTensorBase):
+class CharTensor(_CudaBase, torch._C.CudaCharTensorBase, _TensorBase):
     def is_signed(self):
         # TODO
         return False
-class ByteTensor(InitCuda, torch._C.CudaByteTensorBase, _CudaTensorBase):
+class ByteTensor(_CudaBase, torch._C.CudaByteTensorBase, _TensorBase):
     def is_signed(self):
         return False
-class HalfTensor(InitCuda, torch._C.CudaHalfTensorBase, _CudaTensorBase):
+class HalfTensor(_CudaBase, torch._C.CudaHalfTensorBase, _TensorBase):
     def is_signed(self):
         return True
 
diff --git a/torch/cuda/storage.py b/torch/cuda/storage.py
deleted file mode 100644
index 99358c219b4..00000000000
--- a/torch/cuda/storage.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..storage import _StorageBase
-
-
-class _CudaStorageBase(_StorageBase):
-    is_cuda = True
-
-    def type(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaStorageBase, self).type(*args, **kwargs)
-
-    def new(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaStorageBase, self).new(*args, **kwargs)
-
diff --git a/torch/cuda/tensor.py b/torch/cuda/tensor.py
deleted file mode 100644
index fabd98bf389..00000000000
--- a/torch/cuda/tensor.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from . import device, _dummy_ctx
-from ..tensor import _TensorBase
-
-
-class _CudaTensorBase(_TensorBase):
-    is_cuda = True
-
-    def type(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaTensorBase, self).type(*args, **kwargs)
-
-    def new(self, *args, **kwargs):
-        source_device = self.get_device()
-        ctx = device(source_device) if source_device != -1 else _dummy_ctx()
-        with ctx:
-            return super(_CudaTensorBase, self).new(*args, **kwargs)
-