Add small test case for #140230 (#140850)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140850
Approved by: https://github.com/malfet
ghstack dependencies: #140739, #140740
This commit is contained in:
Mikayla Gawarecki 2024-11-18 14:03:09 -08:00 committed by PyTorch MergeBot
parent f3f305ef3e
commit 37959c554d
2 changed files with 29 additions and 0 deletions

View File

@ -4449,6 +4449,34 @@ class TestSerialization(TestCase, SerializationMixin):
) if not should_import else None ) if not should_import else None
self._attempt_load_from_subprocess(filename, import_string, err_msg) self._attempt_load_from_subprocess(filename, import_string, err_msg)
@parametrize("dtype", all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@parametrize("weights_only", [True, False])
def test_save_load_preserves_dtype(self, dtype, weights_only):
class MyModule(torch.nn.Module):
def __init__(self, t):
super().__init__()
requires_grad = torch.is_floating_point(t) or torch.is_complex(t)
self.param = torch.nn.Parameter(t, requires_grad=requires_grad)
if dtype.is_floating_point or dtype.is_complex:
t = torch.randn(10, dtype=dtype)
sd = MyModule(t).state_dict()
elif dtype is torch.bool:
t = torch.randn(10) > 0
sd = MyModule(t).state_dict()
else:
iinfo = torch.iinfo(dtype)
t = torch.randint(iinfo.min, iinfo.max, (10,), dtype=dtype)
sd = MyModule(t).state_dict()
sd_save = {'t': t, 'sd': sd, 'i' : t[0].item()}
with tempfile.NamedTemporaryFile() as f:
torch.save(sd_save, f)
f.seek(0)
loaded_sd = torch.load(f, weights_only=weights_only)
self.assertEqual(sd_save, loaded_sd)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
with serialization_method(use_zip=True): with serialization_method(use_zip=True):
return super().run(*args, **kwargs) return super().run(*args, **kwargs)

View File

@ -168,6 +168,7 @@ def _get_allowed_globals():
"_codecs.encode": encode, # for bytes "_codecs.encode": encode, # for bytes
"builtins.bytearray": bytearray, # for bytearray "builtins.bytearray": bytearray, # for bytearray
"builtins.set": set, # for set "builtins.set": set, # for set
"builtins.complex": complex, # for complex
} }
# dtype # dtype