Commit Graph

265 Commits

Author SHA1 Message Date
Aaron Orenstein
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.
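
As a hedged illustration of the PEP 585 migration applied here (the function and names below are made up for the example), the deprecated `typing` generics become builtin generics:

```python
# Before (typing generics, deprecated by PEP 585 on Python >= 3.9):
# from typing import Dict, List, Tuple
# def shard_sizes(ranks: List[int]) -> Dict[int, Tuple[int, int]]: ...

# After (builtin generics, no typing import needed):
def shard_sizes(ranks: list[int]) -> dict[int, tuple[int, int]]:
    # map each rank to an illustrative (offset, length) pair
    return {r: (r * 4, 4) for r in ranks}
```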

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
Yu, Guangye
176cde6240 Use torch with statement in torch distributed module (#144951)
# Motivation
In https://github.com/pytorch/pytorch/pull/137678, we used device-agnostic APIs to generalize the distributed module. As noted in this [comment](https://github.com/pytorch/pytorch/pull/137678#discussion_r1828645683), we switch to the `torch.Stream` with statement now that https://github.com/pytorch/pytorch/pull/140138 has landed.
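
A minimal sketch of the pattern being adopted, assuming a PyTorch build where `torch.Stream` supports the with statement (https://github.com/pytorch/pytorch/pull/140138) and `torch.accelerator` is available; the exact calls are illustrative:

```python
import torch

if torch.accelerator.is_available():
    dev = torch.accelerator.current_accelerator()
    s = torch.Stream(device=dev)
    with s:  # replaces backend-specific contexts such as torch.cuda.stream(s)
        out = torch.ones(8, device=dev) * 2
    # re-synchronize with the default stream before consuming `out`
    torch.accelerator.current_stream().wait_stream(s)
```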

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144951
Approved by: https://github.com/kwen2501, https://github.com/albanD
2025-01-17 01:49:28 +00:00
bobrenjc93
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
Xuehai Pan
e1196dfe51 Deprecate torch._utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------
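
A hedged migration sketch, assuming a PyTorch version that already exposes the public `torch.compiler.is_compiling()`:

```python
import torch

# Old (deprecated): torch._utils.is_compiling()
# New (public API): torch.compiler.is_compiling()
def debug_print(msg: str) -> None:
    if not torch.compiler.is_compiling():
        # emit side effects only when not being traced by torch.compile
        print(msg)
```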

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-12-08 22:55:36 +00:00
Aaron Gokaslan
08db735629 [BE]: Update mypy to 1.13.0 (#140808)
Update mypy to 1.13.0. This should reduce linting time, and it adds support for orjson cache serialization, which should improve mypy cache performance if orjson is installed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140808
Approved by: https://github.com/ezyang, https://github.com/malfet
2024-12-03 02:50:10 +00:00
PyTorch MergeBot
daa77f3d9f Revert "[BE]: Update mypy to 1.13.0 (#140808)"
This reverts commit 00134d68af.

Reverted https://github.com/pytorch/pytorch/pull/140808 on behalf of https://github.com/huydhn due to This is failing a distributed test in trunk, target determination missed this test and did not run it on PR ([comment](https://github.com/pytorch/pytorch/pull/140808#issuecomment-2512788426))
2024-12-02 20:47:43 +00:00
Aaron Gokaslan
00134d68af [BE]: Update mypy to 1.13.0 (#140808)
Update mypy to 1.13.0. This should reduce linting time, and it adds support for orjson cache serialization, which should improve mypy cache performance if orjson is installed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140808
Approved by: https://github.com/ezyang, https://github.com/malfet
2024-12-02 18:47:54 +00:00
Will Constable
c82c46ccc7 [C10D] support group_src/dst in broadcast/reduce ops (#140843)
Also add mypy annotations

Partially addresses RFC 0042 (pytorch/rfcs#71)
See more details/motivation in #140460
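
A hedged usage sketch of the group-relative addressing added here; it assumes a build that includes this PR, and the helper names are illustrative:

```python
import torch
import torch.distributed as dist

def broadcast_from_group_leader(t: torch.Tensor, pg: dist.ProcessGroup) -> None:
    # group_src is the rank *within* pg (its first member here), so callers no
    # longer translate it into a global rank themselves.
    dist.broadcast(t, group=pg, group_src=0)

def reduce_to_group_leader(t: torch.Tensor, pg: dist.ProcessGroup) -> None:
    dist.reduce(t, op=dist.ReduceOp.SUM, group=pg, group_dst=0)
```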
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140843
Approved by: https://github.com/kwen2501
2024-11-19 01:23:08 +00:00
lzhang2
1886e33f60 Use device-agnostic runtime API in distributed DDP/FSDP instead of cuda device specific. (#137678)
# Motivation
This PR uses device-agnostic runtime APIs in distributed DDP/FSDP instead of CUDA-specific ones.

cc [@jgong5](https://github.com/jgong5) [@gujinghui](https://github.com/gujinghui) [@EikanWang](https://github.com/EikanWang) [@fengyuan14](https://github.com/fengyuan14) [@guangyey](https://github.com/guangyey)
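
A hedged sketch of the device-agnostic idea: resolve the runtime module from the tensor's device type instead of hardcoding `torch.cuda` (the helper below is illustrative, not the PR's code):

```python
import torch

def synchronize_for(t: torch.Tensor) -> None:
    if t.device.type != "cpu":
        # e.g. torch.cuda, torch.xpu, ... depending on where `t` lives
        backend = torch.get_device_module(t.device)
        backend.synchronize(t.device)
```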

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137678
Approved by: https://github.com/kwen2501, https://github.com/guangyey, https://github.com/jgong5
2024-11-13 05:32:19 +00:00
PyTorch MergeBot
1d28b8b6d5 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit e84d1121ad.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. More details in D65483292 ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2458381056))
2024-11-05 23:10:38 +00:00
Xuehai Pan
e84d1121ad Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-11-05 10:44:56 +00:00
Vishwa Raj Singh
9d7a0869f0 Make DDP Quantization hooks backend Agnostic (#138816)
The current DDP quantization hook code uses the .cuda() API to move tensors and parameters onto backend devices, which limits the DDP quantization hooks to the CUDA backend.
The change makes the code backend agnostic and moves tensors/parameters based on **tensor.device**.
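
A hedged illustration of the change (function name made up):

```python
import torch

def move_like(aux: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # before: aux = aux.cuda()      # CUDA-only
    return aux.to(grad.device)      # follows the gradient's backend
```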

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138816
Approved by: https://github.com/kwen2501
2024-10-29 15:02:45 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
zeshengzong
82443798aa [Distributed] Refactor compress hook to remove duplicated code (#138182)
Fix a TODO in the code:

```python
# TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress.
```

1. Extract the common logic of `fp16_compress_hook` and `bf16_compress_hook` into a `_compress_hook` method
2. Have `fp16_compress_hook` and `bf16_compress_hook` invoke `_compress_hook` with a different `dtype` (see the sketch below)
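
A hedged sketch of the refactor under the DDP comm-hook convention `(state, bucket) -> Future`; the details are illustrative rather than the exact PR code:

```python
import torch
import torch.distributed as dist

def _compress_hook(dtype, process_group, bucket):
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = group.size()
    # cast, pre-divide, all-reduce asynchronously, then cast back
    compressed = bucket.buffer().to(dtype).div_(world_size)
    fut = dist.all_reduce(compressed, group=group, async_op=True).get_future()
    return fut.then(lambda f: f.value()[0].to(bucket.buffer().dtype))

def fp16_compress_hook(process_group, bucket):
    return _compress_hook(torch.float16, process_group, bucket)

def bf16_compress_hook(process_group, bucket):
    return _compress_hook(torch.bfloat16, process_group, bucket)
```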

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138182
Approved by: https://github.com/awgu
2024-10-18 06:01:15 +00:00
Aaron Gokaslan
31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.2 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00
PyTorch MergeBot
3117f2cf67 Revert "[BE]: Update mypy to 1.11.2 (#133816)"
This reverts commit 55299cfc22.

Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684))
2024-09-16 09:11:16 +00:00
Aaron Gokaslan
55299cfc22 [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.2 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-14 21:40:36 +00:00
Xuehai Pan
758a0a88a2 [BE][Easy] enable ruff rule PIE790: unnecessary pass statement (#133200)
This PR removes unnecessary `pass` statements. This is semantically safe because the bytecode for the Python code does not change.

Note that if there is a docstring in the function, an empty function does not need a `pass` statement as a placeholder.
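
For illustration, the kind of code PIE790 flags:

```python
class Handler:
    def on_event(self) -> None:
        """Default no-op handler."""
        # a trailing `pass` here is redundant: the docstring already forms the body
```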

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
2024-08-15 15:50:19 +00:00
PyTorch MergeBot
cbee9c1fd2 Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 0e7e61f7ce.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2272370386))
2024-08-07 00:05:20 +00:00
Xuehai Pan
0e7e61f7ce Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007, https://github.com/malfet
2024-08-03 09:43:38 +00:00
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
Mikayla Gawarecki
1dd10ac802 [BE] [Reland] Make nn.Module state_dict load_state_dict pre-hook and state_dict post-hook public (#131690)
Reland https://github.com/pytorch/pytorch/pull/126704

#### Fixes the issue with type of `nn.Module._state_dict_hooks` being changed in that PR which was problematic:
Instead of using `Tuple[Callable, bool]` to keep track of whether the private `_register_state_dict_hook` or the public `register_state_dict_post_hook` API was used to register the hook and toggle the behavior accordingly, I set an attribute on the Callable in the private API, which is never cleaned up.

If a callable previously registered using the private API is registered via the public API, a RuntimeError will be raised.

#### Copied from previous PR description
Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437

- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
   - Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag, defaulting it to `True`)
    ~- For consistency with the optimizer `load_state_dict_pre_hook` raised by @janeyx99, allow the pre-hook to return a new `state_dict`~
 - For the issue pointed out by https://github.com/pytorch/pytorch/issues/117437 regarding the `_register_state_dict_hook` semantic of returning a new state_dict only being respected for the root module for the private hook:
       - Document this for the private `_register_state_dict_hook`
       - Remove this behavior for the public `register_state_dict_post_hook`
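
A hedged usage sketch of the public APIs; hook signatures follow the documented conventions, but treat the details as illustrative:

```python
import torch.nn as nn

model = nn.Linear(4, 4)

def drop_bias(module, state_dict, prefix, local_metadata):
    # state_dict post-hook: edit in place (the public variant should not
    # return a new dict)
    state_dict.pop(prefix + "bias", None)

def log_load(module, state_dict, prefix, *args):
    # load_state_dict pre-hook: runs before the module consumes state_dict
    print(f"loading {len(state_dict)} entries for prefix '{prefix}'")

model.register_state_dict_post_hook(drop_bias)
model.register_load_state_dict_pre_hook(log_load)

sd = model.state_dict()                 # no 'bias' entry thanks to the post-hook
model.load_state_dict(sd, strict=False)
```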

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131690
Approved by: https://github.com/albanD
2024-07-26 18:14:07 +00:00
Jovian Anthony Jaison
f6edd1f7c9 [BE] Make ActivationWrapper an abstract class (#129808)
Fixes #95481

Test Plan:
Unit tested checkpoint_wrapper.py by instantiating ActivationWrapper and got a TypeError as expected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129808
Approved by: https://github.com/Skylion007
2024-07-02 04:29:43 +00:00
Xuehai Pan
e6d4451ae8 [BE][Easy] enable UFMT for torch/distributed/{algorithms,autograd,benchmarks,checkpoint,elastic}/ (#128866)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128866
Approved by: https://github.com/fegin
2024-06-18 13:51:53 +00:00
PyTorch MergeBot
1d233b8f50 Revert "Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704)"
This reverts commit c38b3381a1.

Reverted https://github.com/pytorch/pytorch/pull/126704 on behalf of https://github.com/clee2000 due to broke internal typecheck D58394110 (which probably means the code wouldn't work either but I guess it didn't run on the diff). Probably an easy fix? ([comment](https://github.com/pytorch/pytorch/pull/126704#issuecomment-2161299193))
2024-06-11 17:45:20 +00:00
Mikayla Gawarecki
c38b3381a1 Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704)
Fixes https://github.com/pytorch/pytorch/issues/75287 and https://github.com/pytorch/pytorch/issues/117437

- `nn.Module._register_state_dict_hook` --> add public `nn.Module.register_state_dict_post_hook`
   - Add a test as this API was previously untested
- `nn.Module._register_load_state_dict_pre_hook` --> add public `nn.Module.register_load_state_dict_pre_hook` (remove the `with_module` flag, defaulting it to `True`)
    ~- For consistency with the optimizer `load_state_dict_pre_hook` raised by @janeyx99, allow the pre-hook to return a new `state_dict`~
 - Document the issue pointed out by https://github.com/pytorch/pytorch/issues/117437 regarding the `_register_state_dict_hook` semantic of returning a new state_dict only being respected for the root module for the private hook
       - Remove this behavior for the public `register_state_dict_post_hook`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126704
Approved by: https://github.com/albanD
ghstack dependencies: #126906
2024-06-10 21:50:17 +00:00
PyTorch MergeBot
90bb510ece Revert "Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)"
This reverts commit 348b181a97.

Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/clee2000 due to sorry I think https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456 is still relevant, I will reach out to them to see what needs to be done in internal to get this remerged ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2159248859))
2024-06-10 20:44:42 +00:00
Aaron Orenstein
3a0d088517 Flip default value for mypy disallow_untyped_defs [5/11] (#127842)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127842
Approved by: https://github.com/oulgen
2024-06-08 18:49:18 +00:00
Xuehai Pan
348b181a97 Deprecate torch._utils.is_compiling() and torch._dynamo.external_utils.is_compiling() (#127690)
This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127690
Approved by: https://github.com/Skylion007
2024-06-08 15:25:03 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
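
A hedged sketch of the two patterns (assumes `typing_extensions>=4.5`, which provides `deprecated`):

```python
import warnings
from typing_extensions import deprecated

@deprecated("old_api() is deprecated, use new_api() instead", category=FutureWarning)
def old_api() -> None:
    ...

def legacy_path() -> None:
    # where the decorator does not fit, make the warning category explicit
    warnings.warn("legacy_path() is deprecated", FutureWarning, stacklevel=2)
```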

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
Jeeja
1110edb94b Fix stream type to generic in comms default hooks (#120069)
In the comms default_hooks, the decompress stream is hardcoded to the CUDA type. Fix this to use a generic type based on the grad tensor's device.
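
A hedged illustration of the fix (helper name made up; assumes the device-agnostic `torch.Stream` is available):

```python
import torch

def decompress_stream_for(grad: torch.Tensor):
    if grad.device.type == "cpu":
        return None  # no stream semantics needed on CPU
    # previously: torch.cuda.Stream() regardless of the gradient's device
    return torch.Stream(device=grad.device)
```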

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120069
Approved by: https://github.com/jgong5, https://github.com/fegin
2024-05-27 10:27:30 +00:00
Chien-Chin Huang
1d2382f141 [DDP] Use compiled_autograd to trace DDP backward allreduce (#110662)
**Summary**
The reducer of `DistributedDataParallel` is implemented in C++, and it is not easy to trace the allreduce launched in the reducer. This PR modifies `DistributedDataParallel` to launch one allreduce per gradient when `compiled_autograd` is enabled. The change allows us to use `compiled_autograd` to trace the allreduces so they can later be optimized (fused) by Inductor.

**Key Logic**
1. If `ddp_python_hook` is True, we assume `compiled_autograd` is used. `DistributedDataParallel` registers `compiled_accum_grad_hook` for all parameters.
2. In the first forward() call, if `DistributedDataParallel` is not compiled, all `compiled_accum_grad_hook`s are deregistered. If `DistributedDataParallel` is compiled, all `compiled_accum_grad_hook`s will be compiled by `compiled_autograd`.
3. `compiled_accum_grad_hook` launches an allreduce to reduce the gradient of the parameter.

**Bucketing**
The compiled backward is slow because there is no bucketing for the allreduces. We rely on Inductor to bucket the allreduces.

The bucketing is done in a separate PR.
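
A hedged sketch of the "one allreduce per gradient" idea (not the PR's implementation): each parameter gets a post-accumulate-grad hook that averages its gradient, and under `compiled_autograd` such Python hooks become traceable:

```python
import torch
import torch.distributed as dist

def install_grad_allreduce(module: torch.nn.Module, pg=None) -> None:
    world_size = dist.get_world_size(pg)

    def make_hook():
        def hook(param: torch.Tensor) -> None:
            param.grad.div_(world_size)
            dist.all_reduce(param.grad, group=pg)
        return hook

    for p in module.parameters():
        if p.requires_grad:
            p.register_post_accumulate_grad_hook(make_hook())
```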

Differential Revision: [D49428482](https://our.internmc.facebook.com/intern/diff/D49428482/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110662
Approved by: https://github.com/wconstab
2024-02-08 03:03:15 +00:00
Edward Z. Yang
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.
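
For illustration, the kind of code this setting affects: a variable initialized to `None` gets a "partial" type that must now be resolved locally or annotated.

```python
from typing import Optional

class Cache:
    # hits = None              # rejected under local_partial_types without an annotation
    hits: Optional[int] = None  # explicit annotation keeps mypy satisfied
```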

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
mingxzhao
1859895ffa Docs: fix docstring errors in model_averaging (#117038)
pydocstyle check

averagers.py

Pre
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:20 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:27 in public method `average_parameters`:
        D102: Missing docstring in public method
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:84 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:106 in public method `average_parameters`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:106 in public method `average_parameters`:
        D400: First line should end with a period (not '`')
6

Post
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:20 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:27 in public method `average_parameters`:
        D102: Missing docstring in public method
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/averagers.py:84 in public method `__init__`:
        D107: Missing docstring in __init__
4

utils.py

Pre
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:17 in public function `average_parameters`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:45 in public function `get_params_to_average`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:45 in public function `get_params_to_average`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:68 in public function `average_parameters_or_parameter_groups`:
        D200: One-line docstring should fit on one line with quotes (found 3)
5

Post
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/utils.py:1 at module level:
        D100: Missing docstring in public module
1

hierarchical_model_averager.py

Pre
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:16 in public class `HierarchicalModelAverager`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:98 in public method `__init__`:
        D107: Missing docstring in __init__
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:137 in private method `_find_process_group`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:137 in private method `_find_process_group`:
        D400: First line should end with a period (not ',')
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:137 in private method `_find_process_group`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:151 in public method `average_parameters`:
        D205: 1 blank line required between summary line and description (found 0)
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:151 in public method `average_parameters`:
        D400: First line should end with a period (not '`')
8

Post
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:1 at module level:
        D100: Missing docstring in public module
/workspaces/pytorch/torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py:99 in public method `__init__`:
        D107: Missing docstring in __init__
2
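
For illustration, a docstring shape that satisfies the D205/D400/D401 checks listed above (the function name is made up):

```python
def average_parameters(params):
    """Average parameters across workers.

    The one-sentence summary above ends with a period (D400), is written in the
    imperative mood (D401), and is separated from this description by a blank
    line (D205).
    """
```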

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117038
Approved by: https://github.com/H-Huang
2024-01-18 04:12:51 +00:00
mingxzhao
b4f1ab4505 Docs: fix docstring errors in ddp_comm_hooks (#116866)
Reopens #115272
Fixes ddp_comm_hooks errors in https://github.com/pytorch/pytorch/issues/112644

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116866
Approved by: https://github.com/awgu
2024-01-10 01:24:06 +00:00
Ashvanth.S
e6bffc6b87 Fix docstring errors in default_hooks.py, optimizer_overlap.py, checkpoint_wrapper.py, copy.py, benchmark_ddp_rpc.py, utils.py, dependency.py, phony.py, pipeline.py, checkpoint.py, worker.py, batchnorm.py, quantization.py (#113511)
Fixes #112645

Updated the files by fixing the docstring errors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113511
Approved by: https://github.com/weifengpy
2023-11-14 18:52:41 +00:00
ChanBong
c0b57d4e3b fix docstring issues in torch.distributed (#113337)
Fixes #112643

Fixes all the issues listed

### Error Count

|File | Count Before | Count now|
|---- | ---- | ---- |
|`torch/distributed/optim/named_optimizer.py` | 13 | 1|
|`torch/distributed/nn/functional.py` | 7 | 1|
|`torch/distributed/nn/api/remote_module.py` | 25 | 3|
|`torch/distributed/algorithms/join.py` | 43 | 4|

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113337
Approved by: https://github.com/ezyang
2023-11-13 19:37:29 +00:00
Fabrice Pont
053367b1ed fix: flake8-bugbear code B024 (#107265)
See #106571 item B024

This fix concerns the addition of `abstractmethod` to methods declared inside abstract classes.

Should I also include PEP8-compliant reformatting of the files I had to modify?
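
For illustration, the B024 fix boils down to marking the method subclasses must provide as abstract (class names made up):

```python
from abc import ABC, abstractmethod

class Hook(ABC):
    @abstractmethod
    def apply(self, value: float) -> float:
        ...

class DoubleHook(Hook):
    def apply(self, value: float) -> float:
        return 2.0 * value
```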
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107265
Approved by: https://github.com/kit1980
2023-10-04 23:52:52 +00:00
redwrasse
ba4782e3c0 cleanup typos; redundant parentheses (#109003)
- minor spelling fixes in `aten/src/ATen/core/TransformationHelper.h`
- remove redundant parentheses in control statements in `torch/distributed/algorithms/_quantization/quantization.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109003
Approved by: https://github.com/davidradl, https://github.com/H-Huang
2023-09-11 17:09:17 +00:00
Rohan Varma
208fd1cb84 [RFC] Somewhat BC breaking: make checkpoint_wrapper default to NO_REENTRANT (#108435)
We should default to NO_REENTRANT. There are a lot of users of this API, but
it is in a prototype state, so it should be fine to change.
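
A hedged usage sketch, assuming the prototype `checkpoint_wrapper` API as of this change:

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
wrapped = checkpoint_wrapper(block)  # now defaults to NO_REENTRANT
legacy = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.REENTRANT)  # opt back in
```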

Differential Revision: [D48898148](https://our.internmc.facebook.com/intern/diff/D48898148/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108435
Approved by: https://github.com/awgu
ghstack dependencies: #108032, #108033
2023-09-05 21:43:41 +00:00
Andrew Gu
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.
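
A hedged sketch of how such a module-level policy is consumed (FSDP internally calls the private `_run_policy()` to build the module-to-kwargs mapping):

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

model = nn.Transformer(d_model=32, nhead=4)
policy = ModuleWrapPolicy({nn.TransformerEncoderLayer, nn.TransformerDecoderLayer})
# FSDP(model, auto_wrap_policy=policy) would wrap each matching submodule with
# the kwargs the policy returns for it.
```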

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
Rohan Varma
2ec7cd2db2 [CheckpointWrapper] Test for kwarg propagation, remove checkpoint_fn_arg support (#102679)
Closes https://github.com/pytorch/pytorch/issues/100576

Differential Revision: [D46342398](https://our.internmc.facebook.com/intern/diff/D46342398/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102679
Approved by: https://github.com/awgu
2023-07-28 21:18:35 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}}$$

This PR goes with the 1st approach, prioritizing performance, because (1) it matches the existing FSDP behavior and (2) it avoids a memory-bandwidth-bound `div_` kernel that blocks the all-reduce launch.
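
A hedged numeric sketch of the factors under approach 1, for S=8 shards and R=4 replicas (values are only for illustration):

```python
import math

S, R = 8, 4
N = S * R

pre_divide = math.sqrt(N)   # applied to each local gradient before reduce-scatter
post_divide = math.sqrt(N)  # applied once after the all-reduce
assert math.isclose(pre_divide * post_divide, N)  # total scaling is 1/N, as for FSDP
```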

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Justin Chu
232b96b6e2 [BE] Enable ruff's UP rules and autoformat distributed/ (#105433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105433
Approved by: https://github.com/albanD
2023-07-19 14:27:11 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to squash the older libstdc++ from the conda environment in favor of the one from the OS in `.ci/docker/install_conda.sh`
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
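
For illustration, the PEP-484 violation pattern fixed throughout (the function name is made up):

```python
from typing import Optional

# was: def connect(timeout: float = None) -> None:   # violates PEP 484
def connect(timeout: Optional[float] = None) -> None:
    ...
```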
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00