OpenReg: fix issue of pin_memory (#145046)

Fix issue of `pin_memory` when rewrapping a storage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145046
Approved by: https://github.com/albanD
This commit is contained in:
Zhenbin Lin 2025-01-28 09:41:01 +00:00 committed by PyTorch MergeBot
parent bdf6dfa17d
commit a08f7f3266
2 changed files with 15 additions and 1 deletions

View File

@ -41,7 +41,10 @@ class Allocator:
class HostAllocator(Allocator):
def is_pinned_ptr(self, ptr):
return ptr in self.allocated
return ptr in self.allocated or any(
ptr_ <= ptr and ptr < ptr_ + size
for ptr_, (size, _) in self.allocated.items()
)
class DeviceAllocator(Allocator):

View File

@ -78,6 +78,17 @@ class TestOpenReg(TestCase):
slice_a = pinned_a[2:5]
self.assertTrue(slice_a.is_pinned())
def test_rewrapped_storage(self):
pinned_a = torch.randn(10).pin_memory()
rewrapped_a = torch.tensor((), dtype=torch.float32).set_(
pinned_a.untyped_storage()[2:],
size=(5,),
stride=(1,),
storage_offset=0,
)
self.assertTrue(rewrapped_a.is_pinned())
self.assertNotEqual(pinned_a.data_ptr(), rewrapped_a.data_ptr())
def test_stream_synchronize(self):
stream = torch.Stream(device="openreg:1")
stream.synchronize()