Use DLPack for creating tensors out of custom classes, when available. (#138697)

Fixes #120614
Takes over #120615

In summary, this PR:
- Adds a `__dlpack__` attribute check in the tensor creation path (i.e. [`internal_new_from_data` @ tensor_new.cpp](cdfe1bffd1/torch/csrc/utils/tensor_new.cpp (L266)))
    - Creates the tensor by using the DLPack machinery, instead of an element-by-element copy (see the sketch right after this list)
    - No changes since #120615
- Adds a test, making sure the DLPack machinery is used
    - Wraps a tensor in a fresh `TensorDLPackWrapper` class that implements only the DLPack methods
    - Creates a new tensor from an instance of `TensorDLPackWrapper`
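
A minimal sketch (not part of this commit's diff) of the user-facing behavior the change enables, assuming a wrapper class analogous to the `TensorDLPackWrapper` added in the test below; the `MyDLPackArray` name is purely illustrative:

```python
import torch


# Hypothetical user-defined class that exposes only the DLPack protocol.
class MyDLPackArray:
    def __init__(self, tensor):
        self._tensor = tensor

    def __dlpack__(self, *args, **kwargs):
        return self._tensor.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        return self._tensor.__dlpack_device__()


source = torch.arange(4, dtype=torch.float32)
# With this change, torch.tensor() detects __dlpack__ and converts the object
# through from_dlpack() instead of copying it element by element.
result = torch.tensor(MyDLPackArray(source))
assert torch.equal(source, result)
```

As the summary above notes, tensor creation from such objects previously went through an element-by-element copy rather than the DLPack machinery.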
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138697
Approved by: https://github.com/ezyang

Co-authored-by: Wenzel Jakob <wenzel.jakob@epfl.ch>
Yukio Siraichi 2024-10-26 01:27:02 +00:00 committed by PyTorch MergeBot
parent e299193423
commit 565a53d326
2 changed files with 47 additions and 0 deletions

test/test_dlpack.py

@@ -15,6 +15,23 @@ from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
from torch.utils.dlpack import from_dlpack, to_dlpack


# Wraps a tensor, exposing only DLPack methods:
# - __dlpack__
# - __dlpack_device__
#
# This is used for guaranteeing we are going through the DLPack method, and not
# something else, e.g.: CUDA array interface, buffer protocol, etc.
class TensorDLPackWrapper:
    def __init__(self, tensor):
        self.tensor = tensor

    def __dlpack__(self, *args, **kwargs):
        return self.tensor.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self, *args, **kwargs):
        return self.tensor.__dlpack_device__(*args, **kwargs)


class TestTorchDlPack(TestCase):
    exact_dtype = True
@@ -251,6 +268,19 @@ class TestTorchDlPack(TestCase):
        # gh-83069, make sure __dlpack__ normalizes strides
        self.assertEqual(z.stride(), (1,))

    @skipMeta
    @onlyNativeDeviceTypes
    def test_automatically_select_in_creation(self, device):
        # Create a new tensor, and wrap it using TensorDLPackWrapper.
        tensor = torch.rand(10)
        wrap = TensorDLPackWrapper(tensor)
        # Create a new tensor from the wrapper.
        # This should identify that the wrapper class provides the DLPack methods
        # and use them for creating the new tensor, instead of iterating element
        # by element.
        new_tensor = torch.tensor(wrap)
        self.assertEqual(tensor, new_tensor)


instantiate_device_type_tests(TestTorchDlPack, globals())

torch/csrc/utils/tensor_new.cpp

@@ -345,6 +345,23 @@ Tensor internal_new_from_data(
  }
#endif

  if (PyObject_HasAttrString(data, "__dlpack__")) {
    py::object tensor_o =
        py::module::import("torch").attr("utils").attr("dlpack").attr(
            "from_dlpack")(py::handle(data));
    Tensor tensor = py::cast<Tensor>(tensor_o);
    const auto& inferred_scalar_type =
        type_inference ? tensor.scalar_type() : scalar_type;
    auto device = device_opt.has_value() ? *device_opt : tensor.device();
    pybind11::gil_scoped_release no_gil;
    maybe_initialize_device(device);
    return tensor.to(
        device,
        inferred_scalar_type,
        /*non_blocking=*/false,
        /*copy=*/copy_variables);
  }

  auto device = device_opt.has_value() ? *device_opt : options.device();
  auto sizes = compute_sizes(data, scalar_type);