mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Continuation after https://github.com/pytorch/pytorch/pull/90163.

Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators):

_Edit:_ I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script.

``` python
import ast
import os
import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        all_node_args = node.args.args
                        if node.args.vararg is not None:
                            all_node_args.append(node.args.vararg)
                        if node.args.kwarg is not None:
                            all_node_args.append(node.args.kwarg)
                        if node.args.posonlyargs is not None:
                            all_node_args.extend(node.args.posonlyargs)
                        if node.args.kwonlyargs is not None:
                            all_node_args.extend(node.args.kwonlyargs)
                        args = [a.arg for a in all_node_args]
                        docstring = docstring_parser.parse(ast.get_docstring(node))
                        doc_args = [a.arg_name for a in docstring.params]
                        clean_doc_args = []
                        for a in doc_args:
                            clean_a = ""
                            for c in a.split()[0]:
                                if c.isalnum() or c == '_':
                                    clean_a += c
                            if clean_a:
                                clean_doc_args.append(clean_a)
                        doc_args = clean_doc_args
                        for a in doc_args:
                            if a not in args:
                                print(full_name, node.lineno, args, doc_args)
                            break
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
87 lines
2.5 KiB
Python
87 lines
2.5 KiB
Python
from typing import List, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from . import _ddp
|
|
from .contract import contract
|
|
|
|
|
|
class _ReplicateState:
|
|
def __init__(self) -> None:
|
|
self.modules: List[nn.Module] = []
|
|
self.has_initialized: bool = False
|
|
self._param_list: nn.ParameterList = nn.ParameterList()
|
|
self.kwargs: dict = {}
|
|
|
|
def mark_modules(self, *modules: nn.Module, **kwargs) -> None:
|
|
for module in modules:
|
|
self.modules.append(module)
|
|
replicate.state(module)._distributed_state = self
|
|
replicate.state(module)._params_collected = False
|
|
module.register_forward_pre_hook(self.forward_pre_hook)
|
|
# TODO(@yhcharles): fix type error
|
|
module.register_forward_hook(self.forward_post_hook) # type: ignore[arg-type]
|
|
self.kwargs = kwargs
|
|
|
|
def _recursive_collect_params(self, module: nn.Module) -> None:
|
|
# TODO: skip if managed by other APIs
|
|
|
|
if hasattr(replicate.state(module), "_params_collected"):
|
|
if replicate.state(module)._params_collected:
|
|
return
|
|
replicate.state(module)._params_collected = True
|
|
|
|
self._param_list.extend(
|
|
param
|
|
for param in module.parameters(recurse=False)
|
|
# for param in module.parameters()
|
|
if param.requires_grad
|
|
)
|
|
for child in module.children():
|
|
self._recursive_collect_params(child)
|
|
|
|
def init_helper(self) -> None:
|
|
if self.has_initialized:
|
|
return
|
|
|
|
self.has_initialized = True
|
|
for module in self.modules:
|
|
self._recursive_collect_params(module)
|
|
|
|
self._ddp = _ddp.DistributedDataParallel(
|
|
self._param_list, **self.kwargs
|
|
)
|
|
|
|
def forward_pre_hook(
|
|
self, module: nn.Module, input: Tuple[torch.Tensor]
|
|
) -> None:
|
|
self.init_helper()
|
|
self._ddp.pre_forward()
|
|
|
|
def forward_post_hook(
|
|
self,
|
|
module: nn.Module,
|
|
input: Tuple[torch.Tensor],
|
|
output: torch.Tensor,
|
|
) -> torch.Tensor:
|
|
return self._ddp.post_forward(output)
|
|
|
|
|
|
@contract
def replicate(
    module: nn.Module,  # NOTE: contract now supports single module only
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    A fresh ``_ReplicateState`` is created for every call and ``module`` is
    registered with it via ``_ReplicateState.mark_modules``.

    Args:
        module (torch.nn.Module): module to replicate
        **kwargs: keyword arguments stored on the replicate state and passed
            to ``_ddp.DistributedDataParallel`` when that state is
            initialized on the first forward pass

    Returns:
        torch.nn.Module: the ``module`` object that was passed in (before any
        post-processing the ``contract`` decorator may apply)

    Example::
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    _ReplicateState().mark_modules(module, **kwargs)
    return module
|