From 0eda02a94c754e2256ff1701bcc03c40ece2bbef Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 17 Jan 2025 08:51:46 -0800 Subject: [PATCH] Prevent legacy_load when weights_only=True (correctly) (#145020) Only prevent `legacy_load` (.tar format removed in https://github.com/pytorch/pytorch/pull/713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False) Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145020 Approved by: https://github.com/kit1980, https://github.com/albanD --- test/test_serialization.py | 6 +++++- torch/serialization.py | 18 ++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index aea2cf1a6f0..6e947a570e8 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -466,7 +466,11 @@ class SerializationMixin: b += [a[0].storage()] b += [a[0].reshape(-1)[1:4].clone().storage()] path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt') - c = torch.load(path, weights_only=weights_only) + if weights_only: + with self.assertRaisesRegex(RuntimeError, + "Cannot use ``weights_only=True`` with files saved in the legacy .tar format."): + c = torch.load(path, weights_only=weights_only) + c = torch.load(path, weights_only=False) self.assertEqual(b, c, atol=0, rtol=0) self.assertTrue(isinstance(c[0], torch.FloatTensor)) self.assertTrue(isinstance(c[1], torch.FloatTensor)) diff --git a/torch/serialization.py b/torch/serialization.py index 0a4d067b6ab..85333525ed7 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -85,6 +85,13 @@ STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedSto IS_WINDOWS = sys.platform == "win32" +UNSAFE_MESSAGE = ( + "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " + "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " + "but it can result in arbitrary code execution. Do it only if you got the file from a " + "trusted source." +) + if not IS_WINDOWS: from mmap import MAP_PRIVATE, MAP_SHARED else: @@ -1341,12 +1348,6 @@ def load( >>> torch.load("module.pt", encoding="ascii", weights_only=False) """ torch._C._log_api_usage_once("torch.load") - UNSAFE_MESSAGE = ( - "In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` " - "from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, " - "but it can result in arbitrary code execution. Do it only if you got the file from a " - "trusted source." - ) DOCS_MESSAGE = ( "\n\nCheck the documentation of torch.load to learn more about types accepted by default with " "weights_only https://pytorch.org/docs/stable/generated/torch.load.html." @@ -1611,6 +1612,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): with closing( tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT) ) as tar, mkdtemp() as tmpdir: + if pickle_module is _weights_only_unpickler: + raise RuntimeError( + "Cannot use ``weights_only=True`` with files saved in the " + "legacy .tar format. " + UNSAFE_MESSAGE + ) tar.extract("storages", path=tmpdir) with open(os.path.join(tmpdir, "storages"), "rb", 0) as f: num_storages = pickle_module.load(f, **pickle_load_args)