Make torch.serialization.skip_data work with torch.load (#148018)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148018
Approved by: https://github.com/albanD
ghstack dependencies: #147786, #147787, #147788
Mikayla Gawarecki 2025-03-06 08:50:56 +00:00 committed by PyTorch MergeBot
parent be0ceee1c3
commit d5184901c4
2 changed files with 33 additions and 18 deletions

test/test_serialization.py

@@ -908,6 +908,27 @@ class SerializationMixin:
         # This test is to make sure that the serialization debug flag is set in CI
         self.assertTrue(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1")
 
+    def test_skip_data_load(self):
+        t_device = "cuda" if torch.cuda.is_available() else "cpu"
+        t_v2 = torch.randn(2, 3, device=t_device)
+        tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
+
+        sd = {'t_v2': t_v2, 'tt': tt}
+        sd_zeroed = {
+            't_v2': torch.zeros(2, 3, device=t_device),
+            'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
+        }
+
+        with BytesIOContext() as f:
+            torch.save(sd, f)
+            f.seek(0)
+            with safe_globals([TwoTensor]), skip_data():
+                sd_loaded = torch.load(f)
+            self.assertNotEqual(sd_loaded, sd)
+            for k in sd_loaded.keys():
+                sd_loaded[k] = sd_loaded[k].zero_()
+            self.assertEqual(sd_loaded, sd_zeroed)
+
 
 class serialization_method:
     def __init__(self, use_zip):
@@ -4463,12 +4484,6 @@ class TestSerialization(TestCase, SerializationMixin):
         with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
             _save_load(t)
 
-        with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
-            with skip_data(), BytesIOContext() as f:
-                torch.save(torch.randn(2, 3), f)
-                f.seek(0)
-                torch.load(f, weights_only=True)
-
     @parametrize("force_weights_only", (True, False))
     def test_weights_only_env_variables(self, force_weights_only):
         env_var = "TORCH_FORCE_WEIGHTS_ONLY_LOAD" if force_weights_only else "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"
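The removed assertion above reflects the old behavior, where torch.load raised under skip_data; the new test_skip_data_load covers what happens instead. A minimal standalone sketch of the new load-path behavior (assuming a plain tensor checkpoint, without the TwoTensor subclass used in the test):

    import io
    import torch
    from torch.serialization import skip_data

    # Save a normal checkpoint, then load it under skip_data: tensor metadata
    # (shape, dtype, device) is restored, but storage bytes are never read from
    # the file, so values are whatever the fresh allocations happen to contain.
    buf = io.BytesIO()
    sd = {"w": torch.randn(2, 3)}
    torch.save(sd, buf)
    buf.seek(0)

    with skip_data():
        sd_loaded = torch.load(buf, weights_only=True)

    assert sd_loaded["w"].shape == (2, 3)

    # The storages can then be populated in a separate pass, for example:
    for k in sd_loaded:
        sd_loaded[k].copy_(sd[k])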

torch/serialization.py

@@ -383,16 +383,18 @@ def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:
 class skip_data:
     """
-    Context-manager that skips writing storage bytes for ``torch.save`` calls.
+    Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls.
 
-    Storages will still be saved, but the space that their bytes would usually be written to
+    For the save path, storages will still be saved, but the space that their bytes would usually be written to
     will be empty space. The storage bytes can then be populated in a separate pass.
 
+    For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data.
+
     .. warning::
         The ``skip_data`` context manager is an early prototype and is subject to change.
 
     Args:
-        materialize_fake_tensors: Whether to materialize FakeTensors.
+        materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path.
 
     Example:
         >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
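As a rough illustration of the documented save-path behavior (a sketch, not part of this diff): an archive written under skip_data keeps the normal checkpoint structure, only the storage payload goes unwritten:

    import io
    import torch
    from torch.serialization import skip_data

    t = torch.randn(1000)

    full, skipped = io.BytesIO(), io.BytesIO()
    torch.save(t, full)
    with skip_data():
        torch.save(t, skipped)  # space for storage bytes is reserved, not written

    # Both archives share the same layout; only the storage regions differ.
    print(len(full.getvalue()), len(skipped.getvalue()))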
@@ -1418,14 +1420,6 @@ def load(
         updated_message += message
         return updated_message + DOCS_MESSAGE
 
-    global _serialization_tls
-    skip_data = _serialization_tls.skip_data
-    if skip_data:
-        raise RuntimeError(
-            "`torch.load` called within a torch.serialization.skip_data context manager "
-            "is not supported yet. Please call torch.load outside the skip_data context manager."
-        )
-
     weights_only_not_set = weights_only is None
 
     if weights_only_not_set:
@@ -1735,6 +1729,9 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
         if root_key not in deserialized_objects:
             if torch._guards.active_fake_mode() is not None:
                 obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
+            elif _serialization_tls.skip_data:
+                obj = cast(Storage, torch.UntypedStorage(nbytes))
+                obj = restore_location(obj, location)
             else:
                 obj = cast(Storage, torch.UntypedStorage(nbytes))
                 obj._torch_load_uninitialized = True
@@ -1807,7 +1804,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
     deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
 
-    if torch._guards.active_fake_mode() is None:
+    if torch._guards.active_fake_mode() is None and not _serialization_tls.skip_data:
         offset = f.tell() if f_should_read_directly else None
         for key in deserialized_storage_keys:
             assert key in deserialized_objects
@@ -1999,6 +1996,9 @@ def _load(
             nbytes = numel * torch._utils._element_size(dtype)
             storage = torch.UntypedStorage(nbytes, device="meta")
             storage._checkpoint_offset = zip_file.get_record_offset(name)
+        elif _serialization_tls.skip_data:
+            nbytes = numel * torch._utils._element_size(dtype)
+            storage = torch.UntypedStorage(nbytes)
         elif overall_storage is not None:
             if can_calculate_storage_offsets and calculate_storage_offsets:
                 storage_offset = _get_offset(key, name, numel)
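Conceptually, the new elif branch in _load allocates each storage at its full size without ever reading the corresponding record from the zip archive. A hand-rolled sketch of that allocation step (numel and dtype stand in for values the unpickler reads from the checkpoint; the internal _element_size helper is replaced here by the public dtype.itemsize):

    import torch

    # Stand-ins for metadata parsed from the checkpoint.
    numel, dtype = 6, torch.float32
    nbytes = numel * dtype.itemsize

    # Under skip_data the storage is allocated but its bytes are never read
    # from the archive, so its contents are uninitialized memory.
    storage = torch.UntypedStorage(nbytes)
    t = torch.empty(0, dtype=dtype).set_(storage)
    print(t.shape, t.untyped_storage().nbytes())  # torch.Size([6]) 24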