mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow codebase: inconsistent use of `Args:` and `Arguments:` in its docstrings.
```sh
(pytorch#c348fae)$ for name in 'Args:' 'Arguments:'; do
printf '%-10s %04d\n' "$name" "$(rg -IFtpy --count-matches "$name" | paste -s -d+ -- | bc)"; done
Args: 1095
Arguments: 0336
```
It is easy enough to extend my parsers to support both variants; however, it looks like `Arguments:` is wrong anyway, as per:
- https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md)
- https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md)
- https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst)
Therefore, only `Args:` is valid. This PR replaces them throughout the codebase.
PS: For related PRs, see tensorflow/tensorflow/pull/45420
PPS: The trackbacks automatically appearing below are sending the same changes to other repositories in the [PyTorch](https://github.com/pytorch) organisation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49736
Reviewed By: albanD
Differential Revision: D25710534
Pulled By: soumith
fbshipit-source-id: 61e8ff01abb433e9f78185c2d1d0cbd7c22c1619
147 lines
5.6 KiB
Python
147 lines
5.6 KiB
Python
|
|
import copy
|
|
|
|
import torch.nn as nn
|
|
|
|
from .fuser_method_mappings import get_fuser_method
|
|
# for backward compatibility
|
|
from .fuser_method_mappings import fuse_conv_bn # noqa: F401
|
|
from .fuser_method_mappings import fuse_conv_bn_relu # noqa: F40
|
|
|
|
from typing import List, Optional
|
|
|
|
# Generalization of getattr
|
|
def _get_module(model, submodule_key):
|
|
tokens = submodule_key.split('.')
|
|
cur_mod = model
|
|
for s in tokens:
|
|
cur_mod = getattr(cur_mod, s)
|
|
return cur_mod
|
|
|
|
# Generalization of setattr
|
|
def _set_module(model, submodule_key, module):
|
|
tokens = submodule_key.split('.')
|
|
sub_tokens = tokens[:-1]
|
|
cur_mod = model
|
|
for s in sub_tokens:
|
|
cur_mod = getattr(cur_mod, s)
|
|
|
|
setattr(cur_mod, tokens[-1], module)
|
|
|
|
def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
    r"""Returns a list of modules that fuses the operations specified
    in the input module list.

    The fuser method is looked up from the tuple of module types via
    ``get_fuser_method``, which also receives ``additional_fuser_method_mapping``.
    Sequences documented as supported include:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity().

    Forward pre-hooks of the first module and forward hooks of the last module
    are moved onto the fused module. Any other hooks are lost.

    Args:
        mod_list: list of modules to fuse, in execution order.
        additional_fuser_method_mapping: optional extra mapping, forwarded to
            ``get_fuser_method`` together with the tuple of module types.

    Returns:
        A list with the same length as ``mod_list``. Index 0 holds the fused
        module; every other index holds an ``nn.Identity`` whose ``training``
        flag matches ``mod_list[0]``.

    Raises:
        NotImplementedError: if no fuser method is found for the module types.
    """
    types = tuple(type(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod: List[Optional[nn.Module]] = [None] * len(mod_list)
    fused = fuser_method(*mod_list)

    # NOTE: forward hooks not processed in the two following steps will be lost after the fusion.
    # Each hook dict is snapshotted and cleared *before* re-registering. Deleting entries from a
    # dict while iterating over it raises RuntimeError as soon as a single hook is present.
    # Clearing first also keeps the hooks if a fuser returns the base module itself.

    # Move pre forward hooks of the base module to the resulting fused module
    base_pre_hooks = list(mod_list[0]._forward_pre_hooks.values())
    mod_list[0]._forward_pre_hooks.clear()
    for pre_hook_fn in base_pre_hooks:
        fused.register_forward_pre_hook(pre_hook_fn)

    # Move post forward hooks of the last module to the resulting fused module
    last_hooks = list(mod_list[-1]._forward_hooks.values())
    mod_list[-1]._forward_hooks.clear()
    for hook_fn in last_hooks:
        fused.register_forward_hook(hook_fn)

    new_mod[0] = fused

    # Remaining slots become identities that share the base module's train/eval mode.
    for i in range(1, len(mod_list)):
        identity = nn.Identity()
        identity.training = mod_list[0].training
        new_mod[i] = identity

    return new_mod
|
|
|
|
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    """Fuse one group of submodules of ``model`` in place.

    Args:
        model: module whose submodules are replaced.
        modules_to_fuse: sequence of dotted submodule names forming one fusion group.
        fuser_func: callable invoked as ``fuser_func(modules, additional_fuser_method_mapping)``.
            It must return replacement modules indexed like ``modules_to_fuse``.
        fuse_custom_config_dict: optional dict. Its ``"additional_fuser_method_mapping"``
            entry (default ``{}``) is forwarded to ``fuser_func``.
    """
    config = fuse_custom_config_dict if fuse_custom_config_dict is not None else {}
    extra_mapping = config.get("additional_fuser_method_mapping", {})

    # Look up every submodule first, then hand them to the fuser as one list.
    originals = [_get_module(model, name) for name in modules_to_fuse]
    replacements = fuser_func(originals, extra_mapping)

    # Write each replacement back to the path its original came from.
    for position, name in enumerate(modules_to_fuse):
        _set_module(model, name, replacements[position])
|
|
|
|
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    r"""Fuses a list of modules into a single module

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    bn, relu
    With the default ``fuser_func``, any other sequence raises
    ``NotImplementedError`` unless it is covered by
    ``fuse_custom_config_dict["additional_fuser_method_mapping"]``.
    For supported sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
        inplace: bool specifying if fusion happens in place on the model. By default
                 (``inplace=False``) the model is deep-copied and the copy is fused.
        fuser_func: Function called as ``fuser_func(mod_list, additional_fuser_method_mapping)``
                    that takes in a list of modules and outputs a list of fused modules
                    of the same length. For example,
                    fuser_func([convModule, BNModule], {}) returns the list [ConvBNModule, nn.Identity()]
                    Defaults to torch.quantization.fuse_known_modules
        fuse_custom_config_dict: custom configuration for fusion

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new copy is created if inplace=False;
        otherwise the passed-in model is modified and returned.

    Examples::

            >>> m = myModel()
            >>> # m is a module containing the sub-modules below
            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

            >>> m = myModel()
            >>> # Alternately provide a single list of modules to fuse
            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
            >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

    """
    if not inplace:
        model = copy.deepcopy(model)

    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
        # Handle case of modules_to_fuse being a list
        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
    else:
        # Handle case of modules_to_fuse being a list of lists
        for module_list in modules_to_fuse:
            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
    return model
|