Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Fixes #112633

Fixed pydocstyle errors in the following files. The remaining errors are not covered in this issue. `torch/utils/dlpack.py` was not modified, as its errors relate to the function signature in the first line of the docstring, which must be maintained as is for proper Sphinx interpretation:

```python
def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
    """from_dlpack(ext_tensor) -> Tensor
    .....
    """
```

`pydocstyle torch/utils/_contextlib.py --count`
before: 4
after: 0

`pydocstyle torch/backends/mps/__init__.py --count`
before: 8
after: 1

**remaining errors**
```
torch/backends/mps/__init__.py:1 at module level: D104: Missing docstring in public package
```

`pydocstyle torch/backends/xeon/run_cpu.py --count`
before: 13
after: 1

**remaining errors**
```
torch/backends/xeon/run_cpu.py:864 in public function `main`: D103: Missing docstring in public function
```

`pydocstyle torch/backends/cpu/__init__.py --count`
before: 2
after: 1

**remaining errors**
```
torch/backends/cpu/__init__.py:1 at module level: D104: Missing docstring in public package
```

`pydocstyle torch/utils/cpp_backtrace.py --count`
before: 4
after: 1

**remaining errors**
```
torch/utils/cpp_backtrace.py:1 at module level: D100: Missing docstring in public module
```

`pydocstyle torch/utils/bundled_inputs.py --count`
before: 8
after: 1

**remaining errors**
```
torch/utils/bundled_inputs.py:1 at module level: D100: Missing docstring in public module
```

`pydocstyle torch/utils/file_baton.py --count`
before: 8
after: 1

**remaining errors**
```
torch/utils/file_baton.py:1 at module level: D100: Missing docstring in public module
```

`pydocstyle torch/utils/mobile_optimizer.py --count`
before: 6
after: 1

**remaining errors**
```
torch/utils/mobile_optimizer.py:8 in public class `LintCode`: D101: Missing docstring in public class
```

`pydocstyle torch/backends/opt_einsum/__init__.py --count`
before: 7
after: 5

**remaining errors**
```
torch/backends/opt_einsum/__init__.py:1 at module level: D104: Missing docstring in public package
torch/backends/opt_einsum/__init__.py:67 in public function `set_flags`: D103: Missing docstring in public function
torch/backends/opt_einsum/__init__.py:77 in public function `flags`: D103: Missing docstring in public function
torch/backends/opt_einsum/__init__.py:93 in public class `OptEinsumModule`: D101: Missing docstring in public class
torch/backends/opt_einsum/__init__.py:94 in public method `__init__`: D107: Missing docstring in __init__
```

`pydocstyle torch/utils/_device.py --count`
before: 9
after: 6

**remaining errors**
```
torch/utils/_device.py:58 in public class `DeviceContext`: D101: Missing docstring in public class
torch/utils/_device.py:59 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/_device.py:62 in public method `__enter__`: D105: Missing docstring in magic method
torch/utils/_device.py:68 in public method `__exit__`: D105: Missing docstring in magic method
torch/utils/_device.py:73 in public method `__torch_function__`: D105: Missing docstring in magic method
torch/utils/_device.py:80 in public function `device_decorator`: D103: Missing docstring in public function
```

`pydocstyle torch/utils/_freeze.py --count`
before: 15
after: 7

**remaining errors**
```
torch/utils/_freeze.py:77 in public function `indent_msg`: D103: Missing docstring in public function
torch/utils/_freeze.py:89 in public class `FrozenModule`: D101: Missing docstring in public class
torch/utils/_freeze.py:100 in public class `Freezer`: D101: Missing docstring in public class
torch/utils/_freeze.py:101 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/_freeze.py:106 in public method `msg`: D102: Missing docstring in public method
torch/utils/_freeze.py:185 in public method `get_module_qualname`: D102: Missing docstring in public method
torch/utils/_freeze.py:206 in public method `compile_string`: D102: Missing docstring in public method
```

`pydocstyle torch/utils/throughput_benchmark.py --count`
before: 25
after: 8

**remaining errors**
```
torch/utils/throughput_benchmark.py:1 at module level: D100: Missing docstring in public module
torch/utils/throughput_benchmark.py:27 in public class `ExecutionStats`: D101: Missing docstring in public class
torch/utils/throughput_benchmark.py:28 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/throughput_benchmark.py:33 in public method `latency_avg_ms`: D102: Missing docstring in public method
torch/utils/throughput_benchmark.py:37 in public method `num_iters`: D102: Missing docstring in public method
torch/utils/throughput_benchmark.py:46 in public method `total_time_seconds`: D102: Missing docstring in public method
torch/utils/throughput_benchmark.py:50 in public method `__str__`: D105: Missing docstring in magic method
torch/utils/throughput_benchmark.py:94 in public method `__init__`: D107: Missing docstring in __init__
```

`pydocstyle torch/utils/hooks.py --count`
before: 14
after: 11

**remaining errors**
```
torch/utils/hooks.py:1 at module level: D100: Missing docstring in public module
torch/utils/hooks.py:23 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/hooks.py:34 in public method `remove`: D102: Missing docstring in public method
torch/utils/hooks.py:44 in public method `__getstate__`: D105: Missing docstring in magic method
torch/utils/hooks.py:50 in public method `__setstate__`: D105: Missing docstring in magic method
torch/utils/hooks.py:64 in public method `__enter__`: D105: Missing docstring in magic method
torch/utils/hooks.py:67 in public method `__exit__`: D105: Missing docstring in magic method
torch/utils/hooks.py:82 in public function `warn_if_has_hooks`: D103: Missing docstring in public function
torch/utils/hooks.py:103 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/hooks.py:188 in public method `setup_input_hook`: D102: Missing docstring in public method
torch/utils/hooks.py:197 in public method `setup_output_hook`: D102: Missing docstring in public method
```

`pydocstyle torch/utils/_traceback.py --count`
before: 19
after: 14

**remaining errors**
```
torch/utils/_traceback.py:47 in public function `report_compile_source_on_error`: D103: Missing docstring in public function
torch/utils/_traceback.py:160 in public class `CapturedTraceback`: D101: Missing docstring in public class
torch/utils/_traceback.py:163 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/_traceback.py:167 in public method `cleanup`: D102: Missing docstring in public method
torch/utils/_traceback.py:170 in public method `summary`: D102: Missing docstring in public method
torch/utils/_traceback.py:182 in public method `__getstate__`: D105: Missing docstring in magic method
torch/utils/_traceback.py:190 in public method `extract`: D205: 1 blank line required between summary line and description (found 0)
torch/utils/_traceback.py:190 in public method `extract`: D400: First line should end with a period (not 't')
torch/utils/_traceback.py:213 in public method `format`: D205: 1 blank line required between summary line and description (found 0)
torch/utils/_traceback.py:213 in public method `format`: D400: First line should end with a period (not 'f')
torch/utils/_traceback.py:213 in public method `format`: D401: First line should be in imperative mood (perhaps 'Format', not 'Formats')
torch/utils/_traceback.py:224 in public method `format_all`: D200: One-line docstring should fit on one line with quotes (found 3)
torch/utils/_traceback.py:247 in private function `_extract_symbolized_tb`: D205: 1 blank line required between summary line and description (found 0)
torch/utils/_traceback.py:247 in private function `_extract_symbolized_tb`: D400: First line should end with a period (not 'f')
```

`pydocstyle torch/utils/mkldnn.py --count`
before: 28
after: 26

**remaining errors**
```
torch/utils/mkldnn.py:1 at module level: D100: Missing docstring in public module
torch/utils/mkldnn.py:4 in public class `MkldnnLinear`: D101: Missing docstring in public class
torch/utils/mkldnn.py:5 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/mkldnn.py:19 in public method `__getstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:23 in public method `__setstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:29 in public method `forward`: D102: Missing docstring in public method
torch/utils/mkldnn.py:75 in public class `MkldnnConv1d`: D101: Missing docstring in public class
torch/utils/mkldnn.py:76 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/mkldnn.py:82 in public method `__setstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:88 in public class `MkldnnConv2d`: D101: Missing docstring in public class
torch/utils/mkldnn.py:89 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/mkldnn.py:100 in public method `__setstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:110 in public class `MkldnnConv3d`: D101: Missing docstring in public class
torch/utils/mkldnn.py:111 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/mkldnn.py:122 in public method `__setstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:133 in public class `MkldnnBatchNorm`: D101: Missing docstring in public class
torch/utils/mkldnn.py:136 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/mkldnn.py:155 in public method `__getstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:163 in public method `__setstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:171 in public method `forward`: D102: Missing docstring in public method
torch/utils/mkldnn.py:184 in public class `MkldnnPrelu`: D101: Missing docstring in public class
torch/utils/mkldnn.py:185 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/mkldnn.py:190 in public method `__getstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:194 in public method `__setstate__`: D105: Missing docstring in magic method
torch/utils/mkldnn.py:199 in public method `forward`: D102: Missing docstring in public method
torch/utils/mkldnn.py:205 in public function `to_mkldnn`: D103: Missing docstring in public function
```

`pydocstyle torch/utils/weak.py --count`
before: 32
after: 30

**remaining errors**
```
torch/utils/weak.py:1 at module level: D100: Missing docstring in public module
torch/utils/weak.py:42 in public class `WeakIdRef`: D101: Missing docstring in public class
torch/utils/weak.py:45 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/weak.py:54 in public method `__call__`: D102: Missing docstring in public method
torch/utils/weak.py:61 in public method `__hash__`: D105: Missing docstring in magic method
torch/utils/weak.py:64 in public method `__eq__`: D105: Missing docstring in magic method
torch/utils/weak.py:84 in public class `WeakIdKeyDictionary`: D101: Missing docstring in public class
torch/utils/weak.py:87 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/weak.py:131 in public method `__delitem__`: D105: Missing docstring in magic method
torch/utils/weak.py:135 in public method `__getitem__`: D105: Missing docstring in magic method
torch/utils/weak.py:138 in public method `__len__`: D105: Missing docstring in magic method
torch/utils/weak.py:145 in public method `__repr__`: D105: Missing docstring in magic method
torch/utils/weak.py:148 in public method `__setitem__`: D105: Missing docstring in magic method
torch/utils/weak.py:151 in public method `copy`: D102: Missing docstring in public method
torch/utils/weak.py:162 in public method `__deepcopy__`: D105: Missing docstring in magic method
torch/utils/weak.py:172 in public method `get`: D102: Missing docstring in public method
torch/utils/weak.py:175 in public method `__contains__`: D105: Missing docstring in magic method
torch/utils/weak.py:182 in public method `items`: D102: Missing docstring in public method
torch/utils/weak.py:189 in public method `keys`: D102: Missing docstring in public method
torch/utils/weak.py:198 in public method `values`: D102: Missing docstring in public method
torch/utils/weak.py:216 in public method `popitem`: D102: Missing docstring in public method
torch/utils/weak.py:224 in public method `pop`: D102: Missing docstring in public method
torch/utils/weak.py:228 in public method `setdefault`: D102: Missing docstring in public method
torch/utils/weak.py:231 in public method `update`: D102: Missing docstring in public method
torch/utils/weak.py:241 in public method `__ior__`: D105: Missing docstring in magic method
torch/utils/weak.py:245 in public method `__or__`: D105: Missing docstring in magic method
torch/utils/weak.py:252 in public method `__ror__`: D105: Missing docstring in magic method
torch/utils/weak.py:262 in public method `__eq__`: D105: Missing docstring in magic method
torch/utils/weak.py:276 in public method `__init__`: D107: Missing docstring in __init__
torch/utils/weak.py:280 in public method `__call__`: D102: Missing docstring in public method
```

@mikaylagawarecki @jbschlosser @svekars

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113311
Approved by: https://github.com/ezyang
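For context, here is a minimal sketch (not taken from the PR) of the kind of rewrite that the remaining D205/D400/D401 errors call for; the function name is hypothetical:

```python
def format_tb(tb):
    """Formats a traceback into lines
    with the description run straight into the summary."""  # trips D205, D400 and D401


def format_tb_fixed(tb):
    """Format a traceback into lines.

    The first line is a single imperative sentence ending in a period,
    separated from the description by one blank line, which satisfies
    D205, D400 and D401.
    """
```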
torch/utils/hooks.py · 250 lines · 9.2 KiB · Python
import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any, Tuple

__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]

class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (Union[dict, List[dict]]): An additional dictionary or list of
            dictionaries whose keys will be deleted when the same keys are
            removed from ``hooks_dict``.
    """

    id: int
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, list):
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        if self.extra_dict_ref is None:
            return (self.hooks_dict_ref(), self.id)
        else:
            return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))

    def __setstate__(self, state) -> None:
        if state[0] is None:
            # create a dead reference
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2])

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()

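# Illustrative usage sketch (not part of the original file): a RemovableHandle
# is what hook-registration APIs return, and __enter__/__exit__ above also let
# it scope a hook with a ``with`` block:
#
#     lin = torch.nn.Linear(3, 3)
#     with lin.register_forward_hook(lambda mod, inp, out: None):
#         lin(torch.randn(1, 3))   # the hook is active inside the block
#     # leaving the block called remove(), so the hook is gone here
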
def unserializable_hook(f):
    """
    Mark a function as an unserializable hook with this decorator.

    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    f.__torch_unserializable__ = True
    return f

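# Illustrative usage sketch (not part of the original file): decorating a
# backward hook so that serializing the tensor does not warn about it. The
# hook body is an arbitrary example:
#
#     @unserializable_hook
#     def scale_grad(grad):
#         return grad * 2
#
#     t = torch.randn(3, requires_grad=True)
#     t.register_hook(scale_grad)   # warn_if_has_hooks() below stays silent
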
def warn_if_has_hooks(tensor):
    if tensor._backward_hooks:
        for k in tensor._backward_hooks:
            hook = tensor._backward_hooks[k]
            # check the hook itself (not the integer key) for the marker
            if not hasattr(hook, "__torch_unserializable__"):
                warnings.warn(f"backward hook {repr(hook)} on tensor will not be "
                              "serialized. If this is expected, you can "
                              "decorate the function with @torch.utils.hooks.unserializable_hook "
                              "to suppress this warning")

class BackwardHook:
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
      - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
      - Generating the proper Node to capture a set of Tensor's gradients
      - Linking the gradients captured for the outputs with the gradients captured for the inputs
      - Calling the user hook once both output and input gradients are available
    """

    def __init__(self, module, user_hooks, user_pre_hooks):
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks
        self.module = module

        self.grad_outputs = None
        self.n_outputs = -1
        self.output_tensors_index = None
        self.n_inputs = -1
        self.input_tensors_index = None

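    # Illustrative usage sketch (not part of the original file): BackwardHook
    # is the machinery behind nn.Module.register_full_backward_hook, so in
    # practice it is exercised like this:
    #
    #     def log_grads(module, grad_input, grad_output):
    #         print([None if g is None else g.shape for g in grad_output])
    #
    #     handle = torch.nn.Linear(4, 2).register_full_backward_hook(log_grads)
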
    def _pack_with_none(self, indices, values, size):
        # Scatter ``values`` into a tuple of length ``size`` at the given
        # ``indices``; every other slot is None.
        res = [None] * size
        for idx, val in zip(indices, values):
            res[idx] = val

        return tuple(res)

    def _unpack_none(self, indices, values):
        # Inverse of _pack_with_none: gather only the entries at ``indices``.
        res = []
        for idx in indices:
            res.append(values[idx])

        return tuple(res)

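    # For example (hypothetical gradient values ga, gb):
    #     _pack_with_none([0, 2], (ga, gb), 4)        -> (ga, None, gb, None)
    #     _unpack_none([0, 2], (ga, None, gb, None))  -> (ga, gb)
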
    def _set_user_hook(self, grad_fn):
        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return
            res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for hook in self.user_hooks:
                out = hook(self.module, res, self.grad_outputs)

                if out is None:
                    continue

                if len(out) != len(res):
                    raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                       f"got {len(out)}, but expected {len(res)}")

                res = out

            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, res)

        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        # Apply the given function to every Tensor contained in args and
        # return the updated args along with the indices of those tensors.
        tensors_idx = []
        tensors = []

        requires_grad = False
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                tensors_idx.append(i)
                tensors.append(arg)
                requires_grad |= arg.requires_grad

        if not (requires_grad and torch.is_grad_enabled()):
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = [t.grad_fn for t in new_tensors if t.grad_fn is not None and t.grad_fn.name() == "BackwardHookFunctionBackward"]
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fns[0])

        arg_list = list(args)
        for idx, val in zip(tensors_idx, new_tensors):
            arg_list[idx] = val

        if type(args) is tuple:
            out = tuple(arg_list)
        else:
            out = type(args)(*arg_list)
        return out, tensors_idx

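    # For example (hypothetical values): given args = (t1, "flag", t2) where
    # both tensors require grad, this returns the rewrapped
    # (t1', "flag", t2') together with tensors_idx == [0, 2], after calling
    # fn() on the BackwardHookFunctionBackward node shared by t1' and t2'.
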
    def setup_input_hook(self, args):
        def fn(grad_fn):
            self._set_user_hook(grad_fn)

        res, input_idx = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = input_idx
        return res

    def setup_output_hook(self, args):
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(self.output_tensors_index,
                                                         grad_output,
                                                         self.n_outputs)

                if self.user_pre_hooks:
                    expected_len = len(self.grad_outputs)
                    for user_pre_hook in self.user_pre_hooks:
                        hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs)
                        if hook_grad_outputs is None:
                            continue

                        actual_len = len(hook_grad_outputs)
                        if actual_len != expected_len:
                            raise RuntimeError("Backward pre hook returned an invalid number of grad_output, "
                                               f"got {actual_len}, but expected {expected_len}")
                        self.grad_outputs = hook_grad_outputs

                # Special case: if no input required gradients, this hook should
                # call the user hooks directly
                if self.input_tensors_index is None:
                    grad_inputs = self._pack_with_none([], [], self.n_inputs)
                    for user_hook in self.user_hooks:
                        res = user_hook(self.module, grad_inputs, self.grad_outputs)
                        if res is not None and not (isinstance(res, tuple) and all(el is None for el in res)):
                            raise RuntimeError("Backward hook for Modules where no input requires "
                                               "gradient should always return None or None for all gradients.")
                    self.grad_outputs = None

                if self.grad_outputs is not None:
                    assert self.output_tensors_index is not None  # mypy
                    return tuple(self.grad_outputs[i] for i in self.output_tensors_index)

            grad_fn.register_hook(hook)

        is_tuple = True
        if not isinstance(args, tuple):
            args = (args,)
            is_tuple = False

        res, output_idx = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = output_idx

        if not is_tuple:
            res = res[0]
        return res