mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Use DLPack for creating tensors out of custom classes, when available. (#138697)
Fixes #120614
Takes over #120615
In summary, this PR:
- Adds a `__dlpack__` attribute check in the tensor creation path (i.e. [`internal_new_from_data` @ tensor_new.cpp](cdfe1bffd1/torch/csrc/utils/tensor_new.cpp (L266)))
- Creates the tensor by using the DLPack machinery, instead of an element-by-element copy
- No changes since #120615
- Adds a test, making sure the DLPack machinery is used
- Wraps a tensor in a fresh `TensorDLPackWrapper` class that implements only the DLPack methods
- Creates a new tensor from an instance of `TensorDLPackWrapper`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138697
Approved by: https://github.com/ezyang
Co-authored-by: Wenzel Jakob <wenzel.jakob@epfl.ch>
This commit is contained in:
parent
e299193423
commit
565a53d326
|
|
@ -15,6 +15,23 @@ from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
|
|||
from torch.utils.dlpack import from_dlpack, to_dlpack
|
||||
|
||||
|
||||
# Wraps a tensor, exposing only DLPack methods:
|
||||
# - __dlpack__
|
||||
# - __dlpack_device__
|
||||
#
|
||||
# This is used for guaranteeing we are going through the DLPack method, and not
|
||||
# something else, e.g.: CUDA array interface, buffer protocol, etc.
|
||||
class TensorDLPackWrapper:
|
||||
def __init__(self, tensor):
|
||||
self.tensor = tensor
|
||||
|
||||
def __dlpack__(self, *args, **kwargs):
|
||||
return self.tensor.__dlpack__(*args, **kwargs)
|
||||
|
||||
def __dlpack_device__(self, *args, **kwargs):
|
||||
return self.tensor.__dlpack_device__(*args, **kwargs)
|
||||
|
||||
|
||||
class TestTorchDlPack(TestCase):
|
||||
exact_dtype = True
|
||||
|
||||
|
|
@ -251,6 +268,19 @@ class TestTorchDlPack(TestCase):
|
|||
# gh-83069, make sure __dlpack__ normalizes strides
|
||||
self.assertEqual(z.stride(), (1,))
|
||||
|
||||
@skipMeta
|
||||
@onlyNativeDeviceTypes
|
||||
def test_automatically_select_in_creation(self, device):
|
||||
# Create a new tensor, and wrap it using TensorDLPackWrapper.
|
||||
tensor = torch.rand(10)
|
||||
wrap = TensorDLPackWrapper(tensor)
|
||||
# Create a new tensor from the wrapper.
|
||||
# This should identify that the wrapper class provides the DLPack methods
|
||||
# and use them for creating the new tensor, instead of iterating element
|
||||
# by element.
|
||||
new_tensor = torch.tensor(wrap)
|
||||
self.assertEqual(tensor, new_tensor)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestTorchDlPack, globals())
|
||||
|
||||
|
|
|
|||
|
|
@ -345,6 +345,23 @@ Tensor internal_new_from_data(
|
|||
}
|
||||
#endif
|
||||
|
||||
if (PyObject_HasAttrString(data, "__dlpack__")) {
|
||||
py::object tensor_o =
|
||||
py::module::import("torch").attr("utils").attr("dlpack").attr(
|
||||
"from_dlpack")(py::handle(data));
|
||||
Tensor tensor = py::cast<Tensor>(tensor_o);
|
||||
const auto& inferred_scalar_type =
|
||||
type_inference ? tensor.scalar_type() : scalar_type;
|
||||
auto device = device_opt.has_value() ? *device_opt : tensor.device();
|
||||
pybind11::gil_scoped_release no_gil;
|
||||
maybe_initialize_device(device);
|
||||
return tensor.to(
|
||||
device,
|
||||
inferred_scalar_type,
|
||||
/*non_blocking=*/false,
|
||||
/*copy=*/copy_variables);
|
||||
}
|
||||
|
||||
auto device = device_opt.has_value() ? *device_opt : options.device();
|
||||
|
||||
auto sizes = compute_sizes(data, scalar_type);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user