From 37959c554df15c2c437fa0a481bcbe23bc6a73cb Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 18 Nov 2024 14:03:09 -0800 Subject: [PATCH] Add small test case for #140230 (#140850) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140850 Approved by: https://github.com/malfet ghstack dependencies: #140739, #140740 --- test/test_serialization.py | 28 ++++++++++++++++++++++++++++ torch/_weights_only_unpickler.py | 1 + 2 files changed, 29 insertions(+) diff --git a/test/test_serialization.py b/test/test_serialization.py index f69ed526acc..e5ee63c4218 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -4449,6 +4449,34 @@ class TestSerialization(TestCase, SerializationMixin): ) if not should_import else None self._attempt_load_from_subprocess(filename, import_string, err_msg) + @parametrize("dtype", all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) + @parametrize("weights_only", [True, False]) + def test_save_load_preserves_dtype(self, dtype, weights_only): + class MyModule(torch.nn.Module): + def __init__(self, t): + super().__init__() + requires_grad = torch.is_floating_point(t) or torch.is_complex(t) + self.param = torch.nn.Parameter(t, requires_grad=requires_grad) + + if dtype.is_floating_point or dtype.is_complex: + t = torch.randn(10, dtype=dtype) + sd = MyModule(t).state_dict() + elif dtype is torch.bool: + t = torch.randn(10) > 0 + sd = MyModule(t).state_dict() + else: + iinfo = torch.iinfo(dtype) + t = torch.randint(iinfo.min, iinfo.max, (10,), dtype=dtype) + sd = MyModule(t).state_dict() + sd_save = {'t': t, 'sd': sd, 'i' : t[0].item()} + + with tempfile.NamedTemporaryFile() as f: + torch.save(sd_save, f) + f.seek(0) + loaded_sd = torch.load(f, weights_only=weights_only) + self.assertEqual(sd_save, loaded_sd) + + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 9a146632ded..8a474587105 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -168,6 +168,7 @@ def _get_allowed_globals(): "_codecs.encode": encode, # for bytes "builtins.bytearray": bytearray, # for bytearray "builtins.set": set, # for set + "builtins.complex": complex, # for complex } # dtype