mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix failures when default is flipped for weights_only (#127627)
Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799 Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627 Approved by: https://github.com/albanD ghstack dependencies: #132349
This commit is contained in:
parent
c8ad5e37e8
commit
d9576c9440
|
|
@ -38,6 +38,7 @@ from cpp_api_parity.utils import (
|
|||
TorchNNModuleTestParams,
|
||||
try_remove_folder,
|
||||
)
|
||||
from torch.jit._pickle import restore_type_tag
|
||||
|
||||
|
||||
# Expected substitutions:
|
||||
|
|
@ -193,6 +194,8 @@ def test_forward_backward(unit_test_class, test_params):
|
|||
backward_grad_dict_file_path,
|
||||
)
|
||||
cpp_output = torch.load(forward_output_file_path)
|
||||
# weights_only: need GLOBAL torch.jit._pickle.restore_type_tag
|
||||
with torch.serialization.safe_globals([restore_type_tag]):
|
||||
cpp_grad_dict = torch.load(backward_grad_dict_file_path)
|
||||
|
||||
# Check that forward outputs are equal
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from torch.distributed.tensor.parallel import (
|
|||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
from torch.serialization import safe_globals
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
|
|
@ -535,18 +536,13 @@ class DTensorTest(DTensorTestBase):
|
|||
buffer = io.BytesIO()
|
||||
torch.save(sharded_tensor, buffer)
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer)
|
||||
reloaded_st = torch.load(buffer, weights_only=False)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
# Test weights_only load
|
||||
try:
|
||||
torch.serialization.add_safe_globals(
|
||||
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
|
||||
)
|
||||
with safe_globals([DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]):
|
||||
buffer.seek(0)
|
||||
reloaded_st = torch.load(buffer, weights_only=True)
|
||||
self.assertEqual(sharded_tensor, reloaded_st)
|
||||
finally:
|
||||
torch.serialization.clear_safe_globals()
|
||||
|
||||
|
||||
class DTensorMeshTest(DTensorTestBase):
|
||||
|
|
|
|||
|
|
@ -376,7 +376,8 @@ class TestTorchbind(JitTestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(nt, b)
|
||||
b.seek(0)
|
||||
nt_loaded = torch.load(b)
|
||||
# weights_only=False as trying to load ScriptObject
|
||||
nt_loaded = torch.load(b, weights_only=False)
|
||||
for exp in [7, 3, 3, 1]:
|
||||
self.assertEqual(nt_loaded.pop(), exp)
|
||||
|
||||
|
|
|
|||
|
|
@ -87,7 +87,8 @@ class TestConvolutionNN(NNTestCase):
|
|||
path = download_file("https://download.pytorch.org/test_data/legacy_conv2d.pt")
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", SourceChangeWarning)
|
||||
m = torch.load(path, encoding="utf-8")
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
m = torch.load(path, encoding="utf-8", weights_only=False)
|
||||
input = torch.randn((1, 1, 1, 1), dtype=torch.float)
|
||||
self.assertEqual(m(input).size(), (1, 1, 1, 1))
|
||||
|
||||
|
|
|
|||
|
|
@ -772,7 +772,8 @@ class TestStateDictHooks(TestCase):
|
|||
# Note that torch.save / torch.load is not recommended to save/load
|
||||
# modules.
|
||||
torch.save(m, f.name)
|
||||
m = torch.load(f.name)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
m = torch.load(f.name, weights_only=False)
|
||||
m.load_state_dict(sd)
|
||||
self.assertFalse(called)
|
||||
|
||||
|
|
@ -855,7 +856,8 @@ class TestStateDictHooks(TestCase):
|
|||
# Note that torch.save / torch.load is not recommended
|
||||
# to save / load modules.
|
||||
torch.save(m, f.name)
|
||||
m = torch.load(f.name)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
m = torch.load(f.name, weights_only=False)
|
||||
|
||||
# Ensure we can run state_dict without issues
|
||||
_ = m.state_dict()
|
||||
|
|
|
|||
|
|
@ -775,7 +775,8 @@ class TestPruningNN(NNTestCase):
|
|||
|
||||
with TemporaryFileName() as fname:
|
||||
torch.save(model, fname)
|
||||
new_model = torch.load(fname)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
new_model = torch.load(fname, weights_only=False)
|
||||
|
||||
# check that the original weight and the new mask are present
|
||||
self.assertIn("0.weight_orig", new_model.state_dict())
|
||||
|
|
|
|||
|
|
@ -113,7 +113,8 @@ class TestSerialization(TestCase):
|
|||
torch.save(qmodule(input_tensor), expected_file)
|
||||
|
||||
input_tensor = torch.load(input_file)
|
||||
qmodule.load_state_dict(torch.load(state_dict_file))
|
||||
# weights_only = False as sometimes get ScriptObject here
|
||||
qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
|
||||
qmodule_scripted = torch.jit.load(scripted_module_file)
|
||||
qmodule_traced = torch.jit.load(traced_module_file)
|
||||
expected = torch.load(expected_file)
|
||||
|
|
@ -266,7 +267,7 @@ class TestSerialization(TestCase):
|
|||
# load input tensor
|
||||
input_tensor = torch.load(input_file)
|
||||
expected_output_tensor = torch.load(expected_file)
|
||||
expected_get_attrs = torch.load(get_attr_targets_file)
|
||||
expected_get_attrs = torch.load(get_attr_targets_file, weights_only=False)
|
||||
|
||||
# load model from package and verify output and get_attr targets match
|
||||
imp = torch.package.PackageImporter(package_file)
|
||||
|
|
|
|||
|
|
@ -185,8 +185,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(qlinear, b)
|
||||
b.seek(0)
|
||||
# Don't test weights_only here as this is legacy code that saves the model
|
||||
loaded = torch.load(b)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded = torch.load(b, weights_only=False)
|
||||
self.assertEqual(qlinear.weight(), loaded.weight())
|
||||
self.assertEqual(qlinear.scale, loaded.scale)
|
||||
self.assertEqual(qlinear.zero_point, loaded.zero_point)
|
||||
|
|
@ -368,8 +368,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(qconv_module, b)
|
||||
b.seek(0)
|
||||
# Don't test weights_only here as this is legacy code that saves the model
|
||||
loaded_conv = torch.load(b)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded_conv = torch.load(b, weights_only=False)
|
||||
|
||||
self.assertEqual(loaded_conv.bias(), qconv_module.bias())
|
||||
self.assertEqual(loaded_conv.scale, qconv_module.scale)
|
||||
|
|
@ -1482,8 +1482,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(dynamic_module, b)
|
||||
b.seek(0)
|
||||
# Don't test weights_only here as this is legacy code that saves the model
|
||||
loaded_conv = torch.load(b)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded_conv = torch.load(b, weights_only=False)
|
||||
|
||||
self.assertEqual(loaded_conv.bias(), dynamic_module.bias())
|
||||
self.assertEqual(loaded_conv.scale, dynamic_module.scale)
|
||||
|
|
@ -1662,8 +1662,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(qlinear, b)
|
||||
b.seek(0)
|
||||
# Don't test weights_only here as this is legacy code that saves the model
|
||||
loaded = torch.load(b)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded = torch.load(b, weights_only=False)
|
||||
self.assertEqual(qlinear.weight(), loaded.weight())
|
||||
self.assertEqual(qlinear.zero_point, loaded.zero_point)
|
||||
|
||||
|
|
|
|||
|
|
@ -1348,8 +1348,8 @@ class TestQuantizedTensor(TestCase):
|
|||
torch.save(f, buf)
|
||||
|
||||
buf.seek(0)
|
||||
# Don't test weights_only here as this is loading a Module (legacy)
|
||||
f2 = torch.load(buf)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
f2 = torch.load(buf, weights_only=False)
|
||||
|
||||
self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)
|
||||
|
||||
|
|
|
|||
|
|
@ -4285,8 +4285,8 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
m.load_state_dict(state_dict)
|
||||
with TemporaryFileName() as fname:
|
||||
torch.save(m.state_dict(), fname)
|
||||
# Don't test weights_only here as this is loading a ScriptModule
|
||||
m.load_state_dict(torch.load(fname))
|
||||
# weights_only=False as this is loading a ScriptModule
|
||||
m.load_state_dict(torch.load(fname, weights_only=False))
|
||||
|
||||
checkModel(m, data, ref_weight, ref_bias, ref_res)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# Owner(s): ["module: cpp-extensions"]
|
||||
|
||||
import _codecs
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
|
@ -9,9 +10,12 @@ import unittest
|
|||
from typing import Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.common_utils as common
|
||||
import torch.utils.cpp_extension
|
||||
from torch.serialization import safe_globals
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_ARM64,
|
||||
skipIfTorchDynamo,
|
||||
|
|
@ -551,6 +555,17 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
|||
)
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(sd, f)
|
||||
with safe_globals(
|
||||
[
|
||||
np.core.multiarray._reconstruct,
|
||||
np.ndarray,
|
||||
np.dtype,
|
||||
_codecs.encode,
|
||||
type(np.dtype(np.float32))
|
||||
if np.__version__ < "1.25.0"
|
||||
else np.dtypes.Float32DType,
|
||||
]
|
||||
):
|
||||
sd_loaded = torch.load(f, map_location="cpu")
|
||||
self.assertTrue(sd_loaded["x"].is_cpu)
|
||||
|
||||
|
|
|
|||
|
|
@ -1196,7 +1196,9 @@ class TestFX(JitTestCase):
|
|||
|
||||
bio.seek(0)
|
||||
|
||||
loaded = torch.load(bio)
|
||||
# weights_only=False as this loads a GraphModule
|
||||
# GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
|
||||
loaded = torch.load(bio, weights_only=False)
|
||||
|
||||
torch.testing.assert_close(loaded(x), x[0])
|
||||
|
||||
|
|
@ -4198,7 +4200,9 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(gm, b)
|
||||
b.seek(0)
|
||||
reload_gm = torch.load(b)
|
||||
# weights_only=False as this loads a GraphModule
|
||||
# GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
|
||||
reload_gm = torch.load(b, weights_only=False)
|
||||
self.assertTrue(hasattr(reload_gm, "foo"))
|
||||
self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
|
||||
self.assertTrue(hasattr(reload_gm, "dummy_parameter"))
|
||||
|
|
|
|||
|
|
@ -239,7 +239,8 @@ class TestModule(TestCase):
|
|||
with tempfile.TemporaryFile() as f:
|
||||
torch.save(m, f)
|
||||
f.seek(0)
|
||||
m_copy = torch.load(f)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
m_copy = torch.load(f, weights_only=False)
|
||||
output_from_copy = m_copy(*args, **kwargs)
|
||||
self.assertEqual(output, output_from_copy)
|
||||
|
||||
|
|
|
|||
|
|
@ -156,7 +156,8 @@ class TestNN(NNTestCase):
|
|||
path = download_file('https://download.pytorch.org/test_data/linear.pt')
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', SourceChangeWarning)
|
||||
m = torch.load(path)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
m = torch.load(path, weights_only=False)
|
||||
input = torch.randn(2, 3, dtype=torch.float)
|
||||
self.assertEqual(m(input).size(), (2, 5))
|
||||
|
||||
|
|
@ -4328,7 +4329,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
buf = io.BytesIO()
|
||||
rnn_pickle = torch.save(rnn, buf)
|
||||
buf.seek(0)
|
||||
rnn2 = torch.load(buf)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
rnn2 = torch.load(buf, weights_only=False)
|
||||
rnn2.flatten_parameters()
|
||||
output3, hy3 = rnn2(input, hx)
|
||||
|
||||
|
|
|
|||
|
|
@ -1204,6 +1204,7 @@ def forward(self, x_a_1, x_b_1, y_1):
|
|||
x = LoggingTensor(torch.randperm(3))
|
||||
torch.save(x, f)
|
||||
f.seek(0)
|
||||
with torch.serialization.safe_globals([LoggingTensor]):
|
||||
x_loaded = torch.load(f)
|
||||
self.assertTrue(type(x_loaded) is type(x))
|
||||
self.assertEqual(x, x_loaded)
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from torch.serialization import (
|
|||
check_module_version_greater_or_equal,
|
||||
get_default_load_endianness,
|
||||
LoadEndianness,
|
||||
safe_globals,
|
||||
set_default_load_endianness,
|
||||
SourceChangeWarning,
|
||||
)
|
||||
|
|
@ -276,7 +277,8 @@ class SerializationMixin:
|
|||
torch.save(x, f, pickle_module=dill)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(ValueError, 'supports dill >='):
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
|
||||
|
||||
def test_pickle_module(self):
|
||||
class ThrowingUnpickler(pickle.Unpickler):
|
||||
|
|
@ -292,7 +294,8 @@ class SerializationMixin:
|
|||
torch.save(x, f)
|
||||
f.seek(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "rumpelstiltskin"):
|
||||
torch.load(f, pickle_module=ThrowingModule)
|
||||
# weights_only=False as True does not support custom pickle module
|
||||
torch.load(f, pickle_module=ThrowingModule, weights_only=False)
|
||||
f.seek(0)
|
||||
z = torch.load(f)
|
||||
self.assertEqual(x, z)
|
||||
|
|
@ -307,11 +310,13 @@ class SerializationMixin:
|
|||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(x, f, pickle_module=dill)
|
||||
f.seek(0)
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
|
||||
# weights_only=False as True does not support custom pickle_module
|
||||
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
|
||||
self.assertIsInstance(x2, type(x))
|
||||
self.assertEqual(x, x2)
|
||||
f.seek(0)
|
||||
x3 = torch.load(f, pickle_module=dill)
|
||||
# weights_only=False as True does not support custom pickle_module
|
||||
x3 = torch.load(f, pickle_module=dill, weights_only=False)
|
||||
self.assertIsInstance(x3, type(x))
|
||||
self.assertEqual(x, x3)
|
||||
|
||||
|
|
@ -718,13 +723,15 @@ class SerializationMixin:
|
|||
# This Pickle contains a Python 2 module with Unicode data and the
|
||||
# loading should fail if the user explicitly specifies ascii encoding!
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
|
||||
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii'))
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii', weights_only=False))
|
||||
|
||||
def test_load_python2_unicode_module(self):
|
||||
# This Pickle contains some Unicode data!
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
self.assertIsNotNone(torch.load(path))
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
self.assertIsNotNone(torch.load(path, weights_only=False))
|
||||
|
||||
def test_load_error_msg(self):
|
||||
expected_err_msg = (".*You can only torch.load from a file that is seekable. " +
|
||||
|
|
@ -735,7 +742,8 @@ class SerializationMixin:
|
|||
delattr(resource, "tell")
|
||||
delattr(resource, "seek")
|
||||
with self.assertRaisesRegex(AttributeError, expected_err_msg):
|
||||
torch.load(resource)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
torch.load(resource, weights_only=False)
|
||||
|
||||
def test_save_different_dtype_unallocated(self):
|
||||
devices = ['cpu']
|
||||
|
|
@ -881,12 +889,11 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
|||
# First check that the checkpoint can be loaded without warning about unsafe loads
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded = torch.load(checkpoint, weights_only=False)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertEqual(w[0].category, FutureWarning)
|
||||
self.assertTrue("You are using `torch.load` with `weights_only=False`" in str(w[0].message))
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
# Replace the module with different source
|
||||
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
|
||||
|
|
@ -894,12 +901,12 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
|||
module = import_module(tmpmodule_name, fname)
|
||||
checkpoint.seek(0)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
loaded = torch.load(checkpoint)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded = torch.load(checkpoint, weights_only=False)
|
||||
self.assertTrue(isinstance(loaded, module.Net))
|
||||
if can_retrieve_source:
|
||||
self.assertEqual(len(w), 2)
|
||||
self.assertEqual(w[0].category, FutureWarning)
|
||||
self.assertEqual(w[1].category, SourceChangeWarning)
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertEqual(w[0].category, SourceChangeWarning)
|
||||
|
||||
def test_serialization_container(self):
|
||||
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
|
||||
|
|
@ -924,7 +931,8 @@ class TestOldSerialization(TestCase, SerializationMixin):
|
|||
a_loaded = torch.load(f)
|
||||
j_loaded = pickle.load(f)
|
||||
b_loaded = torch.load(f)
|
||||
m_loaded = torch.load(f)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
m_loaded = torch.load(f, weights_only=False)
|
||||
self.assertTrue(torch.equal(a, a_loaded))
|
||||
self.assertTrue(torch.equal(b, b_loaded))
|
||||
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
|
||||
|
|
@ -1141,7 +1149,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
|||
torch.load(f, weights_only=True)
|
||||
f.seek(0)
|
||||
# safe_globals doesn't work even with allowlist
|
||||
with torch.serialization.safe_globals([os.execv]):
|
||||
with safe_globals([os.execv]):
|
||||
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
|
|
@ -4252,6 +4260,7 @@ class TestSubclassSerialization(TestCase):
|
|||
with BytesIOContext() as f:
|
||||
torch.save(my_tensor, f)
|
||||
f.seek(0)
|
||||
with safe_globals([TestWrapperSubclass]):
|
||||
new_tensor = torch.load(f)
|
||||
|
||||
self.assertIsInstance(new_tensor, TestWrapperSubclass)
|
||||
|
|
@ -4269,6 +4278,7 @@ class TestSubclassSerialization(TestCase):
|
|||
with BytesIOContext() as f:
|
||||
torch.save(my_tensor, f)
|
||||
f.seek(0)
|
||||
with safe_globals([TestGetStateSubclass]):
|
||||
new_tensor = torch.load(f)
|
||||
|
||||
self.assertIsInstance(new_tensor, TestGetStateSubclass)
|
||||
|
|
@ -4306,6 +4316,7 @@ class TestSubclassSerialization(TestCase):
|
|||
with BytesIOContext() as f:
|
||||
torch.save(tensor, f)
|
||||
f.seek(0)
|
||||
with safe_globals([TestEmptySubclass]):
|
||||
tensor2 = torch.load(f)
|
||||
|
||||
tensor = TestEmptySubclass()
|
||||
|
|
@ -4316,6 +4327,7 @@ class TestSubclassSerialization(TestCase):
|
|||
with BytesIOContext() as f:
|
||||
torch.save(tensor, f)
|
||||
f.seek(0)
|
||||
with safe_globals([TestEmptySubclass]):
|
||||
tensor2 = torch.load(f)
|
||||
|
||||
@skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined")
|
||||
|
|
@ -4357,7 +4369,7 @@ class TestSubclassSerialization(TestCase):
|
|||
|
||||
def test_safe_globals_context_manager_weights_only(self):
|
||||
'''
|
||||
Tests torch.serialization.safe_globals context manager
|
||||
Tests safe_globals context manager
|
||||
'''
|
||||
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
|
||||
p = torch.nn.Parameter(t)
|
||||
|
|
@ -4367,7 +4379,7 @@ class TestSubclassSerialization(TestCase):
|
|||
torch.serialization.add_safe_globals([TestEmptySubclass])
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(sd, f)
|
||||
with torch.serialization.safe_globals([TwoTensor]):
|
||||
with safe_globals([TwoTensor]):
|
||||
f.seek(0)
|
||||
torch.load(f, weights_only=True)
|
||||
self.assertTrue(torch.serialization.get_safe_globals() == [TestEmptySubclass])
|
||||
|
|
@ -4385,6 +4397,7 @@ class TestSubclassSerialization(TestCase):
|
|||
|
||||
with TemporaryFileName() as f:
|
||||
torch.save(sd, f)
|
||||
with safe_globals([TwoTensor]):
|
||||
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
|
||||
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ class TestSubclass(TestCase):
|
|||
x = nn.Parameter(x)
|
||||
torch.save(x, f)
|
||||
f.seek(0)
|
||||
with torch.serialization.safe_globals([tensor_cls]):
|
||||
x_loaded = torch.load(f)
|
||||
|
||||
self.assertEqual(x, x_loaded)
|
||||
|
|
|
|||
|
|
@ -146,6 +146,10 @@ def _tensor_rebuild_functions():
|
|||
torch._utils._rebuild_meta_tensor_no_storage,
|
||||
torch._utils._rebuild_nested_tensor,
|
||||
torch._utils._rebuild_wrapper_subclass,
|
||||
# Allowlisting this, but not allowlisting the numpy functions by default
|
||||
# Reasoning is that we don't have control over the numpy functions, but
|
||||
# this utility is provided by pytorch
|
||||
torch._utils._rebuild_device_tensor_from_numpy,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -568,7 +568,8 @@ class {module_name}(torch.nn.Module):
|
|||
torch.save(module, module_file)
|
||||
blobified_modules.append(module_name)
|
||||
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
|
||||
module_str = f"torch.load(r'{module_file}') # {module_repr}"
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
module_str = f"torch.load(r'{module_file}', weights_only=False) # {module_repr}"
|
||||
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
|
||||
|
||||
for buffer_name, buffer in self._buffers.items():
|
||||
|
|
|
|||
|
|
@ -3437,7 +3437,8 @@ class ModuleTest(TestBase):
|
|||
test_case._forward(module, input)
|
||||
torch.save(module, f)
|
||||
f.seek(0)
|
||||
module_copy = torch.load(f)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
module_copy = torch.load(f, weights_only=False)
|
||||
test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
|
||||
|
||||
self._do_test(test_case, module, input)
|
||||
|
|
|
|||
|
|
@ -657,7 +657,8 @@ class QuantizationTestCase(TestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(model_dict, b)
|
||||
b.seek(0)
|
||||
loaded_dict = torch.load(b)
|
||||
# weights_only=False as we sometimes get a ScriptObect here (weird)
|
||||
loaded_dict = torch.load(b, weights_only=False)
|
||||
loaded_model.load_state_dict(loaded_dict)
|
||||
ref_out = ref_model(*x)
|
||||
load_out = loaded_model(*x)
|
||||
|
|
@ -674,7 +675,8 @@ class QuantizationTestCase(TestCase):
|
|||
b = io.BytesIO()
|
||||
torch.save(ref_model, b)
|
||||
b.seek(0)
|
||||
loaded = torch.load(b)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
loaded = torch.load(b, weights_only=False)
|
||||
load_out = loaded(*x)
|
||||
check_outputs(ref_out, load_out)
|
||||
|
||||
|
|
|
|||
|
|
@ -4271,15 +4271,18 @@ class DistributedTest:
|
|||
if sys.platform == "win32":
|
||||
torch.save(model_DDP, tmp)
|
||||
tmp.seek(0)
|
||||
model_DDP = torch.load(tmp)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
model_DDP = torch.load(tmp, weights_only=False)
|
||||
else:
|
||||
torch.save(model_DDP, tmp.name)
|
||||
model_DDP = torch.load(tmp.name)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
model_DDP = torch.load(tmp.name, weights_only=False)
|
||||
|
||||
with tempfile.TemporaryFile() as tmp_file:
|
||||
torch.save(model_DDP, tmp_file)
|
||||
tmp_file.seek(0)
|
||||
saved_model = torch.load(tmp_file)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
saved_model = torch.load(tmp_file, weights_only=False)
|
||||
for k in model_DDP.state_dict():
|
||||
self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k])
|
||||
|
||||
|
|
@ -4320,10 +4323,12 @@ class DistributedTest:
|
|||
if sys.platform == "win32":
|
||||
torch.save(model_DDP, tmp)
|
||||
tmp.seek(0)
|
||||
model_DDP = torch.load(tmp)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
model_DDP = torch.load(tmp, weights_only=False)
|
||||
else:
|
||||
torch.save(model_DDP, tmp.name)
|
||||
model_DDP = torch.load(tmp.name)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
model_DDP = torch.load(tmp.name, weights_only=False)
|
||||
|
||||
# dummy data initialization
|
||||
local_bs = len(gpu_subset)
|
||||
|
|
@ -5598,10 +5603,12 @@ class DistributedTest:
|
|||
if sys.platform == "win32":
|
||||
torch.save(model_DDP, tmp)
|
||||
tmp.seek(0)
|
||||
model_DDP = torch.load(tmp)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
model_DDP = torch.load(tmp, weights_only=False)
|
||||
else:
|
||||
torch.save(model_DDP, tmp.name)
|
||||
model_DDP = torch.load(tmp.name)
|
||||
# weights_only=False as this is legacy code that saves the model
|
||||
model_DDP = torch.load(tmp.name, weights_only=False)
|
||||
|
||||
# data initialization
|
||||
input_cpu = torch.randn(global_bs, 2)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user