mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
OpenReg: fix issue of pin_memory (#145046)
Fix issue of `pin_memory` when rewrapping a storage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145046 Approved by: https://github.com/albanD
This commit is contained in:
parent
bdf6dfa17d
commit
a08f7f3266
|
|
@ -41,7 +41,10 @@ class Allocator:
|
|||
|
||||
class HostAllocator(Allocator):
|
||||
def is_pinned_ptr(self, ptr):
|
||||
return ptr in self.allocated
|
||||
return ptr in self.allocated or any(
|
||||
ptr_ <= ptr and ptr < ptr_ + size
|
||||
for ptr_, (size, _) in self.allocated.items()
|
||||
)
|
||||
|
||||
|
||||
class DeviceAllocator(Allocator):
|
||||
|
|
|
|||
|
|
@ -78,6 +78,17 @@ class TestOpenReg(TestCase):
|
|||
slice_a = pinned_a[2:5]
|
||||
self.assertTrue(slice_a.is_pinned())
|
||||
|
||||
def test_rewrapped_storage(self):
|
||||
pinned_a = torch.randn(10).pin_memory()
|
||||
rewrapped_a = torch.tensor((), dtype=torch.float32).set_(
|
||||
pinned_a.untyped_storage()[2:],
|
||||
size=(5,),
|
||||
stride=(1,),
|
||||
storage_offset=0,
|
||||
)
|
||||
self.assertTrue(rewrapped_a.is_pinned())
|
||||
self.assertNotEqual(pinned_a.data_ptr(), rewrapped_a.data_ptr())
|
||||
|
||||
def test_stream_synchronize(self):
|
||||
stream = torch.Stream(device="openreg:1")
|
||||
stream.synchronize()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user