mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Prevent legacy_load when weights_only=True (correctly) (#145020)
Only prevent `legacy_load` (.tar format removed in https://github.com/pytorch/pytorch/pull/713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False) Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145020 Approved by: https://github.com/kit1980, https://github.com/albanD
This commit is contained in:
parent
2ef7b68666
commit
0eda02a94c
|
|
@ -466,7 +466,11 @@ class SerializationMixin:
|
||||||
b += [a[0].storage()]
|
b += [a[0].storage()]
|
||||||
b += [a[0].reshape(-1)[1:4].clone().storage()]
|
b += [a[0].reshape(-1)[1:4].clone().storage()]
|
||||||
path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
|
path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
|
||||||
c = torch.load(path, weights_only=weights_only)
|
if weights_only:
|
||||||
|
with self.assertRaisesRegex(RuntimeError,
|
||||||
|
"Cannot use ``weights_only=True`` with files saved in the legacy .tar format."):
|
||||||
|
c = torch.load(path, weights_only=weights_only)
|
||||||
|
c = torch.load(path, weights_only=False)
|
||||||
self.assertEqual(b, c, atol=0, rtol=0)
|
self.assertEqual(b, c, atol=0, rtol=0)
|
||||||
self.assertTrue(isinstance(c[0], torch.FloatTensor))
|
self.assertTrue(isinstance(c[0], torch.FloatTensor))
|
||||||
self.assertTrue(isinstance(c[1], torch.FloatTensor))
|
self.assertTrue(isinstance(c[1], torch.FloatTensor))
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,13 @@ STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedSto
|
||||||
|
|
||||||
IS_WINDOWS = sys.platform == "win32"
|
IS_WINDOWS = sys.platform == "win32"
|
||||||
|
|
||||||
|
UNSAFE_MESSAGE = (
|
||||||
|
"In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` "
|
||||||
|
"from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
|
||||||
|
"but it can result in arbitrary code execution. Do it only if you got the file from a "
|
||||||
|
"trusted source."
|
||||||
|
)
|
||||||
|
|
||||||
if not IS_WINDOWS:
|
if not IS_WINDOWS:
|
||||||
from mmap import MAP_PRIVATE, MAP_SHARED
|
from mmap import MAP_PRIVATE, MAP_SHARED
|
||||||
else:
|
else:
|
||||||
|
|
@ -1341,12 +1348,6 @@ def load(
|
||||||
>>> torch.load("module.pt", encoding="ascii", weights_only=False)
|
>>> torch.load("module.pt", encoding="ascii", weights_only=False)
|
||||||
"""
|
"""
|
||||||
torch._C._log_api_usage_once("torch.load")
|
torch._C._log_api_usage_once("torch.load")
|
||||||
UNSAFE_MESSAGE = (
|
|
||||||
"In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` "
|
|
||||||
"from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
|
|
||||||
"but it can result in arbitrary code execution. Do it only if you got the file from a "
|
|
||||||
"trusted source."
|
|
||||||
)
|
|
||||||
DOCS_MESSAGE = (
|
DOCS_MESSAGE = (
|
||||||
"\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
|
"\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
|
||||||
"weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
|
"weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
|
||||||
|
|
@ -1611,6 +1612,11 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
||||||
with closing(
|
with closing(
|
||||||
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
|
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
|
||||||
) as tar, mkdtemp() as tmpdir:
|
) as tar, mkdtemp() as tmpdir:
|
||||||
|
if pickle_module is _weights_only_unpickler:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cannot use ``weights_only=True`` with files saved in the "
|
||||||
|
"legacy .tar format. " + UNSAFE_MESSAGE
|
||||||
|
)
|
||||||
tar.extract("storages", path=tmpdir)
|
tar.extract("storages", path=tmpdir)
|
||||||
with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
|
with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
|
||||||
num_storages = pickle_module.load(f, **pickle_load_args)
|
num_storages = pickle_module.load(f, **pickle_load_args)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user