There are two scenarios:
* Scenario 1: The checkpoint was saved with pytorch < 1.6
* Scenario 2: The checkpoint was saved with pytorch >= 1.6
Repro Scenario 1:
```python
from torch._subclasses import fake_tensor
import transformers

fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
    fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")
```
Error:
```bash
Some weights of the model checkpoint at sshleifer/tiny-gpt2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:463 in │
│ load_state_dict │
│ │
│ 460 │ │ │ ) │
│ 461 │ │ return safe_load_file(checkpoint_file) │
│ 462 │ try: │
│ ❱ 463 │ │ return torch.load(checkpoint_file, map_location="cpu") │
│ 464 │ except Exception as e: │
│ 465 │ │ try: │
│ 466 │ │ │ with open(checkpoint_file) as f: │
│ │
│ /opt/pytorch/torch/serialization.py:1030 in load │
│ │
│ 1027 │ │ │ │ return _legacy_load(opened_file, map_location, _weights_only_unpickler, │
│ 1028 │ │ │ except RuntimeError as e: │
│ 1029 │ │ │ │ raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None │
│ ❱ 1030 │ │ return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args │
│ 1031 │
│ 1032 │
│ 1033 # Register pickling support for layout instances such as │
│ │
│ /opt/pytorch/torch/serialization.py:1258 in _legacy_load │
│ │
│ 1255 │ _sys_info = pickle_module.load(f, **pickle_load_args) │
│ 1256 │ unpickler = UnpicklerWrapper(f, **pickle_load_args) │
│ 1257 │ unpickler.persistent_load = persistent_load │
│ ❱ 1258 │ result = unpickler.load() │
│ 1259 │ │
│ 1260 │ deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) │
│ 1261 │
│ │
│ /opt/pytorch/torch/_utils.py:201 in _rebuild_tensor_v2 │
│ │
│ 198 def _rebuild_tensor_v2( │
│ 199 │ storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None │
│ 200 ): │
│ ❱ 201 │ tensor = _rebuild_tensor(storage, storage_offset, size, stride) │
│ 202 │ tensor.requires_grad = requires_grad │
│ 203 │ if metadata: │
│ 204 │ │ set_tensor_metadata(tensor, metadata) │
│ │
│ /opt/pytorch/torch/_utils.py:180 in _rebuild_tensor │
│ │
│ 177 def _rebuild_tensor(storage, storage_offset, size, stride): │
│ 178 │ # first construct a tensor with the correct dtype/device │
│ 179 │ t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device) │
│ ❱ 180 │ return t.set_(storage._untyped_storage, storage_offset, size, stride) │
│ 181 │
│ 182 │
│ 183 def get_tensor_metadata(tensor): │
│ │
│ /opt/pytorch/torch/utils/_stats.py:20 in wrapper │
│ │
│ 17 │ │ if fn.__qualname__ not in simple_call_counter: │
│ 18 │ │ │ simple_call_counter[fn.__qualname__] = 0 │
│ 19 │ │ simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1 │
│ ❱ 20 │ │ return fn(*args, **kwargs) │
│ 21 │ return wrapper │
│ 22 │
│ │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1160 in __torch_dispatch__ │
│ │
│ 1157 │ def __torch_dispatch__(self, func, types, args=(), kwargs=None): │
│ 1158 │ │ assert self not in _get_current_dispatch_mode_stack(), func │
│ 1159 │ │ try: │
│ ❱ 1160 │ │ │ return self.dispatch(func, types, args, kwargs) │
│ 1161 │ │ except TypeError: │
│ 1162 │ │ │ log.exception("fake tensor raised TypeError") │
│ 1163 │ │ │ raise │
│ │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1318 in dispatch │
│ │
│ 1315 │ │ │
│ 1316 │ │ # we are falling through to running non constant tensors, any input constant tha │
│ 1317 │ │ # is written to must be invalidated │
│ ❱ 1318 │ │ self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) │
│ 1319 │ │ │
│ 1320 │ │ # Try for fastpath │
│ 1321 │ │ if has_symbolic_sizes: │
│ │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1557 in invalidate_written_to_constants │
│ │
│ 1554 │ │ any_constant = any(e.constant is not None for e in flat_arg_fake_tensors) │
│ 1555 │ │ if any_constant and get_schema_info(func).is_mutable(): │
│ 1556 │ │ │ schema_info = get_schema_info(func) │
│ ❱ 1557 │ │ │ _, new_kwargs = normalize_function( │
│ 1558 │ │ │ │ func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True │
│ 1559 │ │ │ ) │
│ 1560 │ │ │ for k, v in new_kwargs.items(): │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:297 in normalize_function │
│ │
│ 294 │ │ new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, │
│ 295 │ else: │
│ 296 │ │ assert callable(target) │
│ ❱ 297 │ │ torch_op_schemas = get_signature_for_torch_op(target) │
│ 298 │ │ matched_schemas = [] │
│ 299 │ │ if torch_op_schemas: │
│ 300 │ │ │ # Iterate through all of the schema until we find one that matches │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in get_signature_for_torch_op │
│ │
│ 164 │ │ │ return (None, None) if return_schemas else None │
│ 165 │ │ schemas = torch._C._jit_get_schemas_for_operator(aten_fn) │
│ 166 │ │
│ ❱ 167 │ signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] │
│ 168 │ return (signatures, schemas) if return_schemas else signatures │
│ 169 │
│ 170 @compatibility(is_backward_compatible=False) │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in <listcomp> │
│ │
│ 164 │ │ │ return (None, None) if return_schemas else None │
│ 165 │ │ schemas = torch._C._jit_get_schemas_for_operator(aten_fn) │
│ 166 │ │
│ ❱ 167 │ signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] │
│ 168 │ return (signatures, schemas) if return_schemas else signatures │
│ 169 │
│ 170 @compatibility(is_backward_compatible=False) │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:70 in _torchscript_schema_to_signature │
│ │
│ 67 │ from inspect import Parameter │
│ 68 │ parameters : List[Parameter] = [] │
│ 69 │ for arg in ts_schema.arguments: │
│ ❱ 70 │ │ arg_type = _torchscript_type_to_python_type(arg.type) │
│ 71 │ │ default = arg.default_value if arg.has_default_value() else Parameter.empty │
│ 72 │ │ # TODO: Figure out if this is safe. It seems like when generating the type signa │
│ 73 │ │ # PythonArgParser, we emit signatures with `input` instead of `self` as the firs │
│ │
│ /opt/pytorch/torch/fx/operator_schemas.py:64 in _torchscript_type_to_python_type │
│ │
│ 61 │ eval'ing the annotation_str. _type_eval_globals sets up expressions │
│ 62 │ like "List" and "Future" to map to actual types (typing.List and jit.Future) │
│ 63 │ """ │
│ ❱ 64 │ return eval(ts_type.annotation_str, _type_eval_globals) │
│ 65 │
│ 66 def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Sig │
│ 67 │ from inspect import Parameter │
│ <string>:1 in <module> │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NameError: name 'Storage' is not defined
During handling of the above exception, another exception occurred:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:467 in │
│ load_state_dict │
│ │
│ 464 │ except Exception as e: │
│ 465 │ │ try: │
│ 466 │ │ │ with open(checkpoint_file) as f: │
│ ❱ 467 │ │ │ │ if f.read(7) == "version": │
│ 468 │ │ │ │ │ raise OSError( │
│ 469 │ │ │ │ │ │ "You seem to have cloned a repository without having git-lfs ins │
│ 470 │ │ │ │ │ │ "git-lfs and run `git lfs install` followed by `git lfs pull` in │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/codecs.py:322 in decode │
│ │
│ 319 │ def decode(self, input, final=False): │
│ 320 │ │ # decode input (taking the buffer into account) │
│ 321 │ │ data = self.buffer + input │
│ ❱ 322 │ │ (result, consumed) = self._buffer_decode(data, self.errors, final) │
│ 323 │ │ # keep undecoded input until the next call │
│ 324 │ │ self.buffer = data[consumed:] │
│ 325 │ │ return result │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte
During handling of the above exception, another exception occurred:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/pytorch/bug_repro.py:16 in <module> │
│ │
│ 13 fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2") │
│ 14 assert fake_model is not None │
│ 15 with fake_mode: │
│ ❱ 16 │ fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2") # raises │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:484 in │
│ from_pretrained │
│ │
│ 481 │ │ │ ) │
│ 482 │ │ elif type(config) in cls._model_mapping.keys(): │
│ 483 │ │ │ model_class = _get_model_class(config, cls._model_mapping) │
│ ❱ 484 │ │ │ return model_class.from_pretrained( │
│ 485 │ │ │ │ pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, │
│ 486 │ │ │ ) │
│ 487 │ │ raise ValueError( │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:2604 in │
│ from_pretrained │
│ │
│ 2601 │ │ if from_pt: │
│ 2602 │ │ │ if not is_sharded and state_dict is None: │
│ 2603 │ │ │ │ # Time to load the checkpoint │
│ ❱ 2604 │ │ │ │ state_dict = load_state_dict(resolved_archive_file) │
│ 2605 │ │ │ │
│ 2606 │ │ │ # set dtype to instantiate the model under: │
│ 2607 │ │ │ # 1. If torch_dtype is not None, we use that dtype │
│ │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:479 in │
│ load_state_dict │
│ │
│ 476 │ │ │ │ │ │ "model. Make sure you have saved the model properly." │
│ 477 │ │ │ │ │ ) from e │
│ 478 │ │ except (UnicodeDecodeError, ValueError): │
│ ❱ 479 │ │ │ raise OSError( │
│ 480 │ │ │ │ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_f │
│ 481 │ │ │ │ f"at '{checkpoint_file}'. " │
│ 482 │ │ │ │ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please s │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OSError: Unable to load weights from pytorch checkpoint file for '/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin' at
'/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set
from_tf=True.
```
Repro Scenario 2:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)


with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during torch.load (when fake mode is active) by changing the storage's device to 'meta'.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang, https://github.com/atalman
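With this change, loading a checkpoint under an active `FakeTensorMode` is expected to produce meta-backed fake tensors instead of failing. A minimal sketch of the intended usage (assuming a build that includes this PR; it mirrors the Scenario 2 repro above):
```python
import tempfile

import torch
from torch._subclasses import fake_tensor

with tempfile.NamedTemporaryFile() as state_dict_file:
    torch.save(torch.nn.Linear(5, 10).state_dict(), state_dict_file.name)

    with fake_tensor.FakeTensorMode():
        # The storages are rebuilt on the 'meta' device, so the load
        # succeeds and yields fake tensors instead of raising NameError.
        state_dict = torch.load(state_dict_file.name)

    print({name: type(t).__name__ for name, t in state_dict.items()})
```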
import copyreg
import functools
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, DefaultDict, List, Optional

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)


def _hpu(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in HPU memory.

    If this object is already in HPU memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination HPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs)
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    if self.is_hpu:
        if device is None:
            device = hpu.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with hpu.device(device):
        assert not self.is_sparse, "sparse storage is not supported for HPU tensors"
        untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu"))
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage


def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in CUDA memory.

    If this object is already in CUDA memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination GPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs)
    if self.is_cuda:
        if device is None:
            device = torch.cuda.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with torch.cuda.device(device):
        if self.is_sparse:
            new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
            indices = torch.Tensor._indices(self).cuda(device, non_blocking)
            values = torch.Tensor._values(self).cuda(device, non_blocking)
            return new_type(indices, values, self.size())
        else:
            untyped_storage = torch.UntypedStorage(
                self.size(), device=torch.device("cuda")
            )
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        **kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables. This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
#   - It's fragile. If you serialize a backward hook into a saved
#     model, and then you rename the function associated with the hook,
#     now your saved model is broken and you can't load it anymore.
#
#   - It's not actually used. The standard recommendation is to
#     serialize the *state_dict* of a model, not the model itself
#     (since this is more stable to code changes affecting the model
#     serialization), and the state dict saves "data" only, thus
#     stripping the backward hooks. In some cases, hooks are
#     essential to the well-functioning of a model (e.g., DDP),
#     but DDP already manages readding the hooks!
#
#   - We didn't serialize them in many cases. Prior to #10220, we
#     were dropping backward hooks in ForkingPickler. We "fixed" this
#     to be convenient with other serialization sites, but lack of
#     serializing backward hooks wasn't actually the root cause of
#     the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a hook
# has the property _torch_serialize_ignore, we will not emit a warning if we
# attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(), if you pass a None you'll run afoul #12219.


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)


def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor
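
# Illustrative usage sketch, not part of the original file: round-trips a
# tensor through _rebuild_tensor_v2, one of the reducer targets used when
# torch.save/torch.load reconstruct plain dense tensors. Assumes the internal
# Tensor._typed_storage() accessor; the helper name is hypothetical.
def _example_rebuild_tensor_v2_roundtrip():
    from collections import OrderedDict

    src = torch.arange(6.0).reshape(2, 3)
    rebuilt = _rebuild_tensor_v2(
        src._typed_storage(),  # TypedStorage backing the tensor
        src.storage_offset(),
        src.size(),
        src.stride(),
        src.requires_grad,
        OrderedDict(),  # empty backward_hooks, see Note [Don't serialize hooks]
    )
    return torch.equal(src, rebuilt)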


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have been already unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because due
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )

    finally:
        _sparse_tensors_to_validate.clear()


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (str): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")
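
# Illustrative usage sketch, not part of the original file (helper name is
# hypothetical): rebuilds an uncoalesced COO tensor and then runs the deferred
# bulk validation described above _validate_loaded_sparse_tensors.
def _example_rebuild_sparse_coo():
    indices = torch.tensor([[0, 1], [2, 0]])
    values = torch.tensor([3.0, 4.0])
    result = _rebuild_sparse_tensor(torch.sparse_coo, (indices, values, (2, 3)))
    _validate_loaded_sparse_tensors()  # validates and clears the pending list
    return result.to_dense()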


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore state on Parameter like python attr.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class but since Tensor does not inherit
    # from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
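
# Illustrative usage sketch, not part of the original file (helper name is
# hypothetical): flattens same-dtype dense tensors into one 1-D buffer and
# restores views with the original shapes, e.g. as done when bucketing
# gradients for communication.
def _example_flatten_unflatten_roundtrip():
    tensors = [torch.randn(2, 3), torch.randn(4)]
    flat = _flatten_dense_tensors(tensors)  # contiguous buffer of 10 elements
    restored = _unflatten_dense_tensors(flat, tensors)
    return all(torch.equal(a, b) for a, b in zip(tensors, restored))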


def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
            tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
            the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
            reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields a chunk at each time,
    each containing tensors of same type up to certain byte limit in total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within its types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf
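
# Illustrative usage sketch, not part of the original file (helper name is
# hypothetical): groups float32 tensors into chunks whose total size stays
# within a 32-byte (8-element) limit per chunk.
def _example_take_tensors():
    tensors = [torch.randn(4), torch.randn(6), torch.randn(2)]
    return [len(chunk) for chunk in _take_tensors(tensors, 32)]  # -> [1, 2]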


# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the objects in its temporary scope)
# holding reference to the traceback.


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception
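
# Illustrative usage sketch, not part of the original file (helper name is
# hypothetical): captures an exception raised in a worker thread and re-raises
# it in the calling thread with the worker traceback embedded in the message,
# mirroring how DataLoader workers surface errors.
def _example_exception_wrapper():
    import threading

    holder = []

    def worker():
        try:
            raise ValueError("boom")
        except Exception:
            # Captures sys.exc_info() of the in-flight exception.
            holder.append(ExceptionWrapper(where="in example worker thread"))

    t = threading.Thread(target=worker)
    t.start()
    t.join()
    try:
        holder[0].reraise()
    except ValueError as err:
        return str(err)  # message embeds the worker's original traceback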


def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device index
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    i.e., the current default CUDA device will be returned if CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got:{device}"
            )
    return device_idx


def _handle_complex(tensor):
    """
    Returns a real view of a tensor if complex dtype else just the tensor
    need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )


def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3


class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)
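
# Illustrative usage sketch, not part of the original file (class name is
# hypothetical): classproperty exposes a computed, read-only attribute on the
# class itself rather than on instances.
class _ExampleWithClassProperty:
    _registry = {"cpu": True}

    @classproperty
    def supported_devices(cls):
        return sorted(cls._registry)


# _ExampleWithClassProperty.supported_devices evaluates to ["cpu"]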


# Whether we are compiling with torch.compile or not
def is_compiling():
    return False


def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate our updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module