pytorch/torch/_prims/debug_prims.py
Nikita Shulga 5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

These were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that an Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
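
For context, the PEP-484 violation mentioned above is the "implicit Optional" pattern, which newer mypy rejects by default. A minimal sketch (the function name is hypothetical, for illustration only):

```python
from typing import Optional

# Violation: default is None, but the annotation does not allow None.
# Newer mypy (with implicit Optional disabled) flags this:
#   def scale(x: int = None) -> int: ...

# Fix: annotate the parameter as Optional so None is explicitly allowed.
def scale(x: Optional[int] = None) -> int:
    # Treat a missing value as zero.
    return 0 if x is None else x * 2
```

The fix is purely at the annotation level; runtime behavior is unchanged.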

Unrelated changes, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` that squashes the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where this is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00


import contextlib
from typing import Sequence

import torch
from torch._custom_op.impl import custom_op
from torch.utils._content_store import ContentStoreReader

LOAD_TENSOR_READER = None


@contextlib.contextmanager
def load_tensor_reader(loc):
    global LOAD_TENSOR_READER
    assert LOAD_TENSOR_READER is None
    # load_tensor is an "op", and we will play merry hell on
    # Inductor's memory planning if we return a tensor that
    # aliases another tensor that we previously returned from
    # an operator.  So unlike standard ContentStoreReader use,
    # we disable the cache so that you always get fresh storages
    # (no aliasing for you!)
    LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False)
    try:
        yield
    finally:
        LOAD_TENSOR_READER = None


def register_debug_prims():
    @custom_op("debugprims::load_tensor")
    def load_tensor(  # type: ignore[empty-body]
        name: str,
        size: Sequence[int],
        stride: Sequence[int],
        *,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        ...

    @load_tensor.impl_factory()
    def load_tensor_factory(name, size, stride, dtype, device):
        if LOAD_TENSOR_READER is None:
            from torch._dynamo.testing import rand_strided

            return rand_strided(size, stride, dtype, device)
        else:
            from torch._dynamo.utils import clone_input

            # device argument here takes care of coercion
            r = LOAD_TENSOR_READER.read_tensor(name, device=device)
            assert list(r.size()) == size, f"{r.size()} != {size}"
            assert list(r.stride()) == stride, f"{r.stride()} != {stride}"
            assert r.device == device, f"{r.device} != {device}"

            # Unlike the other properties, we will do coercions for dtype
            # mismatch
            if r.dtype != dtype:
                r = clone_input(r, dtype=dtype)
            return r