diff --git a/test/test_dlpack.py b/test/test_dlpack.py
index a9036be160b..fe1107ac850 100644
--- a/test/test_dlpack.py
+++ b/test/test_dlpack.py
@@ -15,6 +15,23 @@ from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
 from torch.utils.dlpack import from_dlpack, to_dlpack
 
 
+# Wraps a tensor, exposing only the DLPack methods:
+#   - __dlpack__
+#   - __dlpack_device__
+#
+# This is used to guarantee that we go through the DLPack methods, and not
+# something else, e.g. the CUDA array interface or the buffer protocol.
+class TensorDLPackWrapper:
+    def __init__(self, tensor):
+        self.tensor = tensor
+
+    def __dlpack__(self, *args, **kwargs):
+        return self.tensor.__dlpack__(*args, **kwargs)
+
+    def __dlpack_device__(self, *args, **kwargs):
+        return self.tensor.__dlpack_device__(*args, **kwargs)
+
+
 class TestTorchDlPack(TestCase):
     exact_dtype = True
 
@@ -251,6 +268,19 @@ class TestTorchDlPack(TestCase):
         # gh-83069, make sure __dlpack__ normalizes strides
         self.assertEqual(z.stride(), (1,))
 
+    @skipMeta
+    @onlyNativeDeviceTypes
+    def test_automatically_select_in_creation(self, device):
+        # Create a new tensor and wrap it with TensorDLPackWrapper.
+        tensor = torch.rand(10)
+        wrap = TensorDLPackWrapper(tensor)
+        # Create a new tensor from the wrapper.
+        # torch.tensor() should detect that the wrapper provides the DLPack
+        # methods and use them for creating the new tensor, instead of
+        # iterating element by element.
+        new_tensor = torch.tensor(wrap)
+        self.assertEqual(tensor, new_tensor)
+
 
 instantiate_device_type_tests(TestTorchDlPack, globals())
 
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 099991f8414..e6371498314 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -345,6 +345,27 @@ Tensor internal_new_from_data(
   }
 #endif
 
+  // If the object implements the DLPack protocol, convert it through
+  // torch.utils.dlpack.from_dlpack instead of copying element by element.
+  if (PyObject_HasAttrString(data, "__dlpack__")) {
+    py::object tensor_o =
+        py::module::import("torch").attr("utils").attr("dlpack").attr(
+            "from_dlpack")(py::handle(data));
+    Tensor tensor = py::cast<Tensor>(tensor_o);
+    // Respect an explicitly requested dtype/device; otherwise keep the
+    // source tensor's own.
+    const auto& inferred_scalar_type =
+        type_inference ? tensor.scalar_type() : scalar_type;
+    auto device = device_opt.has_value() ? *device_opt : tensor.device();
+    pybind11::gil_scoped_release no_gil;
+    maybe_initialize_device(device);
+    return tensor.to(
+        device,
+        inferred_scalar_type,
+        /*non_blocking=*/false,
+        /*copy=*/copy_variables);
+  }
+
   auto device = device_opt.has_value() ? *device_opt : options.device();
 
   auto sizes = compute_sizes(data, scalar_type);
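
A minimal usage sketch of the behavior this patch enables, assuming the patch is applied. TensorDLPackWrapper is the same helper added to test_dlpack.py above, repeated here so the snippet is self-contained; the copy assertion follows from the C++ hunk passing /*copy=*/copy_variables, which is true for torch.tensor():

import torch

# Same wrapper as in test_dlpack.py: exposes only the DLPack protocol.
class TensorDLPackWrapper:
    def __init__(self, tensor):
        self.tensor = tensor

    def __dlpack__(self, *args, **kwargs):
        return self.tensor.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self, *args, **kwargs):
        return self.tensor.__dlpack_device__(*args, **kwargs)

src = torch.arange(10, dtype=torch.float32)
wrap = TensorDLPackWrapper(src)

# torch.tensor() now detects __dlpack__ on the input and converts through
# torch.utils.dlpack.from_dlpack() followed by .to(), instead of iterating
# over the input element by element.
out = torch.tensor(wrap)
assert torch.equal(src, out)

# torch.tensor() reaches internal_new_from_data with copy_variables=true,
# so the .to() call above copies: mutating `out` does not touch `src`.
out[0] = -1.0
assert src[0].item() == 0.0

# An explicit dtype takes the `type_inference ? ... : scalar_type` branch
# of the C++ hunk, overriding the dtype of the source tensor.
out64 = torch.tensor(wrap, dtype=torch.float64)
assert out64.dtype == torch.float64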
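
For contrast, a short sketch of the pre-existing torch.from_dlpack() entry point that the new code path reuses internally. Unlike torch.tensor(), from_dlpack() shares memory with the producer; that aliasing is a documented property of the DLPack exchange, not something introduced by this patch:

import torch

src = torch.arange(4, dtype=torch.float32)

# torch.from_dlpack() consumes src.__dlpack__() directly and aliases the
# same storage, so a write through one tensor is visible through the other.
shared = torch.from_dlpack(src)
shared[0] = 99.0
assert src[0].item() == 99.0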