mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This PR re-lands - [Typing] Fix PEP 484 Violation (#105022) - Update mypy to 1.4.1 (#91983) that were reverted due to a conflict with the internal source repo. Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as optional) Plus a few real fixes: - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi` - Add missing return statement to `torch._export.deserialize_graph` - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights` - Add assert in `torch/optim/optimizer.py` that Optional list is not None TODO (in followup PR): - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py` Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04: - Add hack to squash older libstdc++ from the conda environment in favor of the one from the OS to `.ci/docker/install_conda.sh` - Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where it is done) Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227 Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
import contextlib
|
|
from typing import Sequence
|
|
|
|
import torch
|
|
from torch._custom_op.impl import custom_op
|
|
from torch.utils._content_store import ContentStoreReader
|
|
|
|
# Module-level ContentStoreReader consulted by the debugprims::load_tensor op
# below. It is None except while inside the load_tensor_reader() context
# manager, which installs a reader and resets this back to None on exit.
LOAD_TENSOR_READER = None
|
@contextlib.contextmanager
def load_tensor_reader(loc):
    """Context manager that installs a global ContentStoreReader rooted at
    ``loc`` for the duration of the ``with`` block.

    While active, the debugprims::load_tensor op (registered by
    register_debug_prims) reads real tensor data through this reader instead
    of fabricating random tensors. Nesting is not supported: entering while a
    reader is already installed is an error.
    """
    global LOAD_TENSOR_READER
    assert LOAD_TENSOR_READER is None
    # Caching is deliberately disabled: load_tensor behaves as an "op", and
    # returning a tensor that aliases a previously-returned tensor's storage
    # would wreak havoc on Inductor's memory planning. With cache=False every
    # read hands back a fresh storage — no aliasing.
    reader = ContentStoreReader(loc, cache=False)
    LOAD_TENSOR_READER = reader
    try:
        yield
    finally:
        # Always tear down, even if the body raised, so a later
        # load_tensor_reader() call doesn't trip the nesting assert.
        LOAD_TENSOR_READER = None
|
|
|
|
|
|
def register_debug_prims():
    """Register the ``debugprims::load_tensor`` custom op.

    The op either materializes a random tensor of the requested
    size/stride/dtype/device, or — when a ``load_tensor_reader`` context is
    active (i.e. the module-level LOAD_TENSOR_READER is set) — loads the real
    tensor contents by name from the content store.

    Called for its registration side effects; returns nothing.
    """

    # Schema-only declaration: the body is intentionally empty (the
    # implementation is supplied via impl_factory below), hence the
    # `empty-body` mypy suppression.
    @custom_op("debugprims::load_tensor")
    def load_tensor(  # type: ignore[empty-body]
        name: str,
        size: Sequence[int],
        stride: Sequence[int],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        ...

    @load_tensor.impl_factory()
    def load_tensor_factory(name, size, stride, dtype, device):
        if LOAD_TENSOR_READER is None:
            # No reader installed: fabricate a random tensor with the
            # requested layout instead of loading real data.
            from torch._dynamo.testing import rand_strided

            return rand_strided(size, stride, dtype, device)
        else:
            from torch._dynamo.utils import clone_input

            # device argument here takes care of coercion
            r = LOAD_TENSOR_READER.read_tensor(name, device=device)
            # Size/stride/device must match the requested layout exactly;
            # mismatches indicate a stale or wrong content store entry.
            assert list(r.size()) == size, f"{r.size()} != {size}"
            assert list(r.stride()) == stride, f"{r.stride()} != {stride}"
            assert r.device == device, f"{r.device} != {device}"

            # Unlike the other properties, we will do coercions for dtype
            # mismatch
            if r.dtype != dtype:
                r = clone_input(r, dtype=dtype)
            return r