Commit Graph

2331 Commits

Author SHA1 Message Date
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violations (i.e., when a default arg is set to None but the type is not annotated as Optional), plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Rohan Varma
242fc29c96 [FSDP] Refactor optimizer in backward (#104813)
1) Use zero_grad(set_to_none=True) to set the grad to None, 2) call
prepare_grad_for_optim() before the call to .step(), and 3) use
_reset_flat_param_grad_info to set the flat param gradient back to None. These
changes should just be refactors, equivalent to how gradient memory was
managed before.
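For context, a minimal sketch of the plain-PyTorch `zero_grad(set_to_none=True)` behavior the refactor relies on (standard optimizer API only; the FSDP-internal helpers named above are not shown):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optim.step()

# set_to_none=True replaces each .grad with None instead of zeroing it in place,
# which releases the gradient memory
optim.zero_grad(set_to_none=True)
assert all(p.grad is None for p in model.parameters())
```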

Differential Revision: [D47310761](https://our.internmc.facebook.com/intern/diff/D47310761/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104813
Approved by: https://github.com/awgu
2023-07-13 06:42:53 +00:00
Rohan Varma
f2eed129c4 FSDP optimizer overlap (#98667)
Constraints:

1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload + optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False, otherwise step() might run on invalid data.
4. Step is waited on in the post-backward final callback, when in theory it could wait until the next forward.

Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-13 06:42:53 +00:00
PyTorch MergeBot
5b4aacd691 Revert "[DCP] Add FsspecReader and FsspecWriter to checkpoint __init__.py (#105088)"
This reverts commit 76a053d55c.

Reverted https://github.com/pytorch/pytorch/pull/105088 on behalf of https://github.com/atalman due to broke trunk and  linux-focal-py3.9-clang7-asan ([comment](https://github.com/pytorch/pytorch/pull/105088#issuecomment-1633385350))
2023-07-13 00:59:55 +00:00
Andrew Gu
954bae8e53 [FSDP][Easy] Rename streams; add back stream sharing test (#104966)
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
2023-07-13 00:24:41 +00:00
Iris
4f8ba6f8f6 [DeviceMesh]Add validate mesh flag to DeviceMesh (#104807)
When creating a DeviceMesh, _init_process_group() validates that all calling ranks pass in the same `mesh` argument. In FSDP, we currently create the DeviceMesh based on the pg of the root state, so the mesh will always be valid. Adding the flag to DeviceMesh lets us skip the all_gather_tensor used for validation during construction time.

_validate_mesh defaults to True, but we manually flip it to False when initializing the device mesh in FSDP's _runtime_utils.py.

A follow-up PR will skip pg creation if one already exists for both the 1D and 2D cases and then delete the _init_process_groups flag.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
2023-07-12 23:42:13 +00:00
Iris
76a053d55c [DCP] Add FsspecReader and FsspecWriter to checkpoint __init__.py (#105088)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105088
Approved by: https://github.com/kumpera
2023-07-12 23:40:35 +00:00
Nikita Shulga
4148b7bada [Typing] Fix PEP 484 Violation (#105022)
Not sure how it worked before, but arguments must be annotated as Optional if they default to None.

Towards enabling mypy-1.4.1 in lintrunner

### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>

> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
2023-07-12 10:20:48 +00:00
Aaron Gokaslan
2f95a3d0fc [BE]: Apply ruff PERF fixes to torch (#104917)
Applies automated ruff fixes for the PERF rules and enables all of the automatic ones. I also updated ruff, which applied some additional fixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104917
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-07-11 20:45:21 +00:00
Andrew Gu
63d1fb21f5 [FSDP] Default limit_all_gathers=True (#104900)
This PR defaults to `limit_all_gathers=True`.

I included a `record_function()` for the rate limiter synchronization to help with user confusion about the gap in the pre-forward:
<img width="874" alt="Screenshot 2023-07-10 at 3 28 18 PM" src="https://github.com/pytorch/pytorch/assets/31054793/61f55e0e-58d7-4162-9395-bea06d3e8d8a">
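A minimal sketch of overriding the new default (constructor kwarg only; assumes `model` is already an `nn.Module` and the default process group is initialized):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# limit_all_gathers (the "rate limiter") now defaults to True; it can still be
# disabled explicitly, trading the prefetch limit for potentially higher peak memory
model = FSDP(model, limit_all_gathers=False)
```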

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104900
Approved by: https://github.com/fegin
2023-07-11 01:04:29 +00:00
Matthew Hoffman
3279f06410 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99
2023-07-11 00:07:30 +00:00
fduwjj
aa84078c6c [PTD][TP] Add BWD support for colwise embedding sharding (#104820)
Originally, we didn't enable BWD for colwise embedding because we thought it was just for inference, but it turns out that we do need it for training. So let's enable it for now; a unit test is also added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104820
Approved by: https://github.com/fegin
2023-07-10 22:33:20 +00:00
Iris Zhang (PyTorch)
7b538d8987 [DCP][fsspec] Consolidate OSS FsspecWriter/Reader and internal FsspecWriter/Reader (#104724)
Summary:
This diff does the following:
1. re-enable single_file_per_rank for FsspecWriter, as the file-slicing error is resolved by [https://github.com/pytorch/pytorch/pull/99167]
2. remove sync_files from FsspecWriter, as there is no fsspec equivalent.
3. remove the internal implementation of FsspecWriter/Reader, as it has been upstreamed to PyTorch OSS
4. keep the internal test for manifold internal, as we can only test it in the fb environment
5. consolidate tests to remove duplicates
6. remove unnecessary TARGETS

Test Plan:
```
buck test @//mode/dev-nosan  //caffe2/test/distributed/checkpoint/fb:test_fsspec_filesystem -- --print-passing-details

----------------------------------------------------------------------
Ran 1 test in 54.894s

OK
/usr/local/fbcode/platform010/lib/python3.8/tempfile.py:818: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpzomokvh6'>
  _warnings.warn(warn_message, ResourceWarning)

Buck UI: https://www.internalfb.com/buck2/4cb722a2-3ee7-48f2-a9ef-55ee6fb1a498
Test UI: https://www.internalfb.com/intern/testinfra/testrun/8725724447995201
Network: Up: 8.8 MiB  Down: 1.5 GiB  (reSessionID-04c29f56-ae94-4187-8a1a-c812f432674d)
Jobs completed: 209847. Time elapsed: 1:56.5s.
Cache hits: 100%. Commands: 85687 (cached: 85687, remote: 0, local: 0)
Tests finished: Pass 3. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D47266068

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104724
Approved by: https://github.com/fegin, https://github.com/fduwjj
2023-07-10 19:31:01 +00:00
Mikayla Gawarecki
1ad435772b Added option to always call nn.Module global/non-global forward hooks (#104278)
Fix #103997
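A hedged usage sketch, assuming the new option is exposed as an `always_call` keyword on `register_forward_hook` (the kwarg name is an assumption, not stated in this commit message):

```python
import torch
import torch.nn as nn

m = nn.Linear(2, 2)

def hook(module, args, output):
    print("forward hook fired")

# assumed kwarg: always_call=True asks PyTorch to run the hook even if
# forward (or an earlier hook) raises an exception
m.register_forward_hook(hook, always_call=True)
m(torch.randn(1, 2))
```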

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104278
Approved by: https://github.com/albanD
2023-07-10 18:58:07 +00:00
Jane Xu
e25f5732c8 Add meta registrations and distributed decomps: _foreach_div_.Scalar, sqrt_.default (#104779)
This PR unblocks #104780 by resolving spmd tracing test issues and by adding meta registrations for foreach inplace ops (div_ and sqrt_)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104779
Approved by: https://github.com/fegin, https://github.com/albanD
2023-07-10 17:38:46 +00:00
Iris
af52f6b928 [DCP] Add documentation for HSDP saving using DCP (#104810)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104810
Approved by: https://github.com/fduwjj
2023-07-10 17:33:05 +00:00
Chien-Chin Huang
46154c4c35 [FSDP][optim_state_dict] The correct way to initialize optimizer states if the corresponding param is empty (#104765)
When using KeyedOptimizer.init_state(), some optimizers initialize the states even if the param is empty (size() == 0) while others avoid initializing the states. There is no way for FSDP to tell which is the case. Instead, FSDP should look up `optim.state`. Fortunately, `optim.state` does not rely on FQNs, which some internal users change.

Differential Revision: [D47285562](https://our.internmc.facebook.com/intern/diff/D47285562/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104765
Approved by: https://github.com/fduwjj
2023-07-10 08:00:55 +00:00
Andrew Gu
e600505e32 [FSDP][5/N] Unblock ignored_states + auto wrap (for now) (#104418)
The "for now" is because we still have the issue that when using the parameter `ignored_states` path, we do not recover the ignored modules, so FSDP still wraps those as empty shells (no managed parameters), which is not ideal. This is not a blocking issue as far as I know.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104418
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:14 +00:00
Andrew Gu
610f74627e [FSDP][4/N] Remove _get_fully_sharded_module_to_states (#104409)
`_get_fully_sharded_module_to_states()` was used to emulate auto wrapping without actually calling `fully_shard`. Since we committed to unifying (see previous PR), we can remove this function and its helpers/tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104409
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:14 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of whether it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Andrew Gu
6d71b4f9f1 [FSDP][2/N][Easy] Prepare _auto_wrap for fully_shard (#104407)
This mainly just changes the `_auto_wrap()` function signature and generalizes the `_check_nested_wrapping()` to both wrapper and composable paths (though the composable path will not hit in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104407
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:09 +00:00
Andrew Gu
d58f75be8b [FSDP][1/N] Move wrapper ModuleWrapPolicy to new path (#104346)
This PR is the first in refactoring the auto wrapping, only affecting `ModuleWrapPolicy` for wrapper `FullyShardedDataParallel`. The end goal is to improve the auto wrapping infra to support:
- Checking valid frozen parameters (uniform frozenness per FSDP)
- Checking valid shared parameters (shared parameters assigned to their lowest-common-ancestor module or higher)
- Writing auto wrapping policies that may take multiple passes over the module tree
- Specifying different FSDP kwargs per FSDP instance (instead of enforcing the same for all FSDP instances constructed via an auto wrap policy)

The way I envision achieving this is that, we decouple the actual "wrapping" (which is `_post_order_apply()` in this PR) from constructing the wrapping targets and kwargs (which is `target_module_to_kwargs` in this PR). In that way, a policy reduces to just constructing that latter `target_module_to_kwargs` mapping.

I do not personally recommend the size-based policy, but if we wanted to implement that under this new organization, the tracking of wrapped/nonwrapped numel should be done in the pass over the module tree prior to the actual "wrapping". This modularization keeps the actual "wrapping" part simple.

The change to how `old_dtype` is handled is mainly to avoid keeping a reference to the `_override_module_mixed_precision()` function closure in each hook and to allow the function to take in all module classes at once, returning which ones actually got overridden for the downstream error message. (We can directly store the global state as a mapping.)

To-do in follow-ups (not in order):
- Add frozen parameter check before `_post_order_apply()`
- Add shared parameter check before `_post_order_apply()`
- Expose wrapping policy that allows per module / per module class kwarg customization (where any unspecified kwarg adopts the root's kwarg)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104346
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:07 +00:00
Rohan Varma
0bf39d5663 [FSDP] Option for eval in fp32/bf16 (#104682)
In https://github.com/pytorch/pytorch/pull/97645 and some follow-up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.

However, this is probably not the best idea, and we should give users a bit more control over it. This adds an env var FSDP_FULL_PREC_IN_EVAL, defaulting to off; users who want to run eval in fp32 can toggle it before wrapping the model in FSDP:

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

Verified that unittests, the APS workflow, and TNT workloads can run eval appropriately with this change.
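A minimal usage sketch of the flag described above (assumes `model` and the process group are already set up; the bf16 config is illustrative):

```python
import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# request full-precision eval before wrapping; the default (unset/"0") keeps
# eval running in the configured mixed precision
os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

model = FSDP(model, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16))
model.eval()  # eval forward now runs in fp32 despite the bf16 mixed-precision config
```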

Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
2023-07-07 08:14:23 +00:00
Will Constable
d64bada876 Refactor funcol for readability and dynamo tracing (#104387)
Move the eager kernel impls to a separate file, which is easier to read
(since users may be confused by two versions of each kernel in the same file)
and makes it easier to set a dynamo policy that traces only the first file for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104387
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/kumpera
2023-07-06 23:29:49 +00:00
Andrew Gu
6c1d959889 [FSDP] Annotate modules for fully_shard (#104363)
This annotates modules managed by `fully_shard` for TorchDynamo to treat them specially.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104363
Approved by: https://github.com/fegin
2023-07-06 16:56:59 +00:00
Rodrigo Kumpera
17ab4f85e9 [c10d] Adopt allgather_into_tensor_coalesced for NCCL. (#103086)
This is done by adding c10d::_allgather_into_tensor_coalesced wrapper.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103086
Approved by: https://github.com/rohan-varma
2023-07-06 15:05:55 +00:00
Wanchao Liang
db1ac4e29b fix functional collective's allgather for gloo (#104681)
Summary: We should explicitly check for the gloo backend instead of relying on the shard's device, because a user might pass a GPU tensor as input with a gloo process group as the pg and expect that to work.

Differential Revision: D47249172

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104681
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj
2023-07-06 09:52:48 +00:00
Iris
434fcffa21 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works with offload_to_cpu=False for now.

Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().
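A rough usage sketch under the flags described above (assumes `model` is already FSDP-wrapped, `sharded_sd` is a matching sharded state dict, and `use_dtensor` is a field on `ShardedStateDictConfig` as added earlier in this series):

```python
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedStateDictConfig,
    StateDictType,
)

FSDP.set_state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    state_dict_config=ShardedStateDictConfig(offload_to_cpu=False, use_dtensor=True),
)
model.load_state_dict(sharded_sd)
```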

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-06 05:36:19 +00:00
PyTorch MergeBot
fcb53c1394 Revert "[6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)"
This reverts commit 49af83cf44.

Reverted https://github.com/pytorch/pytorch/pull/104087 on behalf of https://github.com/huydhn due to This is failing in trunk 49af83cf44, probably due to a land race ([comment](https://github.com/pytorch/pytorch/pull/104087#issuecomment-1615608189))
2023-07-01 07:50:31 +00:00
Iris
49af83cf44 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works with offload_to_cpu=False for now.

Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-01 01:02:59 +00:00
Andrew Gu
d982fdb5d5 [FSDP] Rework meta device init (#104189)
This addresses https://github.com/pytorch/pytorch/issues/104187.

After this PR, the contract with the user is that:
- If passing `param_init_fn=None`, each `nn.Module.reset_parameters()` should only initialize its own parameters/buffers (like `parameters(recurse=False)`/`buffers(recurse=False)`).
- If passing `param_init_fn` not equal to `None`, then similarly, one call to `param_init_fn(module)` should only initialize `module`'s own parameters/buffers.

With this contract and this PR's changes, meta device initialization through either `reset_parameters()` or `param_init_fn` should be correct. Those functions will run on the original parameter/buffer shapes allowing for correct shape-dependent computations like for fan-in/fan-out, and there will not be any re-initialization of any module.
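A minimal sketch of a module that satisfies this contract (each `reset_parameters()` touches only its own parameters/buffers; submodules reset themselves):

```python
import torch
import torch.nn as nn

class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(4, 4))
        self.linear = nn.Linear(4, 4)  # resets itself via its own reset_parameters()

    def reset_parameters(self) -> None:
        # only this module's own parameters, i.e. parameters(recurse=False)
        nn.init.kaiming_uniform_(self.weight)
```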

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104189
Approved by: https://github.com/rohan-varma
2023-07-01 00:25:12 +00:00
Xilun Wu
e799f565eb [DTensor][TP][Random] Introduce TensorParallelRNGTracker to integrate parallel RNG state with Tensor Parallel (#103910)
This PR enables the automatic use of `TensorParallelRNGTracker` in the Tensor Parallel API. Some unit tests will be added to cover it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103910
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-30 08:06:41 +00:00
Wanchao Liang
da06920f47 Replace all_gather in device mesh with functional collective equivalent (#104056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104056
Approved by: https://github.com/kumpera, https://github.com/wanchaol
2023-06-30 05:30:02 +00:00
Wanchao Liang
8457703e8d lazy init device mesh in fsdp (#104447)
Since FSDP state is lazily initialized, we also need to lazily initialize the device mesh;
otherwise, the DeviceMesh allgather check would trigger a mismatch in
allgather counts in FSDP tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104447
Approved by: https://github.com/wconstab
2023-06-30 04:40:16 +00:00
Will Constable
d0509fe32d Document how functional collectives work under eager/dynamo (#104386)
Move user-facing APIs to the top for best visibility
(strictly code motion in this PR, besides adding comments)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104386
Approved by: https://github.com/voznesenskym, https://github.com/wanchaol
2023-06-30 01:12:55 +00:00
Rohan Varma
60e2a4a4a0 [2D parallel] workaround for FSDP init issue (#104398)
Closes https://github.com/pytorch/pytorch/issues/96491 and does so by relaxing FSDP's assumption that the entire input module must be on the same device. Now, FSDP can accept a module partially on CPU and GPU and just emits a warning.

Differential Revision: [D47117256](https://our.internmc.facebook.com/intern/diff/D47117256/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104398
Approved by: https://github.com/fegin
2023-06-29 16:07:07 +00:00
Rohan Varma
c866446d6c [FSDP] Check module.training for _root_cast_forward_inputs (#104223)
We might erroneously cast forward inputs for the root if it doesn't
manage any handles (FSDP parameters). As a fix, pass in the module and check
its training attribute to ensure we don't cast inputs in eval mode.

Differential Revision: [D47041673](https://our.internmc.facebook.com/intern/diff/D47041673/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104223
Approved by: https://github.com/fegin
2023-06-28 16:38:01 +00:00
Andrew Gu
6493519fff [Easy][FSDP] Remove misleading asserts (#104274)
Since we do not call `_FSDPState.__init__()` and only use it for typing, it is not possible for these attributes to be `None`. The purpose of these `assert`s is to make sure that these attributes are set by `_init_process_group_state_for_hybrid_shard()`. If we care to make that explicit, I would posit that we should be using `hasattr` checks, not `is not None` checks, because if indeed `_init_process_group_state_for_hybrid_shard()` did not set these attributes, then even checking that it is not `None` would lead to an `AttributeError`. I do not include these `hasattr` checks for now since `_init_process_group_state_for_hybrid_shard()` is short enough that we can quickly tell by inspection that it sets the desired attributes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104274
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
ba9f6e6e92 [FSDP] Validate ignored_modules, ignored_states (#104273)
This checks that `ignored_modules` and `ignored_states` have the expected type and provides a reasonable error message if not. Otherwise, if someone passes a mix of modules and parameters to `ignored_states` for example, then our code may be silently incorrect.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104273
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
cc27e6c0f9 [FSDP] Fix ignored_states doc (#104253)
This fixes https://github.com/pytorch/pytorch/issues/104246.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104253
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:45 +00:00
Andrew Gu
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
shibo19
c2095af3f8 make funcs argument type from torch.cuda.stream as torch.Stream (#104156)
Fixes #ISSUE_NUMBER
1. We want to support FSDP for custom devices, so we change the functions' argument type from torch.cuda.Stream to torch.Stream.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104156
Approved by: https://github.com/awgu
2023-06-28 06:02:56 +00:00
Xilun Wu
a66107a30c [DTensor][Random] Introduce CudaRNGStateTracker to maintain parallel RNG state for DTensor (#103235)
# Change
This PR adds two classes to DTensor:

1. `CudaRNGStateTracker`:  `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).

2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.

# Warning

- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads (ranks) and cause issues. We need to figure out a compatible solution for that.

- The RNG state may get out of sync on ranks not participating in an operation. This is harmless in our current submesh use case, though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
2023-06-27 19:00:25 +00:00
Amr Elshennawy
968b7b5e0f Initial commit of collective_utils (#101037)
Summary:
Details in T133020932
First commit of the collective utils library. Ported over from model store; removed scuba logging, error_trait, and all dependencies on modelstore.

Test Plan: In the following diffs.

Differential Revision: D45545970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101037
Approved by: https://github.com/H-Huang
2023-06-27 02:15:16 +00:00
Rodrigo Kumpera
c17bdb3247 [C10D] Add functional collective reduce_scatter_into_tensor_coalesced. (#101023)
Implementation uses a fallback that does no coalescing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101023
Approved by: https://github.com/wanchaol
2023-06-23 19:24:11 +00:00
fduwjj
23b7035b3c [TP] Add an input resharding wrapper for TP and unit test for 2D + AC (#103334)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103334
Approved by: https://github.com/kumpera
2023-06-23 04:05:01 +00:00
Chien-Chin Huang
1c33c398c7 [FSDP][state_dict] Add a summary log when finishing state_dict (#103784)
Add a summary log when finishing state_dict

Differential Revision: [D46807103](https://our.internmc.facebook.com/intern/diff/D46807103/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103784
Approved by: https://github.com/fduwjj
2023-06-22 16:29:24 +00:00
Iris
613970eb05 [5/n][FSDP] Update _sharded_post_state_dict_hook to use DTensor when use_dtensor=True in state_dict_config (#103921)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.state_dict().

load_state_dict hooks updates will be in next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103921
Approved by: https://github.com/fduwjj, https://github.com/fegin
2023-06-22 08:32:19 +00:00
Andrew Gu
ec8aa6e592 [Easy][FSDP] Fix "column" -> "row" in PG example (#103975)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103975
Approved by: https://github.com/fduwjj
2023-06-21 20:41:50 +00:00
Chien-Chin Huang
a2d001d4dd [FSDP][state_dict] Use _get_module_fsdp_state_if_fully_sharded_module for state_dict (#103783)
Fix https://github.com/pytorch/pytorch/issues/90788
Use an implementation consistent with optim_state_dict

Differential Revision: [D46807090](https://our.internmc.facebook.com/intern/diff/D46807090/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103783
Approved by: https://github.com/awgu, https://github.com/fduwjj
2023-06-21 20:31:30 +00:00
Rodrigo Kumpera
0beec88c93 Inductor support for all_gather_into_tensor_coalesced. (#98643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98643
Approved by: https://github.com/wanchaol
2023-06-21 19:25:03 +00:00
Chien-Chin Huang
6b1d6750b9 [FSDP][state_dict][BE] Remove outdated and fixed TODOs (#103782)
Remove outdated and fixed TODOs

Differential Revision: [D46807071](https://our.internmc.facebook.com/intern/diff/D46807071/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103782
Approved by: https://github.com/rohan-varma
2023-06-21 05:41:19 +00:00
Chien-Chin Huang
1192f5ac46 [FSDP][optim_state_dict] Cleanup the unused optimizer state_dict APIs (#103781)
Cleanup the unused optimizer state_dict APIs

Differential Revision: [D46803955](https://our.internmc.facebook.com/intern/diff/D46803955/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103781
Approved by: https://github.com/rohan-varma
2023-06-21 05:38:48 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Chien-Chin Huang
0ae4c4d417 [FSDP][optim_state_dict] Avoid calling optim.state_dict() to get the initial empty states (#103609)

Users may prefix the keys of the optim state_dict, so using `optim.state_dict()` to get the initial states is brittle. This PR removes the call to `optim.state_dict()` and directly infers the empty states from the input states.

Differential Revision: [D46729119](https://our.internmc.facebook.com/intern/diff/D46729119/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103609
Approved by: https://github.com/awgu
2023-06-20 22:11:58 +00:00
Rodrigo Kumpera
f83ebfe1bb [FSDP] Improve support for CPU tensors. (#103171)
Don't emit the device index when using CPU devices.
Don't call Tensor::record_stream, as it's a CUDA-only op.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103171
Approved by: https://github.com/rohan-varma, https://github.com/wz337
2023-06-20 21:08:19 +00:00
Ke Wen
22e8a61d9b Implement coalesced reduce_scatter_tensor (#103561)
Mirror of #101157.

This PR adds support for coalesced `reduce_scatter_tensor` calls in the following syntax:

Sync communication style:
```
with dist._coalescing_manager():
     for i in range(num_coll):
         dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
```

Async communication style:
```
with dist._coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])

# do a bunch of other things
cm.wait()
# do things that depend on the reduce-scatters' results
```
Each `reduce_scatter_tensor` call can be independent in terms of its data and buffer locations, but the calls can be executed in parallel by supported backends (like NCCL).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103561
Approved by: https://github.com/fegin
2023-06-15 20:11:12 +00:00
Mikayla Gawarecki
d1cecd9c32 Add assign kwarg to module.load_state_dict (#102212)
Fixes #64601 and #98906

Adds an `assign` argument to `load_state_dict` that loads params/buffers by assignment instead of doing `param.copy_(param_from_state_dict)`.

Primarily intended to remove the need for the `.to_empty()` in

```
with torch.device('meta'):
    m = SomeModule()
m.to_empty()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict)
```

so we can instead do

```
with torch.device('meta'):
    m = SomeModule()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict, assign=True)
```

**A remaining problem with this PR, for the case where the model is initialized on meta: what happens to nonpersistent buffers/params corresponding to keys missing from the state dict?**
What happens in the case of `load_state_dict(state_dict, strict=False, assign=True)` where the state_dict is missing some keys? The params missing from the `state_dict` and the nonpersistent buffers would still be on `meta` and would need to be manually initialized. However, I don't think we offer an API that would initialize these.

One solution would be to make these empty tensors, but that might not be semantically correct...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102212
Approved by: https://github.com/albanD
2023-06-15 18:41:00 +00:00
Andrew Gu
2eea3cb19d Fix composable checkpoint(use_reentrant=True) with multi args (#103590)
The `_ModuleHookCheckpointFunction.backward()` should take in `*output_grads` instead of `output_grads`. Otherwise, we may see an error like:
```
TypeError: backward() takes 2 positional arguments but 5 were given
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
2023-06-14 21:53:30 +00:00
Iris
7dd0f525b5 [FSDP][4/n]Update use_dtensor option for _optim_utils.py (#103599)
Same as https://github.com/pytorch/pytorch/pull/103069 (this branch is corrupted so have to re-submit).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103599
Approved by: https://github.com/fegin
2023-06-14 20:18:33 +00:00
Iris
d991ce6da3 [FSDP][3/N]_shard_utils update for dtensor state_dict support (#103479)
Same as https://github.com/pytorch/pytorch/pull/102545 (this branch is corrupted so have to re-submit).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103479
Approved by: https://github.com/fegin
2023-06-14 06:45:28 +00:00
Iris
51d21ffd8a [FSDP][2/n] add use_dtensor flag to both StateDictConfig and OptimStateDictConfig (#103477)
Same as #102552 (this branch is corrupted so have to re-submit).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103477
Approved by: https://github.com/fegin
2023-06-13 19:09:56 +00:00
Andrew Gu
71b560208c [FSDP] Fix device_id when buffer-only module (#103504)
There was an issue reported internally that with `sync_module_states=True`, if the model had buffers on CPU, even with `device_id` specified, FSDP would try to broadcast CPU buffers, leading to an error like:
```
RuntimeError: No backend type associated with device type cpu
```

After some investigation, I determined that we should _not_ fix this by moving the buffers to GPU just for the broadcast and then back to CPU. Instead, we should fix our `device_id` logic.

The issue is that we always used the _parameters_ as the proxy to tell whether we should move module states to the device specified by `device_id`. However, a module (often the root) may not have any parameters but have some buffers! In that case, the buffers are left on CPU even if `device_id` is specified. This PR fixes this by considering both parameters and buffers for movement to `device_id`.

Note that this PR preserves the logic that `ignored_modules` / `ignored_parameters` are not considered for this movement, meaning that ignored parameters are moved to `device_id`.

Note also that I had to move the unit test back from using MTPG to the normal PG since otherwise, I could not repro the original error. (It seems like MTPG does not complain if we try to use `dist._broadcast_coalesced()` with CPU tensors.)
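A rough repro-style sketch of the case described above (assumes the default process group is already initialized and a GPU is available):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Root(nn.Module):
    def __init__(self):
        super().__init__()
        # the root itself has only a buffer; all parameters live in the child
        self.register_buffer("step", torch.zeros(1))
        self.child = nn.Linear(8, 8)

# with the fix, the CPU buffer is also moved to device_id before the
# sync_module_states broadcast instead of hitting the CPU-broadcast error
model = FSDP(Root(), device_id=torch.cuda.current_device(), sync_module_states=True)
```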

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103504
Approved by: https://github.com/rohan-varma
2023-06-13 18:33:26 +00:00
Rodrigo Kumpera
5b33d39114 [FSDP] Workaround for GLOO's lack of all_gather_into_tensor. (#103170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103170
Approved by: https://github.com/rohan-varma
2023-06-13 17:21:41 +00:00
Rodrigo Kumpera
63fe26809d Implement all_gather_into_tensor_coalesced. (#98642)
The implementation is suboptimal since it uses c10d's group coalescing, which
is known to be inefficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98642
Approved by: https://github.com/wanchaol
2023-06-13 15:06:52 +00:00
zhuhong61
50c972bfd2 [c10d] Add xpu to the default device supported by user specified backend (#103410)
**Motivation:**
For collective dispatching, we want to provide a more user-friendly way to map the xpu device to a user-specified backend such as CCL.

**Solution:**
We add xpu to the default device list, so the mapping between xpu and the user-specified backend can be constructed directly.
Usage:
When using an xpu device, the user can specify only the backend name:
`dist.init_process_group(backend='ccl')`
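A minimal sketch of that usage, assuming a build where the `ccl` backend is registered for xpu and the usual rendezvous env vars are set:

```python
import torch.distributed as dist

# with xpu in the default device list, only the backend name is needed;
# the xpu -> ccl device/backend mapping is constructed automatically
dist.init_process_group(backend="ccl")
```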

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103410
Approved by: https://github.com/jgong5, https://github.com/ezyang
2023-06-12 19:46:33 +00:00
PyTorch MergeBot
caecb55223 Revert "Log functional_collectives apis to distributed logger (#103288)"
This reverts commit 37359c36fd.

Reverted https://github.com/pytorch/pytorch/pull/103288 on behalf of https://github.com/malfet due to Broke test_inductor_collectives, see 37359c36fd ([comment](https://github.com/pytorch/pytorch/pull/103288#issuecomment-1587677705))
2023-06-12 16:37:57 +00:00
Will Constable
37359c36fd Log functional_collectives apis to distributed logger (#103288)
This logs functional collectives API calls with debug log level only.

(the `+` in the TORCH_LOGS cmdline enables debug level, otherwise only info level)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103288
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-12 06:33:26 +00:00
Wanchao Liang
4cc474dec4 [dtensor] support torch.save/load with DTensor (#103106)
This PR enables DTensor to be picklable and adds tests that torch.save/load work correctly for DTensor.
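A rough sketch of the round trip being tested, assuming an initialized process group and the private `torch.distributed._tensor` API of that time:

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
dtensor = distribute_tensor(torch.randn(8, 8), mesh, placements=[Shard(0)])

torch.save(dtensor, "dtensor.pt")  # works now that DTensor is picklable
loaded = torch.load("dtensor.pt")
```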
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103106
Approved by: https://github.com/kumpera
2023-06-09 04:11:15 +00:00
Wanchao Liang
d31707a257 Get rid of dim_groups attribute from DeviceMesh (#103105)
This PR gets rid of the dim_groups attribute from DeviceMesh. The main
motivation is that c10d should store the process groups during their creation
instead of DeviceMesh; DeviceMesh should just handle ranks correctly.

This could enable DTensor to become picklable (torch.save/load could be
possible), which I will give a try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
2023-06-09 04:11:15 +00:00
Andrew Gu
48056b168f [FSDP] Reshard frozen params in backward (#101982)
This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept as unsharded through the entire backward pass.
- The approach is to register a multi-grad ~~post~~-hook on the _input_ activations to the FSDP module, where the hook performs the resharding after all gradients for the FSDP module must have been computed (meaning that we are safe to reshard).

~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).~~ This was resolved in https://github.com/pytorch/pytorch/pull/102859.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101982
Approved by: https://github.com/rohan-varma
2023-06-08 21:12:45 +00:00
Xilun Wu
675f2597fa [reland][DTensor][3/N] add DTensor constructor function: full (#101436) (#103165)
This is a reland attempt of reverted PR #101436 .

Differential Revision: [D46537531](https://our.internmc.facebook.com/intern/diff/D46537531)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103165
Approved by: https://github.com/wanchaol
2023-06-08 16:18:33 +00:00
Rodrigo Kumpera
4833dc10b8 [DCP] Rewrite read slicing to use a wrapper. (#99167)
Moved SlicedBufferedReader to utils and renamed it to _ReaderView.

It no longer depends on file handles and is a pure wrapper. This makes it general enough to handle non-IO stream objects like fsspec's.

Should help with #98386
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99167
Approved by: https://github.com/wz337
2023-06-08 13:52:13 +00:00
Wanchao Liang
8585784a34 [dtensor] fix allgather unpadding logic (#103219)
This PR fixes the allgather unpadding logic so that we only need to unpad
the full tensor instead of first chunking it into small tensors and unpadding them
individually, since we know how our padding algorithm works.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103219
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-06-08 03:31:24 +00:00
Iris
d5142c52d3 [FSDP]Remove dim_group from device_mesh init (#103218)
1) remove dim_group
2) don't init device_mesh if not using default_pg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103218
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-08 03:29:19 +00:00
shaoyf42
17737f9d0e [DTensor] Allow DTensor support cuda-like device (#102468)
Allow DTensor to support cuda-like devices; fixes https://github.com/pytorch/pytorch/issues/102442

Currently, DTensor supports cuda and cpu. There are other efforts to make DTensor support third-party devices, for example https://github.com/pytorch/pytorch/pull/101914 and https://github.com/pytorch/pytorch/issues/101911. However, that support only covers a portion of third-party devices and does not work well for third-party cuda-like devices. Therefore, we would like to extend DTensor to support cuda-like devices; after all, cuda is so popular!

1. Similar to what is done here, we need to initialize the communication backend for the device set by DeviceMesh, so `_default_backend_for_device` is added to `Backend`. It is worth noting that when we register a new backend for a device other than cpu and cuda, we also need to add a new default backend for that device.
2. We add `_device_handle` to `DeviceMesh` for cuda-like devices, similar to what is done in FSDP. When `_device_handle` is not None, the device behaves like `cuda`, and calls such as `torch.cuda.device_count()` become `device_mesh._device_handle.device_count()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102468
Approved by: https://github.com/wanchaol
2023-06-07 23:13:53 +00:00
Ke Wen
07104ca99c [c10d] Make it default that PG do not perform barrier after init (#103033)
Both internal and OSS users trying https://github.com/pytorch/pytorch/pull/99937 report that their workloads perform normally even with the barrier removed, and they see a scalability win. Thus, in this PR we make it the default that PGs do not perform a barrier after init.

In the discussion of #99937, people point out that such a barrier might be needed for c10d + RPC cases. IMO, this need originates from RPC's programming model and should be RPC's or the RPC user's responsibility to handle; it can happen with other functions/libraries too, so the need for c10d to do such a big favor is not justified. It is also good to remove the barrier before users become reliant on it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103033
Approved by: https://github.com/XilunWu
2023-06-07 06:11:14 +00:00
Iris
a02a58d862 [FSDP][1/N]Add device_mesh to FSDPstate (#102317) (#102551)
This PR creates a device_mesh and shares it across all FSDP states. The device_mesh will later be used to test out the dtensor state_dict (1d device_mesh).
Approved by: https://github.com/awgu

Additional fixups in this re-land:
- Add device mesh to fsdp state
- Skip dist.get_world_size(pg) != dist.get_world_size()
- Address test_fake_pg.py test failure
- Fix test_fake_pg.py failure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102551
Approved by: https://github.com/fegin
2023-06-07 04:14:00 +00:00
Rohan Varma
dfa64fddeb [FSDP] Fix for optim state dict (#102901)
Fix for HSDP + use_orig_params where we need to pass in the PG that
might not be the default.

Differential Revision: [D46417327](https://our.internmc.facebook.com/intern/diff/D46417327/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102901
Approved by: https://github.com/wz337
2023-06-06 20:21:23 +00:00
Chao Yang
367b0ad062 enforce dtype (reland) (#102996)
Summary: The original diff didn't break the test.

Test Plan: N/A

Differential Revision: D46448488

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102996
Approved by: https://github.com/malfet, https://github.com/wanchaol
2023-06-06 00:35:04 +00:00
PyTorch MergeBot
ecb191683e Revert "enforece dtype (#102802)"
This reverts commit 8e2a86c2a5.

Reverted https://github.com/pytorch/pytorch/pull/102802 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/102802#issuecomment-1577099676))
2023-06-05 16:21:28 +00:00
Samuel Eisenhandler
9cabdff8bd Update documentation to read FileSystemReader instead of FileSystemLoader (#102795)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102795
Approved by: https://github.com/wz337
2023-06-05 15:22:49 +00:00
Chao Yang
8e2a86c2a5 enforece dtype (#102802)
Summary: Add a flag to enforce the gather data dtype. For backward compatibility, the default is False.

Test Plan: local and mast

Reviewed By: zyan0, strisunshinewentingwang

Differential Revision: D46295190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102802
Approved by: https://github.com/mrshenli
2023-06-05 02:04:09 +00:00
Rohan Varma
a748be93df [CheckpointWrapper] Warn on reentrant use (#102890)
We'd like to encourage users to try non-reentrant as much as possible,
and identify any gaps this way.

Differential Revision: [D46397786](https://our.internmc.facebook.com/intern/diff/D46397786/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102890
Approved by: https://github.com/awgu
2023-06-04 18:31:22 +00:00
Rohan Varma
88ce6215f5 [FSDP/DDP] Unify _cast_forward_inputs (#102680)
Closes https://github.com/pytorch/pytorch/issues/96380

Differential Revision: [D46342814](https://our.internmc.facebook.com/intern/diff/D46342814/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102680
Approved by: https://github.com/awgu
2023-06-04 18:31:21 +00:00
Rohan Varma
957ea485c4 [FSDP/AC] checkpoint_wrapper acccept auto_wrap_policy (#102672)
Some feedback for this API is that folks would like to use an auto_wrap_policy,
similar to FSDP, instead of having to adapt to the signature of ``check_fn``.
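A hedged sketch of the new usage, assuming the policy is accepted via an `auto_wrap_policy` keyword mirroring FSDP's (the model here is illustrative):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=32, nhead=4), num_layers=2
)

# instead of check_fn=lambda m: isinstance(m, nn.TransformerEncoderLayer)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    auto_wrap_policy=ModuleWrapPolicy({nn.TransformerEncoderLayer}),
)
```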

Differential Revision: [D46340320](https://our.internmc.facebook.com/intern/diff/D46340320/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102672
Approved by: https://github.com/awgu
2023-06-04 18:31:19 +00:00
Rohan Varma
df40ec82dc [FSDP][Docs] Document get_state_dict_type (#102658)
Per title

Differential Revision: [D46335317](https://our.internmc.facebook.com/intern/diff/D46335317/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102658
Approved by: https://github.com/fegin, https://github.com/awgu
2023-06-04 18:31:18 +00:00
Rohan Varma
c6d0fe39ec [FSDP] Document optim_state_dict_config in method (#102657)
Per title

Differential Revision: [D46335318](https://our.internmc.facebook.com/intern/diff/D46335318/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102657
Approved by: https://github.com/fegin
2023-06-04 18:31:16 +00:00
Rohan Varma
beb7131c64 [FSDP] Use INFO instead of DETAIL for warning logs (#102639)
Since these are just logs and don't introduce any big perf slowdowns,
I think we should just enable them in info mode.

Differential Revision: [D46328510](https://our.internmc.facebook.com/intern/diff/D46328510/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102639
Approved by: https://github.com/awgu
2023-06-04 18:31:15 +00:00
Rohan Varma
4d516f44a1 [FSDP][ez] Type optimizer correctly (#102637)
In ShardedGradScaler, the optimizer doesn't have to be SGD.

Differential Revision: [D46327103](https://our.internmc.facebook.com/intern/diff/D46327103/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102637
Approved by: https://github.com/Skylion007, https://github.com/awgu, https://github.com/fegin
2023-06-04 18:31:13 +00:00
Rohan Varma
e66c498d2d Log modules FSDP hooks fire for (#102508)
Under TORCH_DISTRIBUTED_DEBUG >= INFO and use_orig_params=True, log post-backward hook firing to debug things like the FSDP + AC integration.

Differential Revision: [D46172916](https://our.internmc.facebook.com/intern/diff/D46172916/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102508
Approved by: https://github.com/awgu, https://github.com/fegin
2023-06-04 18:31:12 +00:00
PyTorch MergeBot
0f672e8c67 Revert "[DTensor][3/N] add DTensor constructor function: full (#101436)"
This reverts commit 2ca75d49a8.

Reverted https://github.com/pytorch/pytorch/pull/101436 on behalf of https://github.com/malfet due to Caused internal SEV ([comment](https://github.com/pytorch/pytorch/pull/101436#issuecomment-1575076672))
2023-06-03 17:09:08 +00:00
shaoyf42
fc218a8a13 Fix typos in README of DTensor (#102813)
Fix typos in README of DTensor. But there is still a problem to be fixed. I reported an error when I tried to use distribute_module with  shard_params. I show the specific error message in issue https://github.com/pytorch/pytorch/issues/102812.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102813
Approved by: https://github.com/wanchaol
2023-06-02 19:27:23 +00:00
Ashwin Hari
cf0aa38005 Allow ORT backend for DTensor (#101914)
fixes #101911

Currently, `DTensor` supports cuda and cpu. This PR makes some changes for easier integration with the ort backend.

* `Backend.NAME`  attribute now has value `name` instead of `NAME` for backends registered through `register_backend(name)`; this matches the pattern for backends with built-in support like nccl.
* remove unused `_check_for_nccl_backend` function
* add test case that moves parameters to device in the `partition_fn` - a scenario that's useful for big models
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101914
Approved by: https://github.com/wanchaol
2023-06-01 22:37:09 +00:00
fduwjj
92923aca61 [TP] Use Stride inferred from local tensor in to_local bwd (#102630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102630
Approved by: https://github.com/wanchaol
2023-06-01 04:30:24 +00:00
Wanchao Liang
c5d4ee2d73 [dtensor][simple] fix some comments (#102661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102661
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-06-01 03:23:19 +00:00
shaoyf42
8d7e082300 [c10d] Add is_backend_available for c10d backend. (#101945)
Add is_backend_available for c10d backends, covering both built-in backends and third-party backends registered through ``Backend.register_backend``.

There is a related discussion in https://github.com/pytorch/pytorch/pull/101775#discussion_r1199253553
> For example in python constructor for their backend they should explicitly add the is_X_available. Or if defining in C++ they should modify pybind like this https://github.com/H-Huang/torch_collective_extension/blob/main/custom_backend/include/dummy.hpp#L98-L101
to also add their own is_available property

It is a natural choice for users to add their own `is_available` when they create a backend. One possible way to let users call `is_X_available` in the same way as for native backends is to dynamically add a `torch.distributed.is_dummy_available()` function, which is why we want to dynamically add `is_X_available` to `torch.distributed` in `register_backend`.

> Or we could add an Is_available(backend) function, that checks for the backend.

Providing a public function is indeed another good approach. We have implemented an `is_backend_available` in https://github.com/pytorch/pytorch/pull/101945  that supports both built-in backends and third-party backends.
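A minimal usage sketch of the new helper (the built-in `nccl` backend is used here; it works the same for names added via `register_backend`):

```python
import torch.distributed as dist

if dist.is_backend_available("nccl"):
    dist.init_process_group(backend="nccl")
else:
    dist.init_process_group(backend="gloo")
```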

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101945
Approved by: https://github.com/H-Huang
2023-05-31 22:51:51 +00:00
Rohan Varma
0ecca122e7 [Replicate] Add unit test with replicate param names (#102401)
This attribute wasn't actually used in tests; add a test ensuring that
if replicate is used on top of FSDP, the replicated parameter names are as
expected.

TODO: there are a few ways to check if module is managed by composable API,
such as replicated param names for replicate, _get_module_state API,
_get_registry_api, etc. We should unify all composable APIs to check in a
unified way (filed an issue)

Differential Revision: [D46236377](https://our.internmc.facebook.com/intern/diff/D46236377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102401
Approved by: https://github.com/awgu
2023-05-31 18:41:03 +00:00
Yanli Zhao
f47ee87765 Fix ignored_states when they are passed as generators (#102575)
This PR fixes the case where ignored_states is passed as a generator rather than a List/Set.
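A small sketch of the now-supported call, assuming an already-initialized process group (`model.encoder` is just an illustrative submodule):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# ignored_states may now be a generator such as Module.parameters(),
# not only a materialized list or set
fsdp_model = FSDP(model, ignored_states=model.encoder.parameters())
```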

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102575
Approved by: https://github.com/awgu
2023-05-31 15:58:55 +00:00
Matthew Hoffman
c28f8e314d Add type hints in torch/distributed/utils.py (#102262)
Fixes #77190

Pretty similar to the typing in `torch/nn/parallel`, which was also improved recently: #102194

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102262
Approved by: https://github.com/Skylion007, https://github.com/Neilblaze
2023-05-30 19:57:45 +00:00
Wanchao Liang
ff58d19c89 DeviceMesh use dispatchable PG to support custom backend (#102336)
This PR switches DeviceMesh to use a dispatchable process group instead.
This could enable easier backend integration, as users only need to
integrate their custom backend with the c10d process group, without needing to
change DeviceMesh to plug in the backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102336
Approved by: https://github.com/fduwjj
2023-05-30 19:22:37 +00:00
Wanchao Liang
3ef4d697df [c10d] default backend need to check for nccl availability (#102470)
As titled, we can only initialize nccl backend when NCCL is available
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102470
Approved by: https://github.com/Skylion007, https://github.com/XilunWu
2023-05-30 19:22:37 +00:00
Will Constable
77f97019b7 Dynamo remaps legacy allgather to traceable one (#102232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102232
Approved by: https://github.com/voznesenskym
2023-05-30 16:45:25 +00:00
PyTorch MergeBot
81ac076bce Revert "[FSDP]Add device_mesh to FSDPstate (#102317)"
This reverts commit 4c584acc5d.

Reverted https://github.com/pytorch/pytorch/pull/102317 on behalf of https://github.com/malfet due to Broke test_fake_pg, see https://github.com/pytorch/pytorch/actions/runs/5100633726/jobs/9173277369  ([comment](https://github.com/pytorch/pytorch/pull/102317#issuecomment-1566129496))
2023-05-28 12:53:28 +00:00
Iris
4c584acc5d [FSDP]Add device_mesh to FSDPstate (#102317)
This PR creates a device_mesh and shares it across all FSDP states. The device_mesh will later be used to test out the dtensor state_dict (1d device_mesh).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102317
Approved by: https://github.com/awgu
2023-05-27 20:25:30 +00:00
Matthew Hoffman
0ed22fce97 Merge type stubs torch nn parallel (#102194)
Fixes merge issue for #101528

In the above PR, `torch.nn.parallel.parallel_apply.get_a_var` was marked private to appease the [public interface linter](https://github.com/pytorch/pytorch/actions/runs/4999216467/jobs/8955582204#step:14:21666): ceeb242bc7

This broke CI pipelines running external dependencies that expected `get_a_var`'s name to not change. In this PR, we change the name back to `get_a_var` and include it in the `__all__` instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102194
Approved by: https://github.com/ezyang
2023-05-26 20:10:47 +00:00
Rohan Varma
3dfa755a1f [MTPG] Enable for some tests in test_fsdp_misc (#102043)
Enables MTPG for some FSDP tests in this file. Tests that need the
backward pass and warning logging are left as follow up work.

Backward pass issue: It seems that there is a hang with all_gather. Will sync with @kumpera on this.

Warning issue: We have a couple tests that regex check on warnings, but in the
multithreaded scenario these warnings are somehow not logged.

Differential Revision: [D43209769](https://our.internmc.facebook.com/intern/diff/D43209769/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102043
Approved by: https://github.com/awgu
2023-05-26 06:21:25 +00:00
Iris
080d86acfb [DCP] Add API logging for checkpoint high level API (#102278)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102278
Approved by: https://github.com/fduwjj
2023-05-25 21:13:29 +00:00
Wanchao Liang
7b47cd0a6c [c10d] add fake pg necessary collectives (#102238)
This PR adds the collectives necessary for the fake pg to enable an e2e FSDP
run without multiprocessing or multithreading.
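For context, a rough single-process sketch (assuming the `FakeStore` helper under `torch.testing._internal`, whose import is expected to register the `fake` backend):

```python
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

# One process pretends to be rank 0 of an 8-rank job; collectives return
# immediately with placeholder results, so FSDP code paths can be exercised
# end to end without spawning processes or threads.
store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=8, store=store)
```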
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102238
Approved by: https://github.com/ezyang
2023-05-25 05:01:16 +00:00
Wanchao Liang
9a19262556 [c10d] consolidate barrier after init logic (#102237)
This PR consolidates the barrier-after-init logic so that a custom
backend can set the env var when creating the pg, letting
`init_process_group` skip the barrier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102237
Approved by: https://github.com/ezyang
2023-05-25 05:01:16 +00:00
fduwjj
d4380edb9b [TP] Add API logging for TP high level API (#102209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102209
Approved by: https://github.com/wz337, https://github.com/wanchaol
2023-05-25 03:33:00 +00:00
Rohan Varma
f3e42f15e9 [FSDP] Start to generalize modules to ignore for mixed precision (#102010)
The main use case here is that folks would like to ignore layer norm for mixed precision. This can now be enabled with:

```
mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
            _mixed_precision_module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
        )
```

This is done by wrapping modules whose classes appear in `_mixed_precision_module_classes_to_ignore` in their own FSDP unit with mixed precision disabled. This is only enabled for auto wrapping.

We also add module pre- and post-hooks to cast/downcast inputs to the appropriate precision (full precision around the ignored modules).
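A hedged end-to-end sketch (the `Block` model is illustrative, a process group is assumed to be initialized, and the private field name is taken from this description):

```python
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)
        self.norm = nn.LayerNorm(16)

    def forward(self, x):
        return self.norm(self.lin(x))

mp_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
    _mixed_precision_module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
)

# With auto wrapping, each LayerNorm lands in its own FSDP unit with mixed
# precision disabled, and the added pre/post hooks cast its inputs back to
# full precision.
model = FSDP(
    nn.Sequential(Block(), Block()).cuda(),
    auto_wrap_policy=ModuleWrapPolicy({Block}),
    mixed_precision=mp_config,
)
```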

Differential Revision: [D46079957](https://our.internmc.facebook.com/intern/diff/D46079957/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102010
Approved by: https://github.com/awgu
2023-05-25 00:45:54 +00:00
Edward Z. Yang
c903b12cb8 Add fake process group (#102180)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102180
Approved by: https://github.com/wanchaol
2023-05-24 23:27:40 +00:00
Yeonju Ro
06f656c5d1 [distributed] implemented find_all_descendants (#102138)
Fixes #100397

Implemented find_all_descendants function that identifies the list of nodes that need to be moved. Added unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102138
Approved by: https://github.com/fegin
2023-05-24 21:47:59 +00:00
Yanli Zhao
956bd03808 add ignored_states to FSDP/fully_shard (#102056)
Add an 'ignored_states' argument that accepts either a list of ignored parameters or a list of nn modules, for both the FSDP model wrapper and the fully_shard composable API. It is recommended to use 'ignored_states' over 'ignored_modules' moving forward.
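A brief sketch of both call sites (module names are illustrative; a process group is assumed to be initialized):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._composable import fully_shard

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# ignored_states accepts modules...
fsdp_model = FSDP(model, ignored_states=[model[1]])
# ...or parameters:
# fsdp_model = FSDP(model, ignored_states=list(model[1].parameters()))

# the composable API takes the same argument:
# fully_shard(model, ignored_states=[model[1]])
```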

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
2023-05-24 18:36:48 +00:00
Wanchao Liang
d316a2dd5c [spmd] Enable data parallel to work with non 0 batch dim (#100073)
This PR enables data parallel to work with a non-0 batch dim. The only
thing we need to do is expose the input_batch_dim to DataParallelMode;
the data parallel expansion then works automatically because the batch
dim analysis already handles it correctly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100073
Approved by: https://github.com/mrshenli
2023-05-24 17:55:10 +00:00
Wanchao Liang
d378837039 [spmd] add more decomp and fix a sharding bug (#100938)
This PR adds the native_layernorm_backward op to the decomp table and fixes
a sharding bug so that we do not automatically pad.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100938
Approved by: https://github.com/mrshenli
2023-05-24 17:55:10 +00:00
Wanchao Liang
dd1f295201 [spmd] Improve activation handling, factory ops and batch dim reduction (#100853)
This PR improves the activation handling logic of data parallel to
support cases where tensor factory ops do not depend on any input node;
such an op still produces an activation, either sharded (i.e., if the
output shape has a batch size) or replicated.

It also significantly simplifies the full reduction logic: we no longer
need separate full reduction detection, we only need to ensure that when
computing the batch dim we detect full reduction and mark it as sharded.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100853
Approved by: https://github.com/mrshenli
2023-05-24 17:55:09 +00:00
Wanchao Liang
4d55ea8548 [spmd] enhance batch dim analysis of data parallel (#100852)
This PR enhances the batch dim analysis of data parallel to understand
cases where the batch dim gets flattened or split; using dtensor's view
ops, we can track a batch dim that got transformed in non-trivial ways.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100852
Approved by: https://github.com/mrshenli
2023-05-24 17:55:07 +00:00
Wanchao Liang
b2eaba6b62 [spmd] by default average gradients for nccl backend (#99964)
This PR averages gradients by default for the NCCL backend, which allows
SPMD's data parallel to match DDP/FSDP results.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99964
Approved by: https://github.com/mrshenli
2023-05-24 17:55:06 +00:00
Wanchao Liang
942cd12d55 [spmd] add option to preserve node types (#100072)
This PR adds an option to preserve node types for the entire graph.
This could allow some exploration of using those node types for
things like activation checkpointing, etc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100072
Approved by: https://github.com/mrshenli
2023-05-24 17:55:05 +00:00
medivh-xp
8b7bd81902 determined collective device by _get_pg_default_device rather than explicit cuda (#101533)
There are many communication operations for ShardedTensor in the FSDP state dict. They use the externally passed-in pg (or the default pg), which currently supports cuda devices. Before communication, the memory is moved to cuda, which is implicit (it is essentially moving data to the memory type required by the pg, not to the compute device type). Similarly, when users run FSDP on a custom backend, they pass in a custom pg (which does not support cuda devices), which may cause FSDP to not work properly in some cases. This PR obtains the memory type supported by the pg through _get_pg_default_device during communication and moves the data to it when needed.
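A rough illustration of the pattern (the helper is private, so treat the exact import path as an assumption):

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_pg_default_device

def _all_gather_shard(shard: torch.Tensor, pg) -> torch.Tensor:
    # Move to whatever device the pg communicates on (cuda for nccl, cpu for
    # gloo, or a custom backend's device) instead of hard-coding .cuda().
    comm_device = _get_pg_default_device(pg)
    shard = shard.to(comm_device)
    out = [torch.empty_like(shard) for _ in range(dist.get_world_size(pg))]
    dist.all_gather(out, shard, group=pg)
    return torch.cat(out)
```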
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101533
Approved by: https://github.com/awgu
2023-05-24 13:48:43 +00:00
Iris
ee95e37a69 [c10d] Record time spent for init_process_group, new_group, _store_based_barrier (#101912)
1. Record time spent for init_process_group, new_group, _store_based_barrier
2. Rename c10d_error_logger to c10d_logger for generalization.
3. Refactor to move the logger wrappers in distributed_c10d.py to c10d_logger.py.
4. Rename the logger wrappers (BC-breaking): exception_handler is renamed to exception_logger to avoid confusion with logging handlers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101912
Approved by: https://github.com/fduwjj
2023-05-24 09:36:34 +00:00
Edward Z. Yang
3318a832b3 Tighten FakeTensor reentrancy asserts, add debugging (#102091)
When investigating failures in https://github.com/pytorch/pytorch/pull/100017 I realized that we were reentering FakeTensorMode even though there was already one on the stack. Although we have attempted assert for these cases in the past, e.g., as in https://github.com/pytorch/pytorch/pull/97186 it seems that the existing protections were insufficient.

In this particular case, the reapplication of FakeTensorMode was due to an interaction with NotImplemented multiple dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it will return NotImplemented to give this tensor a chance to unpack itself into proxyable operation. However, this is never the right thing for FakeTensor, where no unpacking is possible. However, today, FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being twice on the stack.

This PR does a number of things:

* It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack, this is ALWAYS an error
* It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active. This prevents us from readding the mode on the stack
* It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode)
* `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway
* I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091
Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh
2023-05-24 05:37:51 +00:00
Edward Z. Yang
f65732552e Support FakeTensor with FlatParameter (#101987)
In this PR we turn FlatParameter into a virtual tensor subclass
which doesn't actually ever get instantiated: __new__ will create
a Parameter instead (or a FakeTensor, if necessary).

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101987
Approved by: https://github.com/awgu, https://github.com/eellison
2023-05-23 23:12:08 +00:00
Wanchao Liang
6e0c741105 [dtensor] hide mesh validation check under init_process_group flag (#101996)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101996
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Wanchao Liang
70eccdbf92 [dtensor] add necessary logging to APIs and components (#101994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101994
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Xilun Wu
2ca75d49a8 [DTensor][3/N] add DTensor constructor function: full (#101436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101436
Approved by: https://github.com/wanchaol
2023-05-23 06:05:40 +00:00
Wanchao Liang
38a29324b0 [dtensor][2/N] more tensor ops to use strategy propagation (#101203)
As titled, this PR adapts a few more tensor ops to use strategy based
sharding prop
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101203
Approved by: https://github.com/XilunWu
2023-05-22 17:16:14 +00:00
Aaron Gokaslan
3e2ea32dab [BE]: Enable ruff rule TRY302 and apply fixes (#101874)
Removes useless try statements and unreachable code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101874
Approved by: https://github.com/malfet
2023-05-19 17:30:52 +00:00
medivh-xp
e06bd8f3b1 fsdp support create hybrid-sharded process group for custom backend (#100622)
FSDP creates communication groups for intra-node communication through dist.new_subgroups. Previously, dist.new_subgroups only supported creation based on the number of CUDA devices. However, #99706 removed the availability check for CUDA devices, allowing custom backends to create groups based on the number of custom devices per node.

This PR allows FSDP to explicitly pass the number of devices per node when creating communication groups for intra-node communication, instead of defaulting to the number of CUDA devices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100622
Approved by: https://github.com/awgu
2023-05-19 06:08:55 +00:00
shaoyf42
97180aca5e Enables barrier to support the specified device (#99589)
Enables barrier to support a specified device, e.g. cuda or a custom device. There is some discussion here: https://github.com/pytorch/pytorch/issues/97938#issue-1646833919

Today, there are two limitations of barrier:
One is that barrier does not support custom devices:
fbdb86c174/torch/csrc/distributed/c10d/ProcessGroup.hpp (L512-L522)

The second is that there is a special check for nccl when device_id is not None, which assumes cuda and nccl bindings and also hinders custom devices.
789070986c/torch/distributed/distributed_c10d.py (L3504-L3508)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99589
Approved by: https://github.com/kwen2501
2023-05-17 05:26:04 +00:00
Thibaut Durand
01da732691 Fix type annotation of torch.split (#100655)
The type annotation indicates `list` but the returned type is `tuple`
```python
>>> import torch
>>> type(torch.arange(10).split(4))
<class 'tuple'>
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100655
Approved by: https://github.com/kit1980
2023-05-16 21:35:41 +00:00
Xilun Wu
010763be9a [DTensor][2/N] add DTensor constructor function: empty (#101022)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101022
Approved by: https://github.com/wanchaol
2023-05-16 16:50:54 +00:00
Xilun Wu
5cc361c736 [DTensor][1/N] add DTensor constructor function: ones (#100933)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100933
Approved by: https://github.com/wanchaol
2023-05-16 16:50:54 +00:00
albanD
59dff01319 Add top level function to check if running with deploy (#101420)
Also not sure if this should be a public function or not. Leaving it private for now but let me know if you prefer for it to be public.

FYI @nikitaved this will logically conflict with your triton kernel PR.
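A minimal sketch of the intended check (the private name `torch._running_with_deploy` is an assumption based on this PR):

```python
import torch

if torch._running_with_deploy():
    # skip features that are unavailable inside torch::deploy
    pass
else:
    # regular eager / compile path
    pass
```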
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101420
Approved by: https://github.com/malfet
2023-05-16 16:05:49 +00:00
Xuehai Pan
05f6250815 Add missing torch.distributed.ReduceOp.AVG in type stubs (#101534)
Add missing `AVG` to `torch.distributed.ReduceOp` enum for type annotation.

Ref:

88b6a4577b/torch/csrc/distributed/c10d/Types.hpp (L35-L47)
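For reference, a small illustrative snippet that the updated stubs now cover (assumes an initialized NCCL process group, since AVG is NCCL-only):

```python
import torch
import torch.distributed as dist

t = torch.ones(4, device="cuda")
dist.all_reduce(t, op=dist.ReduceOp.AVG)  # averages across ranks
```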
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101534
Approved by: https://github.com/Skylion007
2023-05-16 15:51:21 +00:00
fduwjj
9d858642af [PTD] Make input contiguous for _ReduceScatter (#101373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101373
Approved by: https://github.com/wz337
2023-05-15 22:08:21 +00:00
Shen Li
af841f38bd [SPMD] Allow Override.replacement to have a global view (#101427)
It's easier for users to implement one Override that takes care of
all target submodules of different types, instead of specifying one
mapping pair for each FQN/type. For example, when calculating
sharding for sparse layers, the decision needs to be made globally.
In this case, it's helpful to allow the user Override to get access to
all submodules and make replacement decisions accordingly.

Differential Revision: [D45879732](https://our.internmc.facebook.com/intern/diff/D45879732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101427
Approved by: https://github.com/fegin
2023-05-15 21:27:41 +00:00
Aaron Gokaslan
dfe484a3b3 [BE]: Bugfix functorch and some generic typing improvements (#101337)
Fixes some typing bugs found with newer versions of mypy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101337
Approved by: https://github.com/ezyang
2023-05-14 14:20:56 +00:00
Iris
568db1b464 [dtensor] Relax condition for _split_tensor() (#101218)
When tensor.size(self.dim) < num_chunks, we fill the empty chunks with empty tensors (https://github.com/pytorch/pytorch/pull/98722). Therefore, we no longer need this assert.

For example, when sharding a tensor with 1 element on 2 ranks along dim 0, results would be as follows:
```
rank:0, dtensor:DTensor(local_tensor=tensor([0.4963], device='cuda:0'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
rank:1, dtensor:DTensor(local_tensor=tensor([], device='cuda:1'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101218
Approved by: https://github.com/wanchaol
2023-05-14 07:39:27 +00:00
Yanli Zhao
5ac48eb353 [FSDP]Skip unshard call during checkpointing for NO_SHARD sharding strategy (#101095)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101095
Approved by: https://github.com/fegin
2023-05-12 18:19:18 +00:00
Wanchao Liang
3ae612ba7f [dtensor] remove assertions about submesh checks (#101229)
This PR removes the assertions from the submesh checks and directly returns
the local tensor, so that all the other APIs can work with a submesh.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101229
Approved by: https://github.com/fduwjj
2023-05-12 04:20:35 +00:00
Aaron Gokaslan
738ba13b35 [BE]: enable PLE error codes in ruff and fix bugs (#101079)
Enables PyLint error codes implemented in ruff. These are un-opinionated static analysis checks on Python code that finds common bugs. After running all the PLE error codes that are implemented in ruff, I fixed the bugs, added a few ignores for malformed Python code that is part of our JIT test script, and finally added a few ignores for a false positive on PLE0605 and submitted an issue upstream to fix in ruff https://github.com/charliermarsh/ruff/issues/4345 .

Common bugs found here include analysis for malformed logging format calls, bad string format calls, invalid escape sequences, and more.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101079
Approved by: https://github.com/malfet
2023-05-11 23:57:25 +00:00
Wanchao Liang
599ae95d1a [dtensor] use stack to manage mesh resources (#101202)
This PR changes the context manager behavior of device mesh: we now use
a mesh env to track the current mesh and save meshes to a stack, so
that nested context managers are allowed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337
2023-05-11 23:48:36 +00:00
Chien-Chin Huang
49c8a0cad0 [SPMD][BE] Remove the legacy tracing code (#100858)
Remove the legacy tracing code as it caused several test and benchmark issues.

Differential Revision: [D45649123](https://our.internmc.facebook.com/intern/diff/D45649123/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100858
Approved by: https://github.com/wanchaol
2023-05-11 23:08:27 +00:00
Ke Wen
daed3bf8f9 Implement coalesced all_gather_into_tensor (#101157)
This PR adds support for the following use cases:
- Sync style:
```
with dist._coalescing_manager():
     for i in range(num_coll):
         dist.all_gather_into_tensor(output_tensors[i], input_tensors[i])
```
- Async style:
```
with dist._coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.all_gather_into_tensor(output_tensors[i], input_tensors[i])

# do a bunch of other things
cm.wait()
# do things that depend on the all-gather's
```
Each `all_gather_into_tensor` is independent in terms of data and buffer location, but they can be executed in parallel by supported backends (like NCCL).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101157
Approved by: https://github.com/kumpera, https://github.com/wanchaol
2023-05-11 20:58:47 +00:00
Wanchao Liang
a1aa32e204 [dtensor] tensor ops to use strategy based sharding prop (#100607)
This is the first in a series of PRs that adapt operator impls to use a
strategy-based approach: each op utilizes OpStrategy and PlacementStrategy
to generate its own strategy. By utilizing the strategy-based approach
along with the op graph, we could enable more advanced op implementations
(decomp is possible) and turn sharding prop into something more like a
constraint satisfaction problem.

This PR alone only adds some basic tensor op strategies, and it directly
works on the op graph that was used for metadata propagation. The tensor ops
added in this PR mainly follow one of the arg strategies. The next set of
PRs will add strategies for more ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
2023-05-11 02:47:20 +00:00
Ke Wen
0848ed21b8 [c10d] Figure out device to use for object collectives (#100954)
Fixes https://github.com/pytorch/pytorch/issues/97938

This PR is cloned from https://github.com/pytorch/pytorch/pull/100238, which is important to me,
but @kwen2501 has not resolved the conflict, so this PR is submitted to resolve it.
The only conflict is in `distributed_c10d.py:2653`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100954
Approved by: https://github.com/kwen2501
2023-05-11 01:49:09 +00:00
Chien-Chin Huang
0fbe55ea8f [FSDP][state_dict] Make sharded_state_dict work with composable fully_shard (#100856)
The current implementation of sharded_state_dict only works with wrapper-based FSDP (both with and without use_orig_params) but not with fully_shard. This PR changes the load path of sharded_state_dict to fix the incompatibility.
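For context, a hedged sketch of the sharded state dict round trip with the wrapper-based API (the composable `fully_shard` path that this PR fixes presumably follows the same save/load pattern):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# assuming `model` is already wrapped with FSDP and a process group is up
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_sd = model.state_dict()       # each rank holds only its shards

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    model.load_state_dict(sharded_sd)
```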

Differential Revision: [D45626856](https://our.internmc.facebook.com/intern/diff/D45626856/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D45626856/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100856
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2023-05-10 15:32:45 +00:00
Chien-Chin Huang
55844dfdbc [FSDP][state_dict] Restore the state_dict_config for NO_SHARD (#100855)
Any change to the user configurations should be temporary. This PR fixes an issue where, when NO_SHARD state_dict/load_state_dict is called, the state_dict_config and state_dict_type are changed permanently.

Differential Revision: [D45593313](https://our.internmc.facebook.com/intern/diff/D45593313/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D45593313/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100855
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao, https://github.com/rohan-varma
2023-05-10 10:01:21 +00:00
Aaron Gokaslan
8769fb854d [BE] Fix flake8 B027 errors - missing abstractmethod decorator (#100715)
Enables B027 and applies fixes by adding abstract method decorators. Autofix generated by ruff master.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100715
Approved by: https://github.com/ezyang
2023-05-09 17:28:48 +00:00
Will Constable
793bd6993a Work around torchdynamo import error with functional collectives (#100901)
Summary:
Currently there are build configs where the torchdynamo import trips over a
strange SystemError related to some module's __dict__.items() returning NULL,
while torchdynamo tries to iterate all torch modules and process them for
its allowed functions list.

While this is hard to repro, we should be able to work around it and then fix
it properly.

Test Plan: Rely on others to test this, assuming CI passes.

Reviewed By: anijain2305

Differential Revision: D45663313

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100901
Approved by: https://github.com/yanboliang, https://github.com/malfet
2023-05-09 16:09:42 +00:00
lessw2020
ec144b9412 handle new param from torch.compile (Inductor pattern matcher), enable_log (#100814)
This PR adds a placeholder handler for a new param being passed in from Inductor, `enable_log`.
It fixes the error below, which has prevented me from running torch.compile on NanoGPT:

~~~
File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/_inductor/fx_passes/fuse_attention.py", line 219, in _sfdp_init
    register_replacement(
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/_inductor/pattern_matcher.py", line 658, in register_replacement
    search_gm = trace_fn(search_fn, example_inputs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/_inductor/pattern_matcher.py", line 828, in training_graph
    aot_function(
torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
TypeError: patched_aot_function() got an unexpected keyword argument 'enable_log'
~~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100814
Approved by: https://github.com/fegin
2023-05-08 18:34:45 +00:00
Xing Liu
0731420645 [PyTorch/Distributed]Only sync buffers when broadcast_buffers is True (#100729)
Summary: Disable buffer sync in _sync_module_states(...) when broadcast_buffers is False. This change will reduce memory usage when a model has huge buffers and does not need to broadcast them.

Test Plan: .

Differential Revision: D45610709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100729
Approved by: https://github.com/mrshenli
2023-05-08 16:34:29 +00:00
fduwjj
953aa6d90e [TP] Enable more generic attn in Tensor Parallelism (#100508)
To make TP more generic for the Attention module, we come up with this new col/rowwise parallel style.

Basically, the idea is: we only do DTensor ops for the col/rowwise sharded part. For the rest of the ATen ops, we leave them to regular Tensor ops.

We set this behavior as the default for the Colwise and Rowwise parallel styles. If people want to customize it, they can always pass in a different prepare_input or prepare_output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100508
Approved by: https://github.com/wanchaol
2023-05-07 18:15:49 +00:00
Rodrigo Kumpera
7a15e82388 Fix tensor registration to work with coalescing collectives. (#99763)
We do it by making it possible to register multiple tensors for the same
worker and coordinate waiting/cleanup among them.

This ensures that waiting on any number of the output tensors results in a
single stream sync, which simplifies the codegen done by inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99763
Approved by: https://github.com/wanchaol
2023-05-05 14:25:35 +00:00
Rohan Varma
8869897ebe [replicate] support simpler device_id (#100217)
Allow passing in `device_id=[device]` regardless of CPU or GPU. We
modify the kwarg as needed to pass to DDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100217
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao
2023-05-04 21:06:04 +00:00
Rodrigo Kumpera
a204f7f518 [c10d] Fix subprocess group handling in scatter_object_list. (#100552)
scatter_object_list assumed src was a group rank while all collectives use global ranks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100552
Approved by: https://github.com/fduwjj
2023-05-04 10:04:21 +00:00
Will Constable
2dca418112 Reland basic dynamo support for traceable collectives (#100476)
Relative to the original land, this also contains:
- Fix torchdeploy import of functional collectives
- A fix for being unable to import torchdynamo utils due to torch._refs being missing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100476
Approved by: https://github.com/kumpera
2023-05-04 04:25:35 +00:00
Xiaodong Wang
c29ab84115 Fix bug in process_group_name when there is duplicate pgs (#100518)
Summary: with the new c10d API, we don't need all ranks to call new_group. Integrate with the new API, so that every rank just calls new_group 3 times, with a local barrier among the members of the group.

Reviewed By: xunnanxu, eeggl

Differential Revision: D45315615

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100518
Approved by: https://github.com/kumpera
2023-05-04 02:12:28 +00:00
Rohan Varma
253b9d3247 [replicate] input casting support (#100216)
Supports input casting by performing the cast in the pre hook.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100216
Approved by: https://github.com/awgu
2023-05-04 01:46:15 +00:00
Dirk Groeneveld
75945d54f7 Properly propagates checkpoint wrapper args and kwargs (#99791)
It looks like passing `*args` and `**kwargs` to `checkpoint_wrapper()` does not work because someone forgot some `*`s. This adds them back in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99791
Approved by: https://github.com/awgu
2023-05-03 23:19:21 +00:00
Shabab Ayub
287f74c4fc Revert D45387167: Multisect successfully blamed D45387167 for test or build failures (#100424)
Summary:
This diff is reverting D45387167
D45387167: Basic dynamo support for traceable collectives (#94440) by wconstab has been identified to be causing the following test or build failures (internal)

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: s4ayub

Differential Revision: D45448312

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100424
Approved by: https://github.com/rohan-varma, https://github.com/kumpera
2023-05-03 16:10:54 +00:00
Shen Li
2ebb48ff28 [SPMD] add FQN argument to Override.replacement (#100473)
Differential Revision: [D45486089](https://our.internmc.facebook.com/intern/diff/D45486089)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100473
Approved by: https://github.com/wanchaol
2023-05-03 14:20:01 +00:00
Shen Li
9439cb0e11 Avoid using einsum for torch.cat DTensor propagation (#100251)
DTensor was reusing `einop_rule` to propagate sharding for torch.cat.
However, einsum only supports up to 52 subscripts (i.e., input tensors).
We have encountered use cases where one cat operator has more than 60
input tensors. Therefore, this commit reimplements sharding prop
rule for cat without using einsum.

Differential Revision: [D45435232](https://our.internmc.facebook.com/intern/diff/D45435232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100251
Approved by: https://github.com/wanchaol
2023-05-03 01:56:18 +00:00
Animesh Jain
5fbb40669f [dynamo][moco] Disallow_in_graph distributed APIs (#100071)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100071
Approved by: https://github.com/jansel, https://github.com/H-Huang
2023-05-02 20:09:25 +00:00
Ivan Kobzarev
4582ceb2c4 [distributed][sharded_tensor] Move local_shards check from ShardedTensorBase to ShardedTensor (#100197)
Differential Revision: [D45369211](https://our.internmc.facebook.com/intern/diff/D45369211)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100197
Approved by: https://github.com/fduwjj
2023-05-02 12:42:24 +00:00
Wanchao Liang
123be4b694 [dtensor] add debug tool to track op coverage (#100124)
This PR adds a debug tool to track the op coverage needed in DTensor.

Note that we specifically target ops after decomp table in inductor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100124
Approved by: https://github.com/XilunWu
2023-05-02 01:45:55 +00:00
Yanli Zhao
dc9c79d3cf Allow each fully_shard unit to cast foward inputs for mixed precision config (#100290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100290
Approved by: https://github.com/rohan-varma
2023-05-02 00:03:48 +00:00
Andrew Gu
a014d1b18c [Easy][FSDP] Clarify _use_unsharded_grad_views comment (#100359)
This is an easy follow-up to the previous PR to (1) clarify that `view` is the original parameter's gradient and (2) that after `reshard()` the gradient is on CPU only if offloading parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100359
Approved by: https://github.com/rohan-varma
2023-05-01 12:58:43 +00:00
Andrew Gu
83b803c2b5 [FSDP] Fix use_orig_params=True, CPU offload, no_sync() (#100180)
This should fix https://github.com/pytorch/pytorch/issues/98494. We follow a similar approach as in past PRs for mismatched dtype or size from running in `no_sync()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100180
Approved by: https://github.com/rohan-varma
2023-05-01 05:15:51 +00:00
Justin Chu
01abbfbaae [BE] Fix all B022 useless-contextlib-suppress (#100335)
When no arguments are passed to contextlib.suppress, no exceptions will be suppressed, and the context manager is therefore redundant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100335
Approved by: https://github.com/Skylion007
2023-04-30 18:47:40 +00:00
Chien-Chin Huang
e0a2b49f0b [SPMD] Introduce prerequisites to graph_optimization_pass (#99970)
Some optimizations require prerequisite passes. It is hard to debug why an optimization pass fails when its prerequisite condition is not met. Adding this check makes it easier to discover the error.

Differential Revision: [D45255377](https://our.internmc.facebook.com/intern/diff/D45255377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99970
Approved by: https://github.com/lessw2020
2023-04-28 18:38:01 +00:00
Iris
a23365885f [FSDP] Make set_state_type to SHARDED_STATE_DICT compatible with NO_SHARD sharding_strategy (#100208)
Currently, if we use NO_SHARD strategy for fully_shard and set state_dict_type to be SHARDED_STATE_DICT, a runtime error would be raised ("``sharded_state_dict`` can only be used when parameters are flatten and sharded.").

This PR updates pre_state_dict_hook, post_state_dict_hook, pre_load_state_dict_hook, and post_load_state_dict_hook to set state_dict_type and state_dict_config to full state when using NO_SHARD, even if the state_dict_type and state_dict_config of the root module are set to sharded state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100208
Approved by: https://github.com/rohan-varma
2023-04-28 04:37:58 +00:00
fduwjj
89b1e67d0a [Tensor Parallel] Add a new Colwise Parallel style when Pairwise cannot directly used (#100137)
Some use cases, users cannot directly `PairwiseParallelStyle` and they might need to specify colwise and rowwise separately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100137
Approved by: https://github.com/wz337
2023-04-28 03:27:51 +00:00
zhouzaida
b51f92ebda [Docs] Fix docstring format (#99396)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99396
Approved by: https://github.com/awgu
2023-04-28 01:10:07 +00:00
Yanli Zhao
ca1cf434e7 Not flatten states when use_orig_param is True and sharding is NO_SHARD (#100189)
When use_orig_param is True and sharding is NO_SHARD, parameters and states are not flattened, so optimizer states should not be flattened as well. The unit test will fail without the fix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100189
Approved by: https://github.com/awgu
2023-04-27 23:47:01 +00:00
Daniel Dale
477ca1789c Avoid elementwise dispatch of gradient unscaling/validation ops in _foreach_non_finite_check_and_unscale_cpu_ (#100108)
Fixes [#82206](https://github.com/pytorch/pytorch/issues/82206)

When executing a `ShardedGradScaler` step in the context of `cpu_offload`, [the function](ecd2c71871/torch/distributed/fsdp/sharded_grad_scaler.py (L151-L152)) `_foreach_non_finite_check_and_unscale_cpu_` is grindingly slow. This issue is due to the elementwise op dispatching/redispatching/execution that is engendered by the current approach to gradient tensor validation:
ecd2c71871/torch/distributed/fsdp/sharded_grad_scaler.py (L159-L163)

The subsequent `isinf` and `isnan` checks with associated `any` checks result in unscalable elementwise op dispatches:
ecd2c71871/torch/distributed/fsdp/sharded_grad_scaler.py (L173-L181)

This inefficiency is of course hidden in the current FSDP tests given their (appropriately) trivial parameter dimensionality. In the perf analysis below, the example test configures only the final `Linear(4, 8)` module parameters to require grad, so there are 40 elements to iterate through. However, if one increases the dimensionality to a still-modest 320008 elements (changing the final module to `Linear(40000,8)`), the execution time/cpu cost of the test is dominated by the elementwise op dispatching/redispatching/execution of the `any` validation ops in this function.
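To make the contrast concrete, a toy illustration (not the PR's actual code) of per-element versus whole-tensor validation:

```python
import torch

grads = [torch.randn(40_000)]  # a modestly sized CPU gradient

# per-element: one isinf/isnan/any dispatch chain for every element
def check_slow(grads):
    for g in grads:
        for v in g:
            if torch.isinf(v).any() or torch.isnan(v).any():
                return True
    return False

# whole-tensor: a handful of dispatches per gradient tensor
def check_fast(grads):
    return any(torch.isinf(g).any() or torch.isnan(g).any() for g in grads)
```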

To characterize the current behavior, I use a slightly modified version of an existing `ShardedGradScaler` test [^1]. The following modifications to the test are made to allow the analysis:

1. Run just `CUDAInitMode.CUDA_BEFORE` for clarity instead of additional scenarios
2. Increase the final module to `Linear(40000, 8)` (along with modifying the preceding module to make the dimensions work) ,
3. For the cProfile run (but not valgrind or perf) the test runs just a single [`_train_for_several_steps`](ecd2c71871/torch/testing/_internal/common_fsdp.py (L926-L934)) step per rank (instead of 2 steps)
4. I temporarily reduce `init_scale` further to ensure we don't hit any `infs`, short-circuiting our analysis

### Current behavior

The most relevant call subgraph:
![callgrind_subgraph_elementwise_dispatch](https://user-images.githubusercontent.com/7462936/234656744-b7ca81b2-ce5b-4035-9918-0ad57d3689d3.png)

Note that:

1. Instead of dispatching to the relevant autograd op and then redispatching to the relevant CPU op implementation 8 times per test, (2 train steps x 2 any calls per parameter per step x 2 orig parameters) we (I believe unnecessarily) call the relevant dispatch flow elementwise, so 640016 times! (only 1 node in this trace so 320008 elements/2 X 2 train steps x 2 calls per element per step).
2. Nearly 50% of the relative (inclusive) instruction reads for the entire test in `callgrind` are executed by the `isnan` (320008 execs), `isinf` (320008 execs) and `any` (640016 execs) calls.
3. The `any` pre-dispatch entry point IRs (`torch::autograd::THPVariable_any`) vs actual op implementation IRs (`at::native::structured_any_all_out::impl`) are below to give one a sense of the relative dispatch and op execution cost in an elementwise context[^3].
![THPVariable_any_op_elementwise_dispatch_absolute_IR](https://user-images.githubusercontent.com/7462936/234656886-3c017ee3-8a04-4a7d-bdf8-6c690de42c92.png)
![structured_any_all_out_impl_absolute_IR](https://user-images.githubusercontent.com/7462936/234656915-0b203bb7-bd05-4ceb-a38b-67b0d4862aa7.png)

Using cprofile stats:

```bash
python -c "import pstats; stats=pstats.Stats('/tmp/fsdp_cprofile_8wa9uw39.stats'); stats.print_stats()"
...
ncalls  tottime  	percall  	cumtime  	percall  filename:lineno(function)
1   	20.159   	20.159   	66.805   	66.805 	 torch/distributed/fsdp/sharded_grad_scaler.py:151(_foreach_non_finite_check_and_unscale_cpu_)
160004  18.427    	0.000   	18.427    	0.000 	 {built-in method torch.isinf}
160004  6.026    	0.000    	6.026    	0.000 	 {built-in method torch.isnan}
```
We see that a single step of the scaler runs for more than a minute. Though there is non-trivial cprofile overhead, we can infer from this that per-element op dispatches/executions are on the order of a 100ns.

On the order of 100 nanoseconds per dispatch is acceptable if we're using typical tensor access patterns, but if we're dispatching each element for each op, obviously everything is going to come to a grinding halt for many practical use cases.

(Given the cost of this function is currently O(n) in the number of gradient elements, feel free to set `TORCH_SHOW_DISPATCH_TRACE=1` if you want to make this function cry 🤣)

I've attached a flamegraph at the bottom of the PR[^2] that more intuitively demonstrates the manner and extent of resource consumption attributable to this function with just a modest number of gradient elements.

### After the loop refactor in this PR:

The most relevant call subgraph:
![callgrind_subgraph_elementwise_dispatch_fix](https://user-images.githubusercontent.com/7462936/234657001-0a448756-b4ce-468e-9f91-1d21597df057.png)

Note that:

1. Less than 0.4% of the relative (inclusive) instruction reads for the entire test in `callgrind` are executed by the `isnan` (4 execs), `isinf` (4 execs) and `any` (8 execs) calls (versus ~50% and 320008, 320008, 640016 respectively above)
2. The `any` pre-dispatch entry point IRs (`torch::autograd::THPVariable_any`) vs actual op implementation IRs (`at::native::structured_any_all_out::impl`) reflect far less overhead (of secondary importance to item number 1)
![THPVariable_any_op_elementwise_dispatch_absolute_IR_fix](https://user-images.githubusercontent.com/7462936/234659454-b1e262cf-d291-4d44-aff2-e27efe284e9c.png)
![structured_any_all_out_impl_absolute_IR_fix](https://user-images.githubusercontent.com/7462936/234657154-91fa7cb8-e39e-48c7-abf0-cc58f06c0ae1.png)

Using cprofile stats:

```bash
python -c "import pstats; stats=pstats.Stats('/tmp/fsdp_cprofile_pfap7nwk.stats'); stats.print_stats()"
...
ncalls  tottime  	percall  	cumtime  	percall  	filename:lineno(function)
1    	0.013    	0.013    	0.109    	0.109 		torch/distributed/fsdp/sharded_grad_scaler.py:151(_foreach_non_finite_check_and_unscale_cpu_)
2    	0.022    	0.011    	0.022    	0.011 		{built-in method torch.isinf}
2    	0.018    	0.009    	0.018    	0.009 		{built-in method torch.isnan}
```
We can see our function runtime has dropped from more than a minute to ~100ms.

### Assumptions associated with this loop refactor:

The key assumptions here are:

1. The grads are always on CPU in this function so any MTA-safe constraints ([`can_use_fast_route`](efc3887ea5/aten/src/ATen/native/cuda/AmpKernels.cu (L110-L111)) relating to the relevant CUDA kernel path selection, i.e. slower `TensorIterator` gpu kernel vs `multi_tensor_apply_kernel`) do not apply in this context

2. We've already filtered by dtype and device and can assume the presence of a single CPU device. Unless manually creating separate CPU devices with manually set non-default indexes (which I don't think FSDP supports and should be validated prior to this function), device equality should always be `True` for `cpu` type devices so we should just need to check that the current device is of `cpu` type. [^4].

![elementwise_dispatch](https://user-images.githubusercontent.com/7462936/234660413-8c96ef90-7a23-4307-b8ed-c1fbf932f1e9.svg)

[^1]: `TestShardedGradScalerParityWithDDP.test_fsdp_ddp_parity_with_grad_scaler_offload_true_none_mixed_precision_use_orig_params` test in `test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py`
[^2]: Note the native frame stacks for `torch::autograd::THPVariable_isinf`, `torch::autograd::THPVariable_isnan`, `torch::autograd::THPVariable_any` in particular.
[^3]: There's more `TensorIterator` etc. setup overhead further up the stack beyond `structured_any_all_out`, but roughly speaking
[^4]: Device equality is based on [type and index combination](efc3887ea5/c10/core/Device.h (L47-L51)), CPU device type is -1 by default (`None` on the python side) and is intended to [always be 0](cf21240f67/c10/core/Device.h (L29)) if set explicitly. Though technically, unless in debug mode, this constraint isn't [actually validated](bb4e9e9124/c10/core/Device.h (L171-L184)), so one can actually manually create separate `cpu` devices with invalid indices. I suspect it's safe to ignore that potential incorrect/unusual configuration in this context but let me know if you'd like to add another `cpu` device equality check.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100108
Approved by: https://github.com/awgu
2023-04-27 21:33:27 +00:00
Chien-Chin Huang
01de8ee845 [SPMD][Easy] Add time counter in graph_optimization_pass (#99969)
This gives an idea of how expensive the pass is.

Differential Revision: [D45255366](https://our.internmc.facebook.com/intern/diff/D45255366/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99969
Approved by: https://github.com/lessw2020
2023-04-27 17:56:07 +00:00
Ke Wen
ae0eb2342d [Experimental] Remove store barrier after PG init (#99937)
Store based barrier is not scalable.
Experimenting to see if removing it breaks any CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99937
Approved by: https://github.com/kumpera, https://github.com/H-Huang
2023-04-27 17:23:10 +00:00
Iris
fad2f6edab [PTD][Checkpoint] Upstream fsspec storage read/write to PT (#98387)
Remove sync_files.
Remove single_file_per_rank and will add it back once we resolve the issue. https://github.com/pytorch/pytorch/issues/98386

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98387
Approved by: https://github.com/fegin
2023-04-27 16:47:28 +00:00
Chien-Chin Huang
b94a0ba5bb [SPMD] Add embedding dense backward prop rule for postional embedding (#100038)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100038
Approved by: https://github.com/mrshenli
2023-04-27 16:31:51 +00:00
Rodrigo Kumpera
ad21890f8f [c10d] Scalable PG initiation. (#99931)
Add use_local_synchronization argument to new_group.

When this argument is True, new_group does a store barrier only on the ranks that are part of the group, not on the whole cluster.

This addresses both scalability and composability problems associated with new_group.
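A brief usage sketch (ranks are illustrative; `init_process_group` is assumed to have been called):

```python
import torch.distributed as dist

# The store barrier at the end of group creation only involves ranks 0 and 1,
# not the whole cluster, so creation cost stays flat as world size grows.
pg = dist.new_group(ranks=[0, 1], use_local_synchronization=True)
```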

Fixes #81291.

This is relanding #84224
As part of the original PR I did a quick benchmark of creating 3 PGs per rank using both functions and perf is the following:

new_group use_local_synchronization=False:
| World Size | Time (in secs) |
| --- | ----------- |
| 4 | 0.12 |
| 8 | 0.25 |
| 16 | 0.51 |
| 32 | 0.87 |
| 64 | 1.50 |
| 128 | 2.87 |

new_group use_local_synchronization=True:
| World Size | Time (in secs) |
| --- | ----------- |
| 4 | 0.05 |
| 8 | 0.04 |
| 16 | 0.03 |
| 32 | 0.03 |
| 64 | 0.04 |
| 128 | 0.04 |

Scaling for `use_local_synchronization=False` is sub linear because the number of process groups created as a multiple of world_size decreases as we go up. It's 6 with world_size 4 and 192 with world_size 128.

Scaling for `use_local_synchronization=True` is constant as the number of store barriers executed per rank remains constant at 3.

Setup:

1 AWS host, backend gloo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99931
Approved by: https://github.com/xw285cornell
2023-04-27 13:44:02 +00:00
Will Constable
100a25d021 Basic dynamo support for traceable collectives (#94440)
Make traceable collectives work with torchdynamo,
bypassing problems with tracing the AsyncTensor subclass.

Accept a suboptimal solution for now, and optimize it later.
For now, wait happens immediately, which generally forces an early sync.

Later, find a way either in dynamo or AOT stack to handle
AsyncCollectiveTensor to get the wait in the optimal place.

Note on implementation:
- Dynamo traces 'user-level' fc apis that are designed to behave differently
  in eager vs compiled.  In eager, there will be work-obj registration and
  a wrapper subclass will insert a 'wait' call at the appropriate time.
  In compile/trace mode, wait will be immediately called, and work obj
  registration is required to be handled by the compile backend at runtime.
- Dynamo needs to trace into some of the helper functions in the 'user-level'
  api, such as '_expand_group' which is essentially a constant transformation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94440
Approved by: https://github.com/kumpera
2023-04-27 05:38:36 +00:00
medivh-xp
859e82a7a9 Making fsdp device-agnostic for custom-backend which implement cuda-semantics (#99024)
Consider a custom backend implementation based on privateuse1 with semantics identical to CUDA (CUDA is so popular), named for example 'my_device' and registered under the module name torch.my_device.

This PR aims to satisfy the constraints of such a backend, which can be directly integrated into the current FSDP implementation.

The main issues addressed are:

#### 1. Device decision for FSDP wrapping of Modules without Parameters

Users typically organize FSDP code as follows:
```python
m = Module().to('my_device:0')
fsdp_m = FSDP(m)
```
or like this:
```python
m = Module()
fsdp_m = FSDP(m, device_id=torch.device('my_device', 0))
```
If the model has Parameters, everything works fine because FSDP will prioritize the device where the Parameters are located. However, for Modules without Parameters, the to() call has no side effects, and FSDP will assume the current CUDA device, which prevents the use of devices other than the current CUDA device for Modules without Parameters. Therefore, when FSDP is called with a device_id argument, this configuration takes top priority.

#### 2. Abstraction of a cuda-like device

Now, in addition to compute_device, _FSDPState includes a device_handler member. This device_handler is just a reference to either torch.cuda or torch.my_device. From now on, code that works on _FSDPState should use state.device_handler to create, wait on, or synchronize streams, just as it used torch.cuda previously.
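A rough sketch of the kind of indirection described (names are illustrative, not FSDP's actual internals):

```python
import torch

def _get_device_handle(device_type: str):
    # torch.cuda for "cuda", torch.my_device for a cuda-like custom backend
    return torch.cuda if device_type == "cuda" else getattr(torch, device_type)

handle = _get_device_handle("cuda")
stream = handle.Stream()             # create a stream on the chosen backend
with handle.stream(stream):
    pass                             # issue work on that stream
handle.current_stream().synchronize()
```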
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99024
Approved by: https://github.com/awgu
2023-04-27 04:13:28 +00:00
Rohan Varma
be8c7c06b6 [Tensor Parallel] Simplify distribute for MHA (#100046)
This function is only called for nn.MHA or the custom MHA we use, and
if it is the former it is converted to the latter. So this check can actually
be an assert.

Differential Revision: [D45300396](https://our.internmc.facebook.com/intern/diff/D45300396/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100046
Approved by: https://github.com/wanchaol
2023-04-27 00:54:21 +00:00
Rodrigo Kumpera
5b4a523583 Add all_reduce_coalesced to functional collectives (#98640)
This adds all_reduce_coalesced to MTPG to ease testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98640
Approved by: https://github.com/wanchaol
2023-04-26 17:05:54 +00:00
Chien-Chin Huang
33fba6ef07 [SPMD] Add arange and zeros to default factory ops (#100037)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100037
Approved by: https://github.com/mrshenli, https://github.com/wanchaol
2023-04-26 16:32:10 +00:00
Shrikant Nagori
676a23f452 [RFC] Allow elastic agent to fail fast (#99051)
Summary: Today, on a segfault on a single trainer , we end up keeping the gpu on all ranks blocked for 5 minutes due to elastic agents barrier timeouts

Test Plan: Rely on existing test to validate . Looking to get some feedback on adding UTs

Differential Revision: D44929488

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99051
Approved by: https://github.com/kurman, https://github.com/kiukchung
2023-04-25 23:51:20 +00:00
Wanchao Liang
0901b41a5e [spmd] Add a few more loss ops to the reduction op list (#99900)
This PR adds a few more loss ops to the reduction op list
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99900
Approved by: https://github.com/mrshenli
2023-04-25 19:31:00 +00:00
Wanchao Liang
932ed333f7 [spmd] expose input_batch_dim to DataParallelMode (#99899)
This PR exposes the input batch dim to the DataParallelMode so that
we have explicit control over which input dim is the batch dim.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99899
Approved by: https://github.com/awgu, https://github.com/mrshenli
2023-04-25 19:30:58 +00:00
Wanchao Liang
c6949db481 [spmd] enable fully_shard fused_adam test (#99898)
This PR enables the fully_shard fused adam tests with some additional tweaks
to how scalar tensors are handled. We now treat a scalar tensor as if it
were just a scalar value and don't distribute it, as there's no need to
shard a scalar tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99898
Approved by: https://github.com/mrshenli
2023-04-25 19:30:55 +00:00
Wanchao Liang
ad882c5210 [spmd] Use TupleStrategy and enable replicate fused_adam (#99374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99374
Approved by: https://github.com/mrshenli
2023-04-25 19:30:53 +00:00
Aaron Gokaslan
e2a3817dfd [BE] Enable C419 rule for any all shortcircuiting (#99890)
Apparently https://github.com/pytorch/pytorch/pull/78142 made torch.JIT allow for simple generator expressions which allows us to enable rules that replace unnecessary list comprehensions with generators in any/all. This was originally part of #99280 but I split it off into this PR so that it can be easily reverted should anything break.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99890
Approved by: https://github.com/justinchuby, https://github.com/kit1980, https://github.com/malfet
2023-04-25 15:02:13 +00:00
Chien-Chin Huang
3de7fd461a [FSDP][Reland] Include duplicate parameters and modules when calling named_parameters and named_modules (#99448)
The default option of `named_parameters` and `named_modules` is to remove the duplicated parameters and modules. However, in FSDP, we need to know what parameters are shared. As a result, setting `remove_duplicate` to False is required in FSDP. Without setting `remove_duplicate` to False, FSDP won't be able to discover shared weights in some cases (e.g., the shared weights are in the same module or there are shared modules).

The previous PR was reverted because some modules overwrite the signature of `named_parameters()`. This new PR adds a workaround for that case.

Differential Revision: [D45065973](https://our.internmc.facebook.com/intern/diff/D45065973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99448
Approved by: https://github.com/zhaojuanmao
2023-04-25 00:27:07 +00:00
Wanchao Liang
855f611baf [spmd] skip gradient copying for fused adam (#99489)
Gradients do not need to be copied back as they are not useful.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99489
Approved by: https://github.com/mrshenli
2023-04-24 22:50:02 +00:00
Iris
7398b5650d [Lint] Fix wrong docstring for dcp save_state_dict() (#99778)
``no_dist=True`` means not saving in SPMD style.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99778
Approved by: https://github.com/H-Huang
2023-04-24 22:18:42 +00:00
Ke Wen
3a09aa5977 [c10d] Faster coalescing (#98793)
### Description
The PR aims at reducing CPU overhead of context manager style coalescing.

By "context manager style coalescing", we mean:
Sync style:
```
with _coalescing_manager():
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
```
Async style:
```
with _coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
cm.wait()
```
In previous implementation, each collective in the `num_coll` loop actually calls into the C++ backend, accumulating pybind overhead.

In the new implementation, we capture the collectives at Python level, and only fire towards C++ at the exit of the coalescing manager.

### Tests
In current PR, the "fast path" only applies to all-reduce.
- Flattened 512M: 16.38 ms, including CPU time 131.21 us
- Old _coalescing_manager 64 x 8M: 22.19 ms, including CPU time 2865 us
- New _coalescing_manager 64 x 8M: 16.93 ms, including CPU time 635 us

Hence a 4x reduction in CPU overhead (dependent on `num_coll`).

Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
2023-04-24 21:27:26 +00:00
medivh-xp
39590d06c5 Make new_subgroups available for non-CUDA-dependent backends (#99706)
The `new_subgroups` API allows for easy creation of sub-communication groups, but it currently requires CUDA availability. For communications that do not rely on CUDA, such as the CPU-based gloo backend or custom communication backends, I would still like to be able to use it. For example, with the CPU-based gloo backend (the same applies when using a custom backend):
```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def gloo_process(rank_id, world_size, group_size, mp_lock):
    assert not torch.cuda.is_available()
    def lock_print(*args, **kwargs):
        with mp_lock:
            print(*args, **kwargs, flush=True)

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank_id, world_size=world_size)

    subgroup, _ = dist.new_subgroups(group_size)
    subgroup_ranks = list(range(subgroup.rank() * group_size, (subgroup.rank() + 1) * group_size))
    lock_print(f"Rank {rank_id} initialized in subgroup_{subgroup.rank()}: {subgroup_ranks}")

    tensor = torch.Tensor([rank_id + 1])
    subgroup.broadcast(tensor, root=0)

    lock_print(f"After broadcast, rank {rank_id} in subgroup_{subgroup.rank()}:{subgroup_ranks} got {tensor}")

if __name__ == "__main__":
    world_size = 4
    group_size = 2
    processes = []
    mp.set_start_method("spawn")
    mp_lock = mp.Lock()
    for rank in range(world_size):
        p = mp.Process(target=gloo_process, args=(rank, world_size, group_size, mp_lock))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
```

```bash
Rank 0 assigned to subgroup_0: [0, 1]
Rank 1 assigned to subgroup_1: [2, 3]
Rank 2 assigned to subgroup_0: [0, 1]
Rank 3 assigned to subgroup_1: [2, 3]
After broadcast, rank 2 in subgroup_0:[0, 1] got tensor([3.])
After broadcast, rank 3 in subgroup_1:[2, 3] got tensor([3.])
After broadcast, rank 1 in subgroup_1:[2, 3] got tensor([1.])
After broadcast, rank 0 in subgroup_0:[0, 1] got tensor([1.])
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99706
Approved by: https://github.com/kumpera
2023-04-24 18:22:59 +00:00
Justin Chu
7d2a18da0b Enable ruff in lintrunner (#99785)
### This change

- Implements the ruff linter in pytorch lintrunner. It is adapted from https://github.com/justinchuby/lintrunner-adapters/blob/main/lintrunner_adapters/adapters/ruff_linter.py. It does **both linting and fixing**. 🔧
- Migrated all flake8 configs to the ruff config and enabled it for the repo. 
- **`ruff` lints the whole repo in under 2s** 🤯

Fixes https://github.com/pytorch/pytorch/issues/94737. Replaces #99280.

@huydhn @Skylion007

### <samp>🤖 Generated by Copilot at 6b982dd</samp>

### Summary
🧹🛠️🎨

Add `[tool.ruff]` section to `pyproject.toml` to configure `ruff` code formatter and linter. This change aims to improve code quality and consistency with a single tool.

> _`ruff` cleans the code_
> _like a spring breeze in the fields_
> _`pyproject.toml`_

### Walkthrough
*  Configure `ruff` code formatter and linter for the whole project ([link](https://github.com/pytorch/pytorch/pull/99785/files?diff=unified&w=0#diff-50c86b7ed8ac2cf95bd48334961bf0530cdc77b5a56f852c5c61b89d735fd711R22-R79))

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99785
Approved by: https://github.com/malfet, https://github.com/Skylion007
2023-04-24 16:18:44 +00:00
Wanchao Liang
9db6920635 [spmd] Add list handling to data parallel and add foreach tests (#99373)
This PR adds list-handling logic to the new DataParallel expansion and
adds foreach optimizer tests, currently covering SGD optimizers
in foreach mode, for both replicate and fully shard.

Next step:

Add fused optim tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99373
Approved by: https://github.com/mrshenli
2023-04-22 05:39:20 +00:00
Wanchao Liang
c1e2fa8189 [dtensor] add StrategyType and TupleStrategy (#99435)
This PR refactors the current StrategyList. It introduces
StrategyType, the base class for strategies, which has
two sub-strategies:

1. OpStrategy, a refactor/rename of the previous StrategyList.
2. TupleStrategy, a new strategy added to deal with tuple cases where
an op can return multiple different OpStrategy results.

This helps support more complicated ops and unblocks compile-mode
FSDP.
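
A rough sketch of the resulting class relationship (field names are illustrative, not the actual DTensor definitions):
```python
from dataclasses import dataclass, field
from typing import List


class StrategyType:
    """Base class for op sharding strategies."""


@dataclass
class OpStrategy(StrategyType):
    # One entry per viable output placement strategy for a single op
    # (the refactored StrategyList).
    strategies: List[object] = field(default_factory=list)


@dataclass
class TupleStrategy(StrategyType):
    # For ops returning tuples: one child strategy per tuple element,
    # each of which may itself be an OpStrategy.
    children: List[StrategyType] = field(default_factory=list)
```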
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99435
Approved by: https://github.com/mrshenli
2023-04-22 05:39:20 +00:00
Iris
ca8625f456 [BE][1/N]Add sharding spec logger for ShardedTensor (#99748)
Set up a logging NullHandler() on the OSS side.
Next step is to set up the counterpart internally.

This is part of the effort toward ShardedTensor deprecation. We want to log internal use cases for different sharding specs.
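
For reference, the standard-library pattern being set up on the OSS side looks like this (the logger name is illustrative):
```python
import logging

# Library code attaches a NullHandler so importing the module emits no
# "no handler" warnings; applications opt in to real handlers/sinks themselves.
logger = logging.getLogger("torch.distributed._shard.sharding_spec")
logger.addHandler(logging.NullHandler())
```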
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99748
Approved by: https://github.com/H-Huang, https://github.com/fegin
2023-04-22 04:05:21 +00:00
Wanchao Liang
e9bf94149e [spmd] Introduce Compile Mode FSDP with DTensor (#99062)
This PR introduces compile mode Data Parallel (FSDP/DDP) using DTensor sharding.

Along with the algorithm, it also introduces a new DataParallelMode so that the `compile` API can take it
and apply data parallelism. This PR tries to preserve the DTensorExpand
approach first to avoid BC breakage; we shall discuss steps to remove
DTensorExpand.

The data parallel mode uses heuristics to determine node types in the
graphs and assign the corresponding sharding. The detailed algorithm is
described in the design doc.

The benefits of this approach:
- Model parameters and optimizer states are all DTensors after `spmd.compile`, which is necessary for FSDP and also makes checkpointing much easier.
- As model parameters/optimizer states are sharded in a per-parameter fashion, it can compose with sophisticated second-order optimizers (e.g., Shampoo) more easily.
- We leverage the model parameter/gradient information to derive the data parallel pattern. In this way we don't need to worry about DTensor op coverage anymore, as data parallelism is just a special case of DTensor operation.
- Using dtensor_expand might work for DDP but is not going to work for FSDP, as DTensor might choose to all-gather activations, which would violate the native FSDP algorithm.
- The approach is general enough to support both DDP/FSDP and a mixed mode.

Follow ups:
- Add the "default" data parallel mode which supports mixing of
replicate/fully shard
- Test more e2e models with more types of optimizers, etc.
- Migrate the existing stack from the DTensorExpand mode
- Build optimizations on top of this prototype

Differential Revision: [D45174400](https://our.internmc.facebook.com/intern/diff/D45174400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99062
Approved by: https://github.com/mrshenli
2023-04-22 03:13:05 +00:00
Xilun Wu
ce60997376 [BE][DTensor] validate the mesh argument in DeviceMesh construction (#99094)
## What's in this PR
DeviceMesh's __init__ function now requires all calling ranks to pass the same `mesh` argument.

## Why
We want to enforce the SPMD style of programming with DTensor. Before this PR, the 2-D parallel API (e.g. _create_1d_device_mesh) defined a different DeviceMesh on different ranks. After this PR, it defines each sub-mesh and simply performs communications on the one that it is associated with.
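
A minimal sketch of the contract, assuming the DTensor `DeviceMesh` constructor of this era (import path and initialization boilerplate are approximate):
```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh  # path may differ across versions

# Assumes the default process group is already initialized.
world_size = dist.get_world_size()

# Every rank must pass the *same* global mesh tensor; the sub-meshes used for
# 2-D parallelism are derived from it rather than built differently per rank.
mesh = DeviceMesh("cuda", torch.arange(world_size).reshape(2, world_size // 2))
```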

Differential Revision: [D45165511](https://our.internmc.facebook.com/intern/diff/D45165511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99094
Approved by: https://github.com/wanchaol
2023-04-21 23:47:51 +00:00
Horace He
547bef11ee tweak heuristic for sdpa selection based off of *data* (and a decision tree) (#99644)
High level approach:
1. I generated a bunch of data comparing FlashAttention and Cutlass implementations (https://pastebin.com/pe0j3YeK)
2. I trained a decision tree using standard train/val split methodology and hyperparameter sweeps (https://pastebin.com/fjYX1HjR).
2a. I did a bunch of feature augmentation to capture interactions between features.

The heuristic I ended up with is:
```
use_flash = seq_len / (num_heads * batch_size) > 6
```

TL;DR: On my dataset, where FlashAttention and Cutlass differ by more than 10%, the existing heuristic achieves 69% accuracy.  My new heuristic achieves 94% accuracy.
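
Restating the heuristic as a tiny helper for clarity (not the actual dispatcher code):
```python
def prefer_flash_attention(batch_size: int, num_heads: int, seq_len: int) -> bool:
    # Learned decision-tree threshold: favor FlashAttention when the sequence
    # length is large relative to the total number of (batch, head) pairs.
    return seq_len / (num_heads * batch_size) > 6

# e.g. batch_size=8, num_heads=16, seq_len=2048 -> 2048 / 128 = 16 > 6 -> FlashAttention
```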

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99644
Approved by: https://github.com/ngimel, https://github.com/drisspg
2023-04-21 23:28:44 +00:00
Chien-Chin Huang
4f62e7cb10 [FSDP][BE] Remove unused code (#99731)
Remove the unused code. https://github.com/pytorch/pytorch/pull/99675 is duplicated and we should land this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99731
Approved by: https://github.com/wz337
2023-04-21 23:11:37 +00:00
Daniel Dale
363d530035 Fix decision logic for should_cast_forward_inputs in _root_pre_forward() and _pre_forward() (#99546)
Fixes #99545

There is currently no topological constraint dictating that FSDP instances own ``FlatParamHandle``s directly. If all parameters are managed by descendant FSDP instances, leaving an FSDP instance with no direct ``state._handles``, the ``should_cast_forward_inputs`` computations below in both ``_root_pre_forward()`` and ``_pre_forward()`` respectively can return incorrect decisions [^1].

For [``_root_pre_forward()``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L514)):

436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L602-L604)

For [``_pre_forward``](436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L384)):

436edc5ac3/torch/distributed/fsdp/_runtime_utils.py (L420-L422)

See the [related issue](https://github.com/pytorch/pytorch/issues/99545) for reproduction.

### Remediation

In this PR, I amend the two decision statements referenced above (in both `_root_pre_forward()` and `_pre_forward()`) to account for FSDP instances without direct handles:
```python
should_cast_forward_inputs = len(state._handles) > 0 and all(
    not handle._force_full_precision for handle in state._handles
)
```

If one configures ``MixedPrecision`` in the example above with ``cast_forward_inputs=True`` and the ``should_cast_forward_inputs`` adjustment above, FSDP returns to the expected behavior and produces no error.

Though the check is the same in both ``_root_pre_forward()`` and ``_pre_forward()`` and hence could be refactored into a separate function, I figured it may make sense to retain separate statements to preserve the ability for root-specific behavior in the future. Whichever approach the team prefers I can update this PR with.

### Implementation considerations and questions:

1. Rather than write a test that would arguably have a poor utility/resource usage profile, I have not added any tests associated with this PR. The new decision logic is exercised by all existing tests (which continue to pass after this PR of course) so I think the utility of new tests is fairly modest. Let me know if you think new tests should be added and I'm happy to do so.
2. As discussed above, the decision statement shared among ``_pre_forward()`` and ``_root_pre_forward()`` could be factored out into a separate function. Given the simplicity of the statement and to retain current flexibility for root-specific decisions it might not be worth the refactor so I haven't done it yet. Let me know if you'd like me to do so.
3. The note below could be updated to indicate the utility of setting ``cast_forward_inputs=True`` for the situations addressed with this PR but I haven't done so since I'm not sure it's worth complicating the current usage guidance. I'd be happy to add verbiage describing the use case if the team wants it.
cde35b4069/torch/distributed/fsdp/api.py (L175-L181)

Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!

[^1]: Though one could keep the existing decision logic and impose a new topological constraint requiring all FSDP instances have direct `_handles`, I think retaining the current wrapping flexibility is both convenient and useful enough (e.g. programmatic wrapping of modules that may or may not already have all parameters handled by descendant FSDP instances) to update the decision logic as discussed here instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99546
Approved by: https://github.com/awgu
2023-04-21 22:49:50 +00:00
Chien-Chin Huang
7876c503b7 [FSDP][optim_state_dict] Consolidate rank0_only load logic (#99647)
Following up on https://github.com/pytorch/pytorch/pull/99624, this PR consolidates the `use_orig_params=False` logic with the `use_orig_params=True` logic so that both use the same path to load the optimizer checkpoint when rank0_only is True.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99647
Approved by: https://github.com/wz337
2023-04-21 20:29:54 +00:00
Chien-Chin Huang
dd07dab1c7 [FSDP][optim_state_dict] Support rank0_only when use_orig_params is on (#99624)
This PR makes the `use_orig_params=True` case support rank0_only loading for the optim state_dict. The implementation differs from `use_orig_params=False`: the `use_orig_params=False` implementation first flattens the parameters on rank0 and then broadcasts the states, while this implementation broadcasts the states while doing the flattening. This implementation is slower, as it broadcasts the original parameters instead of the flattened ones; however, it is simpler. As loading usually happens once per training run, the performance difference can be ignored. In the next PR, we will consolidate the implementations in favor of simplicity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99624
Approved by: https://github.com/wz337
2023-04-21 20:09:19 +00:00
Wanchao Liang
b96bb2f1a6 [spmd] Introduce ParallelMode and add DTensorExpandMode (#98452)
This PR introduces a ParallelMode interface to define how to do
SPMD expansion and optimize the captured graph. This is beneficial
because it lets different parallelisms expand differently
and apply different optimization passes.

DTensorExpandMode is added as the first parallel mode, implementing the
existing dtensor_expand functionality.
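
A hypothetical sketch of what such an interface could look like (names are illustrative, not the actual torch.distributed._spmd API):
```python
from abc import ABC, abstractmethod

import torch.fx as fx


class ParallelMode(ABC):
    @abstractmethod
    def expand(self, gm: fx.GraphModule) -> fx.GraphModule:
        """Perform SPMD expansion on the captured graph."""

    @abstractmethod
    def optimize(self, gm: fx.GraphModule) -> fx.GraphModule:
        """Apply parallelism-specific optimization passes to the expanded graph."""


class DTensorExpandMode(ParallelMode):
    def expand(self, gm: fx.GraphModule) -> fx.GraphModule:
        return gm  # placeholder for the existing dtensor_expand logic

    def optimize(self, gm: fx.GraphModule) -> fx.GraphModule:
        return gm  # no extra passes in this sketch
```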

Differential Revision: [D45174399](https://our.internmc.facebook.com/intern/diff/D45174399)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98452
Approved by: https://github.com/mrshenli
2023-04-21 17:24:54 +00:00
PyTorch MergeBot
9861ec9785 Revert "[c10d] Faster coalescing (#98793)"
This reverts commit db456ab83d.

Reverted https://github.com/pytorch/pytorch/pull/98793 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-04-21 09:15:04 +00:00
Iris
0d2b55c459 [DTensor] Change Sharding algorithm to be in line with `torch.chunk()` (#98722)
As functional collectives are being updated, using tensor_split() as the underlying sharding algorithm would require padding and unpadding on multiple ranks. Therefore, we are changing the sharding algorithm to be in line with ``torch.chunk()``, which allows padding to be needed only on the last two ranks in most scenarios.
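
The difference in split behavior that motivates this, shown on a plain tensor (not DTensor code):
```python
import torch

x = torch.arange(10)

# tensor_split spreads the remainder across the first ranks: sizes 3, 3, 2, 2
print([t.numel() for t in torch.tensor_split(x, 4)])  # [3, 3, 2, 2]

# chunk keeps a uniform chunk size and shrinks only the tail: sizes 3, 3, 3, 1
print([t.numel() for t in torch.chunk(x, 4)])         # [3, 3, 3, 1]
```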
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98722
Approved by: https://github.com/wanchaol
2023-04-21 02:05:22 +00:00
Edward Z. Yang
abdd1f4a38 Reuse tracing context and fake tensors from backwards in forwards (#99619)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99619
Approved by: https://github.com/wanchaol
2023-04-20 22:39:48 +00:00
Chien-Chin Huang
88c45a1954 [SPMD] Allow users to dynamically pass the last_iter to IterGraphModule (#99575)
The current design of IterGraphModule requires users to specify a concrete iteration count, which is not always possible and not very precise. This PR introduces `last_iter` to IterGraphModule.forward(), which allows users to dynamically specify the last iteration.

Differential Revision: [D45129585](https://our.internmc.facebook.com/intern/diff/D45129585/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99575
Approved by: https://github.com/lessw2020
2023-04-20 16:49:34 +00:00
Iris
a2a4144256 [FSDP]Make param_groups optional for FSDP optim state dict (#99117)
Make param_groups optional for FSDP optim state dict and add corresponding test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99117
Approved by: https://github.com/fegin, https://github.com/zhaojuanmao
2023-04-20 06:34:40 +00:00
ZhongYingMatrix
af7fed1d92 fix osd rank0_only in fsdp (#99136)
Fixes #99135

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99136
Approved by: https://github.com/fegin
2023-04-19 21:50:38 +00:00
Shen Li
ca89e7942a [SPMD][Easy] switch to tree_map_only to simplify code (#99547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99547
Approved by: https://github.com/fegin
2023-04-19 20:40:09 +00:00
Ke Wen
db456ab83d [c10d] Faster coalescing (#98793)
### Description
The PR aims at reducing CPU overhead of context manager style coalescing.

By "context manager style coalescing", we mean:
Sync style:
```
with _coalescing_manager():
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
```
Async style:
```
with _coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
cm.wait()
```
In the previous implementation, each collective in the `num_coll` loop actually calls into the C++ backend, accumulating pybind overhead.

In the new implementation, we capture the collectives at the Python level and only fire toward C++ at the exit of the coalescing manager.

### Tests
In current PR, the "fast path" only applies to all-reduce.
- Flattened 512M: 16.38 ms, including CPU time 131.21 us
- Old _coalescing_manager 64 x 8M: 22.19 ms, including CPU time 2865 us
- New _coalescing_manager 64 x 8M: 16.93 ms, including CPU time 635 us

Hence a 4x reduction in CPU overhead (dependent on `num_coll`).

Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
2023-04-19 20:17:58 +00:00
Yanli Zhao
6ca991cacf [Composable API] Add fully_shard debug function to print sharded tree structure, module names and managed param fqns (#99133)
Add a fully_shard debug function to print the sharded tree structure in the following format, and return module names and their managed parameter FQNs as well.

![Screenshot 2023-04-18 at 5 14 54 PM](https://user-images.githubusercontent.com/48731194/232931628-169a63a9-b4d5-4902-9cfd-f40113f3ec98.png)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99133
Approved by: https://github.com/rohan-varma
2023-04-19 19:27:43 +00:00
Shen Li
e605b5df74 [SPMD] Add sym_stride to DSymInt (#99504)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99504
Approved by: https://github.com/fegin
2023-04-19 14:55:40 +00:00
Shen Li
2cb8a8d4cc [SPMD] Support DSymInt for slice_backward in SPMD expansion (#99501)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99501
Approved by: https://github.com/fegin
2023-04-19 14:55:40 +00:00
Shen Li
292296141a [SPMD] Support SymInt with non-op call_function nodes (#99420)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99420
Approved by: https://github.com/fegin
2023-04-19 14:55:37 +00:00
Shen Li
7c0c663a4c [SPMD] Add aten.stack and aten.select to DTensor prop (#99417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99417
Approved by: https://github.com/fegin
2023-04-19 14:55:34 +00:00
Shen Li
301be37091 Avoid import * from experimental_ops (#99363)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99363
Approved by: https://github.com/fegin
2023-04-19 14:55:30 +00:00
Chien-Chin Huang
41d7969590 [SPMD] Upstream iter_move_grads_and_optimizers (#98785)
This PR upstreams `iter_move_grads_and_optimizer`, which delays some of the gradients and the corresponding optimizer to the next iteration. D44512863 (credit to @lessw2020) is the internal implementation, which only works for the old _SPMD expansion. This PR changes the implementation to use the new APIs.

Differential Revision: [D44836486](https://our.internmc.facebook.com/intern/diff/D44836486/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98785
Approved by: https://github.com/mrshenli
2023-04-19 06:40:33 +00:00
Rodrigo Kumpera
38e964056b Reland python ops (#99170)
Waiting for the revert to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99170
Approved by: https://github.com/albanD
2023-04-18 15:15:46 +00:00
PyTorch MergeBot
1c042a2137 Revert "Reland python ops (#99170)"
This reverts commit d4de64ae8d.

Reverted https://github.com/pytorch/pytorch/pull/99170 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-04-18 11:37:43 +00:00
Xilun Wu
964c7e3e85 [BE][DTensor] fix DTensor equal op (#99014)
## What problem does this PR solve?
#97170 fixed the `equal` operator's return type (old: Tensor, now: bool) by giving it the correct sharding propagation. This is consistent with the `aten::equal` op. However, the correctness only holds at the local-result level:
* the `equal` op returns True if the local copy of DTensor A equals the local copy of DTensor B

This is not the correct semantics of `equal`, which should return True only if all local copies of A are equal to the corresponding local copies of B.

## What is this PR?

1. For non-participating ranks, if the return type is scalar, `local_results` is set to `None`, which means the default value is the reduced result of participating ranks only.
2. For all ranks, if the return type is scalar and the `op_call` is `aten::equal` (because `aten::equal` is the only function that returns a scalar value and needs communication), all-gather the `local_results` within the default PG and reduce them with `operator.and_`. The result becomes the new `local_result` (see the sketch below).
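
A minimal sketch of the step-2 reduction (not the actual DTensor dispatch code); `local_result` is this rank's `aten::equal` output on its local shard, or None on a non-participating rank:
```python
import operator
from functools import reduce

import torch.distributed as dist

def reduce_equal(local_result):
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_result)
    # Treat None (non-participating ranks) as neutral and AND the rest together.
    return reduce(operator.and_, (r for r in gathered if r is not None), True)
```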

## Result/Impact
For non-participating ranks and the return type is scalar:

1. op is `aten::equal`: the return value is the same as on all other ranks
2. op is not `aten::equal`: the return value is None. Before this PR, this would raise "NotImplementedError", but that path has not been tested.

For participating ranks and the return type is scalar:

1. op is `aten::equal`: the return value is the equality of the two DTensor operands - True if all copies are equal, False otherwise.
2. op is not `aten::equal`: simply the local computation result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99014
Approved by: https://github.com/wanchaol
2023-04-18 03:22:44 +00:00
Chien-Chin Huang
bdaf32261f [FSDP] Ensure that customized non tensor optimizer state can be saved (#99214)
The current logic does not actually handle all different non-tensor optimizer states correctly. This PR fixes the issue and adds a test.

This PR will solve https://github.com/pytorch/pytorch/issues/99079

Differential Revision: [D45021331](https://our.internmc.facebook.com/intern/diff/D45021331/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99214
Approved by: https://github.com/awgu, https://github.com/awaelchli
2023-04-17 21:54:16 +00:00
Rodrigo Kumpera
d4de64ae8d Reland python ops (#99170)
Waiting for the revert to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99170
Approved by: https://github.com/albanD
2023-04-17 21:53:41 +00:00
Nikita Shulga
ccc5d1daec Revert D44897935: Multisect successfully blamed D44897935 for test or build failures (#99353)
Summary:
This diff is reverting D44897935
D44897935: [FSDP] Include duplicate parameters and modules when calling named_parameters and named_modules (#98912) by fegin has been identified to be causing the following test or build failures:

Tests affected:
- [caffe2/torch/fb/module_factory/sync_sgd/tests:test_pyper_data_parallel_wrapper - caffe2.torch.fb.module_factory.sync_sgd.tests.test_pyper_data_parallel_wrapper.PyPerDataParallelWrapperTest: test_fsdp_submodules_pyper](https://www.internalfb.com/intern/test/562950025957458/)

Here's the Multisect link:
https://www.internalfb.com/multisect/1893714
Here are the tasks that are relevant to this breakage:

We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it.

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: fegin

Differential Revision: D45027286

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99353
Approved by: https://github.com/izaitsevfb, https://github.com/fegin
2023-04-17 20:53:10 +00:00
Shen Li
62a6d81143 [SPMD][Easy] Making typing consistent by replacing object with Any (#99332)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99332
Approved by: https://github.com/dracifer
2023-04-17 19:33:45 +00:00
Shen Li
19c2804614 [SPMD][EASY] Remove unnecessary torch.ops prefix (#99331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99331
Approved by: https://github.com/dracifer
2023-04-17 19:33:45 +00:00
Chien-Chin Huang
148d49260a [SPMD] Implement split_fused_optimizer to split one fused_optimizer node to two (#98784)
Several optimization passes require the ability to split the fused optimizer node. This PR adds the API to support those use cases.

Differential Revision: [D44806450](https://our.internmc.facebook.com/intern/diff/D44806450/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98784
Approved by: https://github.com/mrshenli
2023-04-17 10:02:07 +00:00
Shen Li
c69d54885a [SPMD][BE] Generalize factory ops support in SPMD expansion (#99233)
Differential Revision: [D45028740](https://our.internmc.facebook.com/intern/diff/D45028740)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99233
Approved by: https://github.com/yifuwang
2023-04-16 00:07:27 +00:00
Shen Li
6bb20822f5 [SPMD][BE] Remove deprecated aten.sym_numel branch (#99232)
Differential Revision: [D45028732](https://our.internmc.facebook.com/intern/diff/D45028732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99232
Approved by: https://github.com/yifuwang
2023-04-16 00:07:27 +00:00
Shen Li
39be994913 [SPMD][BE] Consolidate DSymInt Branches (#99231)
Differential Revision: [D45028726](https://our.internmc.facebook.com/intern/diff/D45028726)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99231
Approved by: https://github.com/yifuwang
2023-04-16 00:07:24 +00:00
Shen Li
544cd8e134 [SPMD] Refactor DSize to DSymInt to enable sym_numel (#99206)
This commit uses `aten.arange.default` and `aten.arange.start` to
test `aten.sym_numel`.

Differential Revision: [D45028715](https://our.internmc.facebook.com/intern/diff/D45028715)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99206
Approved by: https://github.com/yifuwang
2023-04-16 00:07:21 +00:00
Shen Li
bafb984022 [SPMD] Enable aten.full.default with SymInt on sharded dims (#99190)
Differential Revision: [D45028686](https://our.internmc.facebook.com/intern/diff/D45028686)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99190
Approved by: https://github.com/yifuwang
2023-04-16 00:07:18 +00:00
Daniel Dale
157c869026 Enable FSDP `use_orig_params=True` mixed precision training when some ranks have no (non-zero sized) parameter shards (#99175)
Fixes #99174

## Enable FSDP ``use_orig_params=True`` mixed precision training when some ranks have no (non-zero sized) parameter shards

### The issue

Now that ``use_orig_params=True`` allows non-uniform ``requires_grad`` (🎉 🚀 thanks @awgu!!!) with [#98221](https://github.com/pytorch/pytorch/pull/98221), there will be circumstances wherein some ranks have no (non-zero sized) local shards of the original parameters (and hence no associated gradients).

### Use Cases
For a simple Transformer case, imagine a user wraps all encoder layers in separate FSDP instances but allows the classifier head to be wrapped in the same FSDP instance as the relatively large embeddings layers. While this is a sub-optimal wrapping strategy for most use-cases, I believe it is expected to be supported (full precision training works in that context).

I originally encountered this issue while extending a package I maintain, leveraging the relaxed ``requires_grad`` constraint to simplify multi-phase scheduled fine-tuning FSDP configuration, so a [concrete example is there](https://finetuning-scheduler.readthedocs.io/en/latest/advanced/fsdp_scheduled_fine_tuning.html#basic-scheduled-fine-tuning-with-fsdp).

### Reproduction and Remediation
Currently, ``ShardedGradScaler`` does not accommodate these situations, failing to initialize ``optimizer_state["found_inf_per_device"]`` when ``unscale_`` is called.

In this PR, I extend the existing ``ShardedGradScaler`` tests with an ``use_orig_params=True`` dimension added to the parameterization and test scenarios wherein one rank possesses no (non-zero sized) parameter shards.

The relevant issue can be reproduced with the tests I'm adding in this PR. The current (pre-PR) execution of these tests fail in ``use_orig_params=True`` mode with this error:

```python
./test_fsdp_sharded_grad_scaler.py::TestShardedGradScalerParityWithDDP::test_fsdp_ddp_parity_with_grad_scaler_offload_false_none_mixed_precision_use_orig_params Failed with Error: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 657, in run_test
    getattr(self, test_name)()
  File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 543, in wrapper
    fn()
  File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_utils.py", line 259, in instantiated_test
    test(self, **param_kwargs)
  File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_distributed.py", line 174, in wrapper
    return func(*args, **kwargs)
  File "/home/speediedan/repos/pytorch/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py", line 187, in test_fsdp_ddp_parity_with_grad_scaler
    self._test_fsdp_parity(
  File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_fsdp.py", line 1152, in _test_fsdp_parity
    fsdp_loss = self._train_for_several_steps(
  File "/home/speediedan/repos/pytorch/torch/testing/_internal/common_fsdp.py", line 1016, in _train_for_several_steps
    sharded_grad_scaler.step(optim)
  File "/home/speediedan/repos/pytorch/torch/distributed/fsdp/sharded_grad_scaler.py", line 291, in step
    return super().step(optimizer, *args, **kwargs)
  File "/home/speediedan/repos/pytorch/torch/cuda/amp/grad_scaler.py", line 368, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.
```

A few implementation notes/considerations and questions:

1. Rather than just initialize ``per_device_found_inf``, one could disable the grad scaler altogether for the relevant ranks, altering ``unscale_`` to reduce with a subgroup or some rank-mask construct to keep the ``all_reduce``s in ``distributed/fsdp/sharded_grad_scaler.py:unscale_()`` from hanging. Given that users may subsequently add parameter groups to an optimizer that would require re-enabling the scaler, and the complexity associated with maintaining a separate mask construct or process subgroup, I thought this implementation was cleaner.
2. I extended ``_train_for_several_steps`` and ``_test_fsdp_parity`` in ``/torch/testing/_internal/common_fsdp.py`` with the ability to configure ``sharded_grad_scaler_kwargs`` for future testing flexibility.
3. Should the user be warned that no parameter shards were associated with a given rank? My initial thought is that this should be considered an implementation detail, part of supporting ``use_orig_params`` with heterogeneous ``requires_grad``, and therefore should be transparently handled by PyTorch. Should a DEBUG level message be added? If so, likely further upstream rather than at the scaler step level.
4. Rather than extend the existing ``ShardedGradScaler`` tests with an ``use_orig_params=True`` dimension added to the parameterization, let me know if you prefer that I instead narrow the scope of the new testing to a single additional test, e.g.:
	```python
	# from typing import Optional
	from typing import Optional, List
	# ...
	# use_orig_params = ["enable_use_orig_params", None]
	use_orig_params: List[Optional[str]] = [None]
	# ...
	configs = list(itertools.product(cpu_offload_config, sharding_strategy_config, mixed_precision, use_orig_params))
	configs.append((CPUOffload(offload_params=False), None, "enable_mixed_precision", "enable_use_orig_params"))
	```
Thanks as always to the PyTorch distributed team for your astonishingly impressive and valuable contributions to the open-source ML engineering community!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99175
Approved by: https://github.com/awgu
2023-04-15 05:13:23 +00:00
Rohan Varma
ef11966aff [composable] Enable replicate + trec_shard overall (#98890)
replicate + trec_shard works if we shard / replicate individual submodules, as follows:

```
m = TestSparseNN()
shard(m.sparse)
replicate(m.dense)
```

but does not work if users do the following:
```
m = TestSparseNN()
shard(m, sharders=[...])
replicate(m)
```

Many upstream trainers use the latter pattern, as sharding is not done at the individual-module level but rather on the overall module, by specifying planners that contain the logic for how to shard different embedding table types.

This diff enables the latter approach (while keeping the former intact), but users need to specify `ignored_modules` to ignore embedding tables in replicate(). This is similar to FSDP (class based and composable) and DDP today.
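
A sketch of the now-supported overall-module pattern, reusing the illustrative names from the snippets above:
```python
m = TestSparseNN()
shard(m, sharders=[...])                  # planners decide how to shard the embedding tables
replicate(m, ignored_modules=[m.sparse])  # replicate the rest, skipping the sharded tables
```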

Differential Revision: [D44899155](https://our.internmc.facebook.com/intern/diff/D44899155/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98890
Approved by: https://github.com/mrshenli, https://github.com/yhcharles
2023-04-15 01:09:00 +00:00
Rodrigo Kumpera
a910045add [PATCH] Back out "Move functional collectives implementation to python. (#98595) (#99168)
Summary:
Original commit changeset: ba36f8751adc

Original Phabricator Diff: D44788697

Test Plan: model loading is fine after reverting the diff

Reviewed By: zyan0, sayitmemory

Differential Revision: D44921259
---

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99168
Approved by: https://github.com/izaitsevfb
2023-04-14 23:48:19 +00:00
feifan
bd07f8d2e0 DDP forward support custom stream accelerated copy. (#98723)
At present, the DDP forward pass uses `_get_stream` to get a stream, which is a CUDA stream.
If a custom device module is already registered to torch, we can use `getattr` to get it and its stream. Then, the custom stream is used to copy the tensor.
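
A rough sketch of the lookup described above (a hypothetical helper, not the actual DDP code):
```python
import torch

def _stream_for_copy(device: torch.device):
    # Look up the device module registered on torch (e.g. torch.cuda for "cuda",
    # or a custom backend registered under its own device type) and create a
    # stream on it for the copy, instead of assuming a CUDA stream.
    backend = getattr(torch, device.type, None)
    if backend is None or not hasattr(backend, "Stream"):
        return None
    return backend.Stream(device=device)
```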
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98723
Approved by: https://github.com/ezyang
2023-04-14 20:19:56 +00:00
Chien-Chin Huang
286b618b7d [SPMD] Move some functions to IterGraphModule.setup() (#99076)
Since users would otherwise have to perform these steps before calling `setup()`, moving them into `setup()` reduces the API usage complexity.

Differential Revision: [D44973726](https://our.internmc.facebook.com/intern/diff/D44973726/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99076
Approved by: https://github.com/lessw2020
2023-04-14 14:41:10 +00:00
Chien-Chin Huang
d863876545 [SPMD] Remove the unused code (#99075)
Remove the unused code

Differential Revision: [D44973692](https://our.internmc.facebook.com/intern/diff/D44973692/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99075
Approved by: https://github.com/lessw2020
2023-04-14 14:35:55 +00:00
Shen Li
40aaacd4fa Respect sharded dimensions when aten expaned/view consumes SymInt values (#99058)
Currently, aten.expand always expands to the global dimension. Then, it
introduces additional slice and clone ops before running compute on
the expanded tensor with a local tensor.

In this commit, if we detect that the op consumes a SymInt size, it respects
both the local size and the dimension placements from which the SymInt was
extracted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99058
Approved by: https://github.com/wanchaol
2023-04-14 13:54:05 +00:00