Add mmap option to torch.load (#102549)
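
For a quick sense of the API, the new keyword looks like this (a minimal sketch; `gpt2-xlarge.pth` is the checkpoint produced by the save script below):

```
import torch

# mmap=True maps the checkpoint file instead of copying every tensor
# storage into CPU memory up front.
state_dict = torch.load("gpt2-xlarge.pth", mmap=True)
```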

Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py), run:

<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>

```
# test_load_save_gpt.py
from model import GPT
import torch
import time

torch.manual_seed(5)
# gpt2-xlarge 1558M parameters
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 48
    n_head: int = 25
    n_embd: int = 1600
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

def f():
    model = GPT(GPTConfig())
    state_dict = model.state_dict()

    start_saving = time.time()
    torch.save(state_dict, "gpt2-xlarge.pth")
    end_saving = time.time()
    print("saving time: ", end_saving - start_saving)

if __name__ == "__main__":
    f()
```
</details>

<details><summary><b>Click for script to load</b></summary>

```
# test_load_gpt.py

import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse

def f(mmap, meta):
    device = 'meta' if meta else 'cpu'
    assign = meta
    with torch.device(device):
        model = GPT(GPTConfig())
    start_loading = time.time()
    loaded_state_dict = torch.load("gpt2-xlarge.pth", mmap=mmap)
    end_loading = time.time()
    print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
    model.load_state_dict(loaded_state_dict, assign=assign)
    end_load_state_dict = time.time()
    print("load_state_dict time: ", end_load_state_dict - end_loading)
    model.cuda()
    end_cuda = time.time()
    print("cuda time using torch.load with mmap: ", end_cuda - end_load_state_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
    parser.add_argument('-m', '--mmap', action='store_true')
    parser.add_argument('-d', '--devicemeta', action='store_true')
    args = parser.parse_args()
    mmap = args.mmap
    meta = args.devicemeta
    f(mmap, meta)

```

</details>

`python test_load_gpt.py`

<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">

`python test_load_gpt.py --mmap`
<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">

If we further use the `with torch.device('meta')` context manager and pull in the changes from https://github.com/pytorch/pytorch/pull/102212, which let `load_state_dict(assign=True)` reuse the tensors from the state_dict, we get

`python test_load_gpt.py --mmap --devicemeta`
<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">

Running the above in a Docker container containing a build of PyTorch, with RAM limited to 512 MB by:

1) running `make -f docker.Makefile` from the `pytorch/` directory
2) running `docker run -m 512m -it <image> bash`
3) copying `gpt2-xlarge.pth` and `test_load_gpt.py` into the container with `docker cp`

`python test_load_gpt.py`

Docker kills the process due to OOM, whereas

`python test_load_gpt.py --mmap --devicemeta`
<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD
Commit 6fa2d41dc7 (parent 74b7a6c75e), authored by Mikayla Gawarecki on 2023-06-09 13:00:24 +00:00 and committed by PyTorch MergeBot; 3 changed files with 94 additions and 9 deletions.

New tests in `TestSerialization` for mmap loading:

```
@@ -3436,11 +3436,60 @@ class TestSerialization(TestCase, SerializationMixin):
         self.assertTrue(torch.equal(tensor_be_no_bom, tensor_le_bom))
         self.assertTrue(torch.equal(tensor_be_no_bom, tensor_be_bom))
 
+    @parametrize('weights_only', (True, False))
+    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
+    def test_serialization_mmap_loading(self, weights_only):
+        class DummyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(3, 1024)
+                self.fc2 = torch.nn.Linear(1024, 5)
+
+            def forward(self, input):
+                return self.fc2(self.fc1(input))
+
+        with TemporaryFileName() as f:
+            state_dict = DummyModel().state_dict()
+            torch.save(state_dict, f)
+            result = torch.load(f, mmap=True, weights_only=weights_only)
+            result_non_mmap = torch.load(f, mmap=False, weights_only=weights_only)
+            model_mmap_state_dict = DummyModel()
+            model_mmap_state_dict.load_state_dict(result)
+            model_non_mmap_state_dict = DummyModel()
+            model_non_mmap_state_dict.load_state_dict(result_non_mmap)
+            input = torch.randn(4, 3)
+            self.assertEqual(model_mmap_state_dict(input), model_non_mmap_state_dict(input.clone()))
+
+    @unittest.skipIf(not torch.cuda.is_available() or IS_WINDOWS,
+                     "CUDA is unavailable or NamedTemporaryFile on Windows")
+    def test_serialization_mmap_loading_with_map_location(self):
+        class DummyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = torch.nn.Linear(3, 1024)
+                self.fc2 = torch.nn.Linear(1024, 5)
+
+            def forward(self, input):
+                return self.fc2(self.fc1(input))
+
+        # make sure mmap where tensors' location tags are not CPU does not crash
+        # zipfile will first be mmap-ed on CPU and storages are extracted using
+        # overall_storage[start_offset:end_offset] before running
+        # _{device}_deserialize, which moves the storage to device
+        with TemporaryFileName() as f:
+            with torch.device('cuda'):
+                m = DummyModel()
+            state_dict = m.state_dict()
+            torch.save(state_dict, f)
+            result = torch.load(f, mmap=True)
+            for k, v in result.items():
+                self.assertTrue(v.is_cuda)
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
             return super().run(*args, **kwargs)
 
 
 class TestWrapperSubclass(torch.Tensor):
     elem: torch.Tensor
     __slots__ = ['elem', 'other']
```

The `PyTorchStreamReader` bindings in `initJITBindings` gain a `get_record_offset` method:

```
@@ -1475,9 +1475,14 @@ void initJITBindings(PyObject* module) {
             return at::Tensor(std::move(ptr));
           })
       .def("serialization_id", &PyTorchStreamReader::serializationId)
-      .def("get_all_records", [](PyTorchStreamReader& self) {
-        return self.getAllRecords();
-      });
+      .def(
+          "get_all_records",
+          [](PyTorchStreamReader& self) { return self.getAllRecords(); })
+      .def(
+          "get_record_offset",
+          [](PyTorchStreamReader& self, const std::string& key) {
+            return self.getRecordOffset(key);
+          });
 
   // Used by torch.Package to coordinate deserialization of storages across
   // ScriptModules and eager modules
```

In `torch.serialization`, `torch.load` gains the `mmap` argument and `_load` learns to slice storages out of an mmap-ed file:

```
@@ -745,6 +745,7 @@ def load(
     pickle_module: Any = None,
     *,
     weights_only: bool = False,
+    mmap: bool = None,
     **pickle_load_args: Any
 ) -> Any:
     # Reference: https://github.com/pytorch/pytorch/issues/54354

@@ -794,6 +795,11 @@ def load(
             match the :attr:`pickle_module` used to serialize file)
         weights_only: Indicates whether unpickler should be restricted to
             loading only tensors, primitive types and dictionaries
+        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
+            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
+            are moved to the location that they were tagged with when saving, or specified by `map_location`. This
+            second step is a no-op if the final location is CPU. When the `mmap` flag is set, instead of copying the
+            tensor storages from disk to CPU memory in the first step, f is mmaped.
         pickle_load_args: (Python 3 only) optional keyword arguments passed over to
             :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
             :attr:`errors=...`.

@@ -854,6 +860,10 @@ def load(
     if pickle_module is None:
         pickle_module = pickle
 
+    # make flipping default BC-compatible
+    if mmap is None:
+        mmap = False
+
     _check_dill_version(pickle_module)
 
     if 'encoding' not in pickle_load_args.keys():

@@ -865,6 +875,7 @@ def load(
             # If we want to actually tail call to torch.jit.load, we need to
             # reset back to the original position.
             orig_position = opened_file.tell()
+            overall_storage = None
             with _open_zipfile_reader(opened_file) as opened_zipfile:
                 if _is_torchscript_zip(opened_zipfile):
                     warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"

@@ -872,12 +883,29 @@ def load(
                                   " silence this warning)", UserWarning)
                     opened_file.seek(orig_position)
                     return torch.jit.load(opened_file, map_location=map_location)
+                if mmap:
+                    if not isinstance(f, str):
+                        raise ValueError("f must be a string filename in order to use mmap argument")
+                    size = os.path.getsize(f)
+                    overall_storage = torch.UntypedStorage.from_file(f, False, size)
                 if weights_only:
                     try:
-                        return _load(opened_zipfile, map_location, _weights_only_unpickler, **pickle_load_args)
+                        return _load(opened_zipfile,
+                                     map_location,
+                                     _weights_only_unpickler,
+                                     overall_storage=overall_storage,
+                                     **pickle_load_args)
                     except RuntimeError as e:
                         raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-                return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
+                return _load(opened_zipfile,
+                             map_location,
+                             pickle_module,
+                             overall_storage=overall_storage,
+                             **pickle_load_args)
+        if mmap:
+            raise RuntimeError("mmap can only be used with files saved with ",
+                               "`torch.save(_use_new_zipfile_serialization=True), "
+                               "please torch.save your checkpoint with this option in order to use mmap.")
         if weights_only:
             try:
                 return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)

@@ -1172,7 +1200,7 @@ class StorageType():
         return f'StorageType(dtype={self.dtype})'
 
 
-def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
+def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
     restore_location = _get_restore_location(map_location)
 
     loaded_storages = {}

@@ -1187,8 +1215,11 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
     def load_tensor(dtype, numel, key, location):
         name = f'data/{key}'
-        storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
+        if overall_storage is not None:
+            storage_offset = zip_file.get_record_offset(name)
+            storage = overall_storage[storage_offset:storage_offset + numel]
+        else:
+            storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
         # swap here if byteswapping is needed
         if byteorderdata is not None:
             if byteorderdata.decode() != sys.byteorder:
```
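
Putting the `torch.serialization` changes together, the mechanism is roughly the following (a simplified sketch, not the exact implementation; `record_offset` and `numel` stand in for values the loader obtains from the zip reader):

```
import os
import torch

f = "gpt2-xlarge.pth"
size = os.path.getsize(f)
# Map the whole zipfile checkpoint into memory once
# (shared=False, so modifications are not written back to the file).
overall_storage = torch.UntypedStorage.from_file(f, False, size)
# Each tensor's storage is then a slice of that mapping, located via the new
# get_record_offset binding, so no storage bytes are copied into CPU memory up front:
#     record_offset = zip_file.get_record_offset(f"data/{key}")
#     storage = overall_storage[record_offset:record_offset + numel]
```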