Commit Graph

192 Commits

Author SHA1 Message Date
Rohan Varma
a8074a1a0b [Checkpoint] rename apply_ac_wrapper (#85449)
Per title

Differential Revision: [D39714855](https://our.internmc.facebook.com/intern/diff/D39714855/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85449
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
Rohan Varma
cc64f64670 [Docs] Minor fix to apply_ac doc (#85448)
Per title

Created from CodeHub with https://fburl.com/edit-in-codehub

Differential Revision: [D39714530](https://our.internmc.facebook.com/intern/diff/D39714530/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85448
Approved by: https://github.com/awgu
2022-09-23 21:15:08 +00:00
anjali411
85073b8ddc Add __all__ to fx, fistributed and cuda submodules (#85080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85080
Approved by: https://github.com/albanD
2022-09-21 18:04:58 +00:00
Rohan Varma
8cb7826889 [CheckpointWrapper] Reentrant kwarg support (#84908)
A temporary patch to support keyword args when reentrant checkpoint wrapper is used. This is need to unblock some crucial workloads, the ideal fix would be checking this directly into torch.utils.checkpoint.

Differential Revision: [D39453453](https://our.internmc.facebook.com/intern/diff/D39453453/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84908
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rohan Varma
55ca6901a7 [CheckpointWrapper] Decouple CPU offload (#84907)
This fixes the activation offload for checkpoint wrapper, which was previously broken. It was broken because it was tightly coupled with activation checkpoint, i.e. we did:

```
with save_on_cpu:
    checkpoint(module_forward())
```

which would not offload any activation tensors to CPU, as those activations would already be not saved by autograd due to the checkpoint implementation taking priority.

Now, if `offload_to_cpu` is specified, we only do `save_on_cpu` and no checkpoint, so all intermediate tensors are offloaded to CPU instead of checkpointed.

These wrappers can be composed, i.e. if we have

`(Linear, Linear) -> (Linear, Linear) -> (Linear, Linear)`

we can do

`Offload( checkpoint(Linear, Linear) -> checkpoint(Linear, Linear) -> checkpoint(Linear, Linear))`

and inner tensors would be checkpointed while outers will be offloaded.

Differential Revision: [D39448882](https://our.internmc.facebook.com/intern/diff/D39448882/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84907
Approved by: https://github.com/awgu
2022-09-15 00:30:23 +00:00
Rodrigo Kumpera
38192f63cd Add __all__ for a few distributed modules plus a little typing (reland) (#84872)
This handles distributed_c10d, which is massive and ddp_comm_hooks.

This relands #84119 with the required fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84872
Approved by: https://github.com/rohan-varma
2022-09-13 21:57:49 +00:00
PyTorch MergeBot
219ff26172 Revert "Add __all__ for a few distributed modules plus a little typing (#84119)"
This reverts commit 6f21680563.

Reverted https://github.com/pytorch/pytorch/pull/84119 on behalf of https://github.com/izaitsevfb due to breaking internal builds, see D39386448
2022-09-09 20:01:07 +00:00
Rodrigo Kumpera
6f21680563 Add __all__ for a few distributed modules plus a little typing (#84119)
This handles distributed_c10d, which is massive and ddp_comm_hooks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84119
Approved by: https://github.com/rohan-varma
2022-09-08 23:28:31 +00:00
Rodrigo Kumpera
65dc5dd3f3 [c10d] Introduce dist.get_local_rank, dist.get_global_rank and dist.get_global_ranks (#82134)
Those functions enable membership introspection into a ProcessGroup. A common scenario
that needs this is library code that consumes a PG but doesn't create it, which means
it likely doesn't know the global ranks used to create it.

Translating from local to global is necessary when using c10d collectives like broadcast
so if your library code adopts the convention of using local rank 0, it needs
to the following:

```python
import torch.distributed as dist

my_pg: dist.ProcessGroup = ...

def my_library_bcast(tensor)
    dist.broadcast(tensor, src=dist.get_global_rank(my_pg, local_rank=0), my_pg)

```

This implements some of the helpers needed to implement the `clone` API from: https://github.com/pytorch/pytorch/issues/81291
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82134
Approved by: https://github.com/rohan-varma
2022-08-30 17:45:00 +00:00
Rohan Varma
1a53e35b9d Enforce explicit ProcessGroup passed into DefaultState (#84105)
Would prefer to enforce that users pass in explicit PG into these state objects when using comm hooks with FSDP, so that it is clear and easy debugable over which processes communication is taking place.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84105
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao
2022-08-29 14:52:58 +00:00
PyTorch MergeBot
5cf4542f86 Revert "Enforce explicit ProcessGroup passed into DefaultState (#84105)"
This reverts commit adc9a1e2fb.

Reverted https://github.com/pytorch/pytorch/pull/84105 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-28 14:30:18 +00:00
Rohan Varma
adc9a1e2fb Enforce explicit ProcessGroup passed into DefaultState (#84105)
Would prefer to enforce that users pass in explicit PG into these state objects when using comm hooks with FSDP, so that it is clear and easy debugable over which processes communication is taking place.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84105
Approved by: https://github.com/mrshenli, https://github.com/zhaojuanmao
2022-08-27 03:12:20 +00:00
Olga Andreeva
f204afc2bb Added communication hook for sharded cases (#83254)
Fixes https://github.com/pytorch/pytorch/issues/79114

An implementation of a FSDP communication hook interface for a sharded strategies:

- Added `reduce_scatter_hook` to default hooks. Note the difference of `reduce_scatter` from `all_reduce`, it requires 2 tensors:`input_gradient` and `output` variables and stores result in `output`, which is further used as a summed gradient shard.
- Adjusted FSDP logic to return `reduce_scatter_hook` as a default communication hook for sharded strategies, `DefaultState` is the same for sharded and non-sharded strategies.
- Adjusted low-precision hooks to work with both `all_reduce` and `reduce_scatter` depending on whether `output` tensor is provided or not.

Test plan:

Added all existing sharded strategies as an input parameters to existing tests.
For`test_default_communication_hook_behaviour` double checked how a linear layer is sharded across workers. This test creates a simple net ``1 X N``, where ``N`` - is the number of workers. For sharded cases, ``N`` parameters are sharded across ``N`` workers. This test checks that after backward, each worker has a proper value in it's chunk of the gradient, or the whole gradient on every worker is equal to an expected value.

Checked that low-precision tests work for sharded cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83254
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-08-18 18:41:14 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
Rohan Varma
5b2c03823d Generalize CheckpointWrapper (#83035)
Allow checkpoint_wrapper to take in the checkpoint functional impl. This decouples it from torch.utils.checkpoint and allows other checkpoint implementations to be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83035
Approved by: https://github.com/awgu
2022-08-09 23:35:39 +00:00
ProGamerGov
71d50f4f89 Change docstring type callable to Callable for consistency (#82487)
### Description

Across PyTorch's docstrings, both `callable` and `Callable` for variable types. The Callable should be capitalized as we are referring to the `Callable` type, and not the Python `callable()` function.

### Testing

There shouldn't be any testing required.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82487
Approved by: https://github.com/albanD
2022-08-01 17:26:09 +00:00
Olga Andreeva
a60907ec11 Adding fsdp fp16 and bf16 hooks (#81711)
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.

I've also made slight adjustments to existing implementation of an `all_reduce` hook including:

`AllReduceState` ->` DefaultState `, motivation: `AllReduceState` is not specific to all_reduce. Gradients' pre- and post-division factors are also useful for other hooks, that require pre- and post-division, e.g. `fp16_hook` and `bf16_hook`.
I've put all 3 hooks into `default_hooks.py`
Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify MixedPrecision for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to precision and back is performed by a lower precision hook only. I think, as a next step, it would be nice to ensure that user can't have both lower precision hook and MixedPrecision(reduce_dtype=<precision>) specified, but I am happy to discuss this and adjust current implementation.

As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.

PS. first version of this PR was reverted, because added unittests didn't include NCCL version checks for `bf16_hook` (thus failed on trunk). In this version, I've added appropriate checks for tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81711
Approved by: https://github.com/rohan-varma
2022-07-19 23:54:51 +00:00
PyTorch MergeBot
a8f4011e90 Revert "Adding fsdp fp16 and bf16 hooks (#80557)"
This reverts commit f7d6828467.

Reverted https://github.com/pytorch/pytorch/pull/80557 on behalf of https://github.com/aovladi due to broke distributed tests on trunk
2022-07-19 03:11:19 +00:00
Olga Andreeva
f7d6828467 Adding fsdp fp16 and bf16 hooks (#80557)
Recently, `register_comm_hook` was introduced to `FSDP`, which at the moment supports only `NO_SHARD` strategy and has a default `all_reduce` hook implemented. This PR adds two lower precision hooks to an existing default hook.

I've also made slight adjustments to existing implementation of an `all_reduce` hook including:

- `AllReduceState` ->  `DefaultState` , motivation: `AllReduceState` is not specific to `all_reduce`. Gradients' pre- and post-division factors are also useful for other hooks, that require pre- and post-division, e.g. fp16_hook and bf16_hook.
- I've put all 3 hooks into `default_hooks.py`

Additionally, `FSDP` supports `MixedPrecision` and, theoretically, it is possible to specify `MixedPrecision` for gradients and attach a lower precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel`, i.e. casting to precision and back is performed by a lower precision hook only. I think, as a next step, it would be nice to ensure that user can't have both lower precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust current implementation.

As a test, I create two models: one with a lower precision hook and one with a `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward and optimizer step and compare gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80557
Approved by: https://github.com/rohan-varma
2022-07-18 22:40:56 +00:00
Jerome
547e499731 Enable Zero1's ddp_with_overlap for hpu backend (#80438)
Enable zero with ddp overlap feature along with a simple interface to insert functional optimizer to the map

Signed-off-by: Jerome <janand@habana.ai>

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80438
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-07-18 15:05:27 +00:00
Rohan Varma
0c5fdfd95f Revert "Revert "[FSDP Optim State] Remove checkpoint prefix (#80480)"" (#80936)
This reverts commit fe361dede4.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80936
Approved by: https://github.com/awgu
2022-07-06 22:21:07 +00:00
PyTorch MergeBot
fe361dede4 Revert "[FSDP Optim State] Remove checkpoint prefix (#80480)"
This reverts commit 04c50fec1c.

Reverted https://github.com/pytorch/pytorch/pull/80480 on behalf of https://github.com/suo due to Broke master 04c50fec1c, the test failures were not unrelated
2022-07-06 02:43:27 +00:00
Rohan Varma
04c50fec1c [FSDP Optim State] Remove checkpoint prefix (#80480)
Remove `_checkpoint_wrapped_module` prefixes when creating keys for optimizer state_dict.

Having these does not actually create an issue for optim_state_dict save / load, but we'd like to strip these keys out for downstream code that consumes these APIs typically expecting checkpointing prefixes to not exist (as checkpointing should be a transparent operation which should not change module / parameter names).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80480
Approved by: https://github.com/awgu, https://github.com/fegin
2022-07-06 01:17:58 +00:00
Chien-Chin Huang
e0eeb06ec6 Consolidate the naming of named_parameter and state_dict for CheckpointWrapper (#80089)
named_parameter() should return the same parameter names as state_dict() but the current CheckpointWrapper does not enforce this naming rule. This PR resolves this issue.

Differential Revision: [D37344200](https://our.internmc.facebook.com/intern/diff/D37344200/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80089
Approved by: https://github.com/rohan-varma
2022-07-05 22:11:59 +00:00
Charlie Yan
ffae7308c9 Enable test: distributed/algorithms/quantization/test_quantization (#80097)
fixes  https://github.com/pytorch/pytorch/issues/69017
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80097
Approved by: https://github.com/wanchaol
2022-07-01 01:32:33 +00:00
PyTorch MergeBot
f667aaed1d Revert "Added serialization to postlocal_SGD. (#80435)"
This reverts commit dfdf4e79df.

Reverted https://github.com/pytorch/pytorch/pull/80435 on behalf of https://github.com/suo due to broke distributed tests on trunk, see: dfdf4e79df
2022-06-30 01:34:10 +00:00
Olga Andreeva
dfdf4e79df Added serialization to postlocal_SGD. (#80435)
Fixes #75666

Current PR adds the functionality for `PostLocalSGD` communication hook and tests that communication hook can be properly saved and restored. Similar to https://github.com/pytorch/pytorch/pull/79334, where serialization was added to `PowerSGD`.

``__getstate__``

 Returns:
```
        ``Dict[str, Any]`` which will be pickled and saved.
        ``process_group`` and ``subgroup`` are not serializable and excluded from
        a returned state.
```
``__setstate__``
```
          Takes provided ``state`` and retrieves ``PostLocalSGDState``.
          ``process_group`` and ``subgroup`` are set to default process_group and subgroup respectively.
           Default subgroup is equivalent to the subgroup on each node.
```

Small adjustment to `PowerSGD`'s warning message.

Refactored unittest, i.e. separated parity and log checks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80435
Approved by: https://github.com/awgu
2022-06-29 23:59:46 +00:00
Rohan Varma
5fc2d45a3a Remove unneeded TODO (#80453)
This TODO is no longer needed, as we use `_register_fused_optim` to register the overlapped optimizer in DDP.  Also, remove comment about API being experimental, as this API is no longer going to be used by end user.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80453
Approved by: https://github.com/awgu
2022-06-29 01:19:48 +00:00
Olga Andreeva
5fc209ed11 FSDP communication hook interface for NO_SHARD strategy (#79833)
Fixes #79114

An implementation of a FSDP communication hook interface for a NO_SHARD strategy:
- `FullyShardedDataParallel.register_comm_hook(self, state: object, hook: callable)` checks current sharding strategy. If it is other that NO_SHARD, raises a runtime error. Otherwise, sets and shares a specified hook and its state with all submodules
- When FSDP is ready to communicate a gradient, checks if there is a registered hook, and calls it instead of all_reduce. Additionally, gradient pre and post devision are not performed if a hook is registered.

To test the interface, I've implemented a communication hook, that calls for `all_reduce`.

A  unittest:
- checks that is a sharding strategy is anything but NO_SHARD, a runtime error is raised
- checks that for a NO_SHARD case, model with registered all_reduce hook and without a hook work the same.
- checks for 2 types of FSDP models: with the wrapped first layer and without. (to make sure submodules have a hook registered)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79833
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-06-28 08:03:11 +00:00
anjali411
3bcc19b29a Add __all__ to various submodules in torch.fx, distributions, distributed, package (#80367)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80367
Approved by: https://github.com/albanD
2022-06-27 21:27:30 +00:00
Rohan Varma
2ede28724d [CheckpointWrapper] Replace generic mod prefix (#79830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79830
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2022-06-21 16:01:59 +00:00
Olga Andreeva
8a6d83079c Functionality/pickling for commhooks (#79334)
This PR addresses issue address #75666.
Stateful communication hook now can be saved and reloaded to resume training.

Current PR adds the functionality for PowerSGD communication hook and tests that communication hook can be properly saved and restored.

PowerSGD implementation uses ``__slots__``, as a result introduced __getstate__ and __setstate__ methods are implemented to work with `__slots__` and not` __dict__`.

`__getstate__ `

	 Returns:
           A dictionary that represents a ``PowerSGDState`` which will be pickled and saved.
          ``process_group`` is non-serializable and excluded from a returned state.

`__setstate__`

	Takes a provided ``state`` and retrieves ``PowerSGDState``.
        ``process_group`` is set to default with a proper warning issued to a user.

Unit test

A hook-independent `_test_hook_pickling` is added with this PR, as well as `test_ddp_hook_pickling_powerSGD`, which tests `powerSGD`’s ability to be saved and reloaded.

Currently, the test creates a ddp model with a provided hook, trains it for 10 epochs and saves model’s state and hook’s state.
During reloading, unit test makes sure that a warning was logged (only one warning and the proper one). It then proceeds to check that reloaded hook and original hook are the same. Finally, it checks that a hook’s state was properly initialized:
	- it compares slot values (all, but 2: `process_group` and `rng`) for original and reloaded state
	- it checks that process group was set to a default group
	- it checks that a random state was restored properly with np.testing.assert_array_equal, because `rng` is an instance of `np.random.RandomState`, represented by a tuple. One of entries is of `ndarray dtype[uint32]` type and `np.testing.assert_array_equal` is used for assertion.

Future To-Do:
	- Implement similar __getstate__ and __setstate__ for other stateful communication hooks
	- Add appropriate tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79334
Approved by: https://github.com/rohan-varma, https://github.com/awgu
2022-06-16 23:15:34 +00:00
Rohan Varma
543919cfc8 Forward attributes to wrapped module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78854

Approved by: https://github.com/albanD
2022-06-14 01:13:33 +00:00
Rohan Varma
44fe851feb [WIP] Fix non-reentrant hooks based checkpointing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78752

Approved by: https://github.com/albanD
2022-06-14 01:13:33 +00:00
Rohan Varma
ec86070922 Checkpoint util
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78704

Approved by: https://github.com/zhaojuanmao
2022-06-10 18:37:36 +00:00
Rohan Varma
f9f8127414 CheckpointWrapper state_dict fix (#77224)
- Uses state dict / load state dict hooks to ensure that modules wrapped with `CheckpointWrapper` can be loaded into non-checkpointed wrapped module.

This is because a training run can use activation checkpointing, then we can recover `state_dict`, and a future run may not want to wrap modules with activation checkpointing or decide to change activation checkpoint wrapping structure. To support this, we add hooks to remove / add the relevant prefix as needed.

Tests are added to ensure we can load into CheckpointWrapper module as well as local module from CheckpointWrapper-wrapped module. state_dict with FSDP is also verified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77224
Approved by: https://github.com/zhaojuanmao
2022-05-17 03:39:31 +00:00
wayi1
5ab8afe487 [Model Averaging] Support disabling post-local gradient sync (#76723)
I find that sometimes disabling intra-subgroup gradient allreduce can still give a satisfying accuracy for some cases, so better to make such gradient averaging configurable. This does not take into account the saving in the communication of allreducing gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76723
Approved by: https://github.com/rohan-varma
2022-05-16 18:09:09 +00:00
Yi Wang
25fa6235f4 [Model Averaging] Make an error message more clear in hierarchical_model_averager.py
As title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75832
Approved by: https://github.com/mrshenli
2022-04-26 15:20:51 +00:00
wayi1
e90580390d [Model Averaging] Make the error message more informative in hierarchical_model_averager.py
As title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76277
Approved by: https://github.com/rohan-varma
2022-04-24 15:10:19 +00:00
magialiao
7c8c8cc248 Use batched operations for PowerSGD
This PR is a rebased version of #75157 which fixes CI issues
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76041
Approved by: https://github.com/albanD, https://github.com/rohan-varma
2022-04-21 03:25:09 +00:00
Alban Desmaison
da3c848dfa Make distributed raise ImportError when not available
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75975

Approved by: https://github.com/mrshenli
2022-04-20 13:05:18 +00:00
PyTorch MergeBot
c5d57e7be9 Revert "Use batched operations for PowerSGD"
This reverts commit 5654e63398.

Reverted https://github.com/pytorch/pytorch/pull/75157 on behalf of https://github.com/albanD
2022-04-18 13:10:29 +00:00
magialiao
5654e63398 Use batched operations for PowerSGD
This implements method proposed in #74907

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75157
Approved by: https://github.com/wayi1, https://github.com/rohan-varma
2022-04-18 04:34:17 +00:00
Haijunlv
08f3b95857 fix PostLocalSGDOptimizer and ModelAverager average bug
Fixes #74157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74894
Approved by: https://github.com/rohan-varma, https://github.com/wayi1
2022-04-13 11:41:27 +00:00
wayi1
4fb7fa081e [Model Averaging] Code simplification for _find_process_group function (#75007)
Summary:
Previously the highest-level process group in `period_process_group_dict` could be `None`, indicating the global group. Now `period_process_group_dict` cannot contain `None` as a process group, so the function `_find_process_group` can just return a process group instead of a tuple -- when not found, just return `None`, because now the returned process group cannot be `None`.

Proposal: https://github.com/pytorch/pytorch/issues/71325

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75007

Reviewed By: awgu

Differential Revision: D35357816

Pulled By: rohan-varma

fbshipit-source-id: 4522dba49797df7140227bfd822d668b7e118a66
(cherry picked from commit 77ca01b555d52685283c969176b08de4ff46c32d)
2022-04-04 20:31:22 +00:00
Yi Wang
2aebece625 [Model Averaging] Remove unused variable world_size in post_localSGD_hook.py (#74803)
Summary:
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74803

Reviewed By: albanD

Differential Revision: D35175613

Pulled By: mrshenli

fbshipit-source-id: 881933656ed214554b8acb4c5756349cea0af51d
(cherry picked from commit 033efb2eea856d00d5e78c8a99d726c6cf69d714)
2022-03-28 17:41:26 +00:00
wayi1
5fbe8b1966 [Model Averaging] Make HierarchicalModelAverager a subclass of averagers.ModelAverager
Make `HierarchicalModelAverager` a subclass of `averagers.ModelAverager` is a preparation step for incorporating hierarchical SGD into `PostLocalSGDOptimizer`.

Proposal: https://github.com/pytorch/pytorch/issues/73382
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74564
Approved by: https://github.com/rohan-varma
2022-03-24 21:52:00 +00:00
wayi1
5993f48711 [Model Averaging] Add a reference to hierarchical SGD (#73823)
Summary:
Add a reference.

Also fix the comment: unlike `averagers.py`, currently this is not a base class that can inherit many subclasses.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73823

Reviewed By: ejguan

Differential Revision: D34684366

Pulled By: rohan-varma

fbshipit-source-id: e253ed39ba0783ad73bfd889e9a2e7d0c9214a3a
(cherry picked from commit a9fec3585078881ccd5886ebb27e52b15f7181b1)
2022-03-08 05:56:17 +00:00
wayi1
0bb3b0652c [Model Averaging] Support hierarchical model averaging (#73285)
Summary:
Implement hierarchical model averaging proposed in https://github.com/pytorch/pytorch/issues/71325.

Unit tests are added. Since I don't have access to 4-GPU machines in open-source environment, expect that the branch with the prefix of `ci-all` can run the test that requires 4 GPUs.

In the future, the internals of `PeriodicModelAveraging` can be simplified as an implementation of a specialized hierarchical model averaging, where `period_group_size_dict` only has a pair of period and world size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73285

Reviewed By: mrshenli

Differential Revision: D34457792

Pulled By: rohan-varma

fbshipit-source-id: 39a6c5bf8a2852b6394a56abbad17b8a909b9fba
(cherry picked from commit 5f543d46103edb515db199dbb80db43c85665f29)
2022-03-04 18:29:36 +00:00
Andrew Gu
59dd84cab6 [Join][BE] Fix typo; remove obsolete method (#72886)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72886

**Test Plan**
Searching for `_schedule_shadow_all_reduce_for_fwd_pass` shows that it is defined but never used.

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D34255651

Pulled By: awgu

fbshipit-source-id: 205a0325c2cdc05e127a183cb86fa2fc2e0db99d
(cherry picked from commit 4492f03a3f)
2022-02-16 15:03:09 +00:00