## Semantics
The semantics are:
(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` makes `torch.save` skip writing storages while still reserving space for them in the checkpoint.
```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
# Storage bytes are skipped, but space for them is reserved in the file.
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```
(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if a FakeTensor is passed to `torch.save`, the pickler treats it as "materialized": space is reserved in the checkpoint for the associated storage bytes, and on loading the type is Tensor instead of FakeTensor.
```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

# Parameters created under FakeTensorMode are FakeTensors with no real data.
with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')

sd = m.state_dict()
with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
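The "reserve space now, write bytes later" idea behind both modes can be illustrated without PyTorch. The sketch below is only a conceptual analogy using a made-up flat binary layout; it is not how `torch.save` lays out a checkpoint, and the helper names `write_reserved` and `fill_payload` are hypothetical.

```python
import os
import struct
import tempfile


def write_reserved(path, payload_sizes):
    """Write length headers plus zero-filled placeholders.

    Returns the byte offset at which each placeholder starts.
    (Hypothetical layout, chosen only for illustration.)
    """
    offsets = []
    with open(path, "wb") as f:
        for size in payload_sizes:
            f.write(struct.pack("<Q", size))  # 8-byte little-endian length
            offsets.append(f.tell())          # remember where the payload goes
            f.write(b"\x00" * size)           # reserve space, no real data yet
    return offsets


def fill_payload(path, offset, data):
    """Overwrite a reserved region in place during a later pass."""
    with open(path, "r+b") as f:
        f.seek(offset)
        f.write(data)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "reserved.bin")
    offsets = write_reserved(path, [4, 6])
    size_before = os.path.getsize(path)

    # Second pass: fill only the second payload.
    fill_payload(path, offsets[1], b"abcdef")
    size_after = os.path.getsize(path)

    with open(path, "rb") as f:
        contents = f.read()
```

Because the placeholders already occupy their final positions, the second pass changes the bytes but not the file size, which is why knowing the offsets is enough to write the data separately.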
## Follow Ups
- [ ] `torch.load` semantics for the `skip_data` context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
| File |
|---|
| amp_examples.rst |
| autograd.rst |
| broadcasting.rst |
| cpu_threading_runtimes.svg |
| cpu_threading_torchscript_inference.rst |
| cpu_threading_torchscript_inference.svg |
| cuda.rst |
| custom_operators.rst |
| ddp.rst |
| extending.func.rst |
| extending.rst |
| faq.rst |
| fsdp.rst |
| get_start_xpu.rst |
| gradcheck.rst |
| hip.rst |
| large_scale_deployments.rst |
| modules.rst |
| mps.rst |
| multiprocessing.rst |
| numerical_accuracy.rst |
| randomness.rst |
| serialization.rst |
| windows.rst |