mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Support dlpack for privateuse1 (#135331)
Fixes #129652 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135331 Approved by: https://github.com/shink, https://github.com/FFFrog, https://github.com/ezyang Co-authored-by: Jiawei Li <ljw1101.vip@gmail.com>
This commit is contained in:
parent
97d995a0d3
commit
34743d8a16
|
|
@ -121,6 +121,9 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) {
|
|||
case DeviceType::MAIA:
|
||||
ctx.device_type = DLDeviceType::kDLMAIA;
|
||||
break;
|
||||
case DeviceType::PrivateUse1:
|
||||
ctx.device_type = DLDeviceType::kDLExtDev;
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str());
|
||||
}
|
||||
|
|
@ -149,6 +152,8 @@ static Device getATenDevice(const DLDevice& ctx, void* data) {
|
|||
return at::detail::getXPUHooks().getDeviceFromPtr(data);
|
||||
case DLDeviceType::kDLMAIA:
|
||||
return at::Device(DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
case DLDeviceType::kDLExtDev:
|
||||
return at::Device(DeviceType::PrivateUse1, static_cast<c10::DeviceIndex>(ctx.device_id));
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false, "Unsupported device_type: ", std::to_string(ctx.device_type));
|
||||
|
|
@ -287,7 +292,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
|
|||
atDLMTensor->tensor.deleter = &deleter;
|
||||
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
|
||||
c10::DeviceIndex device_id = 0;
|
||||
if (src.is_cuda()) {
|
||||
if (src.is_cuda() || src.is_privateuseone()) {
|
||||
device_id = src.get_device();
|
||||
}
|
||||
atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
|
||||
|
|
|
|||
|
|
@ -644,6 +644,15 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
with torch.serialization.skip_data():
|
||||
torch.save(sd, f)
|
||||
|
||||
def test_open_device_dlpack(self):
|
||||
t = torch.randn(2, 3).to("foo")
|
||||
capsule = torch.utils.dlpack.to_dlpack(t)
|
||||
t1 = torch.from_dlpack(capsule)
|
||||
self.assertTrue(t1.device == t.device)
|
||||
t = t.to("cpu")
|
||||
t1 = t1.to("cpu")
|
||||
self.assertEqual(t, t1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
|
|
|||
|
|
@ -1637,6 +1637,8 @@ class Tensor(torch._C.TensorBase):
|
|||
device_type = DLDeviceType.kDLCPU
|
||||
elif self.device.type == "xpu":
|
||||
device_type = DLDeviceType.kDLOneAPI
|
||||
elif self.device.type == "privateuse1":
|
||||
device_type = DLDeviceType.kDLExtDev
|
||||
else:
|
||||
raise ValueError(f"Unknown device type {torch_device_type} for Dlpack")
|
||||
return (device_type, idx)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user