Commit Graph

232 Commits

Author SHA1 Message Date
Chien-Chin Huang
1d2382f141 [DDP] Use compiled_autograd to trace DDP backward allreduce (#110662)
**Summary**
The reducer of `DistributedDataParallel` is implemented in C++, which makes it hard to trace the allreduces it launches. This PR modifies `DistributedDataParallel` to launch one allreduce per gradient when `compiled_autograd` is enabled. The change allows us to use `compiled_autograd` to trace the allreduces so they can later be optimized (fused) by Inductor.

**Key Logic**
1. If `ddp_python_hook` is True, we assume `compiled_autograd` is used. `DistributedDataParallel` registers `compiled_accum_grad_hook` for all parameters.
2. In the first forward() call, if `DistributedDataParallel` is not compiled, all `compiled_accum_grad_hook` hooks are deregistered. If `DistributedDataParallel` is compiled, all `compiled_accum_grad_hook` hooks will be compiled by `compiled_autograd`.
3. `compiled_accum_grad_hook` launches an allreduce to reduce the gradient of the parameter.
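The per-gradient hook scheme above can be sketched as a toy, torch-free simulation (names like `allreduce_mean` and `run_backward` are illustrative, not the actual PyTorch implementation; the collective is mimicked across in-process "workers"):

```python
def allreduce_mean(grads_for_param):
    """Simulated allreduce: average one parameter's gradient across workers."""
    return sum(grads_for_param) / len(grads_for_param)

def run_backward(world_grads):
    """world_grads[worker][param] -> local gradient.

    The per-parameter hook fires once per parameter and launches one
    allreduce for just that gradient (no bucketing at this stage).
    """
    num_params = len(world_grads[0])
    reduced = []
    for p in range(num_params):  # one hook invocation per parameter
        reduced.append(allreduce_mean([g[p] for g in world_grads]))
    return reduced

# Two workers, three parameters each.
grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
print(run_backward(grads))  # [2.0, 3.0, 4.0]
```

With one allreduce per gradient, a tracer sees each collective as a separate node, which is what later enables fusion.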

**Bucketing**
The compiled backward is slow because there is no bucketing for the allreduces. We rely on Inductor to bucket the allreduces.

The bucketing is done in a separate PR.

Differential Revision: [D49428482](https://our.internmc.facebook.com/intern/diff/D49428482/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110662
Approved by: https://github.com/wconstab
2024-02-08 03:03:15 +00:00
Edward Z. Yang
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
mingxzhao
1859895ffa Docs: fix docstring errors in model_averaging (#117038)
pydocstyle check

averagers.py

Pre
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:20 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:27 in public method `average_parameters`:
        D102: Missing docstring in public method
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:84 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:106 in public method `average_parameters`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:106 in public method `average_parameters`:
        D400: First line should end with a period (not '`')
6

Post
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:20 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:27 in public method `average_parameters`:
        D102: Missing docstring in public method
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:84 in public method `__init__`:
        D107: Missing docstring in __init__
4

utils.py

Pre
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:17 in public function `average_parameters`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:45 in public function `get_params_to_average`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:45 in public function `get_params_to_average`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:68 in public function `average_parameters_or_parameter_groups`:
        D200: One-line docstring should fit on one line with quotes (found 3)
5

Post
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:1 at module level:
        D100: Missing docstring in public module
1

hierarchical_model_averager.py

Pre
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:16 in public class `HierarchicalModelAverager`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:98 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:137 in private method `_find_process_group`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:137 in private method `_find_process_group`:
        D400: First line should end with a period (not ',')
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:137 in private method `_find_process_group`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:151 in public method `average_parameters`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:151 in public method `average_parameters`:
        D400: First line should end with a period (not '`')
8

Post
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:99 in public method `__init__`:
        D107: Missing docstring in __init__
2

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117038
Approved by: https://github.com/H-Huang
2024-01-18 04:12:51 +00:00
mingxzhao
b4f1ab4505 Docs: fix docstring errors in ddp_comm_hooks (#116866)
Reopens #115272
Fixes ddp_comm_hooks errors in https://github.com/pytorch/pytorch/issues/112644

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116866
Approved by: https://github.com/awgu
2024-01-10 01:24:06 +00:00
Ashvanth.S
e6bffc6b87 Fix docstring errors in default_hooks.py, optimizer_overlap.py, checkpoint_wrapper.py, copy.py, benchmark_ddp_rpc.py, utils.py, dependency.py, phony.py, pipeline.py, checkpoint.py, worker.py, batchnorm.py, quantization.py (#113511)
Fixes #112645

Updated the files by fixing the docstring errors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113511
Approved by: https://github.com/weifengpy
2023-11-14 18:52:41 +00:00
ChanBong
c0b57d4e3b fix docstring issues in torch.distributed (#113337)
Fixes #112643

Fixes all the issues listed

### Error Count

|File | Count Before | Count now|
|---- | ---- | ---- |
|`torch/distributed/optim/named_optimizer.py` | 13 | 1|
|`torch/distributed/nn/functional.py` | 7 | 1|
|`torch/distributed/nn/api/remote_module.py` | 25 | 3|
|`torch/distributed/algorithms/join.py` | 43 | 4|

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113337
Approved by: https://github.com/ezyang
2023-11-13 19:37:29 +00:00
Fabrice Pont
053367b1ed fix: flake8-bugbear code B024 (#107265)
See #106571 item B024

This fix concerns the addition of `abstractmethod` to methods declared inside abstract classes.
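A minimal sketch of the B024 fix pattern, using a hypothetical `Hook` base class (not one of the classes actually touched by this PR):

```python
from abc import ABC, abstractmethod

# Before the fix, a class could inherit from ABC yet declare no abstract
# methods (flake8-bugbear B024), so it could be instantiated by mistake.
class Hook(ABC):
    @abstractmethod
    def run(self, state):
        ...

class PrintHook(Hook):
    def run(self, state):
        return f"ran with {state}"

try:
    Hook()  # now correctly fails: abstract method not implemented
except TypeError:
    print("Hook is properly abstract")

print(PrintHook().run("x"))  # ran with x
```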

Should I also include PEP8-compliant reformatting of the files I had to modify?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107265
Approved by: https://github.com/kit1980
2023-10-04 23:52:52 +00:00
redwrasse
ba4782e3c0 cleanup typos; redundant parentheses (#109003)
- minor spelling fixes in `aten/src/ATen/core/TransformationHelper.h`
- remove redundant parentheses in control statements in `torch/distributed/algorithms/_quantization/quantization.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109003
Approved by: https://github.com/davidradl, https://github.com/H-Huang
2023-09-11 17:09:17 +00:00
Rohan Varma
208fd1cb84 [RFC] Somewhat BC breaking: make checkpoint_wrapper default to NO_REENTRANT (#108435)
We should use no_reentrant. There are a lot of users of this API, but
it is in a prototype state, so it should be fine to change.

Differential Revision: [D48898148](https://our.internmc.facebook.com/intern/diff/D48898148/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108435
Approved by: https://github.com/awgu
ghstack dependencies: #108032, #108033
2023-09-05 21:43:41 +00:00
Andrew Gu
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This PR makes some code organization improvements.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
Rohan Varma
2ec7cd2db2 [CheckpointWrapper] Test for kwarg propagation, remove checkpoint_fn_arg support (#102679)
Closes https://github.com/pytorch/pytorch/issues/100576

Differential Revision: [D46342398](https://our.internmc.facebook.com/intern/diff/D46342398/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102679
Approved by: https://github.com/awgu
2023-07-28 21:18:35 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the local unsharded gradient of the worker with shard index $i \in [S]$ and replica index $j \in [R]$. The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling, but coalesce the two `div_`s into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}}.$$

This PR goes with the 1st approach, prioritizing performance, because (1) it matches the existing FSDP behavior and (2) it avoids a memory-bandwidth-bound `div_` kernel that blocks the all-reduce launch.
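The rescaling identities above can be checked numerically. This plain-Python sketch (not FSDP code) verifies that both approaches reproduce the plain mean of all local gradients:

```python
import math

# N = S * R workers; the goal is the global mean of all local gradients.
S, R = 4, 2
N = S * R
grads = [float(k + 1) for k in range(N)]  # one scalar gradient per worker
mean = sum(grads) / N

# Approach 1 (this PR): pre-divide every gradient by sqrt(N), sum across
# all N workers, post-divide by sqrt(N).
approach1 = sum(g / math.sqrt(N) for g in grads) / math.sqrt(N)

# Approach 2: per-replica reduce with pre-divide sqrt(S), one fused
# intermediate div by sqrt(N), then post-divide by sqrt(R).
shard_sums = [sum(grads[j * S:(j + 1) * S]) / math.sqrt(S) for j in range(R)]
approach2 = sum(s / math.sqrt(N) for s in shard_sums) / math.sqrt(R)

assert math.isclose(approach1, mean)
assert math.isclose(approach2, mean)
print(approach1, approach2, mean)
```

The total scale factor in both cases multiplies out to $1/N$, so the difference is purely about how many `div_` kernels run and where.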

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.
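A rough sketch of the dispatch these bullets describe (illustrative names, not the actual FSDP internals): the hook, when present, owns the whole reduction, and the default path keeps its collectives and pre/post-division visible in one place.

```python
import math

def reduce_gradient(grad_shards, comm_hook=None):
    # `comm_hook is not None` iff a hook was registered.
    if comm_hook is not None:
        return comm_hook(grad_shards)  # user-registered hook owns everything
    # Default no-hook path: pre-divide, sum ("reduce"), post-divide.
    n = len(grad_shards)
    pre = [g / math.sqrt(n) for g in grad_shards]
    return sum(pre) / math.sqrt(n)

print(reduce_gradient([2.0, 4.0]))                 # ~3.0 (the mean)
print(reduce_gradient([2.0, 4.0], comm_hook=max))  # 4.0 (hook replaces default)
```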

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Justin Chu
232b96b6e2 [BE] Enable ruff's UP rules and autoformat distributed/ (#105433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105433
Approved by: https://github.com/albanD
2023-07-19 14:27:11 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00
Nikita Shulga
4148b7bada [Typing] Fix PEP 484 Violation (#105022)
Not sure how it worked before, but arguments must be annotated as Optional if they default to None.

Towards enabling mypy-1.4.1 in lintrunner
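A minimal example of the violation being fixed (names are made up; both forms run at runtime, but a type checker with implicit Optional disabled, the mypy default since release 0.990, rejects the first):

```python
from typing import Optional

def before(timeout: float = None):  # PEP 484 violation: implicit Optional
    return timeout

def after(timeout: Optional[float] = None):  # explicitly annotated
    return -1.0 if timeout is None else timeout

print(after())     # -1.0
print(after(2.5))  # 2.5
```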

### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>

> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
2023-07-12 10:20:48 +00:00
Rohan Varma
a748be93df [CheckpointWrapper] Warn on reentrant use (#102890)
We'd like to encourage users to try non-reentrant as much as possible,
and identify any gaps this way.

Differential Revision: [D46397786](https://our.internmc.facebook.com/intern/diff/D46397786/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102890
Approved by: https://github.com/awgu
2023-06-04 18:31:22 +00:00
Rohan Varma
957ea485c4 [FSDP/AC] checkpoint_wrapper accept auto_wrap_policy (#102672)
Some feedback on this API is that folks would like to use
`auto_wrap_policy`, similar to FSDP, instead of having to adapt to the
signature of `check_fn`.

Differential Revision: [D46340320](https://our.internmc.facebook.com/intern/diff/D46340320/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102672
Approved by: https://github.com/awgu
2023-06-04 18:31:19 +00:00
Dirk Groeneveld
75945d54f7 Properly propagates checkpoint wrapper args and kwargs (#99791)
It looks like passing `*args` and `**kwargs` to `checkpoint_wrapper()` does not work because someone forgot some `*`s. This adds them back in.
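The bug class can be illustrated with a toy wrapper (hypothetical names, not the actual `checkpoint_wrapper` code): without the stars, the callee receives one tuple and one dict as positional arguments instead of the re-expanded values.

```python
def inner(a, b, scale=1):
    return (a + b) * scale

def broken_wrapper(*args, **kwargs):
    return inner(args, kwargs)     # forgot the stars: passes tuple + dict

def fixed_wrapper(*args, **kwargs):
    return inner(*args, **kwargs)  # re-expands correctly

print(fixed_wrapper(1, 2, scale=3))  # 9
try:
    broken_wrapper(1, 2, scale=3)
except TypeError:
    print("broken_wrapper fails as expected")
```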
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99791
Approved by: https://github.com/awgu
2023-05-03 23:19:21 +00:00
Rohan Varma
bba2090831 Enable fused optimizer for DP (#98270)
Differential Revision: [D42714482](https://our.internmc.facebook.com/intern/diff/D42714482/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D42714482/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98270
Approved by: https://github.com/awgu
2023-04-13 20:16:32 +00:00
Edward Z. Yang
5a7aad9681 Convert logging f-strings to use % format, part four (#98705)
This does multi-line concatenated string literals.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98705
Approved by: https://github.com/voznesenskym
2023-04-11 13:17:59 +00:00
Kazuaki Ishizaki
6514d71add Fix typos under torch/distributed directory (#98225)
This PR fixes typos in comments and messages of `.py` files under the `torch/distributed` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98225
Approved by: https://github.com/soulitzer, https://github.com/kit1980
2023-04-05 00:21:33 +00:00
Edward Z. Yang
5df59f957f Fix G001,G002,G003 in logs to % syntax (#97812)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97812
Approved by: https://github.com/Skylion007, https://github.com/kiukchung, https://github.com/malfet, https://github.com/mlazos
2023-04-01 01:43:33 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under the `torch/distributed` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Rohan Varma
32f11f58c9 DDP native mixed precision (#92882)
Implements native mixed precision support for DDP in a similar fashion to how it is enabled for FSDP. The implementation works as follows:

1. In DDP init, we save `_mp_param` and `_fp_param` variables to manage mixed precision parameter usage. In particular, `_mp_param` represents the parameter in the reduced precision, while `_fp_param` represents the parameter in regular precision. During forward/backward, we swap back and forth as needed.
2. The root module gets a root pre-forward hook that kicks off copies to the reduced precision for all submodules. An event is recorded for each submodule to allow for waiting, as we run these asynchronously.
3. Each module gets a pre-forward hook that waits on its corresponding event. Note that modules might be reused during training; in this case the wait is only done for the first module. After this wait, the module's parameters are in reduced precision.
4. In the pre-forward hook, we register a backward hook on the lower-precision parameters in order to run reduced-precision allreduce + parameter upcast. We can't rely on the Reducer's constructor setting up these hooks because the gradient is accumulated on the low-precision param, so we need to register them ourselves.
5. In the backward pass, when the hook runs, we first run allreduce + divide in the reduced precision. Next, we upcast parameters and gradients back to fp32 asynchronously. We also queue a callback at the end of backward to wait on these upcasts so that the upcast is complete before optim.step() runs.
6. Parameters that don't require grad are also cast, since they may be used in computation; they are upcast back in the final autograd callback.
7. DDP-ignored parameters are not touched.
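The cast/upcast choreography in steps 1-5 can be sketched with a toy simulation, where "reduced precision" is mimicked by rounding to two decimal places (illustrative only; the real code casts torch tensors and runs the collective on a process group):

```python
def to_low(x):
    """Stand-in for a cast to reduced precision."""
    return round(x, 2)

class Param:
    def __init__(self, value):
        self._fp_param = value  # full-precision master copy
        self._mp_param = None   # reduced-precision working copy

    def cast_for_forward(self):
        # Step 2: pre-forward hook copies the param to reduced precision.
        self._mp_param = to_low(self._fp_param)

    def backward_hook(self, low_grads):
        # Steps 4-5: allreduce + divide in reduced precision, then upcast
        # back to full precision for the optimizer step.
        reduced = to_low(sum(low_grads) / len(low_grads))
        return float(reduced)

p = Param(0.123456)
p.cast_for_forward()
print(p._mp_param)                      # 0.12
print(p.backward_hook([0.104, 0.096]))  # 0.1
```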

Follow-ups:

1. Unify comm hooks and make this work with apply-optimizer-in-backward.
2. Implement keep_low_precision_grads.
3. Allow BN, LN, or custom units to run in reduced precision.
4. Support cast_forward_inputs.
5. Unify certain APIs / helpers with FSDP where possible, such as for _cast_forward_inputs.
6. Integrate this with the replicate() API.
7. The order in which we kick off copies and wait for them is set by the iteration order of module.modules(), but this might not be how the modules are used in the actual training. In the worst case, the last module in module.modules() could be used first which would result in waiting for all copies unnecessarily. For static graphs, we should record the module execution order and copy / wait in this order.
8. Entirely unused modules probably don't need to be cast.

Differential Revision: [D42515803](https://our.internmc.facebook.com/intern/diff/D42515803/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92882
Approved by: https://github.com/zhaojuanmao
2023-03-13 14:10:31 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Cases where the rewrite would change the semantics are kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Vasiliy Kuznetsov
f15ab8a7f2 AO migration: replace torch internal callsites (#94170)
Summary:

Do the following renames:
`torch.quantization` -> `torch.ao.quantization`
`torch.nn.quantized` -> `torch.ao.nn.quantized`
`torch.nn.quantizable` -> `torch.ao.nn.quantizable`
`torch.nn.qat` -> `torch.ao.nn.qat`
`torch.nn.intrinsic` -> `torch.ao.nn.intrinsic`

And then, do
`torch.ao.nn.quantized._reference` -> `torch.ao.nn.quantized.reference` to clean up the aftermath of https://github.com/pytorch/pytorch/pull/84974

Then, manually update `test/test_module_init.py` to fix hanging whitespace due to the replace.

Run this script to do the replacements: https://gist.github.com/vkuzo/7f7afebf8c31b9ba48306223e68a1c82

This is for https://github.com/pytorch/pytorch/issues/81667

Test plan: CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94170
Approved by: https://github.com/jerryzh168
2023-02-07 02:32:23 +00:00
Rohan Varma
e8bf7c21e4 Integrate apply_optim_in_backward with DDP (#89194)
Allow _apply_optim_in_backward to work with DDP.

Example:

```
dist.init_process_group("nccl", rank=rank, world_size=2)
torch.cuda.set_device(rank)
e = enc().cuda(rank)
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=e.parameters(),
    optimizer_kwargs={"lr": 0.03},
)
e = DDP(e, device_ids=[rank])
inp = torch.randn(1, 10, device=rank)
e(inp).sum().backward()
```

Constraints:

1. Custom communication hook not yet supported
2. _apply_optim_in_backward needs to be called _before_ wrapping model in DDP.
3. DDP will remove the gradient hooks that _apply_optim_in_backward registers, so those gradient hooks will not fire and cannot be used.
4. All DDP-managed parameters have grads set to None by default once the optimizer is applied. There is no support for setting only some parameter grads to None; this must be done manually by the user (and `DDP_OVERLAPPED_OPTIM_SET_GRADS_TO_NONE=0` needs to be set).

Differential Revision: [D41329694](https://our.internmc.facebook.com/intern/diff/D41329694/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41329694/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89194
Approved by: https://github.com/zhaojuanmao
2022-12-21 07:35:19 +00:00
Andrew Gu
fc429512d5 [FSDP] Clean up FlatParamHandle dtypes, post-backward hook (#90660)
This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and adds some minor changes to the communication hooks.

**Overview**
This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`, since that is BC breaking. I recommend reading the issue before proceeding.

For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we are following is:
> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.

This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check if mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`; it is no longer a check against `None`.

This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either silently relies on no-ops or errors out (such as in one case reported by MosaicML).
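The inference rule quoted above can be sketched as a small helper (illustrative names, not the actual FSDP internals; dtypes are represented as strings here):

```python
def resolve_dtypes(orig_dtype, param_dtype=None, reduce_dtype=None):
    # Rule: if param_dtype is set and reduce_dtype is not, infer
    # reduce_dtype = param_dtype; otherwise take both as given.
    if param_dtype is not None and reduce_dtype is None:
        reduce_dtype = param_dtype
    # Neither field is ever left as None at the handle level.
    param_dtype = param_dtype if param_dtype is not None else orig_dtype
    reduce_dtype = reduce_dtype if reduce_dtype is not None else orig_dtype
    # Mixed precision is detected by comparing against the original dtype,
    # never by checking for None.
    param_mp_enabled = param_dtype != orig_dtype
    return param_dtype, reduce_dtype, param_mp_enabled

# MixedPrecision(param_dtype=float32) on float32 params: NOT mixed precision.
print(resolve_dtypes("float32", param_dtype="float32"))
# ('float32', 'float32', False)
print(resolve_dtypes("float32", param_dtype="float16"))
# ('float16', 'float16', True)
```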

**Additional Details**
- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.

**Follow-Ups**
- We should get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
2022-12-13 07:34:59 +00:00
Rohan Varma
a5532929da Remove DDP import (#89982)
This import is only used for typing, removing it to avoid circular ref
in next diffs

Differential Revision: [D41636897](https://our.internmc.facebook.com/intern/diff/D41636897/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89982
Approved by: https://github.com/zhaojuanmao
2022-12-01 14:56:48 +00:00
Kazuaki Ishizaki
2ddefbdc3c Fix typos used in documents under torch directory (#88300)
This PR fixes typos in comments of Python files, found via the search box at https://pytorch.org/docs/master/search.html

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88300
Approved by: https://github.com/lezcano
2022-11-02 09:38:13 +00:00
Andrew Gu
a2ffc3be97 [AC] Add trailing "." to _CHECKPOINT_PREFIX like FSDP (#87951)
This is for consistency with FSDP.
- `_FSDP_WRAPPED_MODULE` and `_CHECKPOINT_WRAPPED_MODULE` are exactly the wrapped module variable name, meaning you can call `getattr(module, _FSDP_WRAPPED_MODULE)` or `getattr(module, _CHECKPOINT_WRAPPED_MODULE)`.
- `_FSDP_PREFIX` and `_CHECKPOINT_PREFIX` include the trailing `"."` and are only used for FQNs.
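A small sketch of the convention these bullets describe (the constant values follow the naming stated above but should be treated as illustrative, not a copy of the PyTorch source):

```python
# Exactly the attribute name of the wrapped module, usable with getattr():
_CHECKPOINT_WRAPPED_MODULE = "_checkpoint_wrapped_module"
# The FQN prefix includes the trailing "." and is only used for FQNs:
_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."

def clean_fqn(fqn):
    """Strip checkpoint-wrapper prefixes from a parameter FQN."""
    return fqn.replace(_CHECKPOINT_PREFIX, "")

print(clean_fqn("layer1._checkpoint_wrapped_module.weight"))
# layer1.weight
```

The trailing "." makes the prefix strip cleanly without leaving a dangling dot in the cleaned name.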
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87951
Approved by: https://github.com/zhaojuanmao
2022-10-28 22:05:29 +00:00
Andrew Gu
1c37119a1f [FSDP] New fix for composing with other module wrappers (#87950)
We change `.module` to pass through `ActivationWrapper` directly to the inner wrapped module. This should fix the state dict issues.

Given the invariant that `.module` always returns the inner wrapped module, FSDP always registers the `FlatParameter` on the inner wrapped module, regardless of if there is an intermediate `ActivationWrapper` or not. This avoids casing on whether `ActivationWrapper` is added before or after FSDP construction.
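A minimal sketch of this pass-through invariant (toy classes, not the actual FSDP or `ActivationWrapper` code): `module` skips an intermediate activation wrapper so it always returns the innermost wrapped module.

```python
class ActivationWrapper:
    def __init__(self, module):
        self._wrapped = module

class FSDPLike:
    def __init__(self, module):
        self._wrapped = module

    @property
    def module(self):
        inner = self._wrapped
        if isinstance(inner, ActivationWrapper):  # pass through the AC wrapper
            inner = inner._wrapped
        return inner

class Linear:
    pass

m = Linear()
assert FSDPLike(m).module is m                      # no AC wrapper
assert FSDPLike(ActivationWrapper(m)).module is m   # AC wrapper in between
print("both cases unwrap to the same inner module")
```

Because `module` is the same object in both cases, state registered on it (like a `FlatParameter`) lands in one place regardless of wrapping order.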

This PR removes the added unit test in `test_fsdp_misc.py` for changing the wrapped module because I would rather not complicate `_lazy_init()` logic just to support that kind of adversarial behavior. The user should not be swapping out the wrapped module arbitrarily or deleting the `FlatParameter`. I mainly had those tests to make sure that all branches of the code I added were correct.

Differential Revision: [D40799961](https://our.internmc.facebook.com/intern/diff/D40799961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87950
Approved by: https://github.com/zhaojuanmao
2022-10-28 21:11:40 +00:00
Andrew Gu
f967918411 [AC] Return None from apply_activation_checkpointing() (#87871)
`_recursive_wrap()` returns `Tuple[nn.Module, int]`, where the `nn.Module` is the in-place modified module and the `int` is the numel wrapped. In that sense, the return value is not meant to be publicly used. The `apply_activation_checkpointing()` docs already suggest that the function returns `None`, so this PR simply follows that.
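The in-place/`None` convention can be sketched in plain Python (hypothetical names; like `list.sort()` versus `sorted()`, returning `None` keeps callers from depending on an internal return value):

```python
# Hedged sketch of the apply-in-place convention, not the torch source:
# recursively wrap matching leaves of a module tree in place; return None
# so the (internal) wrap count is never part of the public contract.
def apply_wrapper_in_place(tree: dict, check_fn) -> None:
    for name, child in tree.items():
        if isinstance(child, dict):
            apply_wrapper_in_place(child, check_fn)
        elif check_fn(child):
            tree[name] = ("wrapped", child)
```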

**Test Plan**
CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87871
Approved by: https://github.com/zhaojuanmao
2022-10-28 02:00:39 +00:00
Andrew Gu
04ad0134ae [FSDP] Use reduce_scatter_tensor() (#87240)
Let us silence some more warnings 👍🏼
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87240
Approved by: https://github.com/rohan-varma
2022-10-24 11:29:23 +00:00
Rohan Varma
bdefa260b2 [RFC] Separate CPU offload activation to its own wrapper (#85459)
Passing in `offload_to_cpu=True` to checkpoint_wrapper is a bit confusing, because this causes the activation checkpoint args to be ignored and we do CPU offloading instead. This isn't ideal from an API design perspective, so proposing to make `offload_wrapper` its own concept.

Now, offload to CPU + checkpoint can be composed together, such as

```
# apply AC to transformer layers
apply_ac_wrapper(model, checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))
# offload the rest of activations to CPU
model = offload_wrapper(model)
```

Will polish / add tests if this proposal sounds good.

Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85459
Approved by: https://github.com/awgu
2022-10-15 05:19:28 +00:00
Rohan Varma
2b5625a726 Update hierarchical_model_averager.py (#85648)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85648
Approved by: https://github.com/wayi1, https://github.com/H-Huang
2022-10-03 06:15:20 +00:00
Rohan Varma
a8074a1a0b [Checkpoint] rename apply_ac_wrapper (#85449)
Per title

Differential Revision: [D39714855](https://our.internmc.facebook.com/intern/diff/D39714855/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85449
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
Rohan Varma
cc64f64670 [Docs] Minor fix to apply_ac doc (#85448)
Per title

Created from CodeHub with https://fburl.com/edit-in-codehub

Differential Revision: [D39714530](https://our.internmc.facebook.com/intern/diff/D39714530/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85448
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
anjali411
85073b8ddc Add __all__ to fx, distributed and cuda submodules (#85080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85080
Approved by: https://github.com/albanD
2022-09-21 18:04:58 +00:00
Rohan Varma
8cb7826889 [CheckpointWrapper] Reentrant kwarg support (#84908)
A temporary patch to support keyword args when the reentrant checkpoint wrapper is used. This is needed to unblock some crucial workloads; the ideal fix would be to check this directly into torch.utils.checkpoint.
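
One way such a patch can work is sketched below in plain Python (an assumed approach for illustration, not the torch source): pack the kwargs into the positional argument list, since reentrant checkpointing traditionally forwards positional args only, and unpack them inside.

```python
# Hedged sketch: flatten kwargs into positional args so an inner function
# that only receives *args (as reentrant checkpoint does) can still call
# the user function with its keyword arguments restored.
def checkpoint_with_kwargs(fn, *args, **kwargs):
    kwarg_keys = list(kwargs.keys())
    flat_args = list(args) + [kwargs[k] for k in kwarg_keys]

    def inner(*flat):  # what reentrant checkpoint would actually invoke
        n = len(flat) - len(kwarg_keys)
        return fn(*flat[:n], **dict(zip(kwarg_keys, flat[n:])))

    return inner(*flat_args)  # real code would checkpoint `inner` here
```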

Differential Revision: [D39453453](https://our.internmc.facebook.com/intern/diff/D39453453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84908
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rohan Varma
55ca6901a7 [CheckpointWrapper] Decouple CPU offload (#84907)
This fixes the activation offload for checkpoint wrapper, which was previously broken. It was broken because it was tightly coupled with activation checkpoint, i.e. we did:

```
with save_on_cpu:
    checkpoint(module_forward())
```

which would not offload any activation tensors to CPU, as those activations would already not be saved by autograd, since the checkpoint implementation takes priority.

Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.

These wrappers can be composed, i.e. if we have

`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`

we can do

`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`

and inner tensors would be checkpointed while outers will be offloaded.

Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rodrigo Kumpera
38192f63cd Add __all__ for a few distributed modules plus a little typing (reland) (#84872)
This handles distributed_c10d, which is massive and ddp_comm_hooks.

This relands #84119 with the required fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84872
Approved by: https://github.com/rohan-varma
2022-09-13 21:57:49 +00:00
PyTorch MergeBot
219ff26172 Revert "Add __all__ for a few distributed modules plus a little typing (#84119)"
This reverts commit 6f21680563.

Reverted https://github.com/pytorch/pytorch/pull/84119 on behalf of https://github.com/izaitsevfb due to breaking internal builds, see D39386448
2022-09-09 20:01:07 +00:00
Rodrigo Kumpera
6f21680563 Add __all__ for a few distributed modules plus a little typing (#84119)
This handles distributed_c10d, which is massive and ddp_comm_hooks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84119
Approved by: https://github.com/rohan-varma
2022-09-08 23:28:31 +00:00
Rodrigo Kumpera
65dc5dd3f3 [c10d] Introduce dist.get_local_rank, dist.get_global_rank and dist.get_global_ranks (#82134)
Those functions enable membership introspection into a ProcessGroup. A common scenario
that needs this is library code that consumes a PG but doesn't create it, which means
it likely doesn't know the global ranks used to create it.

Translating from local to global is necessary when using c10d collectives like broadcast
so if your library code adopts the convention of using local rank 0, it needs
to do the following:

```python
import torch.distributed as dist

my_pg: dist.ProcessGroup = ...

def my_library_bcast(tensor):
    dist.broadcast(tensor, src=dist.get_global_rank(my_pg, local_rank=0), group=my_pg)

```

This implements some of the helpers needed to implement the `clone` API from: https://github.com/pytorch/pytorch/issues/81291
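
The bookkeeping behind these helpers can be sketched in plain Python (illustrative, not the c10d implementation): a sub-group records the global ranks that formed it, in order, so local rank `i` is simply `global_ranks[i]`.

```python
# Hedged sketch of local/global rank translation for a process sub-group.
def get_global_rank(group_global_ranks, local_rank):
    return group_global_ranks[local_rank]

def get_local_rank(group_global_ranks, global_rank):
    return group_global_ranks.index(global_rank)
```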
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82134
Approved by: https://github.com/rohan-varma
2022-08-30 17:45:00 +00:00
Rohan Varma
1a53e35b9d Enforce explicit ProcessGroup passed into DefaultState (#84105)
Would prefer to enforce that users pass an explicit PG into these state objects when using comm hooks with FSDP, so that it is clear, and easy to debug, which processes communication is taking place over.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84105
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao
2022-08-29 14:52:58 +00:00