Commit Graph

2331 Commits

Author SHA1 Message Date
Andrew Gu
031ce0fadc [FSDP][7/N] Add warning about frozen params (#104967)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104967
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427
2023-08-02 21:50:38 +00:00
Andrew Gu
a8c52863dd [FSDP][6/N] Check valid param freezing for ModuleWrapPolicy (#104427)
This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
    - There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is acceptable for now unless users complain. (A rough sketch of the underlying check is shown below.)
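
A rough sketch of the kind of check described above (the helper name and the numel-based accounting are illustrative assumptions, not the actual FSDP implementation):

```python
import warnings

import torch.nn as nn


def check_frozen_params(module: nn.Module, use_orig_params: bool) -> None:
    # Hypothetical sketch: inspect the parameters that would be flattened
    # together for one wrapped module and compare their requires_grad flags.
    frozen = [n for n, p in module.named_parameters() if not p.requires_grad]
    nonfrozen = [n for n, p in module.named_parameters() if p.requires_grad]
    if not frozen or not nonfrozen:
        return  # uniform requires_grad: nothing to report
    if not use_orig_params:
        raise ValueError(
            f"{type(module).__name__} has both frozen params {frozen} and "
            f"non-frozen params {nonfrozen}; use_orig_params=False requires "
            "uniform requires_grad within each FSDP instance"
        )
    extra_numel = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    warnings.warn(
        f"{type(module).__name__} mixes frozen and non-frozen params; FSDP "
        f"will use ~{extra_numel} numel of extra gradient memory. "
        f"Frozen: {frozen}; non-frozen: {nonfrozen}"
    )
```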

<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>

```
LoraModel(
  (embed_tokens): Embedding(100, 32)
  (layers): ModuleList(
    (0-3): 4 x LoraDecoder(
      (attn): LoraAttention(
        (q_proj): Linear(in_features=32, out_features=32, bias=False)
        (lora_A): Linear(in_features=32, out_features=8, bias=False)
        (lora_B): Linear(in_features=8, out_features=32, bias=False)
        (k_proj): Linear(in_features=32, out_features=32, bias=False)
        (v_proj): Linear(in_features=32, out_features=32, bias=False)
        (o_proj): Linear(in_features=32, out_features=32, bias=False)
      )
      (mlp): LoraMLP(
        (proj1): Linear(in_features=32, out_features=128, bias=False)
        (proj2): Linear(in_features=128, out_features=32, bias=False)
      )
      (inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
  'embed_tokens',
  'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
  'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
  'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
  'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
  'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
  'norm',
  'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
  'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
  'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
  'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level of the module tree in its registered order, whereas with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition. (A minimal sketch of this DFS is shown below.)
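
A minimal sketch of the stack-based DFS via `named_children()` that produces the first ordering above (the function name is illustrative, not the actual FSDP helper):

```python
from typing import List, Tuple

import torch.nn as nn


def post_order_named_modules(root: nn.Module) -> List[Tuple[str, nn.Module]]:
    # Stack-based DFS via named_children(): the pop order visits parents
    # before children, so reversing the visit list yields children before
    # parents, with each level in its registered order.
    visit_order: List[Tuple[str, nn.Module]] = []
    stack: List[Tuple[str, nn.Module]] = [("", root)]
    visited = {root}
    while stack:
        name, module = stack.pop()
        visit_order.append((name, module))
        for child_name, child in module.named_children():
            if child not in visited:
                visited.add(child)
                full_name = f"{name}.{child_name}" if name else child_name
                stack.append((full_name, child))
    return list(reversed(visit_order))
```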

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104427
Approved by: https://github.com/ezyang
2023-08-02 21:44:44 +00:00
Jane Xu
7e47343d64 [BE] document more of FSDP checkpointing logic with a sprinkle of cleaning (#106069)
This PR should not make any functional difference. It:
- adds clearer documentation
- clarifies a type
- revises minor typos
- swaps a .keys for a .items call on a dictionary

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106069
Approved by: https://github.com/awgu
2023-08-02 17:19:04 +00:00
Iris
0cba33e176 [DTensor]Minor Docstring Update (#106250)
Fix docstring to reflect change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106250
Approved by: https://github.com/wanchaol
2023-08-02 00:27:29 +00:00
Andrew Gu
506b55fc29 [FSDP][Easy] Move _FSDPState attrs to avoid comment confusion (#106392)
Resubmit of https://github.com/pytorch/pytorch/pull/106333 after rebasing (I lost the original branch locally)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106392
Approved by: https://github.com/kwen2501
2023-08-01 20:39:22 +00:00
shibo19
0af3203c72 fix torchrun script for custom device (#105443)
Fixes #ISSUE_NUMBER
As the title says, this adds torchrun support for custom devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105443
Approved by: https://github.com/kumpera
2023-07-31 05:46:23 +00:00
Rohan Varma
5d4e170d58 [Optim in backward] API to retrieve in-backward optimizers (#105991)
API to retrieve in-backward optimizers for checkpointing purposes

Differential Revision: [D47782225](https://our.internmc.facebook.com/intern/diff/D47782225/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105991
Approved by: https://github.com/awgu
2023-07-29 01:36:25 +00:00
Rohan Varma
2ec7cd2db2 [CheckpointWrapper] Test for kwarg propagation, remove checkpoint_fn_arg support (#102679)
Closes https://github.com/pytorch/pytorch/issues/100576

Differential Revision: [D46342398](https://our.internmc.facebook.com/intern/diff/D46342398/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102679
Approved by: https://github.com/awgu
2023-07-28 21:18:35 +00:00
Andrew Gu
800287fb56 [FSDP] Optimize away intermediate div_ for HSDP (#106034)
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.

$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$

Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.

### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{S}}$$

This PR goes with the 1st approach, prioritizing performance, because (1) it matches the existing FSDP behavior and (2) it avoids a memory-bandwidth-bound `div_` kernel that blocks the all-reduce launch. (A quick numerical sanity check of the 1st approach is sketched below.)
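
A quick numerical sanity check of the 1st approach (a toy single-process simulation with assumed group sizes, not the actual FSDP reduction code):

```python
import math

import torch

# With N = S * R workers, pre-dividing each local gradient by sqrt(N),
# summing across all workers, and post-dividing by sqrt(N) recovers the
# plain average (1/N) * sum(g), matching the existing FSDP behavior.
S, R = 2, 4  # assumed shard and replicate group sizes
N = S * R
grads = [torch.randn(8) for _ in range(N)]  # stand-ins for local unsharded gradients

expected = sum(grads) / N
pre_post = sum(g / math.sqrt(N) for g in grads) / math.sqrt(N)
torch.testing.assert_close(pre_post, expected)
```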

### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.

Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
2023-07-28 18:36:26 +00:00
Albert Chen
7c8efc9049 [PT][FSDP] Combine _utils.py into _common_utils.py [2/2] (#106181)
Summary:
https://github.com/pytorch/pytorch/issues/97813
This diff moves `_no_dispatch_record_stream` and `_same_storage_as_data_ptr`

Test Plan: CI

Differential Revision: D47706114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106181
Approved by: https://github.com/awgu
2023-07-28 17:15:25 +00:00
fduwjj
487ebcac3b Clean up unused MHA code to avoid confusion (#105956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105956
Approved by: https://github.com/wz337, https://github.com/ezyang, https://github.com/wanchaol
2023-07-27 17:10:17 +00:00
Wanchao Liang
f026b32008 [device_mesh][BE] reduce_scatter fallback to funcol and remove from DM (#105642)
For the reason similar to https://github.com/pytorch/pytorch/pull/105605
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105642
Approved by: https://github.com/kumpera, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:05 +00:00
Wanchao Liang
2fa063e1e0 [device_mesh][BE] remove allgather from DM (#105614)
For the reason similar to https://github.com/pytorch/pytorch/pull/105605
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105614
Approved by: https://github.com/rohan-varma, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:05 +00:00
Wanchao Liang
4a49f1f46e [device mesh][BE] remove allreduce from DM (#105605)
This PR removes allreduce from DM and uses functional collectives instead. The rationale is that we don't want to maintain yet another set of collective APIs, and since DM's collectives are now thin wrappers around functional collectives, we don't really need these collectives to live in DM.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105605
Approved by: https://github.com/kumpera, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:02 +00:00
Rohan Varma
4137d6e499 [Composable FSDP] Enable HSDP (#105206)
Need to pass in strategy to _init_process_group_state to enable hsdp
for composable.

Differential Revision: [D47462394](https://our.internmc.facebook.com/intern/diff/D47462394/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105206
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 21:03:55 +00:00
Andrew Gu
841b4acf1e [FSDP][Easy] Rename to _comm_hook, _comm_hook_state (#106033)
This is just out of preference to make the naming convention consistent with `register_comm_hook()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106033
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
Andrew Gu
035704e88d [FSDP][Easy] Move post-bwd hook logging to own func (#106032)
This is to help make `_post_backward_hook()` easier to read. I plan to refactor some other parts in future PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106032
Approved by: https://github.com/fegin
2023-07-26 19:59:11 +00:00
FFFrog
9a1cdcb8a0 Format: fixing multiple string concatenation in single line (#106013)
Fixing multiple string concatenation in single line
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106013
Approved by: https://github.com/albanD
2023-07-26 18:39:18 +00:00
Daniel Dale
6b6702f506 Enhance no_grad-context FSDP backward handling (#105374)
Fixes #105369
Fixes #105371

Addressing two somewhat distinct issues that involve the same test in this PR:

1. To fix #105369:
    - Add a `no_grad` guard to [`_register_post_backward_reshard_only_hooks`](93f852f201/torch/distributed/fsdp/_runtime_utils.py (L1406)) to avoid registering post-backward hooks that would not be removed in that context.

2. To fix #105371:
    - Add a `grad` context condition to [`_use_sharded_flat_param`](93f852f201/torch/distributed/fsdp/flat_param.py (L1645C9-L1645C32)) logic to trigger post-forward `_use_sharded_views` in a `no_grad` context for `NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105374
Approved by: https://github.com/awgu
2023-07-26 14:12:13 +00:00
Andrew Gu
c099b80073 [FSDP] Add record_function for explicit prefetching (#105985)
Example:
<img width="568" alt="Screenshot 2023-07-25 at 7 41 43 PM" src="https://github.com/pytorch/pytorch/assets/31054793/5f3f07b3-97f4-4493-9cab-5619484e2f6d">

This can be particularly helpful when `with_stack=False`, in which case it is harder to tell when the prefetch happens. (A minimal usage sketch of `record_function` is shown below.)
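
A minimal usage sketch of `torch.profiler.record_function` for labeling a region in a trace (the label string and the stand-in work below are illustrative, not the actual FSDP internals):

```python
import torch
from torch.profiler import profile, record_function


def prefetch_like_work(x: torch.Tensor) -> torch.Tensor:
    # The annotation makes this region show up as a named block in the
    # profiler trace even when with_stack=False.
    with record_function("FSDP::explicit_prefetch (example label)"):
        return x @ x


with profile() as prof:
    prefetch_like_work(torch.randn(64, 64))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```
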
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105985
Approved by: https://github.com/fegin
2023-07-26 12:16:35 +00:00
Andrew Gu
a9a3c45649 Revert "Simplify handle indexing (#105006)" (#105984)
This reverts commit 429d45f91a.

Unfortunately, https://github.com/pytorch/pytorch/pull/105006 broke backward prefetching (backward prefetching working correctly was not captured in our unit tests).

I need more time to dig into this (tomorrow), but I think the issue is related to:
429d45f91a (diff-9a6937168d232432c34c2c4605b96f3147afa2786e287f74b6074b20aa5980e6R143-R146)

Follow-ups:
1. Investigate this thoroughly
2. Add unit tests to capture backward prefetch functionality
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105984
Approved by: https://github.com/fegin
2023-07-26 12:12:14 +00:00
Matthew Hoffman
0616952d13 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99, https://github.com/malfet
2023-07-26 11:56:42 +00:00
Rohan Varma
a326f5621e composable fsdp, checkpoint, + compile test (#105180)
Test to ensure that composable FSDP, checkpoint, and compile all work
together. Includes a change from https://github.com/pytorch/pytorch/pull/105090
which we can land in that PR first.

Differential Revision: [D47452973](https://our.internmc.facebook.com/intern/diff/D47452973/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105180
Approved by: https://github.com/awgu
2023-07-26 07:03:09 +00:00
Rohan Varma
5d70fe0165 [Composable] Use non-reentrant generator, remove reentrant (#105176)
Removes reentrant support for the composable checkpoint, as non-reentrant is the recommended approach and we should use it when rolling out the composable checkpoint API.

Also removes the standalone implementation for non-reentrant and instead uses the generator from the diff below to reuse the original implementation.

Differential Revision: [D47451375](https://our.internmc.facebook.com/intern/diff/D47451375/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105176
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 07:03:03 +00:00
fduwjj
0003d5135d [TP] Enable partial tensor add without redistribute (#105939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105939
Approved by: https://github.com/wanchaol
2023-07-26 03:12:39 +00:00
Albert Chen
b65b9e6ff4 [PT][FSDP] Combine _utils.py into _common_utils.py [1/3] (#105857)
Summary:
https://github.com/pytorch/pytorch/issues/97813

This diff moves `_override_module_mixed_precision`

Test Plan: CI

Differential Revision: D47706059

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105857
Approved by: https://github.com/awgu
2023-07-25 17:37:08 +00:00
Andrew Gu
c9edf11073 [FSDP][Docs] Make model/optim state dict configs visible in docs (#105848)
This closes https://github.com/pytorch/pytorch/issues/104717.

Rendered docs:
![Screenshot 2023-07-25 at 11 15 23 AM](https://github.com/pytorch/pytorch/assets/31054793/3c38166a-70c0-472c-805d-452d3bd9c700)
![Screenshot 2023-07-25 at 11 15 30 AM](https://github.com/pytorch/pytorch/assets/31054793/6d275d94-020a-44a2-a64c-0eeba083d47f)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105848
Approved by: https://github.com/rohan-varma
2023-07-25 16:23:53 +00:00
Michael Voznesensky
487a33e38a [FSDP x dynamo] simplify registry keys (#104209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104209
Approved by: https://github.com/wconstab, https://github.com/fegin
2023-07-25 07:16:22 +00:00
Jon Bolin
1032a2541e Add option to disable rewriting index hints in default global save plan (#105861)
With distributed checkpointing in PyTorch/XLA SPMD, the WriteItem index hints should not be modified when creating the global plan. In order to reuse the default planner logic for checkpoint metadata creation, we need to make the behavior of rewriting index hints optional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105861
Approved by: https://github.com/kumpera
2023-07-25 06:00:13 +00:00
Louis Feng
3a01c056f5 [PyTorch][ET] Collect Process Groups Mapping Info (#104373)
Summary: Add the logic and interface to log ProcessGroup comms configuration (unique ID, type, and ranks info).

Test Plan:
Testing in HPC:
```
TORCH_LOGS=all ../buck-out/v2/gen/fbcode/c8344b52091f4f7f/hpc/models/ads/__ads_10x_launcher__/ads_10x_launcher.par  +launcher=local launcher.num_trainers=4 +data_loader=random data_loader.num_batches=2000
```
Example output in ET:
```
    {
      "name": "## process_group:init ##", "id": 3, "rf_id": 1, "parent": 2, "fw_parent": 0, "seq_id": -1, "scope": 7, "tid": 1, "fw_tid": 0, "op_schema": "",
      "inputs": ["[{'pg_id': 140538064364672, 'backend_id': 140538060772480, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}, {'pg_id': 140538064363904, 'backend_id': 140538042628864, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}]"], "input_shapes": [[]], "input_types": ["String"],
      "outputs": [], "output_shapes": [], "output_types": []
    },
```

Differential Revision: D46321690

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104373
Approved by: https://github.com/kwen2501
2023-07-25 03:34:53 +00:00
Andrew Gu
6655b6527a [FSDP][Docs] Tidy up FSDP ctor/api docs (#105847)
- This PR rewords the `BackwardPrefetch` docs to make the tradeoffs clear in the first sentence of each with more technical details after.
- The only supported `_FSDPPolicy` is `ModuleWrapPolicy` at the time of writing this PR. We may add others in the future such as in my other PR stack. This PR removes `_FSDPPolicy` from the public docs.
- This provides some more details around `MixedPrecision` such as explaining that layer norm and batch norm accumulate in fp32.

Follow-ups:
- Why do we force batch norm modules to have FSDP applied separately? (E.g., was this because batch norm kernels previously did not support fp16/bf16?) Like layer norm, this just means that the affine parameters are in fp32. Both already accumulate in fp32 even with fp16/bf16 inputs.
- Check the `param_init_fn` + `sync_module_states=True` usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105847
Approved by: https://github.com/rohan-varma
2023-07-25 00:19:08 +00:00
Howard Huang
0ab74044c2 [BE] remove deprecated attributes from distributed_c10d (#105753)
Removing these attributes as they were introduced 5 years ago, before PyTorch 1.0. `Backend` is the only supported usage now.

Differential Revision: [D47683717](https://our.internmc.facebook.com/intern/diff/D47683717)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105753
Approved by: https://github.com/rohan-varma
2023-07-24 16:35:08 +00:00
Wanchao Liang
e3539a0e54 [dtensor] forward fix for dynamo import with deploy (#105760)
Summary: forward fix to avoid revert

Differential Revision: D47679598

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105760
Approved by: https://github.com/atalman
2023-07-23 07:13:38 +00:00
Aaron Gokaslan
6d43c89f37 [BE]: Update Ruff to 0.0.280 (#105724)
Removes unused loop values in Python dictionary iteration. Automated fix from Ruff master

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105724
Approved by: https://github.com/ezyang, https://github.com/janeyx99
2023-07-22 23:03:34 +00:00
Andrew Gu
221853af23 [FSDP][Easy] nit follow-ups to handle refactor (#105738)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105738
Approved by: https://github.com/fegin, https://github.com/voznesenskym
2023-07-21 22:00:14 +00:00
Iris
6b2d48e78c [8/n][FSDP] make use_dtensor=True work with offload_to_cpu=True for optim.load_state_dict() (#105690)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105690
Approved by: https://github.com/fegin
2023-07-21 18:55:01 +00:00
Michael Voznesensky
429d45f91a Simplify handle indexing (#105006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105006
Approved by: https://github.com/awgu
2023-07-21 05:53:23 +00:00
Michael Voznesensky
a832967627 Migrate tuple(handle) -> handle (#104488)
We strengthen the invariant that one FSDP-managed module has one flat parameter, and remove unused code that would have supported a 1:many module-to-flat-parameter mapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104488
Approved by: https://github.com/awgu
2023-07-19 22:33:35 +00:00
Iris
c54f630201 [7/n][FSDP] make use_dtensor=True work with offload_to_cpu=True for load_state_dict (#105378)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105378
Approved by: https://github.com/fegin
2023-07-19 21:36:37 +00:00
Mo Mo
7b56238551 fix typo (#105507)
Differential Revision: D47568928

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105507
Approved by: https://github.com/awgu, https://github.com/fduwjj
2023-07-19 20:34:43 +00:00
Wanchao Liang
f139aab2f4 [dynamo] add initial dynamo support for DTensor (#103146)
This PR adds initial dynamo support for DTensor. In particular, it:
- allows DTensor to be passed into a compiled function, and allows fakifying DTensor during dynamo tracing by turning the inner local tensor into a meta tensor.
- We use `allow_in_graph` to allow `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The DTensor created becomes a normal `TensorVariable`, and it inserts any tensor operations into the output graph just like torch.Tensor
- note that DTensor has a new instance method `redistribute` compared to a plain tensor, and we currently special-case it in `TensorVariable`

`from_local` and `redistribute` both accept some non-trivial metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encode the metadata into a new function (like `functools.partial`), and the new function only accepts prim args (i.e. tensors); then we put a `call_function` with this new function into the graph. This was suggested by @ezyang. The underlying rationale here is that the metadata will not change across graph invocations, so it is safe to encode it. (A rough sketch of this pattern is shown below.)
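
A rough sketch of this pattern using `functools.partial` to bake the non-tensor metadata into a new callable so that only tensor args appear as graph inputs (the metadata objects and function names are placeholders, not the actual dynamo helpers):

```python
import functools

import torch


def redistribute_with_metadata(tensor, mesh, placements):
    # Stand-in for an op that takes non-tensor metadata (DeviceMesh,
    # Placement) that fx.Graph cannot represent as node arguments.
    return tensor + len(placements)  # dummy computation for the sketch


mesh = "device-mesh-placeholder"  # assumed metadata object
placements = ["Shard(0)"]  # assumed metadata object

# Bake the metadata into a new function; the resulting callable only takes
# prim args (tensors), so it can appear as a call_function target in the
# captured graph. This relies on the metadata staying constant across
# graph invocations.
prim_redistribute = functools.partial(
    redistribute_with_metadata, mesh=mesh, placements=placements
)
out = prim_redistribute(torch.randn(4))
```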

Captured graph:
```
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False);  l_x_ = None

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
        prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local);  prim_from_local = None
        to_local = prim_redistribute.to_local();  prim_redistribute = None
        add = to_local + 2;  to_local = None
        return (add,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
2023-07-19 16:01:12 +00:00
Justin Chu
232b96b6e2 [BE] Enable ruff's UP rules and autoformat distributed/ (#105433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105433
Approved by: https://github.com/albanD
2023-07-19 14:27:11 +00:00
Andrew Gu
e983625f22 [FSDP] Fix skip-sharded-views + mixed precision (#105346)
This fixes https://github.com/pytorch/pytorch/issues/104504.

- When not using full-precision eval, the relevant fix is to force `_use_sharded_views()` calls if needed in `SUMMON_FULL_PARAMS` training state.
- When using full-precision in eval, the relevant fix is tracking which unsharded flat parameter the unsharded views were computed from and using that instead of determining the unsharded flat parameter from the calling context via `_get_padded_unsharded_flat_param()`.

This also fixes https://github.com/pytorch/pytorch/issues/104770.
<details>
<summary> Print output showing parity </summary>

```
Key: 0
Model 1: [-1.5, 6.40625, -0.9453125, -0.3828125, 0.16015625, -1.5078125]
Model 2: [-1.5, 6.40625, -0.9453125, -0.3828125, 0.16015625, -1.5078125]

Key: 1
Model 1: [0.0157470703125, -0.8828125, 5.65625, 1.1328125, 0.275390625, 0.11181640625]
Model 2: [0.0157470703125, -0.8828125, 5.65625, 1.1328125, 0.275390625, 0.11181640625]

Key: 2
Model 1: [0.1689453125, -0.00567626953125, -0.09375, 7.34375, -0.18359375, -0.09521484375]
Model 2: [0.1689453125, -0.00567626953125, -0.09375, 7.34375, -0.18359375, -0.09521484375]

Key: 3
Model 1: [0.546875, -0.8984375, 0.228515625, 0.7578125, 6.0625, 0.435546875]
Model 2: [0.546875, -0.8984375, 0.228515625, 0.7578125, 6.0625, 0.435546875]

Key: 4
Model 1: [-0.66796875, -0.88671875, 0.30078125, 0.06494140625, 0.412109375, 6.9375]
Model 2: [-0.66796875, -0.88671875, 0.30078125, 0.06494140625, 0.412109375, 6.9375]

Key: 5
Model 1: [0.07763671875, 0.8671875, -0.43359375, 0.5703125, 0.76171875, -0.0089111328125]
Model 2: [0.07763671875, 0.8671875, -0.43359375, 0.5703125, 0.76171875, -0.0089111328125]

Key: 6
Model 1: [-0.283203125, -0.361328125, 0.474609375, 0.10205078125, 1.125, -0.0859375]
Model 2: [-0.283203125, -0.361328125, 0.474609375, 0.10205078125, 1.125, -0.0859375]

Key: 7
Model 1: [1.140625, 0.62890625, -0.07568359375, -1.0390625, -0.2578125, -0.053955078125]
Model 2: [1.140625, 0.62890625, -0.07568359375, -1.0390625, -0.2578125, -0.053955078125]

Key: 8
Model 1: [0.68359375, -1.09375, 0.59375, 1.0, -0.23828125, 0.578125]
Model 2: [0.68359375, -1.09375, 0.59375, 1.0, -0.23828125, 0.578125]

Key: 9
Model 1: [0.515625, 0.296875, -0.1826171875, -0.12890625, -0.51953125, -0.3359375]
Model 2: [0.515625, 0.296875, -0.1826171875, -0.12890625, -0.51953125, -0.3359375]
```

</details>

Follow-ups:
- I suspect that for `SHARD_GRAD_OP`, train forward -> eval forward when using full-precision in eval will not free the low-precision unsharded parameters from the train forward, resulting in 1.5x unsharded parameter memory.

Differential Revision: [D47527597](https://our.internmc.facebook.com/intern/diff/D47527597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105346
Approved by: https://github.com/fegin, https://github.com/rohan-varma
2023-07-18 23:13:53 +00:00
Wanchao Liang
cb23373264 [dynamo] allow tensor subclass fakification in dynamo (#105308)
This PR adds the necessary plumbing through torchdynamo to allow tensor subclasses with a certain contract (i.e. with `__tensor_flatten__` and `__tensor_unflatten__`) to go through the dynamo fakification pass by fakifying the tensor subclass's internal components.

Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540

Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
2023-07-18 17:28:04 +00:00
Wanchao Liang
bcb9ca4e5a [dtensor] canonicalize detach callsites and use view_as when appropriate (#105239)
This PR canonicalizes the detach callsite to only call detach from `distribute_tensor`, changes other callsites to view_as, and removes the tensor constructor's detach call.

This is so that we don't detach the local tensor on every op run when rewrapping the DTensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
2023-07-18 17:13:37 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
Richard Barnes
15ea0a00cb Fix RRef type annotations (#104876)
Test Plan: Sandcastle

Reviewed By: H-Huang

Differential Revision: D47334579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104876
Approved by: https://github.com/H-Huang
2023-07-14 17:31:51 +00:00
PyTorch MergeBot
1646d6f939 Revert "Merge and improve torch optim optimizer type stubs (#102593)"
This reverts commit 3279f06410.

Reverted https://github.com/pytorch/pytorch/pull/102593 on behalf of https://github.com/malfet due to There is nothing wrong with this PR, but it fails some internal builds that depend on outdated typing_extensions, will reland when update is done ([comment](https://github.com/pytorch/pytorch/pull/102593#issuecomment-1636062515))
2023-07-14 16:04:54 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in followup PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Rohan Varma
242fc29c96 [FSDP] Refactor optimizer in backward (#104813)
1) Use zero_grad(set_to_none=True) to set the grad to None, 2) call prepare_grad_for_optim() before the call to .step(), and 3) use _reset_flat_param_grad_info to set the flat param gradient back to None. These changes should just be refactors, equivalent to how gradient memory was managed before.

Differential Revision: [D47310761](https://our.internmc.facebook.com/intern/diff/D47310761/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104813
Approved by: https://github.com/awgu
2023-07-13 06:42:53 +00:00
Rohan Varma
f2eed129c4 FSDP optimizer overlap (#98667)
constraints:

1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload + optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False, otherwise step() might run on invalid data.
4. Step is waited on in post backward final cb, when in theory it can wait until the next forward.

Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-13 06:42:53 +00:00
PyTorch MergeBot
5b4aacd691 Revert "[DCP] Add FsspecReader and FsspecWriter to checkpoint __init__.py (#105088)"
This reverts commit 76a053d55c.

Reverted https://github.com/pytorch/pytorch/pull/105088 on behalf of https://github.com/atalman due to broke trunk and  linux-focal-py3.9-clang7-asan ([comment](https://github.com/pytorch/pytorch/pull/105088#issuecomment-1633385350))
2023-07-13 00:59:55 +00:00
Andrew Gu
954bae8e53 [FSDP][Easy] Rename streams; add back stream sharing test (#104966)
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
2023-07-13 00:24:41 +00:00
Iris
4f8ba6f8f6 [DeviceMesh]Add validate mesh flag to DeviceMesh (#104807)
When creating a DeviceMesh, _init_process_group() would validate that all calling ranks pass in the same `mesh` argument. In FSDP, we are currently creating the DeviceMesh based on the pg of the root state, so the mesh will always be valid. This adds the flag to DeviceMesh so we can skip the all_gather_tensor used for validation at construction time.

_validate_mesh defaults to True, but we manually flip it to False when initializing the device mesh in FSDP's _runtime_utils.py.

We will modify skipping pg creation if one already exists for both the 1D and 2D cases and then delete the _init_process_groups flag in a follow-up PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
2023-07-12 23:42:13 +00:00
Iris
76a053d55c [DCP] Add FsspecReader and FsspecWriter to checkpoint __init__.py (#105088)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105088
Approved by: https://github.com/kumpera
2023-07-12 23:40:35 +00:00
Nikita Shulga
4148b7bada [Typing] Fix PEP 484 Violation (#105022)
Not sure how it worked before, but arguments must be annotated as Optional if they are defaulted to None.

Towards enabling mypy-1.4.1 in lintrunner

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>

> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
2023-07-12 10:20:48 +00:00
Aaron Gokaslan
2f95a3d0fc [BE]: Apply ruff PERF fixes to torch (#104917)
Applies automated Ruff fixes for the PERF rules and enables all the automatic ones. I also updated Ruff, which applied some additional fixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104917
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-07-11 20:45:21 +00:00
Andrew Gu
63d1fb21f5 [FSDP] Default limit_all_gathers=True (#104900)
This PR defaults to `limit_all_gathers=True`.

I included a `record_function()` for the rate limiter synchronization to help with user confusion on the gap in the pre-forward:
<img width="874" alt="Screenshot 2023-07-10 at 3 28 18 PM" src="https://github.com/pytorch/pytorch/assets/31054793/61f55e0e-58d7-4162-9395-bea06d3e8d8a">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104900
Approved by: https://github.com/fegin
2023-07-11 01:04:29 +00:00
Matthew Hoffman
3279f06410 Merge and improve torch optim optimizer type stubs (#102593)
Fixes #102428

Also improves hook registration type hints:

```python
from typing import Any, Dict, Tuple

from torch import nn
from torch.optim import Adam, Adagrad, Optimizer

linear = nn.Linear(2,2)
optimizer = Adam(linear.parameters(), lr=0.001)

def pre_hook_fn_return_none(optimizer: Adam, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def pre_hook_fn_return_modified(
    optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    return inputs, kwargs

def hook_fn(optimizer: Optimizer, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

def hook_fn_other_optimizer(optimizer: Adagrad, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    return None

optimizer.register_step_post_hook(hook_fn)  # OK

optimizer.register_step_pre_hook(pre_hook_fn_return_none)  # OK
optimizer.register_step_pre_hook(pre_hook_fn_return_modified)  # OK

optimizer.register_step_post_hook(hook_fn_other_optimizer)  # Parameter 1: type "Adam" cannot be assigned to type "Adagrad"

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102593
Approved by: https://github.com/janeyx99
2023-07-11 00:07:30 +00:00
fduwjj
aa84078c6c [PTD][TP] Add BWD support for colwise embedding sharding (#104820)
Originally, we didn't enable BWD for colwise embedding because we thought it was just for inference, but it turns out that we do need it for training. So let's enable it for now; a unit test is also added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104820
Approved by: https://github.com/fegin
2023-07-10 22:33:20 +00:00
Iris Zhang (PyTorch)
7b538d8987 [DCP][fsspec] Consolidate OSS FsspecWriter/Reader and internal FsspecWriter/Reader (#104724)
Summary:
This diff does the following:
1. re-enable single_file_per_rank for FsspecWriter, as the file-slicing error is resolved by [https://github.com/pytorch/pytorch/pull/99167]
2. remove sync_files from FsspecWriter as there is no fsspec equivalent.
3. remove the internal implementation of FsspecWriter/Reader, as it has been upstreamed to PyTorch OSS
4. keep the internal test for manifold inside internal as we can only test it in fb environment
5. consolidate test to remove duplicates
6. remove unnecessary TARGETS

Test Plan:
```
buck test @//mode/dev-nosan  //caffe2/test/distributed/checkpoint/fb:test_fsspec_filesystem -- --print-passing-details

----------------------------------------------------------------------
Ran 1 test in 54.894s

OK
/usr/local/fbcode/platform010/lib/python3.8/tempfile.py:818: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpzomokvh6'>
  _warnings.warn(warn_message, ResourceWarning)

Buck UI: https://www.internalfb.com/buck2/4cb722a2-3ee7-48f2-a9ef-55ee6fb1a498
Test UI: https://www.internalfb.com/intern/testinfra/testrun/8725724447995201
Network: Up: 8.8 MiB  Down: 1.5 GiB  (reSessionID-04c29f56-ae94-4187-8a1a-c812f432674d)
Jobs completed: 209847. Time elapsed: 1:56.5s.
Cache hits: 100%. Commands: 85687 (cached: 85687, remote: 0, local: 0)
Tests finished: Pass 3. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D47266068

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104724
Approved by: https://github.com/fegin, https://github.com/fduwjj
2023-07-10 19:31:01 +00:00
Mikayla Gawarecki
1ad435772b Added option to always call nn.Module global/non-global forward hooks (#104278)
Fix #103997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104278
Approved by: https://github.com/albanD
2023-07-10 18:58:07 +00:00
Jane Xu
e25f5732c8 Add meta registrations and distributed decomps: _foreach_div_.Scalar, sqrt_.default (#104779)
This PR unblocks #104780 by resolving spmd tracing test issues and by adding meta registrations for foreach inplace ops (div_ and sqrt_)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104779
Approved by: https://github.com/fegin, https://github.com/albanD
2023-07-10 17:38:46 +00:00
Iris
af52f6b928 [DCP] Add documentation for HSDP saving using DCP (#104810)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104810
Approved by: https://github.com/fduwjj
2023-07-10 17:33:05 +00:00
Chien-Chin Huang
46154c4c35 [FSDP][optim_state_dict] The correct way to initialize optimizer states if the corresponding param is empty (#104765)
When using KeyedOptimizer.init_state(), some optimizers initialize the states even if the param is empty (size() == 0) while some optimizers avoid initializing the states. There is no way FSDP can tell. Instead, FSDP should look up `optim.state`. Fortunately, `optim.state` does not rely on FQNs, which some internal users change. (A rough sketch of such a lookup is shown below.)
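
A rough sketch of what looking up `optim.state` means, using a plain torch optimizer for illustration (the actual FSDP/KeyedOptimizer logic differs):

```python
import torch

param = torch.nn.Parameter(torch.randn(4))
optim = torch.optim.Adam([param], lr=1e-3)


def has_initialized_state(optim: torch.optim.Optimizer, p: torch.nn.Parameter) -> bool:
    # optim.state maps parameter objects (not FQNs) to their state dicts, so
    # this check is robust to FQN changes; an empty dict means the optimizer
    # has not initialized state for this param yet.
    return len(optim.state.get(p, {})) > 0


print(has_initialized_state(optim, param))  # False before any step()
param.grad = torch.randn(4)
optim.step()
print(has_initialized_state(optim, param))  # True: exp_avg, exp_avg_sq, step exist
```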

Differential Revision: [D47285562](https://our.internmc.facebook.com/intern/diff/D47285562/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104765
Approved by: https://github.com/fduwjj
2023-07-10 08:00:55 +00:00
Andrew Gu
e600505e32 [FSDP][5/N] Unblock ignored_states + auto wrap (for now) (#104418)
The "for now" is because we still have the issue that when using the parameter `ignored_states` path, we do not recover the ignored modules, so FSDP still wraps those as empty shells (no managed parameters), which is not ideal. This is not a blocking issue as far as I know.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104418
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:14 +00:00
Andrew Gu
610f74627e [FSDP][4/N] Remove _get_fully_sharded_module_to_states (#104409)
`_get_fully_sharded_module_to_states()` was used to emulate auto wrapping without actually calling `fully_shard`. Since we committed to unifying (see previous PR), we can remove this function and its helpers/tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104409
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:14 +00:00
Andrew Gu
d9be0366d3 [FSDP][3/N] Unify fully_shard auto wrap (#104408)
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.

This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
2023-07-08 12:40:12 +00:00
Andrew Gu
6d71b4f9f1 [FSDP][2/N][Easy] Prepare _auto_wrap for fully_shard (#104407)
This mainly just changes the `_auto_wrap()` function signature and generalizes the `_check_nested_wrapping()` to both wrapper and composable paths (though the composable path will not hit in this PR).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104407
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:09 +00:00
Andrew Gu
d58f75be8b [FSDP][1/N] Move wrapper ModuleWrapPolicy to new path (#104346)
This PR is the first in refactoring the auto wrapping, only affecting `ModuleWrapPolicy` for wrapper `FullyShardedDataParallel`. The end goal is to improve the auto wrapping infra to support:
- Checking valid frozen parameters (uniform frozenness per FSDP)
- Checking valid shared parameters (shared parameters assigned to their lowest-common-ancestor module or higher)
- Writing auto wrapping policies that may take multiple passes over the module tree
- Specifying different FSDP kwargs per FSDP instance (instead of enforcing the same for all FSDP instances constructed via an auto wrap policy)

The way I envision achieving this is that we decouple the actual "wrapping" (which is `_post_order_apply()` in this PR) from constructing the wrapping targets and kwargs (which is `target_module_to_kwargs` in this PR). In that way, a policy reduces to just constructing that latter `target_module_to_kwargs` mapping.

I do not personally recommend the size-based policy, but if we wanted to implement that under this new organization, the tracking of wrapped/nonwrapped numel should be done in the pass over the module tree prior to the actual "wrapping". This modularization keeps the actual "wrapping" part simple.

The change to how `old_dtype` is handled is mainly to avoid keeping a reference to the `_override_module_mixed_precision()` function closure in each hook and to allow the function to take in all module classes at once to return which ones actually got overridden for the downstream error message. (We can directly store the global state as a mapping.)

To-do in follow-ups (not in order):
- Add frozen parameter check before `_post_order_apply()`
- Add shared parameter check before `_post_order_apply()`
- Expose wrapping policy that allows per module / per module class kwarg customization (where any unspecified kwarg adopts the root's kwarg)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104346
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-07-08 12:40:07 +00:00
Rohan Varma
0bf39d5663 [FSDP] Option for eval in fp32/bf16 (#104682)
In https://github.com/pytorch/pytorch/pull/97645 and some follow up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.

However, this is probably not the best idea, and we should give users a bit more control over this. This adds an env var FSDP_FULL_PREC_IN_EVAL, defaulting to off; users who want to run eval in fp32 can toggle it before wrapping the model in FSDP (a minimal usage sketch is shown below):

os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

Verified that unit tests, the APS workflow, and TNT workloads can run eval appropriately with this change.
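
A minimal usage sketch (assuming the default process group is already initialized and a GPU is available; the model and mixed precision settings are placeholders):

```python
import os

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Opt back into full-precision eval before wrapping the model in FSDP.
os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"

model = nn.Linear(16, 16).cuda()  # placeholder model
fsdp_model = FSDP(
    model,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
)
fsdp_model.eval()  # eval now runs in fp32 because of the env var above
```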

Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
2023-07-07 08:14:23 +00:00
Will Constable
d64bada876 Refactor funcol for readability and dynamo tracing (#104387)
Move eager kernel impls to a separate file, which is easier to read (since users may be confused by two versions of each kernel in the same file) and makes it easier to set a dynamo policy to trace only the first file currently.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104387
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/kumpera
2023-07-06 23:29:49 +00:00
Andrew Gu
6c1d959889 [FSDP] Annotate modules for fully_shard (#104363)
This annotates modules managed by `fully_shard` for TorchDynamo to treat them specially.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104363
Approved by: https://github.com/fegin
2023-07-06 16:56:59 +00:00
Rodrigo Kumpera
17ab4f85e9 [c10d] Adopt allgather_into_tensor_coalesced for NCCL. (#103086)
This is done by adding c10d::_allgather_into_tensor_coalesced wrapper.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103086
Approved by: https://github.com/rohan-varma
2023-07-06 15:05:55 +00:00
Wanchao Liang
db1ac4e29b fix functional collective's allgather for gloo (#104681)
Summary: We should explicitly check for the gloo backend instead of relying on the shard's device, because a user might pass a GPU tensor as input and a gloo process group as the pg, and expect that to work.

Differential Revision: D47249172

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104681
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj
2023-07-06 09:52:48 +00:00
Iris
434fcffa21 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It currently only works for offload_to_cpu=False.

The next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-06 05:36:19 +00:00
PyTorch MergeBot
fcb53c1394 Revert "[6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)"
This reverts commit 49af83cf44.

Reverted https://github.com/pytorch/pytorch/pull/104087 on behalf of https://github.com/huydhn due to This is failing in trunk 49af83cf44, probably due to a land race ([comment](https://github.com/pytorch/pytorch/pull/104087#issuecomment-1615608189))
2023-07-01 07:50:31 +00:00
Iris
49af83cf44 [6/n][FSDP] Update _sharded_pre_load_state_dict_hook to use DTensor when use_dtensor=True in ShardedStateDictConfig (#104087)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It currently only works for offload_to_cpu=False.

The next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
2023-07-01 01:02:59 +00:00
Andrew Gu
d982fdb5d5 [FSDP] Rework meta device init (#104189)
This addresses https://github.com/pytorch/pytorch/issues/104187.

After this PR, the contract with the user is that:
- If passing `param_init_fn=None`, each `nn.Module.reset_parameters()` should only initialize its own parameters/buffers (like `parameters(recurse=False)`/`buffers(recurse=False)`).
- If passing `param_init_fn` not equal to `None`, then similarly, one call to `param_init_fn(module)` should only initialize `module`'s own parameters/buffers.

With this contract and this PR's changes, meta device initialization through either `reset_parameters()` or `param_init_fn` should be correct. Those functions will run on the original parameter/buffer shapes, allowing for correct shape-dependent computations like fan-in/fan-out, and there will not be any re-initialization of any module. (A rough sketch of the contract is shown below.)
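
A rough sketch of the contract (the module and initialization choices are illustrative):

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(32, 32))
        self.child = nn.Linear(32, 32)  # has its own reset_parameters()

    def reset_parameters(self) -> None:
        # Initialize ONLY this module's own parameters (recurse=False scope);
        # do not touch self.child -- each module's reset_parameters() (or
        # param_init_fn) is expected to be called exactly once.
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)


with torch.device("meta"):
    block = Block()  # parameters created on meta; original shapes preserved
# FSDP would then materialize each module and run its own reset_parameters()
# (or the user's param_init_fn) once per module.
```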

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104189
Approved by: https://github.com/rohan-varma
2023-07-01 00:25:12 +00:00
Xilun Wu
e799f565eb [DTensor][TP][Random] Introduce TensorParallelRNGTracker to integrate parallel RNG state with Tensor Parallel (#103910)
This PR enables the automatic use of `TensorParallelRNGTracker` in the Tensor Parallel API. Some unit tests are going to be added for coverage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103910
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-30 08:06:41 +00:00
Wanchao Liang
da06920f47 Replace all_gather in device mesh with functional collective equivalent (#104056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104056
Approved by: https://github.com/kumpera, https://github.com/wanchaol
2023-06-30 05:30:02 +00:00
Wanchao Liang
8457703e8d lazy init device mesh in fsdp (#104447)
Since FSDP state is lazily initialized, we also need to lazily initialize the device mesh; otherwise the DeviceMesh allgather check would trigger a mismatch in allgather counts in FSDP tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104447
Approved by: https://github.com/wconstab
2023-06-30 04:40:16 +00:00
Will Constable
d0509fe32d Document how functional collectives work under eager/dynamo (#104386)
Move user-facing APIs to the top for best visibility
(strictly code-motion in this PR, besides adding comments)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104386
Approved by: https://github.com/voznesenskym, https://github.com/wanchaol
2023-06-30 01:12:55 +00:00
Rohan Varma
60e2a4a4a0 [2D parallel] workaround for FSDP init issue (#104398)
Closes https://github.com/pytorch/pytorch/issues/96491 and does so by relaxing FSDP's assumption that the entire input module must be on the same device. Now, FSDP can accept a module that is partially on CPU and partially on GPU and just emits a warning.

Differential Revision: [D47117256](https://our.internmc.facebook.com/intern/diff/D47117256/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104398
Approved by: https://github.com/fegin
2023-06-29 16:07:07 +00:00
Rohan Varma
c866446d6c [FSDP] Check module.training for _root_cast_forward_inputs (#104223)
We might erroneously cast forward inputs for the root if it doesn't
manage any handles (FSDP parameters). As a fix, pass in the module and check
its training attribute to ensure we don't cast inputs in eval mode.

Differential Revision: [D47041673](https://our.internmc.facebook.com/intern/diff/D47041673/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104223
Approved by: https://github.com/fegin
2023-06-28 16:38:01 +00:00
Andrew Gu
6493519fff [Easy][FSDP] Remove misleading asserts (#104274)
Since we do not call `_FSDPState.__init__()` and only use it for typing, it is not possible for these attributes to be `None`. The purpose of these `assert`s is to make sure that these attributes are set by `_init_process_group_state_for_hybrid_shard()`. If we care to make that explicit, I would posit that we should be using `hasattr` checks, not `is not None` checks, because if indeed `_init_process_group_state_for_hybrid_shard()` did not set these attributes, then even checking that it is not `None` would lead to an `AttributeError`. I do not include these `hasattr` checks for now since `_init_process_group_state_for_hybrid_shard()` is short enough that we can quickly tell by inspection that it sets the desired attributes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104274
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
ba9f6e6e92 [FSDP] Validate ignored_modules, ignored_states (#104273)
This checks that `ignored_modules` and `ignored_states` have the expected type and provides a reasonable error message if not. Otherwise, if someone passes a mix of modules and parameters to `ignored_states` for example, then our code may be silently incorrect.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104273
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:47 +00:00
Andrew Gu
cc27e6c0f9 [FSDP] Fix ignored_states doc (#104253)
This fixes https://github.com/pytorch/pytorch/issues/104246.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104253
Approved by: https://github.com/rohan-varma
2023-06-28 11:08:45 +00:00
Andrew Gu
9db8ad7f1d [FSDP] Support unfreezing params for reshard-only hook (#104186)
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).

- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
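A sketch of the unfreeze-after-`n`-steps training loop this enables (hedged: `model`, `dataloader`, and `UNFREEZE_STEP` are illustrative placeholders, and a process group is assumed to be initialized):

```
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(model, use_orig_params=True)
optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

UNFREEZE_STEP = 100
for step, batch in enumerate(dataloader):
    if step == UNFREEZE_STEP:
        # Previously-frozen parameters start receiving gradients from this step on.
        for p in fsdp_model.parameters():
            p.requires_grad_(True)
    loss = fsdp_model(batch).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
```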
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
2023-06-28 11:04:57 +00:00
shibo19
c2095af3f8 make funcs argument type from torch.cuda.stream as torch.Stream (#104156)
Fixes #ISSUE_NUMBER
1. We want to support FSDP for custom devices, so we change the functions' argument type from torch.cuda.Stream to torch.Stream.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104156
Approved by: https://github.com/awgu
2023-06-28 06:02:56 +00:00
Xilun Wu
a66107a30c [DTensor][Random] Introduce CudaRNGStateTracker to maintain parallel RNG state for DTensor (#103235)
# Change
This PR adds two classes to DTensor:

1. `CudaRNGStateTracker`:  `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).

2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.

# Warning

- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads (ranks) and cause issues. We need to figure out a compatible solution for that.

- The RNG state may go out of sync on non-participating ranks. This is harmless in our current submesh use case, though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
2023-06-27 19:00:25 +00:00
Amr Elshennawy
968b7b5e0f Initial commit of collective_utils (#101037)
Summary:
Details in T133020932
First commit of collective utils library. Ported over from model store, removed scuba logging, error_trait and all dependencies on modelstore.

Test Plan: In the following diffs.

Differential Revision: D45545970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101037
Approved by: https://github.com/H-Huang
2023-06-27 02:15:16 +00:00
Rodrigo Kumpera
c17bdb3247 [C10D] Add functional collective reduce_scatter_into_tensor_coalesced. (#101023)
Implementation uses a fallback that does no coalescing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101023
Approved by: https://github.com/wanchaol
2023-06-23 19:24:11 +00:00
fduwjj
23b7035b3c [TP] Add an input resharding wrapper for TP and unit test for 2D + AC (#103334)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103334
Approved by: https://github.com/kumpera
2023-06-23 04:05:01 +00:00
Chien-Chin Huang
1c33c398c7 [FSDP][state_dict] Add a summary log when finishing state_dict (#103784)
Add a summary log when finishing state_dict

Differential Revision: [D46807103](https://our.internmc.facebook.com/intern/diff/D46807103/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103784
Approved by: https://github.com/fduwjj
2023-06-22 16:29:24 +00:00
Iris
613970eb05 [5/n][FSDP] Update _sharded_post_state_dict_hook to use DTensor when use_dtensor=True in state_dict_config (#103921)
This allows us to use use_dtensor=True for ShardedStateDictConfig() before calling model.state_dict().

Updates to the load_state_dict hooks will be in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103921
Approved by: https://github.com/fduwjj, https://github.com/fegin
2023-06-22 08:32:19 +00:00
Andrew Gu
ec8aa6e592 [Easy][FSDP] Fix "column" -> "row" in PG example (#103975)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103975
Approved by: https://github.com/fduwjj
2023-06-21 20:41:50 +00:00
Chien-Chin Huang
a2d001d4dd [FSDP][state_dict] Use _get_module_fsdp_state_if_fully_sharded_module for state_dict (#103783)
Fix https://github.com/pytorch/pytorch/issues/90788
Use an implementation consistent with optim_state_dict

Differential Revision: [D46807090](https://our.internmc.facebook.com/intern/diff/D46807090/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103783
Approved by: https://github.com/awgu, https://github.com/fduwjj
2023-06-21 20:31:30 +00:00
Rodrigo Kumpera
0beec88c93 Inductor support for all_gather_into_tensor_coalesced. (#98643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98643
Approved by: https://github.com/wanchaol
2023-06-21 19:25:03 +00:00
Chien-Chin Huang
6b1d6750b9 [FSDP][state_dict][BE] Remove outdated and fixed TODOs (#103782)
Remove outdated and fixed TODOs

Differential Revision: [D46807071](https://our.internmc.facebook.com/intern/diff/D46807071/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103782
Approved by: https://github.com/rohan-varma
2023-06-21 05:41:19 +00:00
Chien-Chin Huang
1192f5ac46 [FSDP][optim_state_dict] Cleanup the unused optimizer state_dict APIs (#103781)
Cleanup the unused optimizer state_dict APIs

Differential Revision: [D46803955](https://our.internmc.facebook.com/intern/diff/D46803955/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103781
Approved by: https://github.com/rohan-varma
2023-06-21 05:38:48 +00:00
Michael Voznesensky
02f28de408 [dynamo x fsdp] Simplify stream logic handling (#103902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103902
Approved by: https://github.com/awgu
2023-06-21 01:34:19 +00:00
Chien-Chin Huang
0ae4c4d417 [FSDP][optim_state_dict] Avoid calling optim.state_dict() to get the initial
empty states (#103609)

Users may prefix the keys of the optimizer state_dict. Using `optim.state_dict()` to get the initial states is brittle. This PR removes the call to `optim.state_dict()` and directly infers the empty states from the input states.

Differential Revision: [D46729119](https://our.internmc.facebook.com/intern/diff/D46729119/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103609
Approved by: https://github.com/awgu
2023-06-20 22:11:58 +00:00
Rodrigo Kumpera
f83ebfe1bb [FSDP] Improve support for CPU tensors. (#103171)
Don't emit a device index when using CPU devices.
Don't call Tensor::record_stream as it is a CUDA-only op.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103171
Approved by: https://github.com/rohan-varma, https://github.com/wz337
2023-06-20 21:08:19 +00:00
Ke Wen
22e8a61d9b Implement coalesced reduce_scatter_tensor (#103561)
Counterpart of #101157 (which added coalesced all_gather_into_tensor).

This PR adds support for coalesced `reduce_scatter_tensor` calls in the following syntax:

Sync communication style:
```
with dist._coalescing_manager():
     for i in range(num_coll):
         dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
```

Async communication style:
```
with dist._coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])

# do a bunch of other things
cm.wait()
# do things that depend on the reduce-scatters' results
```
Each `reduce_scatter_tensor` call can be independent in terms of its data and buffer locations, but the calls can be executed in parallel by supported backends (like NCCL).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103561
Approved by: https://github.com/fegin
2023-06-15 20:11:12 +00:00
Mikayla Gawarecki
d1cecd9c32 Add assign kwarg to module.load_state_dict (#102212)
Fixes #64601 and #98906

Adds an `assign` argument to `load_state_dict` that loads params/buffers by assignment instead of doing `param.copy_(param_from_state_dict)`.

Primarily intended to remove the need for the `.to_empty()` in

```
with torch.device('meta'):
    m = SomeModule()
m.to_empty()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict)
```

so we can instead do

```
with torch.device('meta'):
    m = SomeModule()
state_dict = torch.load('...pth')
m.load_state_dict(state_dict, assign=True)
```

**A problem with this PR, for the case where the model is initialized on meta, is what happens to non-persistent buffers/params corresponding to keys missing from the state dict.**
What happens when `load_state_dict(state_dict, strict=False, assign=True)` is called and the state_dict is missing some keys? The params missing from the `state_dict` and the non-persistent buffers would still be on `meta` and would need to be manually initialized. However, I don't think we offer an API that would initialize these.

One solution would be to make these empty tensors but it might not be semantically correct...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102212
Approved by: https://github.com/albanD
2023-06-15 18:41:00 +00:00
Andrew Gu
2eea3cb19d Fix composable checkpoint(use_reentrant=True) with multi args (#103590)
The `_ModuleHookCheckpointFunction.backward()` should take in `*output_grads` instead of `output_grads`. Otherwise, we may see an error like:
```
TypeError: backward() takes 2 positional arguments but 5 were given
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
2023-06-14 21:53:30 +00:00
Iris
7dd0f525b5 [FSDP][4/n]Update use_dtensor option for _optim_utils.py (#103599)
Same as https://github.com/pytorch/pytorch/pull/103069 (this branch is corrupted so have to re-submit).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103599
Approved by: https://github.com/fegin
2023-06-14 20:18:33 +00:00
Iris
d991ce6da3 [FSDP][3/N]_shard_utils update for dtensor state_dict support (#103479)
Same as https://github.com/pytorch/pytorch/pull/102545 (this branch is corrupted so have to re-submit).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103479
Approved by: https://github.com/fegin
2023-06-14 06:45:28 +00:00
Iris
51d21ffd8a [FSDP][2/n] add use_dtensor flag to both StateDictConfig and OptimStateDictConfig (#103477)
Same as #102552 (this branch is corrupted so have to re-submit).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103477
Approved by: https://github.com/fegin
2023-06-13 19:09:56 +00:00
Andrew Gu
71b560208c [FSDP] Fix device_id when buffer-only module (#103504)
There was an issue reported internally that with `sync_module_states=True`, if the model had buffers on CPU, even with `device_id` specified, FSDP would try to broadcast CPU buffers, leading to an error like:
```
RuntimeError: No backend type associated with device type cpu
```

After some investigation, I determined that we should _not_ fix this by moving the buffers to GPU just for the broadcast and then back to CPU. Instead, we should fix our `device_id` logic.

The issue is that we always used the _parameters_ as the proxy to tell whether we should move module states to the device specified by `device_id`. However, a module (often the root) may not have any parameters but have some buffers! In that case, the buffers are left on CPU even if `device_id` is specified. This PR fixes this by considering both parameters and buffers for movement to `device_id`.

Note that this PR preserves the existing behavior for `ignored_modules` / `ignored_parameters`: they are not given special treatment for this movement, so ignored parameters are still moved to `device_id`.

Note also that I had to move the unit test back from using MTPG to the normal PG since otherwise, I could not repro the original error. (It seems like MTPG does not complain if we try to use `dist._broadcast_coalesced()` with CPU tensors.)
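A sketch of the scenario that previously failed (hedged: the module is illustrative; assumes an initialized NCCL process group and a CUDA device):

```
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class BufferOnlyRoot(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))  # root has a buffer but no parameters
        self.inner = nn.Linear(8, 8)                  # parameters live in a submodule

fsdp_model = FSDP(
    BufferOnlyRoot(),
    device_id=torch.cuda.current_device(),
    sync_module_states=True,  # previously tried to broadcast the CPU buffer and errored
)
```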

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103504
Approved by: https://github.com/rohan-varma
2023-06-13 18:33:26 +00:00
Rodrigo Kumpera
5b33d39114 [FSDP] Workaround for GLOO's lack of all_gather_into_tensor. (#103170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103170
Approved by: https://github.com/rohan-varma
2023-06-13 17:21:41 +00:00
Rodrigo Kumpera
63fe26809d Implement all_gather_into_tensor_coalesced. (#98642)
The implementation is suboptimal since it uses c10d's group coalescing, which
is known to be inefficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98642
Approved by: https://github.com/wanchaol
2023-06-13 15:06:52 +00:00
zhuhong61
50c972bfd2 [c10d] Add xpu to the default device supported by user specified backend (#103410)
**Motivation:**
For collective dispatching, we want to provide a more user-friendly way of mapping the xpu device to the CCL backend (a user-specified backend).

**Solution:**
We add xpu to the default device list, so the mapping between xpu and the user-specified backend can be constructed directly.
Usage:
When using an xpu device, the user can specify only the backend name:
`dist.init_process_group(backend='ccl')`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103410
Approved by: https://github.com/jgong5, https://github.com/ezyang
2023-06-12 19:46:33 +00:00
PyTorch MergeBot
caecb55223 Revert "Log functional_collectives apis to distributed logger (#103288)"
This reverts commit 37359c36fd.

Reverted https://github.com/pytorch/pytorch/pull/103288 on behalf of https://github.com/malfet due to Broke test_inductor_collectives, see 37359c36fd ([comment](https://github.com/pytorch/pytorch/pull/103288#issuecomment-1587677705))
2023-06-12 16:37:57 +00:00
Will Constable
37359c36fd Log functional_collectives apis to distributed logger (#103288)
This logs functional collectives API calls with debug log level only.

(the `+` in the TORCH_LOGS cmdline enables debug level, otherwise only info level)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103288
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-12 06:33:26 +00:00
Wanchao Liang
4cc474dec4 [dtensor] support torch.save/load with DTensor (#103106)
This PR enables DTensor to be picklable and adds tests verifying that
torch.save/load works correctly for DTensor.
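A minimal sketch of the round trip this enables (hedged: requires an initialized process group, e.g. under torchrun; the file path is illustrative):

```
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
dtensor = distribute_tensor(torch.randn(8, 8), mesh, placements=[Shard(0)])

path = f"dtensor_rank{dist.get_rank()}.pt"
torch.save(dtensor, path)   # DTensor is now picklable
loaded = torch.load(path)
```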
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103106
Approved by: https://github.com/kumpera
2023-06-09 04:11:15 +00:00
Wanchao Liang
d31707a257 Get rid of dim_groups attribute from DeviceMesh (#103105)
This PR gets rid of the dim_groups attribute from DeviceMesh. The main
motivation is that c10d should store the process groups during their
creation instead of DeviceMesh; DeviceMesh should just handle ranks
correctly.

This could enable DTensor to become picklable (making torch.save/load
possible), which I will try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
2023-06-09 04:11:15 +00:00
Andrew Gu
48056b168f [FSDP] Reshard frozen params in backward (#101982)
This PR makes a first attempt at improving FSDP's fine-tuning support by adding hooks to reshard frozen parameters in the backward pass.
- Without this, frozen parameters involved in gradient computation are kept as unsharded through the entire backward pass.
- The approach is to register a multi-grad ~~post~~-hook on the _input_ activations to the FSDP module, where the hook performs the resharding once all gradients for the FSDP module have been computed (meaning that we are safe to reshard).

~~This PR relies on adding a "multi-grad post-hook" that differs from the existing "multi-grad hook" from `register_multi_grad_hook()`. I find that with `register_multi_grad_hook()`, sometimes the unit test counting the number of times `_post_backward_reshard()` is called fails (due to it not being called).~~ This was resolved in https://github.com/pytorch/pytorch/pull/102859.
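An illustrative, standalone example of the multi-grad hook primitive this approach builds on (hedged: not FSDP internals; the callback here just prints, whereas FSDP's hook would reshard the frozen parameters):

```
import torch
from torch.autograd.graph import register_multi_grad_hook

x = torch.randn(4, 8, requires_grad=True)
y = torch.randn(4, 8, requires_grad=True)

def on_all_grads_ready(grads):
    # Runs once all gradients w.r.t. (x, y) have been computed in this backward pass;
    # FSDP's version reshards the frozen parameters at this point.
    print("all input grads computed")

handle = register_multi_grad_hook((x, y), on_all_grads_ready)
(x * y).sum().backward()
handle.remove()
```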
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101982
Approved by: https://github.com/rohan-varma
2023-06-08 21:12:45 +00:00
Xilun Wu
675f2597fa [reland][DTensor][3/N] add DTensor constructor function: full (#101436) (#103165)
This is a reland attempt of reverted PR #101436 .

Differential Revision: [D46537531](https://our.internmc.facebook.com/intern/diff/D46537531)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103165
Approved by: https://github.com/wanchaol
2023-06-08 16:18:33 +00:00
Rodrigo Kumpera
4833dc10b8 [DCP] Rewrite read slicing to use a wrapper. (#99167)
Moved SlicedBufferedReader to utils and renamed to _ReaderView.

It no longer depends on file handles and is a pure wrapper. This makes it general enough to handle non-IO stream objects like fsspec's.

Should help with #98386
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99167
Approved by: https://github.com/wz337
2023-06-08 13:52:13 +00:00
Wanchao Liang
8585784a34 [dtensor] fix allgather unpadding logic (#103219)
This PR fixes the allgather unpadding logic so that we only need to unpad
the full tensor instead of first chunking it into small tensors and unpadding
them individually, since we know how our padding algorithm works.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103219
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-06-08 03:31:24 +00:00
Iris
d5142c52d3 [FSDP]Remove dim_group from device_mesh init (#103218)
1) remove dim_group
2) don't init device_mesh if not using default_pg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103218
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-08 03:29:19 +00:00
shaoyf42
17737f9d0e [DTensor] Allow DTensor support cuda-like device (#102468)
Allow DTensor to support cuda-like devices; fixes https://github.com/pytorch/pytorch/issues/102442

Currently, DTensor supports cuda and cpu. There are other efforts to make DTensor support third-party devices, for example https://github.com/pytorch/pytorch/pull/101914 and https://github.com/pytorch/pytorch/issues/101911. However, that support only covers a portion of third-party devices and does not serve third-party cuda-like devices well. Therefore, we would like to extend DTensor to support cuda-like devices; after all, cuda is so popular!

1. Similar to what is done here, we need to initialize the communication backend for the device set by DeviceMesh. So `_default_backend_for_device` is added to `Backend`. It is worth noting that when we register a new backend for a device other than cpu and cuda, we also need to add a new default backend for this device.
2. Adding `_device_handle` to `DeviceMesh` for cuda-like devices, similar to what is set in FSDP. When `_device_handle` is not None, the device has similar behavior to `cuda`. In this way, functions like `torch.cuda.device_count()` need to be modified to `device_mesh._device_handle.device_count()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102468
Approved by: https://github.com/wanchaol
2023-06-07 23:13:53 +00:00
Ke Wen
07104ca99c [c10d] Make it default that PG do not perform barrier after init (#103033)
Both internal and OSS users trying https://github.com/pytorch/pytorch/pull/99937 report that their workloads perform normally even with the barrier removed and see a scalability win. Thus, in this PR we make it the default that PGs do not perform a barrier after init.

In the discussion of #99937, people point out that such a barrier might be needed for c10d + RPC cases. IMO, this need originates from RPC's programming model and should be RPC's or the RPC user's responsibility to deal with; it can happen with other functions/libraries too. So the need for c10d to do so big a favor is not justified IMO. It is also good to remove it before users become reliant on this barrier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103033
Approved by: https://github.com/XilunWu
2023-06-07 06:11:14 +00:00
Iris
a02a58d862 [FSDP][1/N]Add device_mesh to FSDPstate (#102317) (#102551)
This PR creates a device_mesh and shares it across all FSDP state. The device_mesh will later be used to test out dtensor state_dict (1d device_mesh).
Approved by: https://github.com/awgu

- Add device mesh to fsdp state
- skip dist.get_world_size(pg) != dist.get_world_size()
- address test_fake_pg.py test failure
- fix test_fake_py.py failure

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102551
Approved by: https://github.com/fegin
2023-06-07 04:14:00 +00:00
Rohan Varma
dfa64fddeb [FSDP] Fix for optim state dict (#102901)
Fix for HSDP + use_orig_params where we need to pass in the PG that
might not be the default.

Differential Revision: [D46417327](https://our.internmc.facebook.com/intern/diff/D46417327/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102901
Approved by: https://github.com/wz337
2023-06-06 20:21:23 +00:00
Chao Yang
367b0ad062 enforce dtype (reland) (#102996)
Summary: The original diff didn't break the test.

Test Plan: N/A

Differential Revision: D46448488

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102996
Approved by: https://github.com/malfet, https://github.com/wanchaol
2023-06-06 00:35:04 +00:00
PyTorch MergeBot
ecb191683e Revert "enforece dtype (#102802)"
This reverts commit 8e2a86c2a5.

Reverted https://github.com/pytorch/pytorch/pull/102802 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/102802#issuecomment-1577099676))
2023-06-05 16:21:28 +00:00
Samuel Eisenhandler
9cabdff8bd Update documentation to read FileSystemReader instead of FileSystemLoader (#102795)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102795
Approved by: https://github.com/wz337
2023-06-05 15:22:49 +00:00
Chao Yang
8e2a86c2a5 enforece dtype (#102802)
Summary: Add a flag to enforce the gather data dtype. To preserve backward compatibility, the default is False.

Test Plan: local and mast

Reviewed By: zyan0, strisunshinewentingwang

Differential Revision: D46295190

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102802
Approved by: https://github.com/mrshenli
2023-06-05 02:04:09 +00:00
Rohan Varma
a748be93df [CheckpointWrapper] Warn on reentrant use (#102890)
We'd like to encourage users to try non-reentrant as much as possible,
and identify any gaps this way.

Differential Revision: [D46397786](https://our.internmc.facebook.com/intern/diff/D46397786/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102890
Approved by: https://github.com/awgu
2023-06-04 18:31:22 +00:00
Rohan Varma
88ce6215f5 [FSDP/DDP] Unify _cast_forward_inputs (#102680)
Closes https://github.com/pytorch/pytorch/issues/96380

Differential Revision: [D46342814](https://our.internmc.facebook.com/intern/diff/D46342814/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102680
Approved by: https://github.com/awgu
2023-06-04 18:31:21 +00:00
Rohan Varma
957ea485c4 [FSDP/AC] checkpoint_wrapper acccept auto_wrap_policy (#102672)
Some feedback for this API is that folks would like to use an
auto_wrap_policy similar to FSDP's instead of having to adapt to the signature of
``check_fn``.
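A hedged sketch of the resulting usage (the model and policy are illustrative; the kwarg name follows this PR's title, so the exact signature may differ):

```
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Wrap every nn.Linear in a checkpoint wrapper via a policy, instead of a check_fn.
apply_activation_checkpointing(
    model,
    auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
)
```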

Differential Revision: [D46340320](https://our.internmc.facebook.com/intern/diff/D46340320/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102672
Approved by: https://github.com/awgu
2023-06-04 18:31:19 +00:00
Rohan Varma
df40ec82dc [FSDP][Docs] Document get_state_dict_type (#102658)
Per title

Differential Revision: [D46335317](https://our.internmc.facebook.com/intern/diff/D46335317/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102658
Approved by: https://github.com/fegin, https://github.com/awgu
2023-06-04 18:31:18 +00:00
Rohan Varma
c6d0fe39ec [FSDP] Document optim_state_dict_config in method (#102657)
Per title

Differential Revision: [D46335318](https://our.internmc.facebook.com/intern/diff/D46335318/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102657
Approved by: https://github.com/fegin
2023-06-04 18:31:16 +00:00
Rohan Varma
beb7131c64 [FSDP] Use INFO instead of DETAIL for warning logs (#102639)
Since these are just logs and don't introduce any big perf slowdowns,
I think we should just enable them in info mode.

Differential Revision: [D46328510](https://our.internmc.facebook.com/intern/diff/D46328510/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102639
Approved by: https://github.com/awgu
2023-06-04 18:31:15 +00:00
Rohan Varma
4d516f44a1 [FSDP][ez] Type optimizer correctly (#102637)
In ShardedGradScaler, the optimizer doesn't have to be SGD.

Differential Revision: [D46327103](https://our.internmc.facebook.com/intern/diff/D46327103/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102637
Approved by: https://github.com/Skylion007, https://github.com/awgu, https://github.com/fegin
2023-06-04 18:31:13 +00:00
Rohan Varma
e66c498d2d Log modules FSDP hooks fire for (#102508)
Under TORCH_DISTRIBUTED_DEBUG >= INFO and use_orig_params=True, log post-backward hook firing to debug things like FSDP + AC integration.

Differential Revision: [D46172916](https://our.internmc.facebook.com/intern/diff/D46172916/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102508
Approved by: https://github.com/awgu, https://github.com/fegin
2023-06-04 18:31:12 +00:00
PyTorch MergeBot
0f672e8c67 Revert "[DTensor][3/N] add DTensor constructor function: full (#101436)"
This reverts commit 2ca75d49a8.

Reverted https://github.com/pytorch/pytorch/pull/101436 on behalf of https://github.com/malfet due to Caused internal SEV ([comment](https://github.com/pytorch/pytorch/pull/101436#issuecomment-1575076672))
2023-06-03 17:09:08 +00:00
shaoyf42
fc218a8a13 Fix typos in README of DTensor (#102813)
Fix typos in the DTensor README. There is still a problem to be fixed: I hit an error when trying to use distribute_module with shard_params; the specific error message is in issue https://github.com/pytorch/pytorch/issues/102812.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102813
Approved by: https://github.com/wanchaol
2023-06-02 19:27:23 +00:00
Ashwin Hari
cf0aa38005 Allow ORT backend for DTensor (#101914)
fixes #101911

Currently, `DTensor` supports cuda and cpu. This PR makes some changes for easier integration with the ort backend.

* `Backend.NAME`  attribute now has value `name` instead of `NAME` for backends registered through `register_backend(name)`; this matches the pattern for backends with built-in support like nccl.
* remove unused `_check_for_nccl_backend` function
* add test case that moves parameters to device in the `partition_fn` - a scenario that's useful for big models
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101914
Approved by: https://github.com/wanchaol
2023-06-01 22:37:09 +00:00
fduwjj
92923aca61 [TP] Use Stride inferred from local tensor in to_local bwd (#102630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102630
Approved by: https://github.com/wanchaol
2023-06-01 04:30:24 +00:00
Wanchao Liang
c5d4ee2d73 [dtensor][simple] fix some comments (#102661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102661
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-06-01 03:23:19 +00:00
shaoyf42
8d7e082300 [c10d] Add is_backend_available for c10d backend. (#101945)
Add is_backend_available for c10d backends, covering both the built-in backends and third-party backends registered through the function ``Backend.register_backend``.

There is a related discussion in https://github.com/pytorch/pytorch/pull/101775#discussion_r1199253553
> For example in python constructor for their backend they should explicitly add the is_X_available. Or if defining in C++ they should modify pybind like this https://github.com/H-Huang/torch_collective_extension/blob/main/custom_backend/include/dummy.hpp#L98-L101
to also add their own is_available property

It is natural for users to add their own `is_available` when they create a backend. One possible approach is to let users call `is_X_available` in the same way as for native backends, for example by dynamically adding a `torch.distributed.is_dummy_available()` function. This is why we want to dynamically add `is_X_available` to `torch.distributed` in `register_backend`.

> Or we could add an Is_available(backend) function, that checks for the backend.

Providing a public function is indeed another good approach. We have implemented an `is_backend_available` in https://github.com/pytorch/pytorch/pull/101945 that supports both built-in backends and third-party backends.
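A minimal sketch of the query this adds (hedged: the entry point follows the PR description, and "dummy" stands in for a hypothetical third-party backend registered via `Backend.register_backend`):

```
import torch.distributed as dist

print(dist.is_backend_available("nccl"))   # built-in backend
print(dist.is_backend_available("dummy"))  # hypothetical third-party backend name
```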

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101945
Approved by: https://github.com/H-Huang
2023-05-31 22:51:51 +00:00
Rohan Varma
0ecca122e7 [Replicate] Add unit test with replicate param names (#102401)
This attribute wasn't actually used in tests; add a test ensuring that
if replicate is used on top of FSDP, the replicated parameter names are as
expected.

TODO: there are a few ways to check if module is managed by composable API,
such as replicated param names for replicate, _get_module_state API,
_get_registry_api, etc. We should unify all composable APIs to check in a
unified way (filed an issue)

Differential Revision: [D46236377](https://our.internmc.facebook.com/intern/diff/D46236377/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102401
Approved by: https://github.com/awgu
2023-05-31 18:41:03 +00:00
Yanli Zhao
f47ee87765 Fix ignored_states when they are passed as generators (#102575)
This PR fixes the case where ignored_states is passed as a generator rather than a List/Set.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102575
Approved by: https://github.com/awgu
2023-05-31 15:58:55 +00:00
Matthew Hoffman
c28f8e314d Add type hints in torch/distributed/utils.py (#102262)
Fixes #77190

Pretty similar to the typing in `torch/nn/parallel`, which was also improved recently: #102194

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102262
Approved by: https://github.com/Skylion007, https://github.com/Neilblaze
2023-05-30 19:57:45 +00:00
Wanchao Liang
ff58d19c89 DeviceMesh use dispatchable PG to support custom backend (#102336)
This PR switches DeviceMesh to use a dispatchable process group instead.
This enables easier backend integration, as users only need to
integrate with the c10d process group custom backend, without needing to
change DeviceMesh to plug in the backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102336
Approved by: https://github.com/fduwjj
2023-05-30 19:22:37 +00:00
Wanchao Liang
3ef4d697df [c10d] default backend need to check for nccl availability (#102470)
As titled, we can only initialize nccl backend when NCCL is available
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102470
Approved by: https://github.com/Skylion007, https://github.com/XilunWu
2023-05-30 19:22:37 +00:00
Will Constable
77f97019b7 Dynamo remaps legacy allgather to traceable one (#102232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102232
Approved by: https://github.com/voznesenskym
2023-05-30 16:45:25 +00:00
PyTorch MergeBot
81ac076bce Revert "[FSDP]Add device_mesh to FSDPstate (#102317)"
This reverts commit 4c584acc5d.

Reverted https://github.com/pytorch/pytorch/pull/102317 on behalf of https://github.com/malfet due to Broke test_fake_pg, see https://github.com/pytorch/pytorch/actions/runs/5100633726/jobs/9173277369  ([comment](https://github.com/pytorch/pytorch/pull/102317#issuecomment-1566129496))
2023-05-28 12:53:28 +00:00
Iris
4c584acc5d [FSDP]Add device_mesh to FSDPstate (#102317)
This PR creates a device_mesh and share it across all FSDP state. The device_mesh will later be used to test out dtensor state_dict (1d device_mesh).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102317
Approved by: https://github.com/awgu
2023-05-27 20:25:30 +00:00
Matthew Hoffman
0ed22fce97 Merge type stubs torch nn parallel (#102194)
Fixes merge issue for #101528

In the above PR, `torch.nn.parallel.parallel_apply.get_a_var` was marked private to appease the [public interface linter](https://github.com/pytorch/pytorch/actions/runs/4999216467/jobs/8955582204#step:14:21666): ceeb242bc7

This broke CI pipelines running external dependencies that expected `get_a_var`'s name to not change. In this PR, we change the name back to `get_a_var` and include it in the `__all__` instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102194
Approved by: https://github.com/ezyang
2023-05-26 20:10:47 +00:00
Rohan Varma
3dfa755a1f [MTPG] Enable for some tests in test_fsdp_misc (#102043)
Enables MTPG for some FSDP tests in this file. Tests that need the
backward pass and warning logging are left as follow up work.

Backward pass issue: It seems that there is a hang with all_gather. Will sync with @kumpera on this.

Warning issue: We have a couple tests that regex check on warnings, but in the
multithreaded scenario these warnings are somehow not logged.

Differential Revision: [D43209769](https://our.internmc.facebook.com/intern/diff/D43209769/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102043
Approved by: https://github.com/awgu
2023-05-26 06:21:25 +00:00
Iris
080d86acfb [DCP] Add API logging for checkpoint high level API (#102278)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102278
Approved by: https://github.com/fduwjj
2023-05-25 21:13:29 +00:00
Wanchao Liang
7b47cd0a6c [c10d] add fake pg necessary collectives (#102238)
This PR adds the collectives necessary for the fake pg to enable an e2e FSDP run
without multiprocessing or multithreading.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102238
Approved by: https://github.com/ezyang
2023-05-25 05:01:16 +00:00
Wanchao Liang
9a19262556 [c10d] conslidate barrier after init logic (#102237)
This PR consolidates the barrier-after-init logic to allow a custom
backend to set the env var when creating the pg, so that
`init_process_group` skips the barrier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102237
Approved by: https://github.com/ezyang
2023-05-25 05:01:16 +00:00
fduwjj
d4380edb9b [TP] Add API logging for TP high level API (#102209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102209
Approved by: https://github.com/wz337, https://github.com/wanchaol
2023-05-25 03:33:00 +00:00
Rohan Varma
f3e42f15e9 [FSDP] Start to generalize modules to ignore for mixed precision (#102010)
The main use case here is that folks would like to ignore layer norm for mixed precision. This can now be enabled with:

```
mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
            _mixed_precision_module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
        )
```

This is done by wrapping modules whose types are in `_mixed_precision_module_classes_to_ignore` in their own FSDP unit with mixed precision disabled. This is only enabled for auto wrapping.

We also add module pre and post hooks to cast / downcast inputs to the appropriate full precision.

Differential Revision: [D46079957](https://our.internmc.facebook.com/intern/diff/D46079957/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102010
Approved by: https://github.com/awgu
2023-05-25 00:45:54 +00:00
Edward Z. Yang
c903b12cb8 Add fake process group (#102180)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102180
Approved by: https://github.com/wanchaol
2023-05-24 23:27:40 +00:00
Yeonju Ro
06f656c5d1 [distributed] implemented find_all_descendants (#102138)
Fixes #100397

Implemented the find_all_descendants function, which identifies the list of nodes that need to be moved. Added a unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102138
Approved by: https://github.com/fegin
2023-05-24 21:47:59 +00:00
Yanli Zhao
956bd03808 add ignored_states to FSDP/fully_shard (#102056)
Add 'ignored_states', which accepts either a list of ignored parameters or a list of nn modules, for the FSDP model wrapper and the fully_shard composable API. It is recommended to use 'ignored_states' over 'ignored_modules' moving forward.
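A hedged sketch of the two accepted forms (assumes an initialized process group; `model.frozen_head` is an illustrative submodule name, and only one of the two calls would be used):

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Either ignore by parameters...
fsdp_model = FSDP(model, ignored_states=list(model.frozen_head.parameters()))

# ...or, equivalently, ignore by modules (one or the other, not both).
fsdp_model = FSDP(model, ignored_states=[model.frozen_head])
```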

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102056
Approved by: https://github.com/awgu
2023-05-24 18:36:48 +00:00
Wanchao Liang
d316a2dd5c [spmd] Enable data parallel to work with non 0 batch dim (#100073)
This PR enables data parallel to work with a non-0 batch dim. The only
thing we need to do is expose the input_batch_dim to DataParallelMode;
the data parallel expansion then works automatically, as we have done
things correctly in the batch dim analysis.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100073
Approved by: https://github.com/mrshenli
2023-05-24 17:55:10 +00:00
Wanchao Liang
d378837039 [spmd] add more decomp and fix a sharding bug (#100938)
This PR adds the native_layernorm_backward op to the decomp table and fixes
a sharding bug so that padding is not applied automatically.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100938
Approved by: https://github.com/mrshenli
2023-05-24 17:55:10 +00:00
Wanchao Liang
dd1f295201 [spmd] Improve activation handling, factory ops and batch dim reduction (#100853)
This PR improves the activation handling logic of data parallel to
support cases where there are tensor factory ops that do not depend
on any input node; such ops still produce an activation, either a
sharded act (i.e. if the output shape has the batch size) or a replicated act.

It also significantly simplifies the full reduction logic: we no longer
need separate full reduction detection, we only need to ensure that when
computing the batch dim, we detect full reductions and mark them as sharded.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100853
Approved by: https://github.com/mrshenli
2023-05-24 17:55:09 +00:00
Wanchao Liang
4d55ea8548 [spmd] enhance batch dim analysis of data parallel (#100852)
This PR enhances the batch dim analysis of data parallel to better
understand cases where the batch dim gets flattened or split; using
dtensor's view ops, we are able to track a batch dim that got
transformed in non-trivial ways.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100852
Approved by: https://github.com/mrshenli
2023-05-24 17:55:07 +00:00
Wanchao Liang
b2eaba6b62 [spmd] by default average gradients for nccl backend (#99964)
This PR averages gradients by default for the NCCL backend, which allows
SPMD's data parallel to match DDP/FSDP results.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99964
Approved by: https://github.com/mrshenli
2023-05-24 17:55:06 +00:00
Wanchao Liang
942cd12d55 [spmd] add option to preserve node types (#100072)
This PR adds an option to preserve node types for the entire graph.
This could allow some exploration of using those node types to do
things like activation checkpointing, etc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100072
Approved by: https://github.com/mrshenli
2023-05-24 17:55:05 +00:00
medivh-xp
8b7bd81902 determined collective device by _get_pg_default_device rather than explicit cuda (#101533)
There are many communication operations for ShardedTensor in the state dict of FSDP. They use the externally passed-in pg (or the default pg), which currently supports cuda devices. Before communication, the memory is moved to cuda, which is implicit (because it is essentially moving data to the memory type required by the pg, not to the computing device type). Similarly, when users use FSDP with a custom backend, they pass in a custom pg (which does not support cuda devices), which may cause FSDP to not work properly in some cases. This PR obtains the memory type supported by the pg through _get_pg_default_device during communication and moves the data to it when needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101533
Approved by: https://github.com/awgu
2023-05-24 13:48:43 +00:00
Iris
ee95e37a69 [c10d] Record time spent for init_process_group, new_group, _store_based_barrier (#101912)
1. Record time spent for init_process_group, new_group, _store_based_barrier
2. Rename c10d_error_logger to c10d_logger for generalization.
3. Refactor to move the logger wrappers in distributed_c10d.py to c10d_logger.py.
4. Rename the logger wrappers (bc breaking). Exception_handler is renamed to exception_logger to avoid confusion with logging handler.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101912
Approved by: https://github.com/fduwjj
2023-05-24 09:36:34 +00:00
Edward Z. Yang
3318a832b3 Tighten FakeTensor reentrancy asserts, add debugging (#102091)
When investigating failures in https://github.com/pytorch/pytorch/pull/100017, I realized that we were reentering FakeTensorMode even though there was already one on the stack. Although we have attempted asserts for these cases in the past, e.g., as in https://github.com/pytorch/pytorch/pull/97186, it seems that the existing protections were insufficient.

In this particular case, the reapplication of FakeTensorMode was due to an interaction with NotImplemented multiple dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it will return NotImplemented to give this tensor a chance to unpack itself into a proxyable operation. However, this is never the right thing for FakeTensor, where no unpacking is possible. Yet today, FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being on the stack twice.

This PR does a number of things:

* It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack; this is ALWAYS an error
* It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active. This prevents us from re-adding the mode to the stack
* It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode)
* `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway
* I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091
Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh
2023-05-24 05:37:51 +00:00
Edward Z. Yang
f65732552e Support FakeTensor with FlatParameter (#101987)
In this PR we turn FlatParameter into a virtual tensor subclass
which doesn't actually ever get instantiated: __new__ will create
a Parameter instead (or a FakeTensor, if necessary).

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101987
Approved by: https://github.com/awgu, https://github.com/eellison
2023-05-23 23:12:08 +00:00
Wanchao Liang
6e0c741105 [dtensor] hide mesh validation check under init_process_group flag (#101996)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101996
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Wanchao Liang
70eccdbf92 [dtensor] add necessary logging to APIs and components (#101994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101994
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Xilun Wu
2ca75d49a8 [DTensor][3/N] add DTensor constructor function: full (#101436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101436
Approved by: https://github.com/wanchaol
2023-05-23 06:05:40 +00:00
Wanchao Liang
38a29324b0 [dtensor][2/N] more tensor ops to use strategy propagation (#101203)
As titled, this PR adapts a few more tensor ops to use strategy based
sharding prop
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101203
Approved by: https://github.com/XilunWu
2023-05-22 17:16:14 +00:00
Aaron Gokaslan
3e2ea32dab [BE]: Enable ruff rule TRY302 and apply fixes (#101874)
Removes useless try statements and unreachable code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101874
Approved by: https://github.com/malfet
2023-05-19 17:30:52 +00:00
medivh-xp
e06bd8f3b1 fsdp support create hybrid-sharded process group for custom backend (#100622)
FSDP creates communication groups for intra-node communication through dist.new_subgroups. Previously, dist.new_subgroups only supported creation based on the number of CUDA devices. However, issue #99706 removed the availability check for CUDA devices, allowing custom backends to create groups based on the number of custom devices per node.

This PR allows FSDP to explicitly pass the number of devices per node when creating communication groups for intra-node communication, instead of defaulting to the number of CUDA devices.
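A sketch of intra-node group creation with an explicit size (hedged: assumes torch.distributed is already initialized; DEVICES_PER_NODE is an illustrative value that would come from the custom device runtime):

```
import torch.distributed as dist

DEVICES_PER_NODE = 8  # illustrative; previously this defaulted to torch.cuda.device_count()
intra_node_group, _all_subgroups = dist.new_subgroups(group_size=DEVICES_PER_NODE)
```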
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100622
Approved by: https://github.com/awgu
2023-05-19 06:08:55 +00:00
shaoyf42
97180aca5e Enables barrier to support the specified device (#99589)
Enables barrier to support the specified device, e.g. cuda or a custom device. There is some discussion here: https://github.com/pytorch/pytorch/issues/97938#issue-1646833919

Today, there are two limitations of barrier:
One is that barrier does not support custom devices:
fbdb86c174/torch/csrc/distributed/c10d/ProcessGroup.hpp (L512-L522)

The second is that there is a special validation for nccl when device_id is not None, which assumes cuda and nccl bindings and also hinders custom devices.
789070986c/torch/distributed/distributed_c10d.py (L3504-L3508)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99589
Approved by: https://github.com/kwen2501
2023-05-17 05:26:04 +00:00
Thibaut Durand
01da732691 Fix type annotation of torch.split (#100655)
The type annotation indicates `list` but the returned type is `tuple`
```python
>>> import torch
>>> type(torch.arange(10).split(4))
<class 'tuple'>
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100655
Approved by: https://github.com/kit1980
2023-05-16 21:35:41 +00:00
Xilun Wu
010763be9a [DTensor][2/N] add DTensor constructor function: empty (#101022)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101022
Approved by: https://github.com/wanchaol
2023-05-16 16:50:54 +00:00
Xilun Wu
5cc361c736 [DTensor][1/N] add DTensor constructor function: ones (#100933)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100933
Approved by: https://github.com/wanchaol
2023-05-16 16:50:54 +00:00
albanD
59dff01319 Add top level function to check if running with deploy (#101420)
Also not sure if this should be a public function or not. Leaving it private for now but let me know if you prefer for it to be public.

FYI @nikitaved this will logically conflict with your triton kernel PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101420
Approved by: https://github.com/malfet
2023-05-16 16:05:49 +00:00
Xuehai Pan
05f6250815 Add missing torch.distributed.ReduceOp.AVG in type stubs (#101534)
Add missing `AVG` to `torch.distributed.ReduceOp` enum for type annotation.

Ref:

88b6a4577b/torch/csrc/distributed/c10d/Types.hpp (L35-L47)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101534
Approved by: https://github.com/Skylion007
2023-05-16 15:51:21 +00:00
fduwjj
9d858642af [PTD] Make input contiguous for _ReduceScatter (#101373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101373
Approved by: https://github.com/wz337
2023-05-15 22:08:21 +00:00
Shen Li
af841f38bd [SPMD] Allow Override.replacement to have a global view (#101427)
It's easier for users to implement one Override that takes care of
all target submodules of different types, instead of specifying one
mapping pair for each FQN/type. For example, when calculating
sharding for sparse layers, the decision needs to be made globally.
In this case, it's helpful to allow the user Override to get access to
all submodules and make replacement decisions accordingly.

Differential Revision: [D45879732](https://our.internmc.facebook.com/intern/diff/D45879732)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101427
Approved by: https://github.com/fegin
2023-05-15 21:27:41 +00:00
Aaron Gokaslan
dfe484a3b3 [BE]: Bugfix functorch and some generic typing improvements (#101337)
Fixes some typing bugs found with newer versions of mypy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101337
Approved by: https://github.com/ezyang
2023-05-14 14:20:56 +00:00
Iris
568db1b464 [dtensor] Relax condition for _split_tensor() (#101218)
When tensor.size(self.dim) < num_chunks, we will fill the empty chunks with empty tensors (https://github.com/pytorch/pytorch/pull/98722). Therefore, we no longer need this assert.

For example, when sharding a tensor with 1 element on 2 ranks along dim 0, results would be as follows:
```
rank:0, dtensor:DTensor(local_tensor=tensor([0.4963], device='cuda:0'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
rank:1, dtensor:DTensor(local_tensor=tensor([], device='cuda:1'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101218
Approved by: https://github.com/wanchaol
2023-05-14 07:39:27 +00:00
Yanli Zhao
5ac48eb353 [FSDP]Skip unshard call during checkpointing for NO_SHARD sharding strategy (#101095)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101095
Approved by: https://github.com/fegin
2023-05-12 18:19:18 +00:00
Wanchao Liang
3ae612ba7f [dtensor] remove assertions about submesh checks (#101229)
This PR removes assertions from the submesh checks and directly returns the local
tensor, so that all the other APIs can work with submeshes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101229
Approved by: https://github.com/fduwjj
2023-05-12 04:20:35 +00:00
Aaron Gokaslan
738ba13b35 [BE]: enable PLE error codes in ruff and fix bugs (#101079)
Enables PyLint error codes implemented in ruff. These are un-opinionated static analysis checks on Python code that find common bugs. After running all the PLE error codes that are implemented in ruff, I fixed the bugs, added a few ignores for malformed Python code that is part of our JIT test script, and finally added a few ignores for a false positive on PLE0605 and submitted an issue upstream to fix it in ruff: https://github.com/charliermarsh/ruff/issues/4345 .

Common bugs found here include analysis for malformed logging format calls, bad string format calls, invalid escape sequences, and more.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101079
Approved by: https://github.com/malfet
2023-05-11 23:57:25 +00:00
Wanchao Liang
599ae95d1a [dtensor] use stack to manage mesh resources (#101202)
This PR changes the context manager behavior of device mesh: we now use
a mesh env to track the current mesh and save the mesh to a stack so
that we can allow nested context managers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337
2023-05-11 23:48:36 +00:00
Chien-Chin Huang
49c8a0cad0 [SPMD][BE] Remove the legacy tracing code (#100858)
Remove the legacy tracing code as it causes several test and benchmark issues.

Differential Revision: [D45649123](https://our.internmc.facebook.com/intern/diff/D45649123/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100858
Approved by: https://github.com/wanchaol
2023-05-11 23:08:27 +00:00
Ke Wen
daed3bf8f9 Implement coalesced all_gather_into_tensor (#101157)
This PR adds support for the following use cases:
- Sync style:
```
with dist._coalescing_manager():
     for i in range(num_coll):
         dist.all_gather_into_tensor(output_tensors[i], input_tensors[i])
```
- Async style:
```
with dist._coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.all_gather_into_tensor(output_tensors[i], input_tensors[i])

# do a bunch of other things
cm.wait()
# do things that depend on the all-gathers' results
```
Each `all_gather_into_tensor` call is independent in terms of its data and buffer location, but the calls can be executed in parallel by supported backends (like NCCL).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101157
Approved by: https://github.com/kumpera, https://github.com/wanchaol
2023-05-11 20:58:47 +00:00
Wanchao Liang
a1aa32e204 [dtensor] tensor ops to use strategy based sharding prop (#100607)
This is the first in a series of PRs that adapts operator impls to use a
strategy-based approach: each op utilizes OpStrategy and PlacementStrategy
to generate its own strategy. By utilizing the strategy-based
approach along with the op graph, we can enable more advanced op
implementations (decomp is possible) and turn the sharding prop into
more of a constraint satisfaction problem.

This PR alone only adds some basic tensor op strategies, and it works directly
on the op graph that was used for metadata propagation. The tensor ops
added in this PR mainly follow one of the arg strategies. The next set of
PRs will add more op strategies for other ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
2023-05-11 02:47:20 +00:00