Commit Graph

9 Commits

Author SHA1 Message Date
Chien-Chin Huang
d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs  are designed to start with the root module of the model. It is tricky for these APIs to tell whether a random submodule is managed by either DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Make `_State` out of `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
Sergii Dymchenko
f51f6aa387 Fix non-existing parameters in docstrings (#90505)
Continuation after https://github.com/pytorch/pytorch/pull/90163.

Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators):

_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script.

``` python
import ast
import os
import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        all_node_args = node.args.args
                        if node.args.vararg is not None:
                            all_node_args.append(node.args.vararg)
                        if node.args.kwarg is not None:
                            all_node_args.append(node.args.kwarg)
                        if node.args.posonlyargs is not None:
                            all_node_args.extend(node.args.posonlyargs)
                        if node.args.kwonlyargs is not None:
                            all_node_args.extend(node.args.kwonlyargs)
                        args = [a.arg for a in all_node_args]
                        docstring = docstring_parser.parse(ast.get_docstring(node))
                        doc_args = [a.arg_name for a in docstring.params]
                        clean_doc_args = []
                        for a in doc_args:
                            clean_a = ""
                            for c in a.split()[0]:
                                if c.isalnum() or c == '_':
                                    clean_a += c
                            if clean_a:
                                clean_doc_args.append(clean_a)
                        doc_args = clean_doc_args
                        for a in doc_args:
                            if a not in args:
                                print(full_name, node.lineno, args, doc_args)
                            break

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
2022-12-09 21:43:09 +00:00
Charlie Yan
99fb39f508 reland #89243: [Composable API] replicate: add support for DDP args (#90255)
reland https://github.com/pytorch/pytorch/pull/89243
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90255
Approved by: https://github.com/zhaojuanmao
2022-12-07 15:22:33 +00:00
Charlie Yan
e818c36647 reland #89222: [Composable API] replicate: change to per module call, remove mark_root_module() (#90254)
reland https://github.com/pytorch/pytorch/pull/89222
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90254
Approved by: https://github.com/zhaojuanmao
2022-12-06 21:17:53 +00:00
PyTorch MergeBot
0d8e53dfe7 Revert "[Composable API] replicate: change to per module call, remove mark_root_module() (#89222)"
This reverts commit 65a0dcffd8.

Reverted https://github.com/pytorch/pytorch/pull/89222 on behalf of https://github.com/malfet due to Included unintended submodule updates
2022-12-06 03:26:28 +00:00
PyTorch MergeBot
3749b9dc73 Revert "[Composable API] replicate: add support for DDP args (#89243)"
This reverts commit 0f274ed385.

Reverted https://github.com/pytorch/pytorch/pull/89243 on behalf of https://github.com/malfet due to Depends on https://github.com/pytorch/pytorch/pull/89222 that introduced spurious module updates
2022-12-06 03:22:18 +00:00
Charlie Yan
0f274ed385 [Composable API] replicate: add support for DDP args (#89243)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89243
Approved by: https://github.com/zhaojuanmao
2022-12-05 21:38:23 +00:00
Charlie Yan
65a0dcffd8 [Composable API] replicate: change to per module call, remove mark_root_module() (#89222)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89222
Approved by: https://github.com/zhaojuanmao
2022-12-05 17:54:55 +00:00
Charlie Yan
f3af5ba48e [WIP] Composable API: replicate and DistributedState (#87649)
This PR adds the first version of the `replicate()` composable API. For this prototype version, I try to reuse as much code from existing `DistributedDataParallel` as possible, and iterate on it in later changes. The basic idea of this prototype is:
- create a `ReplicateState` object. It internally uses a `ParameterList` module to hold all parameters of modules marked by `replicate()` API.
- create an internal `_ddp` object, which reuses existing `DistributedDataParallel` implementation, and wraps the `ParameterList` object
- install pre-forward and after-forward hooks on the root module, which calls methods of `_ddp` to run initialization and forward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87649
Approved by: https://github.com/zhaojuanmao
2022-11-17 03:06:31 +00:00