Mikayla Gawarecki 6fa2d41dc7 Add mmap option to torch.load (#102549)
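
This adds an option to `torch.load` that memory-maps the checkpoint file instead of eagerly copying all tensor storages into RAM, so pages are only faulted in when the tensors are actually touched. A minimal sketch of the landed public API (the option shipped as the `mmap` keyword; this assumes a checkpoint saved with the default zipfile serialization format):

```
import torch

# Storages are memory-mapped from disk rather than read up front;
# the OS faults pages in lazily as the tensors are used.
state_dict = torch.load("gpt2-xlarge.pth", mmap=True)
```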
Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py), run:

<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>

```
# test_load_save_gpt.py
from model import GPT
import torch
import time

torch.manual_seed(5)
# gpt2-xlarge 1558M parameters
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 48
    n_head: int = 25
    n_embd: int = 1600
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

def f():
    model = GPT(GPTConfig())
    state_dict = model.state_dict()

    start_saving = time.time()
    torch.save(state_dict, "gpt2-xlarge.pth")
    end_saving = time.time()
    print("saving time: ", end_saving - start_saving)

if __name__ == "__main__":
    f()
```
</details>

<details><summary><b>Click for script to load</b></summary>

```
# test_load_gpt.py

import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse

def f(mmap, meta):
    device = 'meta' if meta else 'cpu'
    assign = meta  # assign=True adopts the checkpoint tensors (needs the meta-device model)
    with torch.device(device):
        model = GPT(GPTConfig())
    start_loading = time.time()
    loaded_state_dict = torch.load("gpt2-xlarge.pth", mmap=mmap)
    end_loading = time.time()
    print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
    model.load_state_dict(loaded_state_dict, assign=assign)
    end_load_state_dict = time.time()
    print("load_state_dict time: ", end_load_state_dict - end_loading)
    model.cuda()
    end_cuda = time.time()
    print("cuda time: ", end_cuda - end_load_state_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
    parser.add_argument('-m', '--mmap', action='store_true')
    parser.add_argument('-d', '--devicemeta', action='store_true')
    args = parser.parse_args()
    mmap = args.mmap
    meta = args.devicemeta
    f(mmap, meta)

```

</details>

`python test_load_gpt.py`

<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">

`python test_load_gpt.py --mmap`
<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">

If we further construct the model under the `with torch.device('meta')` context manager and pull in the changes from https://github.com/pytorch/pytorch/pull/102212, which allow the model to reuse the tensors from the state_dict via `load_state_dict(..., assign=True)`, we get

`python test_load_gpt.py --mmap --devicemeta`
<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">

Running the above in a Docker container containing a build of PyTorch, with RAM limited to 512 MB, by

1) running `make -f docker.Makefile` from the `pytorch/` directory
2) `docker run -m 512m -it <image> bash`
3) `docker cp`-ing `gpt2-xlarge.pth` and `test_load_gpt.py` into the container

`python test_load_gpt.py`

Docker kills the process due to OOM, whereas

`python test_load_gpt.py --mmap --devicemeta`
<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD