mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140850 Approved by: https://github.com/malfet ghstack dependencies: #140739, #140740
This commit is contained in:
parent
f3f305ef3e
commit
37959c554d
|
|
@ -4449,6 +4449,34 @@ class TestSerialization(TestCase, SerializationMixin):
|
|||
) if not should_import else None
|
||||
self._attempt_load_from_subprocess(filename, import_string, err_msg)
|
||||
|
||||
@parametrize("dtype", all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||
@parametrize("weights_only", [True, False])
|
||||
def test_save_load_preserves_dtype(self, dtype, weights_only):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self, t):
|
||||
super().__init__()
|
||||
requires_grad = torch.is_floating_point(t) or torch.is_complex(t)
|
||||
self.param = torch.nn.Parameter(t, requires_grad=requires_grad)
|
||||
|
||||
if dtype.is_floating_point or dtype.is_complex:
|
||||
t = torch.randn(10, dtype=dtype)
|
||||
sd = MyModule(t).state_dict()
|
||||
elif dtype is torch.bool:
|
||||
t = torch.randn(10) > 0
|
||||
sd = MyModule(t).state_dict()
|
||||
else:
|
||||
iinfo = torch.iinfo(dtype)
|
||||
t = torch.randint(iinfo.min, iinfo.max, (10,), dtype=dtype)
|
||||
sd = MyModule(t).state_dict()
|
||||
sd_save = {'t': t, 'sd': sd, 'i' : t[0].item()}
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(sd_save, f)
|
||||
f.seek(0)
|
||||
loaded_sd = torch.load(f, weights_only=weights_only)
|
||||
self.assertEqual(sd_save, loaded_sd)
|
||||
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=True):
|
||||
return super().run(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -168,6 +168,7 @@ def _get_allowed_globals():
|
|||
"_codecs.encode": encode, # for bytes
|
||||
"builtins.bytearray": bytearray, # for bytearray
|
||||
"builtins.set": set, # for set
|
||||
"builtins.complex": complex, # for complex
|
||||
}
|
||||
|
||||
# dtype
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user