pytorch/torch/distributed/_composable/replicate.py
Sergii Dymchenko f51f6aa387 Fix non-existing parameters in docstrings (#90505)
Continuation after https://github.com/pytorch/pytorch/pull/90163.

Here is the script I used to find all the non-existing arguments in the docstrings (the script can give false positives in the presence of *args/**kwargs or decorators):

_Edit:_
I've realized that the indentation of the last `break` in the script is wrong, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with the corrected script.

``` python
import ast
import os
import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        all_node_args = node.args.args
                        if node.args.vararg is not None:
                            all_node_args.append(node.args.vararg)
                        if node.args.kwarg is not None:
                            all_node_args.append(node.args.kwarg)
                        if node.args.posonlyargs is not None:
                            all_node_args.extend(node.args.posonlyargs)
                        if node.args.kwonlyargs is not None:
                            all_node_args.extend(node.args.kwonlyargs)
                        args = [a.arg for a in all_node_args]
                        docstring = docstring_parser.parse(ast.get_docstring(node))
                        doc_args = [a.arg_name for a in docstring.params]
                        clean_doc_args = []
                        for a in doc_args:
                            clean_a = ""
                            for c in a.split()[0]:
                                if c.isalnum() or c == '_':
                                    clean_a += c
                            if clean_a:
                                clean_doc_args.append(clean_a)
                        doc_args = clean_doc_args
                        for a in doc_args:
                            if a not in args:
                                print(full_name, node.lineno, args, doc_args)
                            break

```
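For reference, the intended tail of the per-function check presumably has the `break` indented one level deeper, inside the `if`, so that every documented argument is examined and the scan stops only after the first mismatch is reported. This is my reading of the edit note above, not a verified fix:

```python
# Assumed fix: `break` moved inside the `if` so all doc args are checked.
for a in doc_args:
    if a not in args:
        print(full_name, node.lineno, args, doc_args)
        break
```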
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
2022-12-09 21:43:09 +00:00

87 lines
2.5 KiB
Python

from typing import List, Tuple

import torch
import torch.nn as nn

from . import _ddp
from .contract import contract


class _ReplicateState:
    def __init__(self) -> None:
        self.modules: List[nn.Module] = []
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        self.kwargs: dict = {}

    def mark_modules(self, *modules: nn.Module, **kwargs) -> None:
        for module in modules:
            self.modules.append(module)
            replicate.state(module)._distributed_state = self
            replicate.state(module)._params_collected = False
            module.register_forward_pre_hook(self.forward_pre_hook)
            # TODO(@yhcharles): fix type error
            module.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]
        self.kwargs = kwargs

    def _recursive_collect_params(self, module: nn.Module) -> None:
        # TODO: skip if managed by other APIs
        if hasattr(replicate.state(module), "_params_collected"):
            if replicate.state(module)._params_collected:
                return
            replicate.state(module)._params_collected = True
        self._param_list.extend(
            param
            for param in module.parameters(recurse=False)
            # for param in module.parameters()
            if param.requires_grad
        )
        for child in module.children():
            self._recursive_collect_params(child)

    def init_helper(self) -> None:
        if self.has_initialized:
            return

        self.has_initialized = True
        for module in self.modules:
            self._recursive_collect_params(module)

        self._ddp = _ddp.DistributedDataParallel(
            self._param_list, **self.kwargs
        )

    def forward_pre_hook(
        self, module: nn.Module, input: Tuple[torch.Tensor]
    ) -> None:
        self.init_helper()
        self._ddp.pre_forward()

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        return self._ddp.post_forward(output)


@contract
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate

    Example::
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    _ReplicateState().mark_modules(module, **kwargs)
    return module
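
For orientation, here is a rough sketch of how this prototype API would be exercised end to end; the single-process gloo setup, the environment variables, and the tensor shapes below are illustrative assumptions and are not part of the file above:

```python
# Illustrative sketch only; the single-process setup is an assumption for demonstration.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed._composable.replicate import replicate

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(3, 3)
replicate(model)                # marks the module and registers the forward hooks
out = model(torch.randn(4, 3))  # first forward runs init_helper(), building the internal _ddp wrapper
out.sum().backward()
dist.destroy_process_group()
```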