fix-unpin-memory-tensor-param (#160992)

Fixes #160983

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160992
Approved by: https://github.com/ngimel
This commit is contained in:
gaoyufeng 2025-08-26 21:55:22 +00:00 committed by PyTorch MergeBot
parent e06d1d6610
commit cde54fe4e9
2 changed files with 9 additions and 1 deletion

View File

@@ -220,6 +220,13 @@ class TestStateDictUtils(DTensorTestBase):
self.assertEqual(cpu_state_dict["step"], 7)
self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]})
def _verify_weakref_finalize(cpu_state_dict):
import gc
del cpu_state_dict["tensor1"]
del cpu_state_dict
gc.collect()
cpu_state_dict = _create_cpu_state_dict(state_dict)
_verify(cpu_state_dict)
cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
@@ -230,6 +237,7 @@ class TestStateDictUtils(DTensorTestBase):
state_dict, share_memory=True, pin_memory=True
)
_verify(cpu_state_dict)
_verify_weakref_finalize(cpu_state_dict)
@with_comms
@skip_if_lt_x_gpu(2)

View File

@@ -423,7 +423,7 @@ def _create_cpu_state_dict(
t = t.share_memory_()
if pin_memory:
pin_memory_utils.pin_memory(t.data_ptr(), t.numel() * t.element_size())
-weakref.finalize(t, pin_memory_utils.unpin_memory, t)
+weakref.finalize(t, pin_memory_utils.unpin_memory, t.data_ptr())
return t
elif pin_memory: