Add mmap option to torch.load (#102549)
Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py), run the scripts below.

<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>

```
# test_load_save_gpt.py
from model import GPT
import torch
import time

torch.manual_seed(5)

# gpt2-xlarge 1558M parameters
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 48
    n_head: int = 25
    n_embd: int = 1600
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

def f():
    model = GPT(GPTConfig())
    state_dict = model.state_dict()
    start_saving = time.time()
    torch.save(state_dict, "gpt2-xlarge.pth")
    end_saving = time.time()

if __name__ == "__main__":
    f()
```

</details>

<details><summary><b>Click for script to load</b></summary>

```
# test_load_gpt.py
import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse

def f(mmap, meta):
    device = 'meta' if meta else 'cpu'
    assign = True if meta else False
    with torch.device(device):
        model = GPT(GPTConfig())
    start_loading = time.time()
    loaded_state_dict = torch.load("gpt2-xlarge.pth", mmap=mmap)
    end_loading = time.time()
    print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
    model.load_state_dict(loaded_state_dict, assign=assign)
    end_load_state_dict = time.time()
    print("load_state_dict time: ", end_load_state_dict - end_loading)
    model.cuda()
    end_cuda = time.time()
    print("cuda time using torch.load with mmap: ", end_cuda - end_load_state_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
    parser.add_argument('-m', '--mmap', action='store_true')
    parser.add_argument('-d', '--devicemeta', action='store_true')
    args = parser.parse_args()
    mmap = args.mmap
    meta = args.devicemeta
    f(mmap, meta)
```

</details>

`python test_load_gpt.py`

<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">

`python test_load_gpt.py --mmap`

<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">

If we further use the `with torch.device('meta')` context manager and pull in the changes from https://github.com/pytorch/pytorch/pull/102212, which allow the model to reuse tensors from the state_dict, we get

`python test_load_gpt.py --mmap --devicemeta`

<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">

We also ran the above in a Docker container containing a build of PyTorch, with RAM limited to 512 MB, by

1) running `make -f docker.Makefile` from the `pytorch/` directory
2) running `docker run -m 512m -it <image> bash`
3) using `docker cp` to copy `gpt2-xlarge.pth` and `test_load_gpt.py` into the container

With `python test_load_gpt.py`, Docker kills the process due to OOM, whereas with

`python test_load_gpt.py --mmap --devicemeta`

<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD
Parent: `74b7a6c75e`
Commit: `6fa2d41dc7`
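A minimal sketch of the loading pattern described above, using a small stand-in module instead of nanoGPT's `GPT`; it assumes a build that already includes this PR and the `assign=True` support from #102212:

```
import torch
import torch.nn as nn

class TinyModel(nn.Module):  # stand-in for the GPT model used in the benchmark scripts
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc(x)

# save a checkpoint with the default zipfile serializer (required for mmap loading)
torch.save(TinyModel().state_dict(), "tiny.pth")

# build the module on the meta device so no real parameter memory is allocated up front
with torch.device("meta"):
    model = TinyModel()

# mmap=True maps the checkpoint file instead of copying every storage into CPU memory
state_dict = torch.load("tiny.pth", mmap=True)

# assign=True (from #102212) reuses the mmapped tensors directly rather than copying them
model.load_state_dict(state_dict, assign=True)

if torch.cuda.is_available():
    model.cuda()  # storages are paged in from disk as they are moved to the GPU
```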
```diff
@@ -3436,11 +3436,60 @@ class TestSerialization(TestCase, SerializationMixin):
         self.assertTrue(torch.equal(tensor_be_no_bom, tensor_le_bom))
         self.assertTrue(torch.equal(tensor_be_no_bom, tensor_be_bom))
 
+    @parametrize('weights_only', (True, False))
+    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
+    def test_serialization_mmap_loading(self, weights_only):
+        class DummyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(3, 1024)
+                self.fc2 = torch.nn.Linear(1024, 5)
+
+            def forward(self, input):
+                return self.fc2(self.fc1(input))
+
+        with TemporaryFileName() as f:
+            state_dict = DummyModel().state_dict()
+            torch.save(state_dict, f)
+            result = torch.load(f, mmap=True, weights_only=weights_only)
+            result_non_mmap = torch.load(f, mmap=False, weights_only=weights_only)
+
+            model_mmap_state_dict = DummyModel()
+            model_mmap_state_dict.load_state_dict(result)
+            model_non_mmap_state_dict = DummyModel()
+            model_non_mmap_state_dict.load_state_dict(result_non_mmap)
+            input = torch.randn(4, 3)
+            self.assertEqual(model_mmap_state_dict(input), model_non_mmap_state_dict(input.clone()))
+
+    @unittest.skipIf(not torch.cuda.is_available() or IS_WINDOWS,
+                     "CUDA is unavailable or NamedTemporaryFile on Windows")
+    def test_serialization_mmap_loading_with_map_location(self):
+        class DummyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(3, 1024)
+                self.fc2 = torch.nn.Linear(1024, 5)
+
+            def forward(self, input):
+                return self.fc2(self.fc1(input))
+
+        # make sure mmap where tensors' location tags are not CPU does not crash
+        # zipfile will first be mmap-ed on CPU and storages are extracted using
+        # overall_storage[start_offset:end_offset] before running
+        # _{device}_deserialize, which moves the storage to device
+        with TemporaryFileName() as f:
+            with torch.device('cuda'):
+                m = DummyModel()
+            state_dict = m.state_dict()
+            torch.save(state_dict, f)
+            result = torch.load(f, mmap=True)
+            for k, v in result.items():
+                self.assertTrue(v.is_cuda)
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
             return super().run(*args, **kwargs)
 
 
 class TestWrapperSubclass(torch.Tensor):
     elem: torch.Tensor
     __slots__ = ['elem', 'other']
```
```diff
@@ -1475,9 +1475,14 @@ void initJITBindings(PyObject* module) {
             return at::Tensor(std::move(ptr));
           })
       .def("serialization_id", &PyTorchStreamReader::serializationId)
-      .def("get_all_records", [](PyTorchStreamReader& self) {
-        return self.getAllRecords();
-      });
+      .def(
+          "get_all_records",
+          [](PyTorchStreamReader& self) { return self.getAllRecords(); })
+      .def(
+          "get_record_offset",
+          [](PyTorchStreamReader& self, const std::string& key) {
+            return self.getRecordOffset(key);
+          });
 
   // Used by torch.Package to coordinate deserialization of storages across
   // ScriptModules and eager modules
```
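The new `get_record_offset` binding can be poked at from Python through `torch._C.PyTorchFileReader`, which is how `PyTorchStreamReader` is exposed; a hedged sketch, with the record key `'data/0'` chosen only as a typical example (check `get_all_records()` for the real keys):

```
import torch

# write a small zipfile-format checkpoint
torch.save({"w": torch.arange(4, dtype=torch.float32)}, "tiny.pth")

reader = torch._C.PyTorchFileReader("tiny.pth")
print(reader.get_all_records())        # lists entries such as 'data.pkl' and 'data/0'

# byte offset of a storage record inside the archive; with mmap=True, _load slices
# the file-backed overall_storage at this offset instead of copying the bytes out
offset = reader.get_record_offset("data/0")
print(offset)
```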
```diff
@@ -745,6 +745,7 @@ def load(
     pickle_module: Any = None,
     *,
     weights_only: bool = False,
+    mmap: bool = None,
     **pickle_load_args: Any
 ) -> Any:
     # Reference: https://github.com/pytorch/pytorch/issues/54354
```
```diff
@@ -794,6 +795,11 @@ def load(
             match the :attr:`pickle_module` used to serialize file)
         weights_only: Indicates whether unpickler should be restricted to
             loading only tensors, primitive types and dictionaries
+        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
+            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
+            are moved to the location that they were tagged with when saving, or specified by `map_location`. This
+            second step is a no-op if the final location is CPU. When the `mmap` flag is set, instead of copying the
+            tensor storages from disk to CPU memory in the first step, f is mmaped.
         pickle_load_args: (Python 3 only) optional keyword arguments passed over to
             :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
             :attr:`errors=...`.
```
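To make the two-step behaviour described in the new docstring concrete, a short sketch (file name and device are placeholders):

```
import torch

torch.save({"w": torch.randn(8)}, "ckpt.pth")

# default path: storages are copied from disk into CPU memory, then moved to the
# location they were tagged with (a no-op here, since that location is CPU)
sd = torch.load("ckpt.pth")

# mmap path: the file is mapped instead, so the disk -> CPU copy is skipped;
# map_location still decides where the tensors end up afterwards
sd_mmap = torch.load("ckpt.pth", mmap=True)
if torch.cuda.is_available():
    sd_cuda = torch.load("ckpt.pth", mmap=True, map_location="cuda")
```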
```diff
@@ -854,6 +860,10 @@ def load(
     if pickle_module is None:
         pickle_module = pickle
 
+    # make flipping default BC-compatible
+    if mmap is None:
+        mmap = False
+
     _check_dill_version(pickle_module)
 
     if 'encoding' not in pickle_load_args.keys():
```
```diff
@@ -865,6 +875,7 @@ def load(
         # If we want to actually tail call to torch.jit.load, we need to
         # reset back to the original position.
         orig_position = opened_file.tell()
+        overall_storage = None
         with _open_zipfile_reader(opened_file) as opened_zipfile:
             if _is_torchscript_zip(opened_zipfile):
                 warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
```
```diff
@@ -872,12 +883,29 @@ def load(
                               " silence this warning)", UserWarning)
                 opened_file.seek(orig_position)
                 return torch.jit.load(opened_file, map_location=map_location)
+            if mmap:
+                if not isinstance(f, str):
+                    raise ValueError("f must be a string filename in order to use mmap argument")
+                size = os.path.getsize(f)
+                overall_storage = torch.UntypedStorage.from_file(f, False, size)
             if weights_only:
                 try:
-                    return _load(opened_zipfile, map_location, _weights_only_unpickler, **pickle_load_args)
+                    return _load(opened_zipfile,
+                                 map_location,
+                                 _weights_only_unpickler,
+                                 overall_storage=overall_storage,
+                                 **pickle_load_args)
                 except RuntimeError as e:
                     raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-            return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
+            return _load(opened_zipfile,
+                         map_location,
+                         pickle_module,
+                         overall_storage=overall_storage,
+                         **pickle_load_args)
+        if mmap:
+            raise RuntimeError("mmap can only be used with files saved with ",
+                               "`torch.save(_use_new_zipfile_serialization=True), "
+                               "please torch.save your checkpoint with this option in order to use mmap.")
         if weights_only:
             try:
                 return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
```
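The heart of the mmap branch above is `torch.UntypedStorage.from_file(f, shared, size)`; a hedged sketch of that call in isolation (file name is a placeholder):

```
import os
import torch

torch.save({"w": torch.randn(1024)}, "ckpt.pth")

size = os.path.getsize("ckpt.pth")
# shared=False maps the file copy-on-write, so modifying the returned storage
# does not write back to the checkpoint on disk
overall_storage = torch.UntypedStorage.from_file("ckpt.pth", False, size)

# slicing yields a view into the mapping; only the pages that are actually
# touched need to be read from disk
chunk = overall_storage[128:256]
print(overall_storage.nbytes(), chunk.nbytes())
```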
```diff
@@ -1172,7 +1200,7 @@ class StorageType():
         return f'StorageType(dtype={self.dtype})'
 
 
-def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
+def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
     restore_location = _get_restore_location(map_location)
 
     loaded_storages = {}
```
```diff
@@ -1187,8 +1215,11 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
 
     def load_tensor(dtype, numel, key, location):
         name = f'data/{key}'
 
-        storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
+        if overall_storage is not None:
+            storage_offset = zip_file.get_record_offset(name)
+            storage = overall_storage[storage_offset:storage_offset + numel]
+        else:
+            storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
         # swap here if byteswapping is needed
         if byteorderdata is not None:
             if byteorderdata.decode() != sys.byteorder:
```
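For context, the slice that `load_tensor` takes out of `overall_storage` is raw untyped bytes; roughly how such a slice becomes a typed tensor, sketched with placeholder sizes (the real code path goes through `torch._utils._rebuild_tensor`):

```
import torch

overall = torch.UntypedStorage(64)        # stand-in for the mmapped file contents
nbytes = 16                               # four float32 values
chunk = overall[32:32 + nbytes]           # same slicing as in load_tensor above

t = torch.empty(0, dtype=torch.float32)
t.set_(chunk, 0, (4,), (1,))              # view the raw bytes as a float32 tensor of shape (4,)
print(t.shape)                            # torch.Size([4])
```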