Optimize __dlpack_device__ performance (#86665)
This can be critical when processing a large number of tensors.

```bash
python -m timeit --setup 'import torch; t = torch.empty(1000, device="cuda")' 't.__dlpack_device__()'
```

Based on 1.12.1:

before: 100000 loops, best of 5: 2.32 usec per loop
after:  500000 loops, best of 5: 844 nsec per loop

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86665
Approved by: https://github.com/SunDoge, https://github.com/soulitzer
This commit is contained in:
parent: c12f829cce
commit: 92562046e9
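For reference, here is a standalone sketch of the same measurement as the `python -m timeit` command in the commit message. It only assumes a stock PyTorch install, falls back to CPU when CUDA is unavailable, and the absolute numbers will differ by machine and build.

```python
# Standalone version of the micro-benchmark from the commit message.
# Timings are machine-dependent; what matters is the relative cost of the
# repeated `self.device` property lookups that this commit eliminates.
import timeit

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.empty(1000, device=device)

n = 100_000
seconds = timeit.timeit(t.__dlpack_device__, number=n)
print(f"{seconds / n * 1e9:.0f} ns per __dlpack_device__ call ({device} tensor)")
```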
```diff
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -25,6 +25,7 @@ from torch.overrides import (
     has_torch_function_unary,
     has_torch_function_variadic,
 )
+from torch.utils.dlpack import DLDeviceType


 def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
@@ -1327,23 +1328,22 @@ class Tensor(torch._C._TensorBase):
         return torch.to_dlpack(self)

     def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
-        # Avoid circular import
-        from torch.utils.dlpack import DLDeviceType
-
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.__dlpack_device__, (self,), self)
-        idx = self.device.index if self.device.index is not None else 0
-        if self.device.type == "cuda" and torch.version.hip is not None:
+        device = self.device
+        idx = device.index if device.index is not None else 0
+        torch_device_type = device.type
+        if torch_device_type == "cuda" and torch.version.hip is not None:
             device_type = DLDeviceType.kDLROCM
-        elif self.device.type == "cpu" and self.is_pinned():
+        elif torch_device_type == "cpu" and self.is_pinned():
             device_type = DLDeviceType.kDLCPUPinned
-        elif self.device.type == "cuda":
+        elif torch_device_type == "cuda":
             device_type = DLDeviceType.kDLGPU
-        elif self.device.type == "cpu":
+        elif torch_device_type == "cpu":
             device_type = DLDeviceType.kDLCPU
         else:
             raise ValueError(
-                "Unknown device type {} for Dlpack".format(self.device.type)
+                "Unknown device type {} for Dlpack".format(torch_device_type)
             )
         return (device_type, idx)

```
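As a usage sketch (not part of the diff), the protocol method returns a `(device_type, device_id)` pair built from `torch.utils.dlpack.DLDeviceType`. The GPU branch below only runs when a device is available, and on ROCm builds the reported type is `kDLROCM` rather than `kDLGPU`, matching the branches above; results are otherwise machine-dependent.

```python
# Query the DLPack device of a tensor through the method optimized above.
import torch
from torch.utils.dlpack import DLDeviceType

cpu_t = torch.empty(4)
device_type, device_id = cpu_t.__dlpack_device__()
# A plain (non-pinned) CPU tensor reports kDLCPU with index 0.
assert device_type == DLDeviceType.kDLCPU and device_id == 0

if torch.cuda.is_available():
    gpu_t = torch.empty(4, device="cuda:0")
    device_type, device_id = gpu_t.__dlpack_device__()
    # kDLROCM on ROCm builds, kDLGPU on CUDA builds (see the branches above).
    expected = DLDeviceType.kDLROCM if torch.version.hip else DLDeviceType.kDLGPU
    assert device_type == expected and device_id == 0
```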
```diff
--- a/torch/utils/dlpack.py
+++ b/torch/utils/dlpack.py
@@ -45,7 +45,7 @@ The DLPack capsule shares the tensor's memory.

 # TODO: add a typing.Protocol to be able to tell Mypy that only objects with
 # __dlpack__ and __dlpack_device__ methods are accepted.
-def from_dlpack(ext_tensor: Any) -> torch.Tensor:
+def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
     """from_dlpack(ext_tensor) -> Tensor

     Converts a tensor from an external library into a ``torch.Tensor``.
```
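To see why `__dlpack_device__` sits on a hot path, here is a small round-trip sketch through `from_dlpack`, with PyTorch acting as both producer and consumer. The consumer typically queries the producer's `__dlpack_device__` before importing the capsule, so the per-call saving above applies to every exchange; this is an illustrative sketch, not part of the change.

```python
# Round-trip a tensor through the DLPack protocol. from_dlpack() consults the
# producer's __dlpack_device__() (the method this commit speeds up) and then
# imports the capsule produced by __dlpack__() without copying the data.
import torch
from torch.utils.dlpack import from_dlpack

src = torch.arange(6, dtype=torch.float32).reshape(2, 3)
dst = from_dlpack(src)

dst[0, 0] = 42.0
assert src[0, 0].item() == 42.0  # zero-copy: both views share storage
```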