mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140850 Approved by: https://github.com/malfet ghstack dependencies: #140739, #140740
This commit is contained in:
parent
f3f305ef3e
commit
37959c554d
|
|
@ -4449,6 +4449,34 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||||
) if not should_import else None
|
) if not should_import else None
|
||||||
self._attempt_load_from_subprocess(filename, import_string, err_msg)
|
self._attempt_load_from_subprocess(filename, import_string, err_msg)
|
||||||
|
|
||||||
|
@parametrize("dtype", all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
|
||||||
|
@parametrize("weights_only", [True, False])
|
||||||
|
def test_save_load_preserves_dtype(self, dtype, weights_only):
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self, t):
|
||||||
|
super().__init__()
|
||||||
|
requires_grad = torch.is_floating_point(t) or torch.is_complex(t)
|
||||||
|
self.param = torch.nn.Parameter(t, requires_grad=requires_grad)
|
||||||
|
|
||||||
|
if dtype.is_floating_point or dtype.is_complex:
|
||||||
|
t = torch.randn(10, dtype=dtype)
|
||||||
|
sd = MyModule(t).state_dict()
|
||||||
|
elif dtype is torch.bool:
|
||||||
|
t = torch.randn(10) > 0
|
||||||
|
sd = MyModule(t).state_dict()
|
||||||
|
else:
|
||||||
|
iinfo = torch.iinfo(dtype)
|
||||||
|
t = torch.randint(iinfo.min, iinfo.max, (10,), dtype=dtype)
|
||||||
|
sd = MyModule(t).state_dict()
|
||||||
|
sd_save = {'t': t, 'sd': sd, 'i' : t[0].item()}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile() as f:
|
||||||
|
torch.save(sd_save, f)
|
||||||
|
f.seek(0)
|
||||||
|
loaded_sd = torch.load(f, weights_only=weights_only)
|
||||||
|
self.assertEqual(sd_save, loaded_sd)
|
||||||
|
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
with serialization_method(use_zip=True):
|
with serialization_method(use_zip=True):
|
||||||
return super().run(*args, **kwargs)
|
return super().run(*args, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -168,6 +168,7 @@ def _get_allowed_globals():
|
||||||
"_codecs.encode": encode, # for bytes
|
"_codecs.encode": encode, # for bytes
|
||||||
"builtins.bytearray": bytearray, # for bytearray
|
"builtins.bytearray": bytearray, # for bytearray
|
||||||
"builtins.set": set, # for set
|
"builtins.set": set, # for set
|
||||||
|
"builtins.complex": complex, # for complex
|
||||||
}
|
}
|
||||||
|
|
||||||
# dtype
|
# dtype
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user