mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make torch.serialization.set_default_mmap_options usable as a context manager (#134371)
As title Pull Request resolved: https://github.com/pytorch/pytorch/pull/134371 Approved by: https://github.com/albanD
This commit is contained in:
parent
0fa0ac80e4
commit
2ac710e667
|
|
@ -4052,7 +4052,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||||
@parametrize('path_type', (str, Path))
|
@parametrize('path_type', (str, Path))
|
||||||
@parametrize('weights_only', (True, False))
|
@parametrize('weights_only', (True, False))
|
||||||
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
||||||
def test_serialization_mmap_loading(self, weights_only, path_type):
|
def test_serialization_mmap_loading_options(self, weights_only, path_type):
|
||||||
class DummyModel(torch.nn.Module):
|
class DummyModel(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -4101,7 +4101,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||||
for v in result.values():
|
for v in result.values():
|
||||||
self.assertTrue(v.is_cuda)
|
self.assertTrue(v.is_cuda)
|
||||||
|
|
||||||
def test_serialization_mmap_loading_options(self):
|
def test_serialization_mmap_loading(self):
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
with self.assertRaisesRegex(RuntimeError, "Changing the default mmap options is currently not supported"):
|
with self.assertRaisesRegex(RuntimeError, "Changing the default mmap options is currently not supported"):
|
||||||
torch.serialization.set_default_mmap_options(2)
|
torch.serialization.set_default_mmap_options(2)
|
||||||
|
|
@ -4111,22 +4111,36 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||||
with tempfile.NamedTemporaryFile() as f:
|
with tempfile.NamedTemporaryFile() as f:
|
||||||
torch.save(sd, f)
|
torch.save(sd, f)
|
||||||
# with MmapVisibility.MAP_PRIVATE, should not be able to modify file
|
# with MmapVisibility.MAP_PRIVATE, should not be able to modify file
|
||||||
sd_loaded = torch.load(f.name, mmap=True)
|
sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
|
||||||
sd_loaded['weight'][0][0] = 0
|
sd_loaded['weight'][0][0] = 0
|
||||||
sd_loaded2 = torch.load(f.name, mmap=True)
|
sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
|
||||||
self.assertEqual(sd_loaded2['weight'], sd['weight'])
|
self.assertEqual(sd_loaded2['weight'], sd['weight'])
|
||||||
# with MmapVisibility.MAP_SHARED, should be able to modify file
|
# with MmapVisibility.MAP_SHARED, should be able to modify file
|
||||||
torch.serialization.set_default_mmap_options(MAP_SHARED)
|
torch.serialization.set_default_mmap_options(MAP_SHARED)
|
||||||
try:
|
try:
|
||||||
sd_loaded = torch.load(f.name, mmap=True)
|
sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
|
||||||
sd_loaded['weight'][0][0] = 0
|
sd_loaded['weight'][0][0] = 0
|
||||||
sd_loaded2 = torch.load(f.name, mmap=True)
|
sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
|
||||||
self.assertNotEqual(sd_loaded2['weight'], sd['weight'])
|
self.assertNotEqual(sd_loaded2['weight'], sd['weight'])
|
||||||
self.assertEqual(sd_loaded2['weight'][0][0].item(), 0)
|
self.assertEqual(sd_loaded2['weight'][0][0].item(), 0)
|
||||||
self.assertEqual(sd_loaded2['weight'], sd_loaded['weight'])
|
self.assertEqual(sd_loaded2['weight'], sd_loaded['weight'])
|
||||||
finally:
|
finally:
|
||||||
torch.serialization.set_default_mmap_options(MAP_PRIVATE)
|
torch.serialization.set_default_mmap_options(MAP_PRIVATE)
|
||||||
|
|
||||||
|
@unittest.skipIf(IS_WINDOWS, "mmap ctx doesn't work on Windows")
|
||||||
|
def test_serialization_mmap_loading_ctx(self):
|
||||||
|
sd = torch.nn.Linear(3, 5).state_dict()
|
||||||
|
with tempfile.NamedTemporaryFile() as f:
|
||||||
|
torch.save(sd, f)
|
||||||
|
with torch.serialization.set_default_mmap_options(MAP_SHARED):
|
||||||
|
sd_loaded = torch.load(f.name, mmap=True, weights_only=True)
|
||||||
|
sd_loaded['weight'][0][0] = 0
|
||||||
|
sd_loaded2 = torch.load(f.name, mmap=True, weights_only=True)
|
||||||
|
self.assertNotEqual(sd_loaded2['weight'], sd['weight'])
|
||||||
|
self.assertEqual(sd_loaded2['weight'][0][0].item(), 0)
|
||||||
|
self.assertEqual(sd_loaded2['weight'], sd_loaded['weight'])
|
||||||
|
self.assertTrue(torch.serialization.get_default_mmap_options() == MAP_PRIVATE)
|
||||||
|
|
||||||
@parametrize('dtype', (torch.float8_e5m2, torch.float8_e4m3fn, torch.complex32))
|
@parametrize('dtype', (torch.float8_e5m2, torch.float8_e4m3fn, torch.complex32))
|
||||||
@parametrize('weights_only', (True, False))
|
@parametrize('weights_only', (True, False))
|
||||||
def test_serialization_dtype(self, dtype, weights_only):
|
def test_serialization_dtype(self, dtype, weights_only):
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,8 @@ __all__ = [
|
||||||
"LoadEndianness",
|
"LoadEndianness",
|
||||||
"get_default_load_endianness",
|
"get_default_load_endianness",
|
||||||
"set_default_load_endianness",
|
"set_default_load_endianness",
|
||||||
|
"get_default_mmap_options",
|
||||||
|
"set_default_mmap_options",
|
||||||
"clear_safe_globals",
|
"clear_safe_globals",
|
||||||
"get_safe_globals",
|
"get_safe_globals",
|
||||||
"add_safe_globals",
|
"add_safe_globals",
|
||||||
|
|
@ -163,9 +165,9 @@ def get_default_mmap_options() -> int:
|
||||||
return _default_mmap_options
|
return _default_mmap_options
|
||||||
|
|
||||||
|
|
||||||
def set_default_mmap_options(flags: int):
|
class set_default_mmap_options:
|
||||||
"""
|
"""
|
||||||
Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
|
Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
|
||||||
|
|
||||||
For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
|
For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
|
||||||
Please open an issue if you need any other option to be added here.
|
Please open an issue if you need any other option to be added here.
|
||||||
|
|
@ -176,17 +178,27 @@ def set_default_mmap_options(flags: int):
|
||||||
Args:
|
Args:
|
||||||
flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
|
flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
|
||||||
"""
|
"""
|
||||||
global _default_mmap_options
|
|
||||||
if IS_WINDOWS:
|
def __init__(self, flags: int) -> None:
|
||||||
raise RuntimeError(
|
if IS_WINDOWS:
|
||||||
"Changing the default mmap options is currently not supported for Windows"
|
raise RuntimeError(
|
||||||
)
|
"Changing the default mmap options is currently not supported for Windows"
|
||||||
if flags != MAP_PRIVATE and flags != MAP_SHARED:
|
)
|
||||||
raise ValueError(
|
if flags != MAP_PRIVATE and flags != MAP_SHARED:
|
||||||
"Invalid argument in function set_default_mmap_options, "
|
raise ValueError(
|
||||||
f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
|
"Invalid argument in function set_default_mmap_options, "
|
||||||
)
|
f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
|
||||||
_default_mmap_options = flags
|
)
|
||||||
|
global _default_mmap_options
|
||||||
|
self.prev = _default_mmap_options
|
||||||
|
_default_mmap_options = flags
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||||
|
global _default_mmap_options
|
||||||
|
_default_mmap_options = self.prev
|
||||||
|
|
||||||
|
|
||||||
def clear_safe_globals() -> None:
|
def clear_safe_globals() -> None:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user