Support dlpack for privateuse1 (#135331)

Fixes #129652
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135331
Approved by: https://github.com/shink, https://github.com/FFFrog, https://github.com/ezyang

Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
This commit is contained in:
hipudding 2024-11-13 13:13:13 +00:00 committed by PyTorch MergeBot
parent 97d995a0d3
commit 34743d8a16
3 changed files with 17 additions and 1 deletions

View File

@ -121,6 +121,9 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
case DeviceType::MAIA:
ctx.device_type = DLDeviceType::kDLMAIA;
break;
case DeviceType::PrivateUse1:
ctx.device_type = DLDeviceType::kDLExtDev;
break;
default:
TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
}
@ -149,6 +152,8 @@ static Device getATenDevice(const DLDevice& ctx, void* data) {
return at::detail::getXPUHooks().getDeviceFromPtr(data);
case DLDeviceType::kDLMAIA:
return at::Device(DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
case DLDeviceType::kDLExtDev:
return at::Device(DeviceType::PrivateUse1, static_cast<c10::DeviceIndex>(ctx.device_id));
default:
TORCH_CHECK(
false, "Unsupported device_type: ", std::to_string(ctx.device_type));
@ -287,7 +292,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
atDLMTensor->tensor.deleter = &deleter;
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
c10::DeviceIndex device_id = 0;
if (src.is_cuda()) {
if (src.is_cuda() || src.is_privateuseone()) {
device_id = src.get_device();
}
atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);

View File

@ -644,6 +644,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
with torch.serialization.skip_data():
torch.save(sd, f)
def test_open_device_dlpack(self):
t = torch.randn(2, 3).to("foo")
capsule = torch.utils.dlpack.to_dlpack(t)
t1 = torch.from_dlpack(capsule)
self.assertTrue(t1.device == t.device)
t = t.to("cpu")
t1 = t1.to("cpu")
self.assertEqual(t, t1)
if __name__ == "__main__":
common.run_tests()

View File

@ -1637,6 +1637,8 @@ class Tensor(torch._C.TensorBase):
device_type = DLDeviceType.kDLCPU
elif self.device.type == "xpu":
device_type = DLDeviceType.kDLOneAPI
elif self.device.type == "privateuse1":
device_type = DLDeviceType.kDLExtDev
else:
raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
return (device_type, idx)