Fix failures when default is flipped for weights_only (#127627)

Tests on the XLA shard are not fixed yet; that work is tracked in https://github.com/pytorch/xla/issues/7799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627
Approved by: https://github.com/albanD
ghstack dependencies: #132349
This commit is contained in:
Mikayla Gawarecki 2024-08-15 19:48:35 +00:00 committed by PyTorch MergeBot
parent c8ad5e37e8
commit d9576c9440
22 changed files with 135 additions and 78 deletions

View File

@ -38,6 +38,7 @@ from cpp_api_parity.utils import (
TorchNNModuleTestParams,
try_remove_folder,
)
from torch.jit._pickle import restore_type_tag
# Expected substitutions:
@ -193,6 +194,8 @@ def test_forward_backward(unit_test_class, test_params):
backward_grad_dict_file_path,
)
cpp_output = torch.load(forward_output_file_path)
# weights_only: need GLOBAL torch.jit._pickle.restore_type_tag
with torch.serialization.safe_globals([restore_type_tag]):
cpp_grad_dict = torch.load(backward_grad_dict_file_path)
# Check that forward outputs are equal

View File

@ -27,6 +27,7 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
@ -535,18 +536,13 @@ class DTensorTest(DTensorTestBase):
buffer = io.BytesIO()
torch.save(sharded_tensor, buffer)
buffer.seek(0)
reloaded_st = torch.load(buffer)
reloaded_st = torch.load(buffer, weights_only=False)
self.assertEqual(sharded_tensor, reloaded_st)
# Test weights_only load
try:
torch.serialization.add_safe_globals(
[DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]
)
with safe_globals([DTensor, DeviceMesh, Shard, DTensorSpec, TensorMeta]):
buffer.seek(0)
reloaded_st = torch.load(buffer, weights_only=True)
self.assertEqual(sharded_tensor, reloaded_st)
finally:
torch.serialization.clear_safe_globals()
class DTensorMeshTest(DTensorTestBase):

View File

@ -376,7 +376,8 @@ class TestTorchbind(JitTestCase):
b = io.BytesIO()
torch.save(nt, b)
b.seek(0)
nt_loaded = torch.load(b)
# weights_only=False as trying to load ScriptObject
nt_loaded = torch.load(b, weights_only=False)
for exp in [7, 3, 3, 1]:
self.assertEqual(nt_loaded.pop(), exp)

View File

@ -87,7 +87,8 @@ class TestConvolutionNN(NNTestCase):
path = download_file("https://download.pytorch.org/test_data/legacy_conv2d.pt")
with warnings.catch_warnings():
warnings.simplefilter("ignore", SourceChangeWarning)
m = torch.load(path, encoding="utf-8")
# weights_only=False as this is legacy code that saves the model
m = torch.load(path, encoding="utf-8", weights_only=False)
input = torch.randn((1, 1, 1, 1), dtype=torch.float)
self.assertEqual(m(input).size(), (1, 1, 1, 1))

View File

@ -772,7 +772,8 @@ class TestStateDictHooks(TestCase):
# Note that torch.save / torch.load is not recommended to save/load
# modules.
torch.save(m, f.name)
m = torch.load(f.name)
# weights_only=False as this is legacy code that saves the model
m = torch.load(f.name, weights_only=False)
m.load_state_dict(sd)
self.assertFalse(called)
@ -855,7 +856,8 @@ class TestStateDictHooks(TestCase):
# Note that torch.save / torch.load is not recommended
# to save / load modules.
torch.save(m, f.name)
m = torch.load(f.name)
# weights_only=False as this is legacy code that saves the model
m = torch.load(f.name, weights_only=False)
# Ensure we can run state_dict without issues
_ = m.state_dict()

View File

@ -775,7 +775,8 @@ class TestPruningNN(NNTestCase):
with TemporaryFileName() as fname:
torch.save(model, fname)
new_model = torch.load(fname)
# weights_only=False as this is legacy code that saves the model
new_model = torch.load(fname, weights_only=False)
# check that the original weight and the new mask are present
self.assertIn("0.weight_orig", new_model.state_dict())

View File

@ -113,7 +113,8 @@ class TestSerialization(TestCase):
torch.save(qmodule(input_tensor), expected_file)
input_tensor = torch.load(input_file)
qmodule.load_state_dict(torch.load(state_dict_file))
# weights_only=False as we sometimes get a ScriptObject here
qmodule.load_state_dict(torch.load(state_dict_file, weights_only=False))
qmodule_scripted = torch.jit.load(scripted_module_file)
qmodule_traced = torch.jit.load(traced_module_file)
expected = torch.load(expected_file)
@ -266,7 +267,7 @@ class TestSerialization(TestCase):
# load input tensor
input_tensor = torch.load(input_file)
expected_output_tensor = torch.load(expected_file)
expected_get_attrs = torch.load(get_attr_targets_file)
expected_get_attrs = torch.load(get_attr_targets_file, weights_only=False)
# load model from package and verify output and get_attr targets match
imp = torch.package.PackageImporter(package_file)

View File

@ -185,8 +185,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
b = io.BytesIO()
torch.save(qlinear, b)
b.seek(0)
# Don't test weights_only here as this is legacy code that saves the model
loaded = torch.load(b)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(b, weights_only=False)
self.assertEqual(qlinear.weight(), loaded.weight())
self.assertEqual(qlinear.scale, loaded.scale)
self.assertEqual(qlinear.zero_point, loaded.zero_point)
@ -368,8 +368,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
b = io.BytesIO()
torch.save(qconv_module, b)
b.seek(0)
# Don't test weights_only here as this is legacy code that saves the model
loaded_conv = torch.load(b)
# weights_only=False as this is legacy code that saves the model
loaded_conv = torch.load(b, weights_only=False)
self.assertEqual(loaded_conv.bias(), qconv_module.bias())
self.assertEqual(loaded_conv.scale, qconv_module.scale)
@ -1482,8 +1482,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
b = io.BytesIO()
torch.save(dynamic_module, b)
b.seek(0)
# Don't test weights_only here as this is legacy code that saves the model
loaded_conv = torch.load(b)
# weights_only=False as this is legacy code that saves the model
loaded_conv = torch.load(b, weights_only=False)
self.assertEqual(loaded_conv.bias(), dynamic_module.bias())
self.assertEqual(loaded_conv.scale, dynamic_module.scale)
@ -1662,8 +1662,8 @@ class TestDynamicQuantizedModule(QuantizationTestCase):
b = io.BytesIO()
torch.save(qlinear, b)
b.seek(0)
# Don't test weights_only here as this is legacy code that saves the model
loaded = torch.load(b)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(b, weights_only=False)
self.assertEqual(qlinear.weight(), loaded.weight())
self.assertEqual(qlinear.zero_point, loaded.zero_point)

View File

@ -1348,8 +1348,8 @@ class TestQuantizedTensor(TestCase):
torch.save(f, buf)
buf.seek(0)
# Don't test weights_only here as this is loading a Module (legacy)
f2 = torch.load(buf)
# weights_only=False as this is legacy code that saves the model
f2 = torch.load(buf, weights_only=False)
self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)

View File

@ -4285,8 +4285,8 @@ class TestQuantizeFx(QuantizationTestCase):
m.load_state_dict(state_dict)
with TemporaryFileName() as fname:
torch.save(m.state_dict(), fname)
# Don't test weights_only here as this is loading a ScriptModule
m.load_state_dict(torch.load(fname))
# weights_only=False as this is loading a ScriptModule
m.load_state_dict(torch.load(fname, weights_only=False))
checkModel(m, data, ref_weight, ref_bias, ref_res)

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: cpp-extensions"]
import _codecs
import os
import shutil
import sys
@ -9,9 +10,12 @@ import unittest
from typing import Union
from unittest.mock import patch
import numpy as np
import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.serialization import safe_globals
from torch.testing._internal.common_utils import (
IS_ARM64,
skipIfTorchDynamo,
@ -551,6 +555,17 @@ class TestCppExtensionOpenRgistration(common.TestCase):
)
with TemporaryFileName() as f:
torch.save(sd, f)
with safe_globals(
[
np.core.multiarray._reconstruct,
np.ndarray,
np.dtype,
_codecs.encode,
type(np.dtype(np.float32))
if np.__version__ < "1.25.0"
else np.dtypes.Float32DType,
]
):
sd_loaded = torch.load(f, map_location="cpu")
self.assertTrue(sd_loaded["x"].is_cpu)

View File

@ -1196,7 +1196,9 @@ class TestFX(JitTestCase):
bio.seek(0)
loaded = torch.load(bio)
# weights_only=False as this loads a GraphModule
# GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
loaded = torch.load(bio, weights_only=False)
torch.testing.assert_close(loaded(x), x[0])
@ -4198,7 +4200,9 @@ class TestFXAPIBackwardCompatibility(JitTestCase):
b = io.BytesIO()
torch.save(gm, b)
b.seek(0)
reload_gm = torch.load(b)
# weights_only=False as this loads a GraphModule
# GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
reload_gm = torch.load(b, weights_only=False)
self.assertTrue(hasattr(reload_gm, "foo"))
self.assertTrue(hasattr(reload_gm, "dummy_buffer"))
self.assertTrue(hasattr(reload_gm, "dummy_parameter"))

View File

@ -239,7 +239,8 @@ class TestModule(TestCase):
with tempfile.TemporaryFile() as f:
torch.save(m, f)
f.seek(0)
m_copy = torch.load(f)
# weights_only=False as this is legacy code that saves the model
m_copy = torch.load(f, weights_only=False)
output_from_copy = m_copy(*args, **kwargs)
self.assertEqual(output, output_from_copy)

View File

@ -156,7 +156,8 @@ class TestNN(NNTestCase):
path = download_file('https://download.pytorch.org/test_data/linear.pt')
with warnings.catch_warnings():
warnings.simplefilter('ignore', SourceChangeWarning)
m = torch.load(path)
# weights_only=False as this is legacy code that saves the model
m = torch.load(path, weights_only=False)
input = torch.randn(2, 3, dtype=torch.float)
self.assertEqual(m(input).size(), (2, 5))
@ -4328,7 +4329,8 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
buf = io.BytesIO()
rnn_pickle = torch.save(rnn, buf)
buf.seek(0)
rnn2 = torch.load(buf)
# weights_only=False as this is legacy code that saves the model
rnn2 = torch.load(buf, weights_only=False)
rnn2.flatten_parameters()
output3, hy3 = rnn2(input, hx)

View File

@ -1204,6 +1204,7 @@ def forward(self, x_a_1, x_b_1, y_1):
x = LoggingTensor(torch.randperm(3))
torch.save(x, f)
f.seek(0)
with torch.serialization.safe_globals([LoggingTensor]):
x_loaded = torch.load(f)
self.assertTrue(type(x_loaded) is type(x))
self.assertEqual(x, x_loaded)

View File

@ -25,6 +25,7 @@ from torch.serialization import (
check_module_version_greater_or_equal,
get_default_load_endianness,
LoadEndianness,
safe_globals,
set_default_load_endianness,
SourceChangeWarning,
)
@ -276,7 +277,8 @@ class SerializationMixin:
torch.save(x, f, pickle_module=dill)
f.seek(0)
with self.assertRaisesRegex(ValueError, 'supports dill >='):
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
# weights_only=False as True does not support custom pickle_module
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
def test_pickle_module(self):
class ThrowingUnpickler(pickle.Unpickler):
@ -292,7 +294,8 @@ class SerializationMixin:
torch.save(x, f)
f.seek(0)
with self.assertRaisesRegex(RuntimeError, "rumpelstiltskin"):
torch.load(f, pickle_module=ThrowingModule)
# weights_only=False as True does not support custom pickle module
torch.load(f, pickle_module=ThrowingModule, weights_only=False)
f.seek(0)
z = torch.load(f)
self.assertEqual(x, z)
@ -307,11 +310,13 @@ class SerializationMixin:
with tempfile.NamedTemporaryFile() as f:
torch.save(x, f, pickle_module=dill)
f.seek(0)
x2 = torch.load(f, pickle_module=dill, encoding='utf-8')
# weights_only=False as True does not support custom pickle_module
x2 = torch.load(f, pickle_module=dill, encoding='utf-8', weights_only=False)
self.assertIsInstance(x2, type(x))
self.assertEqual(x, x2)
f.seek(0)
x3 = torch.load(f, pickle_module=dill)
# weights_only=False as True does not support custom pickle_module
x3 = torch.load(f, pickle_module=dill, weights_only=False)
self.assertIsInstance(x3, type(x))
self.assertEqual(x, x3)
@ -718,13 +723,15 @@ class SerializationMixin:
# This Pickle contains a Python 2 module with Unicode data and the
# loading should fail if the user explicitly specifies ascii encoding!
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii'))
# weights_only=False as this is legacy code that saves the model
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii', weights_only=False))
def test_load_python2_unicode_module(self):
# This Pickle contains some Unicode data!
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
with warnings.catch_warnings(record=True) as w:
self.assertIsNotNone(torch.load(path))
# weights_only=False as this is legacy code that saves the model
self.assertIsNotNone(torch.load(path, weights_only=False))
def test_load_error_msg(self):
expected_err_msg = (".*You can only torch.load from a file that is seekable. " +
@ -735,7 +742,8 @@ class SerializationMixin:
delattr(resource, "tell")
delattr(resource, "seek")
with self.assertRaisesRegex(AttributeError, expected_err_msg):
torch.load(resource)
# weights_only=False as this is legacy code that saves the model
torch.load(resource, weights_only=False)
def test_save_different_dtype_unallocated(self):
devices = ['cpu']
@ -881,12 +889,11 @@ class TestOldSerialization(TestCase, SerializationMixin):
# First check that the checkpoint can be loaded without warning about unsafe loads
checkpoint.seek(0)
with warnings.catch_warnings(record=True) as w:
loaded = torch.load(checkpoint)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(checkpoint, weights_only=False)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, FutureWarning)
self.assertTrue("You are using `torch.load` with `weights_only=False`" in str(w[0].message))
self.assertEqual(len(w), 0)
# Replace the module with different source
fname = get_file_path_2(os.path.dirname(os.path.dirname(torch.__file__)), 'torch', 'testing',
@ -894,12 +901,12 @@ class TestOldSerialization(TestCase, SerializationMixin):
module = import_module(tmpmodule_name, fname)
checkpoint.seek(0)
with warnings.catch_warnings(record=True) as w:
loaded = torch.load(checkpoint)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(checkpoint, weights_only=False)
self.assertTrue(isinstance(loaded, module.Net))
if can_retrieve_source:
self.assertEqual(len(w), 2)
self.assertEqual(w[0].category, FutureWarning)
self.assertEqual(w[1].category, SourceChangeWarning)
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, SourceChangeWarning)
def test_serialization_container(self):
self._test_serialization_container('file', tempfile.NamedTemporaryFile)
@ -924,7 +931,8 @@ class TestOldSerialization(TestCase, SerializationMixin):
a_loaded = torch.load(f)
j_loaded = pickle.load(f)
b_loaded = torch.load(f)
m_loaded = torch.load(f)
# weights_only=False as this is legacy code that saves the model
m_loaded = torch.load(f, weights_only=False)
self.assertTrue(torch.equal(a, a_loaded))
self.assertTrue(torch.equal(b, b_loaded))
self.assertTrue(m.kernel_size == m_loaded.kernel_size)
@ -1141,7 +1149,7 @@ class TestSerialization(TestCase, SerializationMixin):
torch.load(f, weights_only=True)
f.seek(0)
# safe_globals doesn't work even with allowlist
with torch.serialization.safe_globals([os.execv]):
with safe_globals([os.execv]):
with self.assertRaisesRegex(pickle.UnpicklingError, error_msg):
torch.load(f, weights_only=True)
@ -4252,6 +4260,7 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(my_tensor, f)
f.seek(0)
with safe_globals([TestWrapperSubclass]):
new_tensor = torch.load(f)
self.assertIsInstance(new_tensor, TestWrapperSubclass)
@ -4269,6 +4278,7 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(my_tensor, f)
f.seek(0)
with safe_globals([TestGetStateSubclass]):
new_tensor = torch.load(f)
self.assertIsInstance(new_tensor, TestGetStateSubclass)
@ -4306,6 +4316,7 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(tensor, f)
f.seek(0)
with safe_globals([TestEmptySubclass]):
tensor2 = torch.load(f)
tensor = TestEmptySubclass()
@ -4316,6 +4327,7 @@ class TestSubclassSerialization(TestCase):
with BytesIOContext() as f:
torch.save(tensor, f)
f.seek(0)
with safe_globals([TestEmptySubclass]):
tensor2 = torch.load(f)
@skipIfTorchDynamo("name 'SYNTHETIC_LOCAL' is not defined")
@ -4357,7 +4369,7 @@ class TestSubclassSerialization(TestCase):
def test_safe_globals_context_manager_weights_only(self):
'''
Tests torch.serialization.safe_globals context manager
Tests safe_globals context manager
'''
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
p = torch.nn.Parameter(t)
@ -4367,7 +4379,7 @@ class TestSubclassSerialization(TestCase):
torch.serialization.add_safe_globals([TestEmptySubclass])
with tempfile.NamedTemporaryFile() as f:
torch.save(sd, f)
with torch.serialization.safe_globals([TwoTensor]):
with safe_globals([TwoTensor]):
f.seek(0)
torch.load(f, weights_only=True)
self.assertTrue(torch.serialization.get_safe_globals() == [TestEmptySubclass])
@ -4385,6 +4397,7 @@ class TestSubclassSerialization(TestCase):
with TemporaryFileName() as f:
torch.save(sd, f)
with safe_globals([TwoTensor]):
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))

View File

@ -86,6 +86,7 @@ class TestSubclass(TestCase):
x = nn.Parameter(x)
torch.save(x, f)
f.seek(0)
with torch.serialization.safe_globals([tensor_cls]):
x_loaded = torch.load(f)
self.assertEqual(x, x_loaded)

View File

@ -146,6 +146,10 @@ def _tensor_rebuild_functions():
torch._utils._rebuild_meta_tensor_no_storage,
torch._utils._rebuild_nested_tensor,
torch._utils._rebuild_wrapper_subclass,
# Allowlisting this, but not allowlisting the numpy functions by default
# Reasoning is that we don't have control over the numpy functions, but
# this utility is provided by pytorch
torch._utils._rebuild_device_tensor_from_numpy,
}

View File

@ -568,7 +568,8 @@ class {module_name}(torch.nn.Module):
torch.save(module, module_file)
blobified_modules.append(module_name)
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
# weights_only=False as this is legacy code that saves the model
module_str = f"torch.load(r'{module_file}', weights_only=False) # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
for buffer_name, buffer in self._buffers.items():

View File

@ -3437,7 +3437,8 @@ class ModuleTest(TestBase):
test_case._forward(module, input)
torch.save(module, f)
f.seek(0)
module_copy = torch.load(f)
# weights_only=False as this is legacy code that saves the model
module_copy = torch.load(f, weights_only=False)
test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
self._do_test(test_case, module, input)

View File

@ -657,7 +657,8 @@ class QuantizationTestCase(TestCase):
b = io.BytesIO()
torch.save(model_dict, b)
b.seek(0)
loaded_dict = torch.load(b)
# weights_only=False as we sometimes get a ScriptObject here (weird)
loaded_dict = torch.load(b, weights_only=False)
loaded_model.load_state_dict(loaded_dict)
ref_out = ref_model(*x)
load_out = loaded_model(*x)
@ -674,7 +675,8 @@ class QuantizationTestCase(TestCase):
b = io.BytesIO()
torch.save(ref_model, b)
b.seek(0)
loaded = torch.load(b)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(b, weights_only=False)
load_out = loaded(*x)
check_outputs(ref_out, load_out)

View File

@ -4271,15 +4271,18 @@ class DistributedTest:
if sys.platform == "win32":
torch.save(model_DDP, tmp)
tmp.seek(0)
model_DDP = torch.load(tmp)
# weights_only=False as this is legacy code that saves the model
model_DDP = torch.load(tmp, weights_only=False)
else:
torch.save(model_DDP, tmp.name)
model_DDP = torch.load(tmp.name)
# weights_only=False as this is legacy code that saves the model
model_DDP = torch.load(tmp.name, weights_only=False)
with tempfile.TemporaryFile() as tmp_file:
torch.save(model_DDP, tmp_file)
tmp_file.seek(0)
saved_model = torch.load(tmp_file)
# weights_only=False as this is legacy code that saves the model
saved_model = torch.load(tmp_file, weights_only=False)
for k in model_DDP.state_dict():
self.assertEqual(model_DDP.state_dict()[k], saved_model.state_dict()[k])
@ -4320,10 +4323,12 @@ class DistributedTest:
if sys.platform == "win32":
torch.save(model_DDP, tmp)
tmp.seek(0)
model_DDP = torch.load(tmp)
# weights_only=False as this is legacy code that saves the model
model_DDP = torch.load(tmp, weights_only=False)
else:
torch.save(model_DDP, tmp.name)
model_DDP = torch.load(tmp.name)
# weights_only=False as this is legacy code that saves the model
model_DDP = torch.load(tmp.name, weights_only=False)
# dummy data initialization
local_bs = len(gpu_subset)
@ -5598,10 +5603,12 @@ class DistributedTest:
if sys.platform == "win32":
torch.save(model_DDP, tmp)
tmp.seek(0)
model_DDP = torch.load(tmp)
# weights_only=False as this is legacy code that saves the model
model_DDP = torch.load(tmp, weights_only=False)
else:
torch.save(model_DDP, tmp.name)
model_DDP = torch.load(tmp.name)
# weights_only=False as this is legacy code that saves the model
model_DDP = torch.load(tmp.name, weights_only=False)
# data initialization
input_cpu = torch.randn(global_bs, 2)