pytorch/torch/utils/weak.py
zabboud 7f9fafed53 Resolve docstring errors in throughput_benchmark.py, weak.py, _traceback.py, file_baton.py, _contextlib.py, _device.py, cpp_backtrace.py, bundled_inputs.py, run_cpu.py, hooks.py, mobile_optimizer.py, _freeze.py, __init__.py, mkldnn.py, dlpack.py (#113311)
Fixes #112633

Fixed pydocstyle errors in the files listed below. The remaining errors are not covered by this issue. `torch/utils/dlpack.py` was not modified, as its errors relate to the function signature on the first line of the docstring, which must be kept as is for Sphinx to interpret it correctly.

```python
def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
    """from_dlpack(ext_tensor) -> Tensor
         .....
    """
```

pydocstyle torch/utils/_contextlib.py --count
before: 4
after: 0

pydocstyle torch/backends/mps/__init__.py --count
before: 8
after: 1

**remaining errors**
```
torch/backends/mps/__init__.py:1 at module level:
        D104: Missing docstring in public package
```

pydocstyle torch/backends/xeon/run_cpu.py --count
before: 13
after: 1

**remaining errors**
```
torch/backends/xeon/run_cpu.py:864 in public function `main`:
        D103: Missing docstring in public function
```

pydocstyle torch/backends/cpu/__init__.py --count
before: 2
after: 1

**remaining errors**
```
torch/backends/cpu/__init__.py:1 at module level:
        D104: Missing docstring in public package
```

pydocstyle torch/utils/cpp_backtrace.py --count
before: 4
after: 1

**remaining errors**
```
torch/utils/cpp_backtrace.py:1 at module level:
        D100: Missing docstring in public module
```

pydocstyle torch/utils/bundled_inputs.py --count
before: 8
after: 1

**remaining errors**
```
torch/utils/bundled_inputs.py:1 at module level:
        D100: Missing docstring in public module
```

pydocstyle torch/utils/file_baton.py --count
before: 8
after: 1

**remaining errors**
```
torch/utils/file_baton.py:1 at module level:
        D100: Missing docstring in public module
```

pydocstyle torch/utils/mobile_optimizer.py --count
before: 6
after: 1

**remaining errors**
```
torch/utils/mobile_optimizer.py:8 in public class `LintCode`:
        D101: Missing docstring in public class
```

pydocstyle torch/backends/opt_einsum/__init__.py --count
before: 7
after: 5

**remaining errors**
```
torch/backends/opt_einsum/__init__.py:1 at module level:
        D104: Missing docstring in public package
torch/backends/opt_einsum/__init__.py:67 in public function `set_flags`:
        D103: Missing docstring in public function
torch/backends/opt_einsum/__init__.py:77 in public function `flags`:
        D103: Missing docstring in public function
torch/backends/opt_einsum/__init__.py:93 in public class `OptEinsumModule`:
        D101: Missing docstring in public class
torch/backends/opt_einsum/__init__.py:94 in public method `__init__`:
        D107: Missing docstring in __init__
```

pydocstyle torch/utils/_device.py --count
before: 9
after: 6

**remaining errors**
```
torch/utils/_device.py:58 in public class `DeviceContext`:
        D101: Missing docstring in public class
torch/utils/_device.py:59 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/_device.py:62 in public method `__enter__`:
        D105: Missing docstring in magic method
torch/utils/_device.py:68 in public method `__exit__`:
        D105: Missing docstring in magic method
torch/utils/_device.py:73 in public method `__torch_function__`:
        D105: Missing docstring in magic method
torch/utils/_device.py:80 in public function `device_decorator`:
        D103: Missing docstring in public function

```

pydocstyle torch/utils/_freeze.py --count
before: 15
after: 7

**remaining errors**
```
torch/utils/_freeze.py:77 in public function `indent_msg`:
        D103: Missing docstring in public function
torch/utils/_freeze.py:89 in public class `FrozenModule`:
        D101: Missing docstring in public class
torch/utils/_freeze.py:100 in public class `Freezer`:
        D101: Missing docstring in public class
torch/utils/_freeze.py:101 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/_freeze.py:106 in public method `msg`:
        D102: Missing docstring in public method
torch/utils/_freeze.py:185 in public method `get_module_qualname`:
        D102: Missing docstring in public method
torch/utils/_freeze.py:206 in public method `compile_string`:
        D102: Missing docstring in public method

```

pydocstyle torch/utils/throughput_benchmark.py --count
before: 25
after: 8

**remaining errors**
```
torch/utils/throughput_benchmark.py:1 at module level:
        D100: Missing docstring in public module
torch/utils/throughput_benchmark.py:27 in public class `ExecutionStats`:
        D101: Missing docstring in public class
torch/utils/throughput_benchmark.py:28 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/throughput_benchmark.py:33 in public method `latency_avg_ms`:
        D102: Missing docstring in public method
torch/utils/throughput_benchmark.py:37 in public method `num_iters`:
        D102: Missing docstring in public method
torch/utils/throughput_benchmark.py:46 in public method `total_time_seconds`:
        D102: Missing docstring in public method
torch/utils/throughput_benchmark.py:50 in public method `__str__`:
        D105: Missing docstring in magic method
torch/utils/throughput_benchmark.py:94 in public method `__init__`:
        D107: Missing docstring in __init__

```

pydocstyle torch/utils/hooks.py --count

before: 14
after: 11

**remaining errors**
```
torch/utils/hooks.py:1 at module level:
        D100: Missing docstring in public module
torch/utils/hooks.py:23 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/hooks.py:34 in public method `remove`:
        D102: Missing docstring in public method
torch/utils/hooks.py:44 in public method `__getstate__`:
        D105: Missing docstring in magic method
torch/utils/hooks.py:50 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/utils/hooks.py:64 in public method `__enter__`:
        D105: Missing docstring in magic method
torch/utils/hooks.py:67 in public method `__exit__`:
        D105: Missing docstring in magic method
torch/utils/hooks.py:82 in public function `warn_if_has_hooks`:
        D103: Missing docstring in public function
torch/utils/hooks.py:103 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/hooks.py:188 in public method `setup_input_hook`:
        D102: Missing docstring in public method
torch/utils/hooks.py:197 in public method `setup_output_hook`:
        D102: Missing docstring in public method
```

pydocstyle torch/utils/_traceback.py --count
before: 19
after: 14

**remaining errors**
```
torch/utils/_traceback.py:47 in public function `report_compile_source_on_error`:
        D103: Missing docstring in public function
torch/utils/_traceback.py:160 in public class `CapturedTraceback`:
        D101: Missing docstring in public class
torch/utils/_traceback.py:163 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/_traceback.py:167 in public method `cleanup`:
        D102: Missing docstring in public method
torch/utils/_traceback.py:170 in public method `summary`:
        D102: Missing docstring in public method
torch/utils/_traceback.py:182 in public method `__getstate__`:
        D105: Missing docstring in magic method
torch/utils/_traceback.py:190 in public method `extract`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/_traceback.py:190 in public method `extract`:
        D400: First line should end with a period (not 't')
torch/utils/_traceback.py:213 in public method `format`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/_traceback.py:213 in public method `format`:
        D400: First line should end with a period (not 'f')
torch/utils/_traceback.py:213 in public method `format`:
        D401: First line should be in imperative mood (perhaps 'Format', not 'Formats')
torch/utils/_traceback.py:224 in public method `format_all`:
        D200: One-line docstring should fit on one line with quotes (found 3)
torch/utils/_traceback.py:247 in private function `_extract_symbolized_tb`:
        D205: 1 blank line required between summary line and description (found 0)
torch/utils/_traceback.py:247 in private function `_extract_symbolized_tb`:
        D400: First line should end with a period (not 'f')
```
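The D200, D205, D400 and D401 checks that recur in these reports all concern the shape of a docstring's opening lines. As a rough illustration only (the functions `format_frames` and `frame_count` below are hypothetical and not part of `torch/utils/_traceback.py`), docstrings that satisfy those checks might look like this:

```python
def format_frames(frames):
    """Format captured frames as a list of strings.

    The summary line above is in the imperative mood (D401), ends with a
    period (D400), and is separated from this description by exactly one
    blank line (D205).
    """
    # Hypothetical helper: each frame is assumed to be a (name, lineno) pair.
    return [f"{name}:{lineno}" for name, lineno in frames]


def frame_count(frames):
    """Return the number of captured frames."""
    # The one-line docstring above sits on a single line with its quotes (D200).
    return len(frames)
```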

pydocstyle torch/utils/mkldnn.py --count
before: 28
after: 26

**remaining errors**
```
torch/utils/mkldnn.py:1 at module level:
        D100: Missing docstring in public module
torch/utils/mkldnn.py:4 in public class `MkldnnLinear`:
        D101: Missing docstring in public class
torch/utils/mkldnn.py:5 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/mkldnn.py:19 in public method `__getstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:23 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:29 in public method `forward`:
        D102: Missing docstring in public method
torch/utils/mkldnn.py:75 in public class `MkldnnConv1d`:
        D101: Missing docstring in public class
torch/utils/mkldnn.py:76 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/mkldnn.py:82 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:88 in public class `MkldnnConv2d`:
        D101: Missing docstring in public class
torch/utils/mkldnn.py:89 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/mkldnn.py:100 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:110 in public class `MkldnnConv3d`:
        D101: Missing docstring in public class
torch/utils/mkldnn.py:111 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/mkldnn.py:122 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:133 in public class `MkldnnBatchNorm`:
        D101: Missing docstring in public class
torch/utils/mkldnn.py:136 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/mkldnn.py:155 in public method `__getstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:163 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:171 in public method `forward`:
        D102: Missing docstring in public method
torch/utils/mkldnn.py:184 in public class `MkldnnPrelu`:
        D101: Missing docstring in public class
torch/utils/mkldnn.py:185 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/mkldnn.py:190 in public method `__getstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:194 in public method `__setstate__`:
        D105: Missing docstring in magic method
torch/utils/mkldnn.py:199 in public method `forward`:
        D102: Missing docstring in public method
torch/utils/mkldnn.py:205 in public function `to_mkldnn`:
        D103: Missing docstring in public function
```

pydocstyle torch/utils/weak.py --count
before: 32
after: 30

**remaining errors**
```
torch/utils/weak.py:1 at module level:
        D100: Missing docstring in public module
torch/utils/weak.py:42 in public class `WeakIdRef`:
        D101: Missing docstring in public class
torch/utils/weak.py:45 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/weak.py:54 in public method `__call__`:
        D102: Missing docstring in public method
torch/utils/weak.py:61 in public method `__hash__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:64 in public method `__eq__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:84 in public class `WeakIdKeyDictionary`:
        D101: Missing docstring in public class
torch/utils/weak.py:87 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/weak.py:131 in public method `__delitem__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:135 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:138 in public method `__len__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:145 in public method `__repr__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:148 in public method `__setitem__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:151 in public method `copy`:
        D102: Missing docstring in public method
torch/utils/weak.py:162 in public method `__deepcopy__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:172 in public method `get`:
        D102: Missing docstring in public method
torch/utils/weak.py:175 in public method `__contains__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:182 in public method `items`:
        D102: Missing docstring in public method
torch/utils/weak.py:189 in public method `keys`:
        D102: Missing docstring in public method
torch/utils/weak.py:198 in public method `values`:
        D102: Missing docstring in public method
torch/utils/weak.py:216 in public method `popitem`:
        D102: Missing docstring in public method
torch/utils/weak.py:224 in public method `pop`:
        D102: Missing docstring in public method
torch/utils/weak.py:228 in public method `setdefault`:
        D102: Missing docstring in public method
torch/utils/weak.py:231 in public method `update`:
        D102: Missing docstring in public method
torch/utils/weak.py:241 in public method `__ior__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:245 in public method `__or__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:252 in public method `__ror__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:262 in public method `__eq__`:
        D105: Missing docstring in magic method
torch/utils/weak.py:276 in public method `__init__`:
        D107: Missing docstring in __init__
torch/utils/weak.py:280 in public method `__call__`:
        D102: Missing docstring in public method

```

@mikaylagawarecki @jbschlosser @svekars
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113311
Approved by: https://github.com/ezyang
2023-11-15 17:40:04 +00:00

288 lines
9.4 KiB
Python

from __future__ import annotations
import weakref
from weakref import ref
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
from collections.abc import MutableMapping, Mapping
from torch import Tensor
import collections.abc as _collections_abc
WeakRef = ref
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
# This file defines a variant of WeakKeyDictionary that overrides the hashing
# behavior of the key to use object identity, rather than the builtin
# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
# __eq__ implementation returns a Tensor (elementwise equality), which means
# you can't use them directly with the WeakKeyDictionary in the standard library.
#
# Our implementation strategy is to create a wrapper weak key object, which we
# use as a key in a stock Python dictionary. This is similar to how weakref
# implements WeakKeyDictionary, but instead of using weakref.ref as the
# wrapper, we use a custom wrapper that has different __eq__ and __hash__
# behavior. Note that we subsequently store this weak key directly in an
# ORDINARY dictionary, since the newly constructed WeakIdKey's only use would
# be a dictionary so it would have no strong references. Ensuring that
# only live WeakIdKeys are in the map is handled by putting finalizers on the
# original key object.
# It is simpler to implement this with composition, but if we want to
# directly reuse the callback mechanism on weakref, we need the weakref
# and the key to be exactly the same object. Reusing the callback mechanism
# minimizes the divergence between our implementation and Lib/weakref.py
#
# NB: Prefer using this when working with weakrefs of Tensors; e.g., do
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
# easy to get wrong cases transparently for you.
class WeakIdRef(weakref.ref):
    __slots__ = ['_id']

    def __init__(self, key, callback=None):
        # Unlike stock weakref, which preserves hash semantics of the
        # original object but lazily defers hash calls until the first
        # time the user attempts to hash the weakref, we can eagerly
        # cache the id of the key as we know this is definitely the hash
        # method
        self._id = id(key)
        super().__init__(key, callback)  # type: ignore[call-arg]

    def __call__(self):
        r = super().__call__()
        # Special logic for Tensor PyObject resurrection
        if hasattr(r, '_fix_weakref'):
            r._fix_weakref()  # type: ignore[union-attr]
        return r

    def __hash__(self):
        return self._id

    def __eq__(self, other):
        # An attractive but wrong alternate implementation is to only test if
        # the stored _ids match. This can lead to an ABA problem if you have:
        #
        #   a1 = A()
        #   w1 = WeakIdRef(a1)
        #   del a1
        #   a2 = A()  # suppose it gets the same ID as a1
        #   w2 = WeakIdRef(a2)
        #   print(w1 == w2)
        #
        # This should be False, as a1 and a2 are unrelated (and a1 is
        # dead anyway)
        a = self()
        b = other()
        if a is not None and b is not None:
            return a is b
        return self is other
# This is directly adapted from cpython/Lib/weakref.py
class WeakIdKeyDictionary(MutableMapping):
    data: dict[WeakIdRef, object]

    def __init__(self, dict=None):
        self.data = {}

        def remove(k, selfref=ref(self)):
            self = selfref()
            if self is not None:
                if self._iterating:
                    self._pending_removals.append(k)
                else:
                    try:
                        del self.data[k]
                    except KeyError:
                        pass
        self._remove = remove
        # A list of dead weakrefs (keys to be removed)
        self._pending_removals = []
        self._iterating = set()
        self._dirty_len = False
        if dict is not None:
            self.update(dict)

    def _commit_removals(self):
        # NOTE: We don't need to call this method before mutating the dict,
        # because a dead weakref never compares equal to a live weakref,
        # even if they happened to refer to equal objects.
        # However, it means keys may already have been removed.
        pop = self._pending_removals.pop
        d = self.data
        while True:
            try:
                key = pop()
            except IndexError:
                return

            try:
                del d[key]
            except KeyError:
                pass

    def _scrub_removals(self):
        d = self.data
        self._pending_removals = [k for k in self._pending_removals if k in d]
        self._dirty_len = False

    def __delitem__(self, key):
        self._dirty_len = True
        del self.data[WeakIdRef(key)]  # CHANGED

    def __getitem__(self, key):
        return self.data[WeakIdRef(key)]  # CHANGED

    def __len__(self):
        if self._dirty_len and self._pending_removals:
            # self._pending_removals may still contain keys which were
            # explicitly removed, we have to scrub them (see issue #21173).
            self._scrub_removals()
        return len(self.data) - len(self._pending_removals)

    def __repr__(self):
        return f"<{self.__class__.__name__} at {id(self):#x}>"

    def __setitem__(self, key, value):
        self.data[WeakIdRef(key, self._remove)] = value  # CHANGED

    def copy(self):
        new = WeakIdKeyDictionary()
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = value
        return new

    __copy__ = copy

    def __deepcopy__(self, memo):
        from copy import deepcopy
        new = self.__class__()
        with _IterationGuard(self):
            for key, value in self.data.items():
                o = key()
                if o is not None:
                    new[o] = deepcopy(value, memo)
        return new

    def get(self, key, default=None):
        return self.data.get(WeakIdRef(key), default)  # CHANGED

    def __contains__(self, key):
        try:
            wr = WeakIdRef(key)
        except TypeError:
            return False
        return wr in self.data

    def items(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                key = wr()
                if key is not None:
                    yield key, value

    def keys(self):
        with _IterationGuard(self):
            for wr in self.data:
                obj = wr()
                if obj is not None:
                    yield obj

    __iter__ = keys

    def values(self):
        with _IterationGuard(self):
            for wr, value in self.data.items():
                if wr() is not None:
                    yield value

    def keyrefs(self):
        """Return a list of weak references to the keys.

        The references are not guaranteed to be 'live' at the time
        they are used, so the result of calling the references needs
        to be checked before being used.  This can be used to avoid
        creating references that will cause the garbage collector to
        keep the keys around longer than needed.
        """
        return list(self.data)

    def popitem(self):
        self._dirty_len = True
        while True:
            key, value = self.data.popitem()
            o = key()
            if o is not None:
                return o, value

    def pop(self, key, *args):
        self._dirty_len = True
        return self.data.pop(WeakIdRef(key), *args)  # CHANGED

    def setdefault(self, key, default=None):
        return self.data.setdefault(WeakIdRef(key, self._remove), default)  # CHANGED

    def update(self, dict=None, **kwargs):
        d = self.data
        if dict is not None:
            if not hasattr(dict, "items"):
                dict = type({})(dict)
            for key, value in dict.items():
                d[WeakIdRef(key, self._remove)] = value  # CHANGED
        if len(kwargs):
            self.update(kwargs)

    def __ior__(self, other):
        self.update(other)
        return self

    def __or__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.copy()
            c.update(other)
            return c
        return NotImplemented

    def __ror__(self, other):
        if isinstance(other, _collections_abc.Mapping):
            c = self.__class__()
            c.update(other)
            c.update(self)
            return c
        return NotImplemented

    # Default Mapping equality tests keys for equality, but
    # we want to test ids for equality
    def __eq__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
# Convenience alias
WeakTensorKeyDictionary = WeakIdKeyDictionary
class TensorWeakRef:
    """Wrapper around a weak ref of a Tensor that handles the _fix_weakref() call required when unwrapping a Tensor weakref."""

    ref: WeakRef[Tensor]

    def __init__(self, tensor: Tensor):
        assert isinstance(tensor, Tensor)
        self.ref = weakref.ref(tensor)

    def __call__(self):
        out = self.ref()
        if out is None:
            return out
        assert isinstance(out, Tensor)
        # TODO, add _fix_weakref type binding
        out._fix_weakref()  # type: ignore[attr-defined]
        return out