Commit Graph

182 Commits

Author SHA1 Message Date
Mikayla Gawarecki
a135776307 Remove tensor subclass detection logic from weights_only unpickler (#127808)
Remove logic to auto-detect and allow subclasses that did not override certain methods from the weights_only unpickler from https://github.com/pytorch/pytorch/pull/124331 for 2.4 release

Subclasses should be loadable using `torch.serialization.add_safe_globals`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127808
Approved by: https://github.com/malfet
2024-06-05 02:14:30 +00:00
Xuehai Pan
8b08b0f340 [BE] enable ruff rule Q from flake8-quotes (#127713)
Enable [ruff rule `Q`](https://docs.astral.sh/ruff/rules/#flake8-quotes-q) from flake8-quotes. Fixes:

- [avoidable-escaped-quote (Q003)](https://docs.astral.sh/ruff/rules/avoidable-escaped-quote/#avoidable-escaped-quote-q003)
- [unnecessary-escaped-quote (Q004)](https://docs.astral.sh/ruff/rules/unnecessary-escaped-quote/#unnecessary-escaped-quote-q004)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127713
Approved by: https://github.com/ezyang
2024-06-02 23:25:26 +00:00
Mikayla Gawarecki
4644def434 Update docstring for weights_only (#127575)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127575
Approved by: https://github.com/malfet
2024-05-31 14:27:31 +00:00
Mikayla Gawarecki
87f79af24d Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)
Fixes https://github.com/pytorch/pytorch/issues/124418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728
Approved by: https://github.com/albanD
2024-05-24 16:39:30 +00:00
Mikayla Gawarecki
66dc8fb7ff Allow tensor subclasses and add torch.serialization.add_safe_globals that allows users to allowlist classes for weights_only load (#124331)
#### Conditions for allowlisting tensor subclasses
We allow tensor subclasses types that
(1) Do not override `__setstate__`, `__getattr__`, `__setattr__`, `__get__`, `__set__` or `__getattribute__` of `torch.Tensor` (`torch.Tensor` does not have a definition of `__getattr__`, `__get__` or `__set__` so we check that these are `None`)
(2) Use the generic `tp_alloc`
(3) Are in a module that *has been imported by the user*
to be pushed onto the stack as strings by `GLOBAL` instructions, while storing the type in a dict

The strings will be converted to the classes as appropriate when executing `REBUILD` with `_rebuild_from_type_v2`

*Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class/function as this method claims to have no code execution.

The rationale for the 3 conditions above is as follows:

The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as such (note the call to `getattr`, `Tensor.__setstate__` and the call to `as_subclass` as well as the call to `_set_obj_state` which calls `setattr`)

4e66aaa010/torch/_tensor.py (L57-L71)

`as_subclass` is implemented with a call to `THPVariable_NewWithVar`

that will eventually call `tp_alloc` here
4e66aaa010/torch/csrc/autograd/python_variable.cpp (L2053)

The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `Tensor.rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc`

**Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling**

### How do we check something is a tensor subclass/constraints around imports

In order to check whether `bla` is a tensor subclass in the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We *do not* arbitrarily import modules but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys[modules], name), torch.Tensor)`

This PR also allowlisted  `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`)

### API for allow listing
This PR also added `torch.serialization.{add/get/clear}_safe_globals` that enables user to allowlist globals they have deemed safe and manipulate this list (for example they could allowlist a tensor subclass with a custom `__setstate__` if they have checked that this is safe).

Next steps:
- Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor` etc.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124331
Approved by: https://github.com/albanD
2024-05-17 17:56:57 +00:00
Mikayla Gawarecki
776b878917 [easy] Fix typing for map_location docs in torch.load (#125473)
Currently it incorrectly has `Callable[[Tensor, str], Tensor]` as a possible type signature, this should be `Callable[[Storage, str], Storage]`

<img width="716" alt="Screenshot 2024-05-03 at 12 09 54 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b8946f95-8297-445f-a9d9-570b8a3caab1">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125473
Approved by: https://github.com/albanD
2024-05-17 01:15:25 +00:00
Dmitry Rogozhkin
8f0c207e18 xpu: implement xpu serialization (#125530)
Fixes: #125529

BC-breaking note:
The deprecated "async" argument to the Storage.cuda and Storage.hpu has been removed. Use non_blocking instead.

CC: @jbschlosser, @frank-wei @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125530
Approved by: https://github.com/guangyey, https://github.com/albanD
2024-05-16 20:22:17 +00:00
Mikayla Gawarecki
2480e8b8a1 Add MAP_SHARED option for torch.load(mmap=True) (#124889)
Fixes #124528

Going over the options for our MapAllocator and what they do, I don't think any other of them need to be piped up to `torch.load`

4f29103749/aten/src/ATen/MapAllocator.h (L8-L16)

~However, I wonder if this `MmapVisibility(Enum)` is a good way to represent "or-ing" together of `mmap` flags if we want to extend it in the future. I looked over the flags for [`mmap(2)`](https://man7.org/linux/man-pages/man2/mmap.2.html), and could not immediately see how most of them would be useful for `torch.load` (would maybe `MAP_LOCKED` (like `mlock`) or `MAP_HUGE` ever be worthwhile?)~

Using the flags provided by the python `mmap` library so that we can extend the allowed flags and pipe them down to the cpp `mmap` call if there is a need for other flags in the future

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124889
Approved by: https://github.com/albanD
2024-04-30 15:02:19 +00:00
Thiago Crepaldi
6c11d3ce0c Add support to save safetensors checkpoint directly into onnx (#121001)
Currently, when `torch.onnx.dynamo_export` is called within `torch.onnx.enable_fake_mode`, all the external pytorch checkpoint files used to initialize the model are automatically and used by `torch.onnx.ONNXProgram.save` to recreate the initializers for
the newly exported ONNX model.

This API extends the mechanism for HuggingFace models that use safetensors weights. This PR detects safetensors state files and converts them to PyTorch format using mmap on a temporary file, which is deleted after conversion is finished.

Without this PR, the user would have to convert the safetensors files to pytorch format manually and feed it to `torch.onnx.ONNXProgram.save` manually.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121001
Approved by: https://github.com/BowenBao, https://github.com/malfet
2024-03-11 15:21:59 +00:00
albanD
8cb4855d1e Release the GIL in serialization when it is safe to do so (#120818)
In particular this ensures we release the GIL when serializing:
- PyBytes objects (this is how we get the pickle object)
- Storage objects

Other string-like objects keep the gil which is fine because we only use this for very small strings today (for endianess) and so releasing the GIL is not important there
Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120818
Approved by: https://github.com/colesbury
2024-03-01 22:37:26 +00:00
Sam Larsen
06f8af30fa Change FakeTensor serialization to consider only an _active_ FakeTensor mode (#120848)
Summary: https://github.com/pytorch/pytorch/pull/108186 make some changes related to FakeTensor serialization such that saving and loading a tensor will give us a meta tensor, even if FakeTensor mode is not enabled. This means we can't properly save and load Tensors as part of Fx graph caching. This PR changes the logic to check if there's an _active_ FakeTensor mode.

Test Plan:
* New unit tests
* Validated unit tests introduced in https://github.com/pytorch/pytorch/pull/108186 still pass
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120848
Approved by: https://github.com/eellison, https://github.com/thiagocrepaldi
2024-03-01 02:37:21 +00:00
Thiago Crepaldi
761fa5d6ec Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
There are two scenarios:

* Scenario 1: The checkpoint was saved with pytorch < 1.6
* Scenario 2: The checkpoint was saved with pytorch >= 1.6

Repro Scenario 1:

```python
from torch._subclasses import fake_tensor
import transformers

fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
    fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")
```

Error:

```bash
Some weights of the model checkpoint at sshleifer/tiny-gpt2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:463 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    460 │   │   │   )                                                                             │
│    461 │   │   return safe_load_file(checkpoint_file)                                            │
│    462 │   try:                                                                                  │
│ ❱  463 │   │   return torch.load(checkpoint_file, map_location="cpu")                            │
│    464 │   except Exception as e:                                                                │
│    465 │   │   try:                                                                              │
│    466 │   │   │   with open(checkpoint_file) as f:                                              │
│                                                                                                  │
│ /opt/pytorch/torch/serialization.py:1030 in load                                                 │
│                                                                                                  │
│   1027 │   │   │   │   return _legacy_load(opened_file, map_location, _weights_only_unpickler,   │
│   1028 │   │   │   except RuntimeError as e:                                                     │
│   1029 │   │   │   │   raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None           │
│ ❱ 1030 │   │   return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args  │
│   1031                                                                                           │
│   1032                                                                                           │
│   1033 # Register pickling support for layout instances such as                                  │
│                                                                                                  │
│ /opt/pytorch/torch/serialization.py:1258 in _legacy_load                                         │
│                                                                                                  │
│   1255 │   _sys_info = pickle_module.load(f, **pickle_load_args)                                 │
│   1256 │   unpickler = UnpicklerWrapper(f, **pickle_load_args)                                   │
│   1257 │   unpickler.persistent_load = persistent_load                                           │
│ ❱ 1258 │   result = unpickler.load()                                                             │
│   1259 │                                                                                         │
│   1260 │   deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)                 │
│   1261                                                                                           │
│                                                                                                  │
│ /opt/pytorch/torch/_utils.py:201 in _rebuild_tensor_v2                                           │
│                                                                                                  │
│   198 def _rebuild_tensor_v2(                                                                    │
│   199 │   storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None    │
│   200 ):                                                                                         │
│ ❱ 201 │   tensor = _rebuild_tensor(storage, storage_offset, size, stride)                        │
│   202 │   tensor.requires_grad = requires_grad                                                   │
│   203 │   if metadata:                                                                           │
│   204 │   │   set_tensor_metadata(tensor, metadata)                                              │
│                                                                                                  │
│ /opt/pytorch/torch/_utils.py:180 in _rebuild_tensor                                              │
│                                                                                                  │
│   177 def _rebuild_tensor(storage, storage_offset, size, stride):                                │
│   178 │   # first construct a tensor with the correct dtype/device                               │
│   179 │   t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device)      │
│ ❱ 180 │   return t.set_(storage._untyped_storage, storage_offset, size, stride)                  │
│   181                                                                                            │
│   182                                                                                            │
│   183 def get_tensor_metadata(tensor):                                                           │
│                                                                                                  │
│ /opt/pytorch/torch/utils/_stats.py:20 in wrapper                                                 │
│                                                                                                  │
│   17 │   │   if fn.__qualname__ not in simple_call_counter:                                      │
│   18 │   │   │   simple_call_counter[fn.__qualname__] = 0                                        │
│   19 │   │   simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1     │
│ ❱ 20 │   │   return fn(*args, **kwargs)                                                          │
│   21 │   return wrapper                                                                          │
│   22                                                                                             │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1160 in __torch_dispatch__                         │
│                                                                                                  │
│   1157 │   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                      │
│   1158 │   │   assert self not in _get_current_dispatch_mode_stack(), func                       │
│   1159 │   │   try:                                                                              │
│ ❱ 1160 │   │   │   return self.dispatch(func, types, args, kwargs)                               │
│   1161 │   │   except TypeError:                                                                 │
│   1162 │   │   │   log.exception("fake tensor raised TypeError")                                 │
│   1163 │   │   │   raise                                                                         │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1318 in dispatch                                   │
│                                                                                                  │
│   1315 │   │                                                                                     │
│   1316 │   │   # we are falling through to running non constant tensors, any input constant tha  │
│   1317 │   │   # is written to must be invalidated                                               │
│ ❱ 1318 │   │   self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)   │
│   1319 │   │                                                                                     │
│   1320 │   │   # Try for fastpath                                                                │
│   1321 │   │   if has_symbolic_sizes:                                                            │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1557 in invalidate_written_to_constants            │
│                                                                                                  │
│   1554 │   │   any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)         │
│   1555 │   │   if any_constant and get_schema_info(func).is_mutable():                           │
│   1556 │   │   │   schema_info = get_schema_info(func)                                           │
│ ❱ 1557 │   │   │   _, new_kwargs = normalize_function(                                           │
│   1558 │   │   │   │   func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True         │
│   1559 │   │   │   )                                                                             │
│   1560 │   │   │   for k, v in new_kwargs.items():                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:297 in normalize_function                              │
│                                                                                                  │
│   294 │   │   new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,    │
│   295 │   else:                                                                                  │
│   296 │   │   assert callable(target)                                                            │
│ ❱ 297 │   │   torch_op_schemas = get_signature_for_torch_op(target)                              │
│   298 │   │   matched_schemas = []                                                               │
│   299 │   │   if torch_op_schemas:                                                               │
│   300 │   │   │   # Iterate through all of the schema until we find one that matches             │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in get_signature_for_torch_op                      │
│                                                                                                  │
│   164 │   │   │   return (None, None) if return_schemas else None                                │
│   165 │   │   schemas = torch._C._jit_get_schemas_for_operator(aten_fn)                          │
│   166 │                                                                                          │
│ ❱ 167 │   signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]          │
│   168 │   return (signatures, schemas) if return_schemas else signatures                         │
│   169                                                                                            │
│   170 @compatibility(is_backward_compatible=False)                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in <listcomp>                                      │
│                                                                                                  │
│   164 │   │   │   return (None, None) if return_schemas else None                                │
│   165 │   │   schemas = torch._C._jit_get_schemas_for_operator(aten_fn)                          │
│   166 │                                                                                          │
│ ❱ 167 │   signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]          │
│   168 │   return (signatures, schemas) if return_schemas else signatures                         │
│   169                                                                                            │
│   170 @compatibility(is_backward_compatible=False)                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:70 in _torchscript_schema_to_signature                 │
│                                                                                                  │
│    67 │   from inspect import Parameter                                                          │
│    68 │   parameters : List[Parameter] = []                                                      │
│    69 │   for arg in ts_schema.arguments:                                                        │
│ ❱  70 │   │   arg_type = _torchscript_type_to_python_type(arg.type)                              │
│    71 │   │   default = arg.default_value if arg.has_default_value() else Parameter.empty        │
│    72 │   │   # TODO: Figure out if this is safe. It seems like when generating the type signa   │
│    73 │   │   # PythonArgParser, we emit signatures with `input` instead of `self` as the firs   │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:64 in _torchscript_type_to_python_type                 │
│                                                                                                  │
│    61 │   eval'ing the annotation_str. _type_eval_globals sets up expressions                    │
│    62 │   like "List" and "Future" to map to actual types (typing.List and jit.Future)           │
│    63 │   """                                                                                    │
│ ❱  64 │   return eval(ts_type.annotation_str, _type_eval_globals)                                │
│    65                                                                                            │
│    66 def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Sig   │
│    67 │   from inspect import Parameter                                                          │
│ <string>:1 in <module>                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NameError: name 'Storage' is not defined

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:467 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    464 │   except Exception as e:                                                                │
│    465 │   │   try:                                                                              │
│    466 │   │   │   with open(checkpoint_file) as f:                                              │
│ ❱  467 │   │   │   │   if f.read(7) == "version":                                                │
│    468 │   │   │   │   │   raise OSError(                                                        │
│    469 │   │   │   │   │   │   "You seem to have cloned a repository without having git-lfs ins  │
│    470 │   │   │   │   │   │   "git-lfs and run `git lfs install` followed by `git lfs pull` in  │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/codecs.py:322 in decode                                       │
│                                                                                                  │
│    319 │   def decode(self, input, final=False):                                                 │
│    320 │   │   # decode input (taking the buffer into account)                                   │
│    321 │   │   data = self.buffer + input                                                        │
│ ❱  322 │   │   (result, consumed) = self._buffer_decode(data, self.errors, final)                │
│    323 │   │   # keep undecoded input until the next call                                        │
│    324 │   │   self.buffer = data[consumed:]                                                     │
│    325 │   │   return result                                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/pytorch/bug_repro.py:16 in <module>                                                         │
│                                                                                                  │
│   13 fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")                  │
│   14 assert fake_model is not None                                                               │
│   15 with fake_mode:                                                                             │
│ ❱ 16 │   fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")  # raises    │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:484 in │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   481 │   │   │   )                                                                              │
│   482 │   │   elif type(config) in cls._model_mapping.keys():                                    │
│   483 │   │   │   model_class = _get_model_class(config, cls._model_mapping)                     │
│ ❱ 484 │   │   │   return model_class.from_pretrained(                                            │
│   485 │   │   │   │   pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs,   │
│   486 │   │   │   )                                                                              │
│   487 │   │   raise ValueError(                                                                  │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:2604 in          │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   2601 │   │   if from_pt:                                                                       │
│   2602 │   │   │   if not is_sharded and state_dict is None:                                     │
│   2603 │   │   │   │   # Time to load the checkpoint                                             │
│ ❱ 2604 │   │   │   │   state_dict = load_state_dict(resolved_archive_file)                       │
│   2605 │   │   │                                                                                 │
│   2606 │   │   │   # set dtype to instantiate the model under:                                   │
│   2607 │   │   │   # 1. If torch_dtype is not None, we use that dtype                            │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:479 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    476 │   │   │   │   │   │   "model. Make sure you have saved the model properly."             │
│    477 │   │   │   │   │   ) from e                                                              │
│    478 │   │   except (UnicodeDecodeError, ValueError):                                          │
│ ❱  479 │   │   │   raise OSError(                                                                │
│    480 │   │   │   │   f"Unable to load weights from pytorch checkpoint file for '{checkpoint_f  │
│    481 │   │   │   │   f"at '{checkpoint_file}'. "                                               │
│    482 │   │   │   │   "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please s  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OSError: Unable to load weights from pytorch checkpoint file for '/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin' at
'/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set
from_tf=True.
```

Repro scenario 2:

```python
import tempfile
import torch
from torch._subclasses import fake_tensor

class TheModelClass(torch.nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```

Error:

```bash
Traceback (most recent call last):
  File "issue_gh_torch_105077.py", line 22, in <module>
    state_dict = torch.load(state_dict_file.name)
  File "/opt/pytorch/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/opt/pytorch/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
    self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
    _, new_kwargs = normalize_function(
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
    torch_op_schemas = get_signature_for_torch_op(target)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
    arg_type = _torchscript_type_to_python_type(arg.type)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
    return eval(ts_type.annotation_str, _type_eval_globals)
  File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```

This PR adds the ability to create fake tensors during torch.load (when fake mode is active) by changing the storage's device to 'meta'.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang, https://github.com/atalman
2024-02-16 23:42:50 +00:00
PyTorch MergeBot
458e83b5b3 Revert "Add FakeTensor support to torch._utils._rebuild_tensor (#108186)"
This reverts commit 113506d2d4.

Reverted https://github.com/pytorch/pytorch/pull/108186 on behalf of https://github.com/atalman due to Reverted Internally ([comment](https://github.com/pytorch/pytorch/pull/108186#issuecomment-1935310344))
2024-02-09 04:19:20 +00:00
Thiago Crepaldi
113506d2d4 Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
Partially fixes https://github.com/pytorch/pytorch/issues/105077

Repro:

```python
import tempfile
import torch
from torch._subclasses import fake_tensor

class TheModelClass(torch.nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```

Error:

```bash
Traceback (most recent call last):
  File "issue_gh_torch_105077.py", line 22, in <module>
    state_dict = torch.load(state_dict_file.name)
  File "/opt/pytorch/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/opt/pytorch/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
    self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
    _, new_kwargs = normalize_function(
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
    torch_op_schemas = get_signature_for_torch_op(target)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
    arg_type = _torchscript_type_to_python_type(arg.type)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
    return eval(ts_type.annotation_str, _type_eval_globals)
  File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```

This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.

Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
2024-02-08 03:01:34 +00:00
PyTorch MergeBot
499040ac32 Revert "Add FakeTensor support to torch._utils._rebuild_tensor (#108186)"
This reverts commit 426339e4de.

Reverted https://github.com/pytorch/pytorch/pull/108186 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/108186#issuecomment-1929978008))
2024-02-06 15:04:48 +00:00
Thiago Crepaldi
426339e4de Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
Partially fixes https://github.com/pytorch/pytorch/issues/105077

Repro:

```python
import tempfile
import torch
from torch._subclasses import fake_tensor

class TheModelClass(torch.nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```

Error:

```bash
Traceback (most recent call last):
  File "issue_gh_torch_105077.py", line 22, in <module>
    state_dict = torch.load(state_dict_file.name)
  File "/opt/pytorch/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/opt/pytorch/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
    self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
    _, new_kwargs = normalize_function(
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
    torch_op_schemas = get_signature_for_torch_op(target)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
    arg_type = _torchscript_type_to_python_type(arg.type)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
    return eval(ts_type.annotation_str, _type_eval_globals)
  File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```

This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.

Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
2024-02-02 20:35:38 +00:00
Edward Z. Yang
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
Aaron Gokaslan
4f9858a902 [BE]: Use os.fspath and os.PathLike in torch serialization (#116562)
Use proper `os.fspath` to better convert `os.PathLike` object to a path.
Replace `pathlib.Path` with `os.PathLike` which is more generic and typing correct. `pathlib.Path` is an instance of `os.PathLike`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116562
Approved by: https://github.com/malfet
2023-12-30 20:53:10 +00:00
Adrian Wälchli
8220d5c66d Support pathlib.Path as input to torch.load when mmap=True (#116104)
Fixes #116103

This now works:

```py
import torch
from pathlib import Path

file = Path("example.pt")
torch.save(torch.rand(5, 3), file)
torch.load(file, mmap=True)   # works!
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116104
Approved by: https://github.com/mikaylagawarecki
2023-12-28 22:54:11 +00:00
Nikita Shulga
1d640566d4 [BE] Do not warn when safely loading legacy dicts (#113614)
Use the same strategy as for unsafe pickler, i.e. use dummy `torch.serialization.StorageType` to represent legacy typed storage classes during deserialization. Add `_dtype` property to be able to use it for both new and legacy format deserialization.

Parametrize `test_serialization_new_format_old_format_compat`

Add regression test to validate that loading legacy modes can be done
without any warnings

Before the change:
```
% python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_safe_cpu (__main__.TestBothSerializationCPU) ... /Users/nshulga/git/pytorch/pytorch/torch/_utils.py:836: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
ok

----------------------------------------------------------------------
Ran 2 tests in 0.116s

OK
```
Without the change but update test to catch warnings:
```
 % python test_serialization.py -v -k test_serialization_new_format_old_format_compat_
test_serialization_new_format_old_format_compat_weights_only_False_cpu (__main__.TestBothSerializationCPU) ... ok
test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU) ... FAIL

======================================================================
FAIL: test_serialization_new_format_old_format_compat_weights_only_True_cpu (__main__.TestBothSerializationCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2536, in wrapper
    method(*args, **kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
  File "/Users/nshulga/git/pytorch/pytorch/test/test_serialization.py", line 807, in test_serialization_new_format_old_format_compat
    self.assertTrue(len(w) == 0, msg=f"Expected no warnings but got {[str(x) for x in w]}")
AssertionError: False is not true : Expected no warnings but got ["{message : UserWarning('TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()'), category : 'UserWarning', filename : '/Users/nshulga/git/pytorch/pytorch/torch/_utils.py', lineno : 836, line : None}"]

To execute this test, run the following from the base repo dir:
     python test/test_serialization.py -k test_serialization_new_format_old_format_compat_weights_only_True_cpu

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0

----------------------------------------------------------------------
Ran 2 tests in 0.109s

FAILED (failures=1)

```

Fixes problem reported in https://github.com/pytorch/pytorch/issues/52181#issuecomment-1715738910
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113614
Approved by: https://github.com/kit1980, https://github.com/albanD
2023-11-14 22:09:10 +00:00
Vidit Agarwal
7b99b3efb1 added 'weights_only' param in torch.load examples (#112860)
Fixes #111876

`torch.load` without setting `weights_only=True` is unsafe. So updating examples of `torch.load` to use `weights_only=True` where possible and `weights_only=False` elsewhere with a warning of being unsafety.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112860
Approved by: https://github.com/kit1980
2023-11-06 21:17:36 +00:00
XDaoHong
1b34238d67 fix get device index if has _utils._get_device_index in privateuse1 (#108123)
**Get device index by torch.privateuse1._utils._get_device_index, if the metched exists.**

Reason:
Can only get device_index 0 if ```location``` such as 'privateuse1' before modify.
Can get accurate deivce index use _get_device_index in this scenario.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108123
Approved by: https://github.com/albanD
2023-10-07 06:18:59 +00:00
Yu, Guangye
871b5caae7 Fix hpu deserialization bug (#109499)
# Motivation
fix hpu deserialization bug. It should check hpu model if and only if location start with hpu. Otherwise, it always raise an AssertError if hpu is not imported. This break the serialization/desirialization functionality abourt other third-party like IPEX.

# Solution
only assert hpu model when start with hpu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109499
Approved by: https://github.com/ezyang
2023-09-19 00:10:51 +00:00
Tobias Ringwald
f7574ea43f torch.load: Replaced multiple one byte read() calls during the _is_zipfile check with a single call (#109119)
Fixes #108955.

Right now, the `_is_zipfile` check in `torch.load` performs multiple `read()` calls, reading 1 byte at a time in a loop. This is rather wasteful and leads to performance problems when accessing files on a network share (see #108955) .
This PR replaces those 1 byte calls with a single big call. Functionally, this is equivalent as `read(n)` only reads up to `n` bytes, so even if the file is shorter there should not be any problems.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109119
Approved by: https://github.com/mikaylagawarecki
2023-09-14 19:39:10 +00:00
moto
d64e1c5f9d Fix error message concatanation (#108581)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108581
Approved by: https://github.com/mikaylagawarecki
2023-09-05 19:46:52 +00:00
Aleksei Nikiforov
51c2e22e94 When byteorder record is missing load as little endian by default (#108343)
Fixes #101688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108343
Approved by: https://github.com/mikaylagawarecki
2023-09-04 15:20:22 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please hep them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Mikayla Gawarecki
9c458942ae [easy] Minor torch.load docs fix (#105876)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105876
Approved by: https://github.com/albanD
2023-07-25 03:58:30 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow up on the pyupgrade series to convert more strings to use f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Justin Chu
79c5e33349 [BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105436
Approved by: https://github.com/malfet, https://github.com/albanD
2023-07-21 07:38:46 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export. deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00
Nikita Shulga
4148b7bada [Typing] Fix PEP 484 Violation (#105022)
Not sure, how it worked before, but if arguments must be annotated is optional if they are defaulted to None

Towards enabling mypy-1.4.1 in lintrunner

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>

> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
2023-07-12 10:20:48 +00:00
Aleksei Nikiforov
c42fd73cf9 Add functions to get and set default endianness in load() functions (#101973)
By default interpret tensor data as native endian, but add an option to interpret data as little endian or big endian.

Related to #101688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101973
Approved by: https://github.com/mikaylagawarecki
2023-07-06 20:12:56 +00:00
Mikayla Gawarecki
981f24e806 Add docstring to torch.serialization.register_package (#104046)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104046
Approved by: https://github.com/albanD
2023-06-26 23:28:32 +00:00
Paweł Piskorski
7fb2a928cf fix hpu storage serialization (#101680)
Change-Id: Ia534400a0e8972590374eceba5b62a2525b796e5

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101680
Approved by: https://github.com/mikaylagawarecki
2023-06-21 21:19:49 +00:00
magic-akari
e56cdfd74b [MPS] Handle deserialization more permissively (#98834)
MPS deserialization should handle `mps:0`.
It is generated from some codes like the following

```python
torch.rand(size=(3, 4)).to("mps")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98834
Approved by: https://github.com/kulinseth, https://github.com/kit1980, https://github.com/malfet
2023-06-15 15:51:03 +00:00
Mikayla Gawarecki
6fa2d41dc7 Add mmap option to torch.load (#102549)
Using [`nanoGPT/model.py`](https://github.com/karpathy/nanoGPT/blob/master/model.py) run

<details><summary><b>Click for script to save gpt2-xlarge (1.5B params)</b></summary>

```
# test_load_save_gpt.py
from model import GPT
import torch
import time

torch.manual_seed(5)
# gpt2-xlarge 1558M parameters
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 48
    n_head: int = 25
    n_embd: int = 1600
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

def f():
    model = GPT(GPTConfig())
    state_dict = model.state_dict()

    start_saving = time.time()
    torch.save(state_dict, "gpt2-xlarge.pth")
    end_saving = time.time()

if __name__ == "__main__":
    f()
```
</details>

<details><summary><b>Click for script to load</b></summary>

```
# test_load_gpt.py

import torch
from model import GPT
from test_load_save_gpt import GPTConfig
import time
import argparse

def f(mmap, meta):
    device = 'meta' if meta else 'cpu'
    assign = True if meta else False
    with torch.device(device):
        model = GPT(GPTConfig())
    start_loading = time.time()
    loaded_state_dict = torch.load("gpt2-xlarge.pth", _mmap=mmap)
    end_loading = time.time()
    print(f"loading time using torch.load with mmap={mmap}: ", end_loading - start_loading)
    model.load_state_dict(loaded_state_dict, assign=assign)
    end_load_state_dict = time.time()
    print("load_state_dict time: ", end_load_state_dict - end_loading)
    model.cuda()
    end_cuda = time.time()
    print("cuda time using torch.load with mmap: ", end_cuda - end_load_state_dict)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='load_gpt_xlarge')
    parser.add_argument('-m', '--mmap', action='store_true')
    parser.add_argument('-d', '--devicemeta', action='store_true')
    args = parser.parse_args()
    mmap = args.mmap
    meta = args.devicemeta
    f(mmap, meta)

```

</details>

`python test_load_gpt.py`

<img width="614" alt="Screenshot 2023-06-06 at 1 35 43 PM" src="https://github.com/pytorch/pytorch/assets/35276741/ee06e5b3-b610-463b-a867-df995d21af29">

`python test_load_gpt.py --mmap`
<img width="622" alt="Screenshot 2023-06-06 at 1 35 30 PM" src="https://github.com/pytorch/pytorch/assets/35276741/00d2fdd0-b1f5-4313-83dc-e540b654b2af">

If we further use the `with torch.device('meta')` context manager and pull the changes from https://github.com/pytorch/pytorch/pull/102212 that allow the model to reuse tensors from the state_dict, we have

`python test_load_gpt.py --mmap --devicemeta`
<img width="727" alt="Screenshot 2023-06-06 at 1 35 51 PM" src="https://github.com/pytorch/pytorch/assets/35276741/b50257d9-092a-49c3-acae-876ee44d009f">

\
\
Running the above in a docker container containing a build of PyTorch with RAM limited to 512mb by

1) running `make -f docker.Makefile` from `pytorch/` directory
2) `docker run -m 512m -it <image> bash`
3) docker cp `gpt2-xlarge.pth` and `test_load_gpt.py` into the image

`python test_load_gpt.py`

Docker will Kill the process due to OOM whereas

`python test_load_gpt.py --mmap --devicemeta`
<img width="635" alt="Screenshot 2023-06-06 at 1 55 48 PM" src="https://github.com/pytorch/pytorch/assets/35276741/f3820d9e-f24c-43e7-885b-3bfdf24ef8ad">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102549
Approved by: https://github.com/albanD
2023-06-09 15:49:58 +00:00
atannous
b469ed72d0 Integrating new API usage metadata logger (#101762)
Summary: The new logger allows passing metadata into the api usage logger. The immediate use case is to pass the serialization_id to the save and load events to be enable tracking serialized models in API events. It could be extended to add more metadata in the future.

Test Plan:
```
buck2 test @//mode/dev //caffe2/caffe2/serialize:inline_container_test
```

Reviewed By: davidberard98

Differential Revision: D45683697

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101762
Approved by: https://github.com/davidberard98
2023-05-26 00:24:26 +00:00
XDaoHong
a723f1f2b9 fix _privateuse1_tag problem (#100632)
Fix _privateuse1_tag bug in torch/serialization.py
Add device_index after device_type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100632
Approved by: https://github.com/ezyang
2023-05-10 09:53:19 +00:00
Rob Guo
111358de19 Support non-ASCII characters in model file paths (#99453)
Fixes #98918

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99453
Approved by: https://github.com/albanD, https://github.com/malfet
2023-04-26 01:15:49 +00:00
Aleksei Nikiforov
87a2af6d4a Fix loading data on different encoding (#94503)
Add endianness marker when saving,
and if it doesn't match host endianness when loading data, do a byteswap.

Older data will load correctly only on systems
with same endianness it was saved on.
New data should load correctly on systems
with any endianness.

Fixes #65300
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94503
Approved by: https://github.com/kurtamohler, https://github.com/ezyang
2023-04-25 21:05:20 +00:00
XDaoHong
27f8eb8c2b add storage serialization methods for privateuse1 (#98920)
add entry for privateuse1 storage serialization register_package in _register_device_module.
1. User only need to implement `privateuse1_tag` and `privateuse1_deserialize` in the device module of open device. When registering device module, the methods are registered with _package_registry in storage serialization.
2. Provides a fixed sequence number 30 for privateuse1 in storage serialization _package_registry list.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98920
Approved by: https://github.com/ezyang
2023-04-21 01:51:08 +00:00
Xuehai Pan
e6888697c4 Revisit torch._six.string_classes removal (#94709) (#97863)
Revisit `torch._six.string_classes` (which is `(str, bytes)`) removal: `isinstance(obj, string_classes) -> isinstance(obj, str)`.

Both `str` and `bytes` are `Sequence` classes.

```python
In [1]: from typing import Sequence

In [2]: issubclass(bytes, Sequence)
Out[2]: True

In [3]: issubclass(str, Sequence)
Out[3]: True
```

Re-add `bytes` to type guards like:

```python
def is_seq(obj):
    return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))
```

Ref:

- https://github.com/pytorch/pytorch/pull/94709#issuecomment-1487282912
- #97737
- #97789
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97863
Approved by: https://github.com/Skylion007, https://github.com/albanD
2023-03-30 17:02:45 +00:00
Horace He
5bbec680d7 Fix usages of contextmanager without finally (#96170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96170
Approved by: https://github.com/ngimel, https://github.com/malfet
2023-03-08 20:59:27 +00:00
Xuehai Pan
b005ec62b9 [BE] Remove dependency on six and future (#94709)
Remove the Python 2 and 3 compatibility library [six](https://pypi.org/project/six) and [future](https://pypi.org/project/future) and `torch._six`. We only support Python 3.8+ now. It's time to retire them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94709
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-02-14 09:14:14 +00:00