Previously, when we sliced a submesh out of a mesh, we assigned that mesh as the parent mesh of the submesh. With a 3D mesh topology, this means the parent mesh of a 1D mesh sliced out of the 3D mesh is different from the parent mesh of the same 1D mesh sliced out of a 2D submesh of the 3D mesh. For example:
```
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]
# This would evaluate to True
print(_mesh_resources.get_parent_mesh(mesh_dim0) != _mesh_resources.get_parent_mesh(mesh_dim0_2))
```
We can always reconstruct the mesh we need from the mesh dim names, as long as the dims come from the same root. For simplicity, we do not see the need to build a tree structure to represent the child-parent relationship. Therefore, we are replacing the parent mesh concept with a root mesh concept in `_MeshEnv`, so we would have:
```
mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
mesh_2d = mesh_3d["dim0", "dim1"]
mesh_dim0_2 = mesh_2d["dim0"]
# This would evaluate to True
print(_mesh_resources.get_root_mesh(mesh_dim0) == _mesh_resources.get_root_mesh(mesh_dim0_2))
```
With this change, we will have two types of meshes in an environment.
1. `device_mesh != _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh was created by slicing.
2. `device_mesh == _mesh_resources.get_root_mesh(device_mesh)` means that the device_mesh is a root mesh, not created through slicing (see the sketch below).
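The sketch uses the internal `_mesh_resources.get_root_mesh` API shown above; the `is_sliced` helper itself is hypothetical:
```
from torch.distributed.device_mesh import init_device_mesh, _mesh_resources

def is_sliced(device_mesh) -> bool:
    # A sliced submesh has a root mesh different from itself.
    return device_mesh != _mesh_resources.get_root_mesh(device_mesh)

mesh_3d = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dim0", "dim1", "dim2"))
mesh_dim0 = mesh_3d["dim0"]
print(is_sliced(mesh_3d))    # False: created directly, so it is a root mesh
print(is_sliced(mesh_dim0))  # True: created by slicing from mesh_3d
```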
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132339
Approved by: https://github.com/wanchaol
ghstack dependencies: #132310, #132311
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Apart from `pyproject.toml`, all changes were generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127122
Approved by: https://github.com/kit1980
This PR enables DDP + TP by using a TP internal API. This should not be the final implementation; a more sound implementation would inline the TP internal API into DDP. In other words, DDP needs to be aware of DTensor so that we can support a 2D state_dict.
This PR adds a compiled DDP + TP test to ensure the new compiled DDP fusion doesn't break TP all_reduce.
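For reference, here is a rough sketch of the kind of 2D (DDP + TP) setup involved. It is an illustrative example under assumed shapes and module names, not the PR's implementation, and the import path of the TP internal API is an assumption that may change between releases:
```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform  # TP internal API (path assumed)
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes a 4-GPU run launched with torchrun: 2-way data parallel x 2-way tensor parallel.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()
# Shard the two Linear layers across the "tp" mesh dimension.
model = parallelize_module(model, mesh_2d["tp"], {"0": ColwiseParallel(), "2": RowwiseParallel()})
# The internal API converts DTensor parameters so that DDP can manage them.
_pre_dp_module_transform(model)
# Replicate across the "dp" dimension; DDP's all_reduce then only spans the "dp" process group.
model = DDP(model, process_group=mesh_2d["dp"].get_group())

out = model(torch.randn(8, 16, device="cuda"))
out.sum().backward()
```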
**TODOs**
- [x] Implement DDP allreduce fusion algorithm for Inductor post_grad pass.
- [x] Add unit tests to ensure the fusion doesn't break DDP + TP.
- [ ] Group all_reduces by process group and data type.
- [ ] Mixed precision support and tests
- [ ] Implement the fusions with Inductor IR.
- [ ] Add auto bucketing based on Inductor profiling.
Differential Revision: [D54105050](https://our.internmc.facebook.com/intern/diff/D54105050/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120479
Approved by: https://github.com/wz337
ghstack dependencies: #113209
I prefer not to modify the module if it does not have any of our APIs applied. The side effect of inserting a registry on the module when calling a getter is non-intuitive to me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113654
Approved by: https://github.com/fegin
This attribute wasn't actually used in tests. This PR adds a test ensuring that if `replicate` is used on top of FSDP, the replicated parameter names are as expected.
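For context, a minimal sketch of the composition under test. It assumes an already-initialized process group (e.g. via `torchrun` + `init_process_group`); the model and the choice of which submodule is sharded are made up for illustration:
```python
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate

# Assumes torch.distributed.init_process_group() has already been called.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
fully_shard(model[1])  # FSDP manages the second layer
replicate(model)       # replicate (DDP) manages the remaining parameters
# The test then checks the parameter names recorded by replicate against the expected FQNs.
print(sorted(name for name, _ in model.named_parameters()))
```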
TODO: there are a few ways to check whether a module is managed by a composable API, such as the replicated param names for `replicate`, the `_get_module_state` API, `_get_registry_api`, etc. We should unify how all composable APIs perform this check (filed an issue).
Differential Revision: [D46236377](https://our.internmc.facebook.com/intern/diff/D46236377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102401
Approved by: https://github.com/awgu
replicate + trec_shard works if we shard / replicate individually, as follows:
```
m = TestSparseNN()
shard(m.sparse)
replicate(m.dense)
```
but does not work if users do the following:
```
m = TestSparseNN()
shard(m, sharders=[...])
replicate(m)
```
Many upstream trainers use the latter pattern, since sharding is not done at the individual-module level but on the overall module, by specifying planners that contain the logic for how to shard different embedding table types.
This diff enables the latter approach (while keeping the former intact), but users need to pass `ignored_modules` to `replicate()` to ignore the embedding tables. This is similar to FSDP (class-based and composable) and DDP today.
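A sketch of the newly enabled usage, continuing the example above (`m.sparse` is assumed to hold the embedding tables):
```
m = TestSparseNN()
shard(m, sharders=[...])
# Tell replicate() to skip the sharded embedding tables.
replicate(m, ignored_modules=[m.sparse])
```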
Differential Revision: [D44899155](https://our.internmc.facebook.com/intern/diff/D44899155/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98890
Approved by: https://github.com/mrshenli, https://github.com/yhcharles
This PR adds FSDP and composable API files to `.lintrunner.toml` so that (1) lintrunner enforces that those files are formatted and (2) `lintrunner f` formats those files for you.
There are two requirements here (see https://github.com/pytorch/pytorch/wiki/lintrunner for details):
1. Install lintrunner:
```
pip install lintrunner
lintrunner init
```
2. Run `lintrunner f` before you finalize your PR; after this PR, this formatting is enforced by CI.
The code changes in this PR outside of `.lintrunner.toml` are the result of `lintrunner f`.
---
I only plan to land this PR if all of the composable API developers agree that this is something that makes sense and is not too intrusive to the workflow.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90873
Approved by: https://github.com/yhcharles, https://github.com/mrshenli, https://github.com/rohan-varma
Fixes #91654.
Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`.
The typing of the first parameter of `hook` as `TypeVar("T", bound="Module")` allows the binding of `Callable` whose first parameter is a subclass of `Module`.
---
Here are some examples of:
1. forward hooks and pre-hook hooks being accepted by mypy according to the new type hints
2. mypy throwing errors due to incorrect `hook` signatures
3. false negatives of pre-hooks being accepted as forward hooks
4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs`
```python
from typing import Any, Dict, Tuple
import torch
from torch import nn
def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...

def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...

def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...

def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...

def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

model = nn.Module()
# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)
model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)
# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)
model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)
# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```
---
Though it is not functional as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):
```python
T = TypeVar("T", bound="Module")

class Module:
    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[True] = ...,
    ) -> RemovableHandle:
        ...

    def register_forward_hook(...):
        ...
```
which would:
1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above)
2. implicitly define the [fallback `bool` signature](https://github.com/python/mypy/issues/6113#issuecomment-1266186192) e.g. to handle if a non-literal is provided for `with_kwargs`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92061
Approved by: https://github.com/albanD
**Why this PR?**
For the composable API implementation, internal APIs sometimes have access only to a local module rather than to the application (FSDP, DDP) root module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start from the root module of the model, so it is tricky for them to tell whether an arbitrary submodule is managed by either DDP or FSDP.
It will be useful to have APIs like:
`_get_module_state(module)`: returns the composable state if this module is managed by a composable API.
`_get_module_fsdp_state(module)`: returns the FSDP state if this module is managed by FSDP.
**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. Add a global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just the root module) to their state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state and then verifies that the state is an `_FSDPState` (a rough sketch follows).
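A rough sketch of the proposed lookups (names taken from the list above; class bodies simplified, so this is not the actual implementation):
```python
from typing import Dict, Optional

import torch.nn as nn

class _State:  # base state class shared by the composable APIs (simplified placeholder)
    pass

class _FSDPState(_State):  # FSDP's state, which now subclasses _State
    pass

# Populated whenever a composable API (or FSDP) starts managing a module tree;
# every submodule, not just the root, gets an entry.
_module_state_mapping: Dict[nn.Module, _State] = {}

def _get_module_state(module: nn.Module) -> Optional[_State]:
    # Return the composable state if this (sub)module is managed by a composable API.
    return _module_state_mapping.get(module)

def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    state = _get_module_state(module)
    # Only return the state if the module is actually managed by FSDP.
    return state if isinstance(state, _FSDPState) else None
```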
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
Continuation after https://github.com/pytorch/pytorch/pull/90163.
Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in the presence of *args/**kwargs or decorators):
_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with the corrected script.
``` python
import ast
import os

import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    all_node_args = node.args.args
                    if node.args.vararg is not None:
                        all_node_args.append(node.args.vararg)
                    if node.args.kwarg is not None:
                        all_node_args.append(node.args.kwarg)
                    if node.args.posonlyargs is not None:
                        all_node_args.extend(node.args.posonlyargs)
                    if node.args.kwonlyargs is not None:
                        all_node_args.extend(node.args.kwonlyargs)
                    args = [a.arg for a in all_node_args]
                    docstring = docstring_parser.parse(ast.get_docstring(node))
                    doc_args = [a.arg_name for a in docstring.params]
                    clean_doc_args = []
                    for a in doc_args:
                        clean_a = ""
                        for c in a.split()[0]:
                            if c.isalnum() or c == '_':
                                clean_a += c
                        if clean_a:
                            clean_doc_args.append(clean_a)
                    doc_args = clean_doc_args
                    for a in doc_args:
                        if a not in args:
                            print(full_name, node.lineno, args, doc_args)
                        break
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
This PR adds the first version of the `replicate()` composable API. For this prototype version, I tried to reuse as much code from the existing `DistributedDataParallel` as possible and will iterate on it in later changes. The basic idea of this prototype is:
- create a `ReplicateState` object. It internally uses a `ParameterList` module to hold all parameters of modules marked by the `replicate()` API.
- create an internal `_ddp` object, which reuses the existing `DistributedDataParallel` implementation and wraps the `ParameterList` object
- install pre-forward and after-forward hooks on the root module, which call methods of `_ddp` to run initialization and forward (a usage sketch follows)
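A usage-level sketch of what this enables (assuming the process group is already initialized, e.g. via `torchrun`; the `ReplicateState`/`_ddp` details above stay internal):
```python
import torch
import torch.nn as nn
from torch.distributed._composable import replicate

# Assumes torch.distributed.init_process_group() has already been called.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda()
replicate(model)  # marks the module; DDP machinery is set up lazily
out = model(torch.randn(4, 8, device="cuda"))  # the pre-forward hook initializes the internal _ddp
out.sum().backward()  # gradients are synchronized via the reused DDP implementation
```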
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87649
Approved by: https://github.com/zhaojuanmao