Commit Graph

38 Commits

Author SHA1 Message Date
Wei Feng
472b0daeaa [DDP][FSDP2] keep DTensor params for replicate(fully_shard) (#133059)
current status: for `replicate(fully_shard)`, DDP lazy_init will convert DTensor into local tensor, and that breaks FSDP unshard

this PR keeps FSDP params untouched during DDP lazy_init
I came across it because of a CI error in FSDP2's unit test #132978
thanks @awgu for fix proposal

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133059
Approved by: https://github.com/Skylion007, https://github.com/fegin
2024-08-09 18:38:05 +00:00
wz337
0ff0bf3d31 [Replicate] Fix replicate with DeviceMesh initialization (#133024)
A follow up on https://github.com/pytorch/pytorch/pull/132339.

`get_parent_mesh` is replaced by `get_root_mesh`. In addition, modify a few places that parent mesh is mentioned in test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133024
Approved by: https://github.com/Skylion007, https://github.com/fegin
2024-08-09 00:45:47 +00:00
wz337
87053132ea [DeviceMesh] Remove parent mesh concept from _MeshEnv and replace by root mesh (#132339)
Previously, when we slice out a submesh from a mesh, we assign the mesh as the parent mesh of the submesh. In this case, when we have a 3D mesh topology, the parent mesh of a 1D mesh sliced out from the 3D mesh is different from the parent mesh of the same 1D mesh sliced out from the 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]

mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 =  mesh_2d["dim0_2"]

# This would evaluate to be True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0))
```

We can always reconstruct the mesh needed from the mesh dim names, as long as two dims come from the same root. For simplicity, we do not see the necessity of building a tree structure to represent child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv` so we would have:

```
mesh_3d = init_device_mesh("cuda", (2,2,2), ("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]

mesh_2d = mesh_2d["dim0", "dim1"]
mesh_dim0_2 =  mesh_2d["dim0_2"]

# This would evaluate to be True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh not created through slicing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
2024-08-07 07:01:12 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Xuehai Pan
973037be6a [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199)
This PR changes the empty collection factory call to Python literals:

- `list()` -> `[]`
- `tuple()` -> `()`
- `dict()` -> `{}`

The Python literals are more performant and safer. For example, the bytecode for building an empty dictionary:

```bash
$ python3 -m dis - <<EOS
import collections

d1 = {}
d2 = dict()

dict = collections.OrderedDict
d3 = dict()
EOS
```

```text
  0           0 RESUME                   0

  1           2 LOAD_CONST               0 (0)
              4 LOAD_CONST               1 (None)
              6 IMPORT_NAME              0 (collections)
              8 STORE_NAME               0 (collections)

  3          10 BUILD_MAP                0
             12 STORE_NAME               1 (d1)

  4          14 PUSH_NULL
             16 LOAD_NAME                2 (dict)
             18 CALL                     0
             26 STORE_NAME               3 (d2)

  6          28 LOAD_NAME                0 (collections)
             30 LOAD_ATTR                8 (OrderedDict)
             50 STORE_NAME               2 (dict)

  7          52 PUSH_NULL
             54 LOAD_NAME                2 (dict)
             56 CALL                     0
             64 STORE_NAME               5 (d3)
             66 RETURN_CONST             1 (None)
```

The dict literal `{}` only has one bytecode `BUILD_MAP`, while the factory call `dict()` has three `PUSH_NULL + LOAD_NAME + CALL`. Also, the factory call is not safe if users override the `dict` name in `locals` or `globals` (see the example of replacing with `OrderedDict` above).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130199
Approved by: https://github.com/malfet
2024-07-11 17:30:28 +00:00
Xuehai Pan
94dc3253a0 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin, https://github.com/wconstab
2024-06-22 18:53:28 +00:00
PyTorch MergeBot
9c929f6ce9 Revert "[BE][Easy] enable UFMT for torch/distributed/ (#128870)"
This reverts commit a0e1e20c41.

Reverted https://github.com/pytorch/pytorch/pull/128870 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128870#issuecomment-2181780356))
2024-06-21 00:38:28 +00:00
Xuehai Pan
a0e1e20c41 [BE][Easy] enable UFMT for torch/distributed/ (#128870)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128870
Approved by: https://github.com/fegin
ghstack dependencies: #128868, #128869
2024-06-18 21:49:08 +00:00
Aaron Orenstein
3a0d088517 Flip default value for mypy disallow_untyped_defs [5/11] (#127842)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127842
Approved by: https://github.com/oulgen
2024-06-08 18:49:18 +00:00
Xuehai Pan
ba3b05fdf3 [1/N][Easy] fix typo for usort config in pyproject.toml (kown -> known): sort stdlib (#127122)
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127122
Approved by: https://github.com/kit1980
2024-05-25 08:25:50 +00:00
Chien-Chin Huang
1eb7b8eb60 [PT2D] Ensure the trace rules are correct with distributed (#125333)
Summary:
1. Avoid using `torch._dynamo.disable`.
2. Clear the LRU cache of the trace rules. This won't do anything if rules are not evluated before PG initilization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125333
Approved by: https://github.com/yanboliang
2024-05-02 16:28:38 +00:00
Chien-Chin Huang
f3af049b88 [DDP][PT2D] Fix the import issue (#124846)
As title

Differential Revision: [D56521582](https://our.internmc.facebook.com/intern/diff/D56521582/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124846
Approved by: https://github.com/LucasLLC, https://github.com/wz337
ghstack dependencies: #124421, #124422, #123424
2024-04-25 11:08:27 +00:00
Chien-Chin Huang
290bfbe01f [DDP][PT2D] Lazy Initialization of DDP Module for Replicate API (#123424)
In order to make replicate work with Meta tensor, we need to do lazy Initialization for the replicate API. This PR impelements the lazy initialization and ensures that replicate still work with the new DDP compilation.

Differential Revision: [D55787340](https://our.internmc.facebook.com/intern/diff/D55787340/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123424
Approved by: https://github.com/yf225
ghstack dependencies: #124421, #124422
2024-04-24 06:30:19 +00:00
Chien-Chin Huang
c7193f4099 [DDP][PT2D][2D] Enable DDP + TP and add test for compiled DDP + TP (#120479)
This PR enables DDP + TP using a TP internal API. This should not be the final implementation. A more sound implementation is to inline the TP internal API in DDP. In other words, DDP needs to be aware of DTensor so that we can support 2D state_dict.

This PR adds a compiled DDP + TP test to ensure the new compiled DDP fusion doesn't break TP all_reduce.

**TODOs**

- [x] Implement DDP allreduce fusion algorithm for Inductor post_grad pass.
- [x] Add unit tests to ensure the fusion doesn't DDP + TP.
- [ ] Group different PG and data type of all_reduces.
- [ ] Mixed precision supports and tests
- [ ] Implement the fusions with Inductor IR.
- [ ] Add auto bucketing based on Inductor profiling.

Differential Revision: [D54105050](https://our.internmc.facebook.com/intern/diff/D54105050/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120479
Approved by: https://github.com/wz337
ghstack dependencies: #113209
2024-03-13 21:41:22 +00:00
Chien-Chin Huang
b1fb591272 [replicate] Simplify replicate() init logic and remove unnecessary variables in _ReplicateState (#113679)
Many variables _ReplicateState are created because replicate() was lazy initialized. This PR removes these variables and simplifes the logic.y

Differential Revision: [D51317874](https://our.internmc.facebook.com/intern/diff/D51317874/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113679
Approved by: https://github.com/awgu
2023-11-28 00:55:36 +00:00
Andrew Gu
20eaa49dde [PT-D] Made _get_registry return None if no APIs applied (#113654)
I prefer to not modify the module if it does not have any of our APIs applied. The side effect of inserting a registry on the module when calling a getter is non-intuitive to me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113654
Approved by: https://github.com/fegin
2023-11-14 20:28:11 +00:00
Chien-Chin Huang
b35279dfac [DDP] Make _ReplicateState inherit from _State and make replicate eagerly initialized (#109647)
Follow how fully_shard store the _FSDPState, this PR makes _ReplicateState inherit from _State.  This PR also makes replicate eagerly initialize the internal DDP instance so that users can access the required methods/functions before the first forward().

Differential Revision: [D49428291](https://our.internmc.facebook.com/intern/diff/D49428291/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109647
Approved by: https://github.com/wz337, https://github.com/rohan-varma
ghstack dependencies: #110688
2023-10-12 07:58:39 +00:00
Rohan Varma
0ecca122e7 [Replicate] Add unit test with replicate param names (#102401)
This attribute wasn't actually used in tests, add a test ensuring that
if replicate is used on top of FSDP, the replicated parameter names are as
expected.

TODO: there are a few ways to check if module is managed by composable API,
such as replicated param names for replicate, _get_module_state API,
_get_registry_api, etc. We should unify all composable APIs to check in a
unified way (filed an issue)

Differential Revision: [D46236377](https://our.internmc.facebook.com/intern/diff/D46236377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102401
Approved by: https://github.com/awgu
2023-05-31 18:41:03 +00:00
Rohan Varma
8869897ebe [replicate] support simpler device_id (#100217)
Allow passing in `device_id=[device]` regardless of CPU or GPU. We
modify the kwarg as needed to pass to DDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100217
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2023-05-04 21:06:04 +00:00
Rohan Varma
253b9d3247 [replicate] input casting support (#100216)
Supports input casting by doing this in the pre hook.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100216
Approved by: https://github.com/awgu
2023-05-04 01:46:15 +00:00
Rohan Varma
ef11966aff [composable] Enable replicate + trec_shard overall (#98890)
replicate + trec_shard works if we shard / replicate individually, such as follows:

```
m = TestSparseNN()
shard(m.sparse)
replicate(m.dense)
```

but does not work if users do the following:
```
m = TestSparseNN()
shard(m, sharders=[...])
replicate(m)
```

Many upstream trainers use the latter use case, as sharding is not done on individual module level but rather overall module by specifying planners that contain logic for how to shard different embedding table types.

This diff enables the latter approach (while keeping the former intact), but users need to specify `ignored_modules` to ignore embedding tables in replicate(). This is similar to FSDP (class based and composable) and DDP today.

Differential Revision: [D44899155](https://our.internmc.facebook.com/intern/diff/D44899155/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98890
Approved by: https://github.com/mrshenli, https://github.com/yhcharles
2023-04-15 01:09:00 +00:00
Rohan Varma
51ff9ce997 [Replicate] Simplify code a bit (#98889)
Simplifies the code, such as making self.modules not a list and only a
single module.

Differential Revision: [D44899281](https://our.internmc.facebook.com/intern/diff/D44899281/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98889
Approved by: https://github.com/mrshenli, https://github.com/yhcharles
2023-04-13 03:21:06 +00:00
Charlie Yan
721260e966 [3/n] Consolidate replicate and DDP: update replicate to reuse functions in DDP (#96660)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96660
Approved by: https://github.com/rohan-varma
2023-03-30 03:54:34 +00:00
Colin Taylor
e5496ebcac [torch] [composable] [analytics] add analytics logging to PT-D composable APIs (#95016)
Summary: as title

Test Plan: N/A

Differential Revision: D43376274

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95016
Approved by: https://github.com/awgu, https://github.com/rohan-varma, https://github.com/fegin
2023-02-17 02:49:16 +00:00
Andrew Gu
0d4bbd1996 [Lint] Add FSDP/composable API files to ufmt include (#90873)
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.

There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. `lintrunner f` before you finalize your PR, which would now be enforced by CI after this PR.

The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.

---

I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
2023-01-18 05:33:34 +00:00
Matthew Hoffman
a26e5e21b5 Improve type hints for Module forward hooks (#92061)
Fixes #91654.

Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`.

The typing of the first parameter of `hook` as `TypeVar("T", bound="Module")` allows the binding of `Callable` whose first parameter is a subclass of `Module`.

---

Here are some examples of:
1. forward hooks and pre-hook hooks being accepted by mypy according to the new type hints
2. mypy throwing errors d.t. incorrect `hook` signatures
3. false negatives of pre-hooks being accepted as forward hooks
4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs`

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn

def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...

def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...

def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...

def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...

def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

model = nn.Module()

# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)

model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)

# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)

model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)

# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)

model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```

---

Though it is not functional as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):

```python
T = TypeVar("T", bound="Module")

class Module:

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[True] = ...,
    ) -> RemovableHandle:
        ...

    def register_forward_hook(...):
        ...

```

which would:

1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above)
2. implicitly define the [fallback `bool` signature](https://github.com/python/mypy/issues/6113#issuecomment-1266186192) e.g. to handle if a non-literal is provided for `with_kwargs`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92061
Approved by: https://github.com/albanD
2023-01-13 15:45:42 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Chien-Chin Huang
d08e3d2304 [Composable API] Apply ufmt to _composable and the corresponding test folders (#91255)
This PR apply ufmt to format `_composable` related code. This is a request from https://github.com/pytorch/pytorch/pull/91234 to separate formatting changes as a new PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91255
Approved by: https://github.com/awgu
2022-12-23 16:08:27 +00:00
Charlie Yan
a1a2f548a9 [Composable API] Enable composable fully_shard submodules in replicate parent module (#90711)
To make sure `fully_shard` and `replicate` can work together, we need to check for each other in the implementation. This change adds the check in `replicate()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90711
Approved by: https://github.com/mrshenli
2022-12-17 09:28:38 +00:00
Chien-Chin Huang
d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs  are designed to start with the root module of the model. It is tricky for these APIs to tell whether a random submodule is managed by either DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Make `_State` out of `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
Sergii Dymchenko
f51f6aa387 Fix non-existing parameters in docstrings (#90505)
Continuation after https://github.com/pytorch/pytorch/pull/90163.

Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators):

_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script.

``` python
import ast
import os
import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        all_node_args = node.args.args
                        if node.args.vararg is not None:
                            all_node_args.append(node.args.vararg)
                        if node.args.kwarg is not None:
                            all_node_args.append(node.args.kwarg)
                        if node.args.posonlyargs is not None:
                            all_node_args.extend(node.args.posonlyargs)
                        if node.args.kwonlyargs is not None:
                            all_node_args.extend(node.args.kwonlyargs)
                        args = [a.arg for a in all_node_args]
                        docstring = docstring_parser.parse(ast.get_docstring(node))
                        doc_args = [a.arg_name for a in docstring.params]
                        clean_doc_args = []
                        for a in doc_args:
                            clean_a = ""
                            for c in a.split()[0]:
                                if c.isalnum() or c == '_':
                                    clean_a += c
                            if clean_a:
                                clean_doc_args.append(clean_a)
                        doc_args = clean_doc_args
                        for a in doc_args:
                            if a not in args:
                                print(full_name, node.lineno, args, doc_args)
                            break

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
2022-12-09 21:43:09 +00:00
Charlie Yan
99fb39f508 reland #89243: [Composable API] replicate: add support for DDP args (#90255)
reland https://github.com/pytorch/pytorch/pull/89243
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90255
Approved by: https://github.com/zhaojuanmao
2022-12-07 15:22:33 +00:00
Charlie Yan
e818c36647 reland #89222: [Composable API] replicate: change to per module call, remove mark_root_module() (#90254)
reland https://github.com/pytorch/pytorch/pull/89222
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90254
Approved by: https://github.com/zhaojuanmao
2022-12-06 21:17:53 +00:00
PyTorch MergeBot
0d8e53dfe7 Revert "[Composable API] replicate: change to per module call, remove mark_root_module() (#89222)"
This reverts commit 65a0dcffd8.

Reverted https://github.com/pytorch/pytorch/pull/89222 on behalf of https://github.com/malfet due to Included unintended submodule updates
2022-12-06 03:26:28 +00:00
PyTorch MergeBot
3749b9dc73 Revert "[Composable API] replicate: add support for DDP args (#89243)"
This reverts commit 0f274ed385.

Reverted https://github.com/pytorch/pytorch/pull/89243 on behalf of https://github.com/malfet due to Depends on https://github.com/pytorch/pytorch/pull/89222 that introduced spurious module updates
2022-12-06 03:22:18 +00:00
Charlie Yan
0f274ed385 [Composable API] replicate: add support for DDP args (#89243)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89243
Approved by: https://github.com/zhaojuanmao
2022-12-05 21:38:23 +00:00
Charlie Yan
65a0dcffd8 [Composable API] replicate: change to per module call, remove mark_root_module() (#89222)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89222
Approved by: https://github.com/zhaojuanmao
2022-12-05 17:54:55 +00:00
Charlie Yan
f3af5ba48e [WIP] Composable API: replicate and DistributedState (#87649)
This PR adds the first version of the `replicate()` composable API. For this prototype version, I try to reuse as much code from existing `DistributedDataParallel` as possible, and iterate on it in later changes. The basic idea of this prototype is:
- create a `ReplicateState` object. It internally uses a `ParameterList` module to hold all parameters of modules marked by `replicate()` API.
- create an internal `_ddp` object, which reuses existing `DistributedDataParallel` implementation, and wraps the `ParameterList` object
- install pre-forward and after-forward hooks on the root module, which calls methods of `_ddp` to run initialization and forward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87649
Approved by: https://github.com/zhaojuanmao
2022-11-17 03:06:31 +00:00