Make torch.serialization.skip_data work with torch.load (#148018)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148018
Approved by: https://github.com/albanD
ghstack dependencies: #147786, #147787, #147788
Commit: d5184901c4
Parent: be0ceee1c3
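For orientation, a minimal sketch of the behavior this change enables (not part of the diff itself): under the skip_data context manager, torch.load now succeeds and rebuilds tensors with the checkpoint's shapes, dtypes, and devices while leaving their storage bytes unpopulated. The file name and tensor shapes below are illustrative assumptions.

import torch
from torch.serialization import skip_data

# Write an ordinary checkpoint.
torch.save({"weight": torch.randn(2, 3), "bias": torch.randn(3)}, "checkpoint.pt")

# Previously, torch.load inside skip_data() raised a RuntimeError asking the
# caller to load outside the context manager; with this change it returns
# tensors whose metadata matches the checkpoint but whose values are
# uninitialized (no storage bytes are read from the file).
with skip_data():
    skeleton = torch.load("checkpoint.pt")

print(skeleton["weight"].shape)  # torch.Size([2, 3]); values are undefined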
@@ -908,6 +908,27 @@ class SerializationMixin:
         # This test is to make sure that the serialization debug flag is set in CI
         self.assertTrue(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1")

+    def test_skip_data_load(self):
+        t_device = "cuda" if torch.cuda.is_available() else "cpu"
+        t_v2 = torch.randn(2, 3, device=t_device)
+        tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))
+
+        sd = {'t_v2': t_v2, 'tt': tt}
+        sd_zeroed = {
+            't_v2': torch.zeros(2, 3, device=t_device),
+            'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
+        }
+
+        with BytesIOContext() as f:
+            torch.save(sd, f)
+            f.seek(0)
+            with safe_globals([TwoTensor]), skip_data():
+                sd_loaded = torch.load(f)
+            self.assertNotEqual(sd_loaded, sd)
+            for k in sd_loaded.keys():
+                sd_loaded[k] = sd_loaded[k].zero_()
+            self.assertEqual(sd_loaded, sd_zeroed)
+

 class serialization_method:
     def __init__(self, use_zip):
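Note that the new test zeroes the loaded tensors before comparing: after a skip_data load their storages are allocated but uninitialized, so values must be overwritten before being read. A sketch of the intended second pass, populating the skeleton from a regular load (the file name and copy strategy are illustrative assumptions, not part of this commit):

import torch
from torch.serialization import skip_data

torch.save({"weight": torch.randn(2, 3)}, "checkpoint.pt")

# First pass: rebuild the tensors without reading their bytes.
with skip_data():
    skeleton = torch.load("checkpoint.pt")

# Second pass: fill the uninitialized storages in place, here from a full load.
full = torch.load("checkpoint.pt")
for key, tensor in skeleton.items():
    tensor.copy_(full[key])

assert torch.equal(skeleton["weight"], full["weight"])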
@@ -4463,12 +4484,6 @@ class TestSerialization(TestCase, SerializationMixin):
             with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
                 _save_load(t)

-        with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
-            with skip_data(), BytesIOContext() as f:
-                torch.save(torch.randn(2, 3), f)
-                f.seek(0)
-                torch.load(f, weights_only=True)
-
     @parametrize("force_weights_only", (True, False))
     def test_weights_only_env_variables(self, force_weights_only):
         env_var = "TORCH_FORCE_WEIGHTS_ONLY_LOAD" if force_weights_only else "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"
@@ -383,16 +383,18 @@ def get_unsafe_globals_in_checkpoint(f: FileLike) -> list[str]:

 class skip_data:
     """
-    Context-manager that skips writing storage bytes for ``torch.save`` calls.
+    Context-manager that skips writing/reading storage bytes for ``torch.save`` / ``torch.load`` calls.

-    Storages will still be saved, but the space that their bytes would usually be written to
+    For the save path, storages will still be saved, but the space that their bytes would usually be written to
     will be empty space. The storage bytes can then be populated in a separate pass.

+    For the load path, tensors will be loaded per the checkpoint but their storages will not be populated with data.
+
     .. warning::
         The ``skip_data`` context manager is an early prototype and is subject to change.

     Args:
-        materialize_fake_tensors: Whether to materialize FakeTensors.
+        materialize_fake_tensors: Whether to materialize FakeTensors during save. This is a no-op for the load path.

     Example:
         >>> # xdoctest: +SKIP("NamedTemporaryFile on Windows")
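The updated docstring covers both directions; a sketch of the round trip under skip_data, assuming CPU tensors and an illustrative file name: the save writes the checkpoint layout with empty space where the storage bytes would go, and the load rebuilds matching tensors without reading any storage bytes.

import torch
from torch.serialization import skip_data

sd = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# Save path: storages are recorded with their sizes and offsets, but their
# bytes are left as empty space in the file, to be populated separately.
with skip_data():
    torch.save(sd, "layout_only.pt")

# Load path: tensors are reconstructed per the checkpoint metadata; since
# skip_data is active, no storage bytes are read from the file either.
with skip_data():
    skeleton = torch.load("layout_only.pt")

assert skeleton["weight"].shape == (4, 4)  # metadata round-trips; values do not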
@@ -1418,14 +1420,6 @@ def load(
            updated_message += message
        return updated_message + DOCS_MESSAGE

-    global _serialization_tls
-    skip_data = _serialization_tls.skip_data
-    if skip_data:
-        raise RuntimeError(
-            "`torch.load` called within a torch.serialization.skip_data context manager "
-            "is not supported yet. Please call torch.load outside the skip_data context manager."
-        )
-
     weights_only_not_set = weights_only is None

     if weights_only_not_set:
@@ -1735,6 +1729,9 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
             if root_key not in deserialized_objects:
                 if torch._guards.active_fake_mode() is not None:
                     obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
+                elif _serialization_tls.skip_data:
+                    obj = cast(Storage, torch.UntypedStorage(nbytes))
+                    obj = restore_location(obj, location)
                 else:
                     obj = cast(Storage, torch.UntypedStorage(nbytes))
                     obj._torch_load_uninitialized = True
@@ -1807,7 +1804,7 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):

     deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

-    if torch._guards.active_fake_mode() is None:
+    if torch._guards.active_fake_mode() is None and not _serialization_tls.skip_data:
         offset = f.tell() if f_should_read_directly else None
         for key in deserialized_storage_keys:
             assert key in deserialized_objects
@@ -1999,6 +1996,9 @@ def _load(
             nbytes = numel * torch._utils._element_size(dtype)
             storage = torch.UntypedStorage(nbytes, device="meta")
             storage._checkpoint_offset = zip_file.get_record_offset(name)
+        elif _serialization_tls.skip_data:
+            nbytes = numel * torch._utils._element_size(dtype)
+            storage = torch.UntypedStorage(nbytes)
         elif overall_storage is not None:
             if can_calculate_storage_offsets and calculate_storage_offsets:
                 storage_offset = _get_offset(key, name, numel)
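Taken together, the branches added to _legacy_load and _load apply the same allocation policy during load. The helper below is a hypothetical standalone paraphrase of that ordering, not code from the commit; restore_location, location, and skip_data_active stand in for the corresponding values inside torch.serialization.

import torch

def allocate_storage_for_load(nbytes, restore_location, location, skip_data_active):
    # Hypothetical paraphrase of the allocation order used while loading.
    if torch._guards.active_fake_mode() is not None:
        # Under a fake tensor mode: allocate on the meta device, no real memory.
        return torch.UntypedStorage(nbytes, device="meta")
    if skip_data_active:
        # Under skip_data: allocate real but uninitialized memory on the mapped
        # device, and skip reading the bytes from the checkpoint.
        return restore_location(torch.UntypedStorage(nbytes), location)
    # Default: allocate memory that will later be filled from the file's bytes.
    return torch.UntypedStorage(nbytes)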