Given standalone generates args anyways, it seems like it would be more convenient if it explicitly used a random port by default instead of trying to use 29400.
That way users can directly go with `--standalone` instead of having to spell out `--rdzv-backend=c10d --rdzv-endpoint=localhost:0`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107734
Approved by: https://github.com/H-Huang
As the title says, I was trying to test the functional collectives, and, when printing the resulting tensors, sometimes they wouldn't have finished the Async operation yet. According to the comments in the file, "AsyncTensor wrapper applied to returned tensor, which issues wait_tensor() at the time of first use". This is true in most cases, but not when print() is your first use. This PR fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107808
Approved by: https://github.com/fduwjj
**Overview**
This PR runs the HSDP all-reduce as async so that it can overlap with both all-gather and reduce-scatter, which can lead to slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serializes with reduce-scatter, so it can only overlap with one all-gather.
For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.
**Experiment**
<details>
<summary> Example 'before' trace </summary>
<img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5">
</details>
<details>
<summary> Example 'after' trace </summary>
<img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44">
</details>
For the 6-encoder-layer, 6-decoder layer transformer with `d_model=8192`, `nhead=64` on 4 nodes / 32 40 GB A100s via AWS, the end-to-end iteration times are as follows (with AG == all-gather, RS == reduce-scatter, AR == all-reduce; bandwidth reported as algorithmic bandwidth):
- Reference FSDP:
- **1160 ms / iteration**
- ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
- ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
- 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
- **665 ms / iteration**
- ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
- ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
- ~30 ms / encoder AR --> 2.34 GB/s bandwidth
- ~55 ms / decoder AR --> 2.65 GB/s bandwidth
- 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
- **597 ms / iteration**
- ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
- ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
- ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
- ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
- ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
- Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
- **556 ms / iteration**
- Speedup comes from faster overlapped AR
Thus, for this particular workload, the async all-reduce enables 16% iteration-time speedup compared to the existing HSDP and 52% speedup compared to FSDP. These speedups are pronounced due to the workload being communication bound, so any communication time reduction translates directly to speedup.
**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```
Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).
The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half.
Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.
Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068
Approved by: https://github.com/ezyang, https://github.com/fegin
The `broadcast_object_list` function can easily broadcast the state_dict of models/optimizers. However, the `torch.cat` operation performed within `broadcast_object_list` consumes an additional double amount of memory space. This means that only objects with a maximum memory occupancy of half the device capacity can be broadcasted. This PR improves usability by skipping the `torch.cat` operation on object_lists with only a single element.
Before (30G tensor):
<img width="607" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/c0c67931-0851-4f27-81c1-0119c6cd2944">
After (46G tensor):
<img width="600" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/90cd1536-be7c-43f4-82ef-257234afcfa5">
Test Code:
```python
if __name__ == "__main__":
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
fake_tensor = torch.randn(30 * 1024 * 1024 * 1024 // 4)
if dist.get_rank() == 0:
state_dict = {"fake_tensor": fake_tensor}
else:
state_dict = {}
object_list = [state_dict]
dist.broadcast_object_list(object_list, src=0)
print("Rank: ", dist.get_rank(), " Broadcasted Object: ", object_list[0].keys())
dist.barrier()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107509
Approved by: https://github.com/awgu
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.
I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
This PR fixes the requires_grad set when calling distribute_tensor, we
should set the requires_grad of the local tensor after the detach call
to make sure we create the leaf correctly, otherwise it would raise
warnings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107606
Approved by: https://github.com/fduwjj
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.
I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.
This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms

after (with this change), aten.addmm latency: 0.341ms

overall one layer of mlp time reduced from 13.535 -> 9.665ms
Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
We cannot use inner tensors for finalizers as they are uncollective until waited.
This PR adds a bunch of tests for the observable behavior we want, including the
necessary scafold for us to test code for their waitiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107250
Approved by: https://github.com/wconstab
This allows infra/trainers to get detailed stats about communication
efficiencies without know anything about what model or distributed
training paradigms have been used. This is helpful as infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainer can have access to all
collectives used by the model authors.
This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107233
Approved by: https://github.com/kumpera
https://github.com/pytorch/pytorch/pull/106524 gets merged so fast that we didn't figure out that we should hash both stride and dtype in DTensorSpec. This is a forward fix.
One analysis for why using just shape is not enough.
1. We use the hash value for sharding propogation cache. And the output sharding contains the stride, size of the output DTensor. If we don't consider stride, we will see errors.
2. One reason can be found below:
```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(128, 1), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```
```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(1, 64), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```
The only difference between two op_schame is the tensor stride:
<img width="151" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/161335df-bdfb-47c5-ba79-82616d070d15">
that makes the transpose op generates wrong result and leads to the add_/addmm_ op failing with errors:
```
Traceback (most recent call last):
File "/data/users/fduwjj/pytorch/torch/multiprocessing/spawn.py", line 74, in _wrap
fn(i, *args)
File "/data/users/fduwjj/pytorch/benchmarks/distributed/tensor/tp_benchmark.py", line 210, in run_tp
output.sum().backward()
File "/data/users/fduwjj/pytorch/torch/_tensor.py", line 491, in backward
torch.autograd.backward(
File "/data/users/fduwjj/pytorch/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/api.py", line 252, in __torch_dispatch__
return op_dispatch.operator_dispatch(
File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 116, in operator_dispatch
out, _, _ = _operator_dispatch(op_call, args, kwargs, sharding_propagator)
File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 246, in _operator_dispatch
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
File "/data/users/fduwjj/pytorch/torch/_ops.py", line 435, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 1
```
Same thing with dtype, if we are using DTensor in the environment of mixed precision, we will run into situations like this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107181
Approved by: https://github.com/wanchaol
ghstack dependencies: #106524
This allows infra/trainers to get detailed stats about communication
efficiencies without know anything about what model or distributed
training paradigms have been used. This is helpful as infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainer can have access to all
collectives used by the model authors.
This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106988
Approved by: https://github.com/kumpera, https://github.com/H-Huang
ghstack dependencies: #107140, #107141, #107160
…out specifying the Backend
When init_process_group is not been done before, it will automatically apply init_process_group within Devicemesh without specifying the backend. Thus, when a third-party device want to use Devicemesh without doing init_process_group before, there comes a problem. In this PR, add a default_device_backend_map for third-party device users to add their backends to this map when they register their backends to pytorch firstly. When doing init_process_group without parameter backend, it will init the backends in this map. Thus, a third-party user can use init_process_group method without specifying the Backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107113
Approved by: https://github.com/wanchaol
Summary:
When loading a CPU state_dict with a pg initialized with
cpu:gloo,cuda:nccl, we hit a gloo crash since dest tensor is on GPU and input
is on CPU.
As a workaround, just enforce that if local_tensor.is_cpu, the dest tensor is
also cpu.
Test Plan: CI
Differential Revision: D48324752
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107172
Approved by: https://github.com/fegin
Move the remaining collectives to a separate file to prepare device mesh
to become a public distributed API
For those remaining utils, we need to upstream them to functional
collectives with proper implementation, added TODO there for a follow up
PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107012
Approved by: https://github.com/fduwjj
AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives API's. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and internally calling `.wait()` to synchronize the data the next time it is used.
Previously, these wait() calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor, immediately causing the sync (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207))
AsyncCollectiveTensor shouldn't need to do a synchronization if you try to detach() it though - in fact, it should be fine to avoid synchronizing if you perform any view ops on it (which just require viewing metadata, but not actual data). This PR tries to update `AsyncCollectiveTensor` to delay `wait()` calls whenever the subclass encounters a view op.
Added some light testing, that just runs some DTensor compute followed by view ops, and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
This fixes a pretty vicious bug relating to `SHARD_GRAD_OP`, mixed precision, EMA, and eval.
**Bug Explanation**
The model has a main module and an EMA module, where the main module is used for training and the EMA module is used for eval. The model has FSDP's fp16 mixed precision enabled. The flow consists of (1) training forward/backward/optimizer -> (2) EMA update (copy main module to EMA module) -> eval forward in `torch.no_grad()`, where this repeats for many iterations.
Consider the _second_ iteration.
- From the first iteration's eval forward, the EMA module has the fp16 unsharded parameters in memory (not freed due to `SHARD_GRAD_OP`).
- In this second iteration's step (2), we perform the EMA update under the `summon_full_params()` context, where FSDP specially forces full precision. This means that the EMA module now uses fp32 unsharded parameters, distinct from the fp16 unsharded parameters still in memory. The EMA update modifies those fp32 parameters, and upon exiting the context, FSDP correctly writes the modifications back to the fp32 sharded parameters.
- In the second iteration's step (3) (eval forward), FSDP checks whether it needs to run the unshard op (including all-gather) but sees it does not since the fp16 unsharded parameters are still in memory. Thus, FSDP uses those fp16 unsharded parameters directly without all-gather. However, these fp16 unsharded parameters are stale and do not include the EMA update!
- In other words, at this point, the fp32 sharded parameters are correct, the fp16 unsharded parameters are stale, and FSDP chooses _not_ to re-all-gather since the fp16 unsharded parameters are in memory.
**Fix Explanation**
This PR fixes this by freeing the fp16 unsharded parameters if they are still allocated when forcing full precision, i.e. using fp32 unsharded parameters in `summon_full_params()`. This ensures that any modifications written back to the fp32 sharded parameters will be persisted via the next all-gather.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106858
Approved by: https://github.com/kumpera
ghstack dependencies: #106857
issue resolved: https://github.com/pytorch/pytorch/issues/97791
before this PR, mixed_precision applies to buffers from ignored modules. see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` for reproduce
after, we avoid applying mixed_precision semantics to buffers from ignored modules
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
Currently, DCP treats tensors as duplicates and only saves them on rank0. This won't work for PiPPy as PiPPy does have unique tensors across different ranks. With the current setup, we would only be saving the tensors on rank0 (coordinator rank).
In this PR, we are changing to letting each rank create its own WriteItem for tensors. For the ones that does replicate across different ranks, we are handling it thru dedup_tensors(), which will dedup the replicate WriteItem so we only do the actual writing once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106415
Approved by: https://github.com/wz337
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.
The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
---
After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.
This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
This PR adds improved error/warning messaging when auto wrapping with `ModuleWrapPolicy` in the presence of frozen parameters.
- For `use_orig_params=False`, FSDP requires uniform `requires_grad` for each FSDP instance. This PR adds a `ValueError` at wrapping time with a message that mentions the violating module and the frozen/non-frozen parameter names.
- For `use_orig_params=True`, FSDP allows non-uniform `requires_grad` for each FSDP instance. However, it will result in higher-than-expected gradient memory usage. This PR adds a `UserWarning` at wrapping time with a message that mentions the violating module, how much extra gradient memory will be used (in units of numel), and the frozen/non-frozen parameter names.
- There is a possibility that this warning will be spammy/verbose, but my current thinking is that it is okay for now unless users complain.
<details>
<summary> Why DFS via named_children() vs. Using named_modules()</summary>
```
LoraModel(
(embed_tokens): Embedding(100, 32)
(layers): ModuleList(
(0-3): 4 x LoraDecoder(
(attn): LoraAttention(
(q_proj): Linear(in_features=32, out_features=32, bias=False)
(lora_A): Linear(in_features=32, out_features=8, bias=False)
(lora_B): Linear(in_features=8, out_features=32, bias=False)
(k_proj): Linear(in_features=32, out_features=32, bias=False)
(v_proj): Linear(in_features=32, out_features=32, bias=False)
(o_proj): Linear(in_features=32, out_features=32, bias=False)
)
(mlp): LoraMLP(
(proj1): Linear(in_features=32, out_features=128, bias=False)
(proj2): Linear(in_features=128, out_features=32, bias=False)
)
(inp_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
(post_attn_layernorm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
)
(norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
)
```
Reverse topological order with stack-based DFS via `named_children()`:
```
[
'embed_tokens',
'layers.0.attn.q_proj', 'layers.0.attn.lora_A', 'layers.0.attn.lora_B', 'layers.0.attn.k_proj', 'layers.0.attn.v_proj', 'layers.0.attn.o_proj', 'layers.0.attn', 'layers.0.mlp.proj1', 'layers.0.mlp.proj2', 'layers.0.mlp', 'layers.0.inp_layernorm', 'layers.0.post_attn_layernorm', 'layers.0',
'layers.1.attn.q_proj', 'layers.1.attn.lora_A', 'layers.1.attn.lora_B', 'layers.1.attn.k_proj', 'layers.1.attn.v_proj', 'layers.1.attn.o_proj', 'layers.1.attn', 'layers.1.mlp.proj1', 'layers.1.mlp.proj2', 'layers.1.mlp', 'layers.1.inp_layernorm', 'layers.1.post_attn_layernorm', 'layers.1',
'layers.2.attn.q_proj', 'layers.2.attn.lora_A', 'layers.2.attn.lora_B', 'layers.2.attn.k_proj', 'layers.2.attn.v_proj', 'layers.2.attn.o_proj', 'layers.2.attn', 'layers.2.mlp.proj1', 'layers.2.mlp.proj2', 'layers.2.mlp', 'layers.2.inp_layernorm', 'layers.2.post_attn_layernorm', 'layers.2',
'layers.3.attn.q_proj', 'layers.3.attn.lora_A', 'layers.3.attn.lora_B', 'layers.3.attn.k_proj', 'layers.3.attn.v_proj', 'layers.3.attn.o_proj', 'layers.3.attn', 'layers.3.mlp.proj1', 'layers.3.mlp.proj2', 'layers.3.mlp', 'layers.3.inp_layernorm', 'layers.3.post_attn_layernorm', 'layers.3',
'layers', 'norm', ''
]
```
Reverse topological order with `named_modules()`:
```
[
'norm',
'layers.3.post_attn_layernorm', 'layers.3.inp_layernorm', 'layers.3.mlp.proj2', 'layers.3.mlp.proj1', 'layers.3.mlp', 'layers.3.attn.o_proj', 'layers.3.attn.v_proj', 'layers.3.attn.k_proj', 'layers.3.attn.lora_B', 'layers.3.attn.lora_A', 'layers.3.attn.q_proj', 'layers.3.attn', 'layers.3',
'layers.2.post_attn_layernorm', 'layers.2.inp_layernorm', 'layers.2.mlp.proj2', 'layers.2.mlp.proj1', 'layers.2.mlp', 'layers.2.attn.o_proj', 'layers.2.attn.v_proj', 'layers.2.attn.k_proj', 'layers.2.attn.lora_B', 'layers.2.attn.lora_A', 'layers.2.attn.q_proj', 'layers.2.attn', 'layers.2',
'layers.1.post_attn_layernorm', 'layers.1.inp_layernorm', 'layers.1.mlp.proj2', 'layers.1.mlp.proj1', 'layers.1.mlp', 'layers.1.attn.o_proj', 'layers.1.attn.v_proj', 'layers.1.attn.k_proj', 'layers.1.attn.lora_B', 'layers.1.attn.lora_A', 'layers.1.attn.q_proj', 'layers.1.attn', 'layers.1', 'layers.0.post_attn_layernorm', 'layers.0.inp_layernorm', 'layers.0.mlp.proj2', 'layers.0.mlp.proj1', 'layers.0.mlp', 'layers.0.attn.o_proj', 'layers.0.attn.v_proj', 'layers.0.attn.k_proj', 'layers.0.attn.lora_B', 'layers.0.attn.lora_A', 'layers.0.attn.q_proj', 'layers.0.attn', 'layers.0',
'layers', 'embed_tokens', ''
]
```
With the stack-based DFS via `named_children()`, reversing the topological order gives us each level in the module tree in the registered order, wheres with `named_modules()`, reversing the topological order gives us each level in reverse. Both are valid orders, but we prefer the former since it allows us to error/warn on the _first-registered_ module that violates the frozen/non-frozen condition.
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104427
Approved by: https://github.com/ezyang
This PR should not make any functional difference. It:
- adds clearer documentation
- clarifies a type
- revises minor typos
- swaps a .keys for a .items call on a dictionary
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106069
Approved by: https://github.com/awgu
### Background: Gradient Pre-Divide
Consider $N$ data parallel workers. Define $g_i$ to be the $i$ th worker's local unsharded gradient. Data parallel gradient reduction computes $\overline g = \frac{1}{N} \sum_{i \in [N]} g_i$.
$\sum_{i \in [N]} g_i$ increases the magnitude by a factor of $N$, which may overflow for fp16. However, if we pre-divide and compute $\sum_{i \in [N]} \frac{g_i}{N}$, then the $\frac{g_i}{N}$ may underflow. The current solution from Myle for FSDP is to pre-divide by $\sqrt{N}$ and post-divide by $\sqrt{N}$:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{i \in [N]} \frac{g_i}{\sqrt{N}}.$$
Now, consider HSDP with $N = S \cdot R$ data parallel workers, sharding over $S$ workers and replicating over $R$ workers. Define $g_{i,j}$ to be the $i \cdot S + j$ th worker's local unsharded gradient (so sharding indexes with $i$ and replication indexes with $j$). The existing implementation computes
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}},$$
where the $\frac{1}{\sqrt{R}} \frac{1}{\sqrt{S}}$ involves two separate `aten::div_` kernels.
### Revisiting Pre-Divide for HSDP
A minor optimization that we can do is with this intermediate `div_`. There are two options:
1. Compute $\overline{g}$ in the same way as FSDP:
$$\overline{g} = \frac{1}{\sqrt{N}} \sum_{j \in [R]} \sum_{i \in [S]} \frac{g_{i,j}}{\sqrt{N}}.$$
2. Compute $\overline{g}$ still with an intermediate division for rescaling but coalescing the two `divs_` into one:
$$\overline{g} = \frac{1}{\sqrt{R}} \sum_{j \in [R]} \textcolor{red}{ \frac{1}{\sqrt{N}} } \sum_{i \in [S]} \frac{g_i}{\sqrt{S}}$$
This PR goes with the 1st approach prioritizing performance because (1) it matches the existing FSDP behavior and (2) it avoids a memor-bandwidth bound `div_` kernel that blocks all-reduce launch.
### Implementation Details
In order to accommodate this, we need to refactor the communication hook logic that baked the gradient pre/post-division into the default hook.
- We raise an error if registering a communication hook for HSDP since the current implementation would only apply the hook to the reduce-scatter, not the all-reduce, which may be unexpected.
- We change it so that `state._comm_hook is not None` iff a communication hook is registered. This makes the collectives and the pre/post-division in the default no-communication-hook path more visible in the code.
Differential Revision: [D47852459](https://our.internmc.facebook.com/intern/diff/D47852459)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106034
Approved by: https://github.com/rohan-varma
With distributed checkpointing in PyTorch/XLA SPMD, the WriteItem index hints should not be modified when creating the global plan. In order to reuse the default planner logic for checkpoint metadata creation, we need to make the behavior of rewriting index hints optional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105861
Approved by: https://github.com/kumpera
- This PR rewords the `BackwardPrefetch` docs to make the tradeoffs clear in the first sentence of each with more technical details after.
- The only supported `_FSDPPolicy` is `ModuleWrapPolicy` at the time of writing this PR. We may add others in the future such as in my other PR stack. This PR removes `_FSDPPolicy` from the public docs.
- This provides some more details around `MixedPrecision` such as explaining that layer norm and batch norm accumulate in fp32.
Follow-ups:
- Why do we force batch norm modules to have FSDP applied separately? (E.g. was this because before batch norm kernels did not support fp16/bf16?) Like layer norm, this just means that the affine parameters are in fp32. Both already accumulate in fp32 even with fp16/bf16 inputs.
- Check the `param_init_fn` + `sync_module_states=True` usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105847
Approved by: https://github.com/rohan-varma
This PR adds initial dynamo support for DTensor, in particular, it:
- allows DTensor be passed into a compiled function, and allow fakify
DTensor during dynamo tracing by turning the inner local tensor to meta
tensor.
- We use `allow_in_graph` to include `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The dtensor created becomes a normal `TensorVariable` and it would insert any tensor operations to the output graph just like torch.Tensor
- note that dtensor have a new instance method `redistribute` compare to plain tensor, and we currently special handle it in `TensorVariable`
`from_local` and `redistribute` both accepts some non-trival metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encoded the metadata into a new_function (like `functools.partial`) and the new function only accepts prim args (i.e. tensor), then we put `call_function` with this new_function to the graph. This is suggested by @ezyang. The underlying rationale here is that the metadata will not change across the graph invocations so it's safe to encode them.
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
to_local = prim_redistribute.to_local(); prim_redistribute = None
add = to_local + 2; to_local = None
return (add,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
This fixes https://github.com/pytorch/pytorch/issues/104504.
- When not using full-precision eval, the relevant fix is to force `_use_sharded_views()` calls if needed in `SUMMON_FULL_PARAMS` training state.
- When using full-precision in eval, the relevant fix is tracking what was the unsharded flat parameter from which the unsharded views were computed and using that instead of determining the unsharded flat parameter from the calling context via `_get_padded_unsharded_flat_param()`.
This also fixes https://github.com/pytorch/pytorch/issues/104770.
<details>
<summary> Print output showing parity </summary>
```
Key: 0
Model 1: [-1.5, 6.40625, -0.9453125, -0.3828125, 0.16015625, -1.5078125]
Model 2: [-1.5, 6.40625, -0.9453125, -0.3828125, 0.16015625, -1.5078125]
Key: 1
Model 1: [0.0157470703125, -0.8828125, 5.65625, 1.1328125, 0.275390625, 0.11181640625]
Model 2: [0.0157470703125, -0.8828125, 5.65625, 1.1328125, 0.275390625, 0.11181640625]
Key: 2
Model 1: [0.1689453125, -0.00567626953125, -0.09375, 7.34375, -0.18359375, -0.09521484375]
Model 2: [0.1689453125, -0.00567626953125, -0.09375, 7.34375, -0.18359375, -0.09521484375]
Key: 3
Model 1: [0.546875, -0.8984375, 0.228515625, 0.7578125, 6.0625, 0.435546875]
Model 2: [0.546875, -0.8984375, 0.228515625, 0.7578125, 6.0625, 0.435546875]
Key: 4
Model 1: [-0.66796875, -0.88671875, 0.30078125, 0.06494140625, 0.412109375, 6.9375]
Model 2: [-0.66796875, -0.88671875, 0.30078125, 0.06494140625, 0.412109375, 6.9375]
Key: 5
Model 1: [0.07763671875, 0.8671875, -0.43359375, 0.5703125, 0.76171875, -0.0089111328125]
Model 2: [0.07763671875, 0.8671875, -0.43359375, 0.5703125, 0.76171875, -0.0089111328125]
Key: 6
Model 1: [-0.283203125, -0.361328125, 0.474609375, 0.10205078125, 1.125, -0.0859375]
Model 2: [-0.283203125, -0.361328125, 0.474609375, 0.10205078125, 1.125, -0.0859375]
Key: 7
Model 1: [1.140625, 0.62890625, -0.07568359375, -1.0390625, -0.2578125, -0.053955078125]
Model 2: [1.140625, 0.62890625, -0.07568359375, -1.0390625, -0.2578125, -0.053955078125]
Key: 8
Model 1: [0.68359375, -1.09375, 0.59375, 1.0, -0.23828125, 0.578125]
Model 2: [0.68359375, -1.09375, 0.59375, 1.0, -0.23828125, 0.578125]
Key: 9
Model 1: [0.515625, 0.296875, -0.1826171875, -0.12890625, -0.51953125, -0.3359375]
Model 2: [0.515625, 0.296875, -0.1826171875, -0.12890625, -0.51953125, -0.3359375]
```
</details>
Follow-ups:
- I suspect that for `SHARD_GRAD_OP`, train forward -> eval forward when using full-precision in eval will not free the low-precision unsharded parameters from the train forward, resulting in 1.5x unsharded parameter memory.
Differential Revision: [D47527597](https://our.internmc.facebook.com/intern/diff/D47527597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105346
Approved by: https://github.com/fegin, https://github.com/rohan-varma
This PR adds necessary plumbing through torchdynamo to allow tensor
subclasses with certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to goes through the dynamo fakification pass by
fakifying the tensor subclass internal components.
Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540
Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
This PR canonicalize the detach callsite to only call the detach
from `distribute_tensor`. Change other callsite to view_as and remove the
tensor constructor detach call
This is so that we don't detach local tensor for every op run when
rewrapping the DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add hack to squash older libstdc++ from conda environment in favor one from OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds loose the ability to catch exceptions (probably because they link with cupti statically, but I could not found where it is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
That were reverted due to the conflict with internal source repo.
Mostly fixes for PEP-484 violation (i.e. when default arg is set to None, but type is not annotated as optional)
Plus few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export. deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert it `torch/optim/optimizer.py` that Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
constraints:
1. No support for gradient accumulation
2. CPU offload runs step() on CPU. In future PRs ideally we'd run this on GPU.
3. When CPU offload + optimizer overlap, we have to copy the flat_param grad to CPU with non_blocking=False, otherwise step() might run on invalid data.
4. Step is waited on in post backward final cb, when in theory it can wait until the next forward.
Differential Revision: [D44809582](https://our.internmc.facebook.com/intern/diff/D44809582/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98667
Approved by: https://github.com/awgu, https://github.com/fegin
Purely out of preference, this PR renames the streams to `_unshard_stream` instead of `_streams_unshard` etc. since the former reads more naturally. The PR also removes some duplicated comments and adds back a unit test that streams are shared.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104966
Approved by: https://github.com/rohan-varma
When creating DeviceMesh, _init_process_group() would validate that all calling ranks pass in the same `mesh` argument. In FSDP, we are currently creating the DeviceMesh based on the pg of the root state so the mesh will always be valid. Adding the flag to DeviceMesh, so we can skip the all_gather_tensor of the validation during construction time.
_validate_mesh is default to True, but we manually flip it to False when initializing device mesh in FSDP's _runtime_utils.py.
Will modify skipping pg creation if existed for both 1D and 2D cases and then delete _init_process_groups flag in a follow up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
Not sure, how it worked before, but if arguments must be annotated is optional if they are defaulted to None
Towards enabling mypy-1.4.1 in lintrunner
<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>
> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
Originally, we didn't enable BWD for colwise embedding because we thought it was just for inference, but it turns out that we do need it for training. So, let's enable it for now and unit test is also added.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104820
Approved by: https://github.com/fegin
Summary:
This diff does the following:
1. re-enable single_file_per_rank for FsspecWriter, as the issue of file slicing error is resolved because of [https://github.com/pytorch/pytorch/pull/99167]
2. remove sync_files from FsspecWriter as there is no fsspec equivalence.
3. remove the internal implementation of FsspecWriter/Reader, as it has been upstreamed to PyTorch OSS
4. keep the internal test for manifold inside internal as we can only test it in fb environment
5. consolidate test to remove duplicates
6. remove unnecessary TARGETS
Test Plan:
```
buck test @//mode/dev-nosan //caffe2/test/distributed/checkpoint/fb:test_fsspec_filesystem -- --print-passing-details
----------------------------------------------------------------------
Ran 1 test in 54.894s
OK
/usr/local/fbcode/platform010/lib/python3.8/tempfile.py:818: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpzomokvh6'>
_warnings.warn(warn_message, ResourceWarning)
Buck UI: https://www.internalfb.com/buck2/4cb722a2-3ee7-48f2-a9ef-55ee6fb1a498
Test UI: https://www.internalfb.com/intern/testinfra/testrun/8725724447995201
Network: Up: 8.8 MiB Down: 1.5 GiB (reSessionID-04c29f56-ae94-4187-8a1a-c812f432674d)
Jobs completed: 209847. Time elapsed: 1:56.5s.
Cache hits: 100%. Commands: 85687 (cached: 85687, remote: 0, local: 0)
Tests finished: Pass 3. Fail 0. Fatal 0. Skip 0. Build failure 0
```
Differential Revision: D47266068
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104724
Approved by: https://github.com/fegin, https://github.com/fduwjj
When using KeyedOptimizer.init_state(), some optimizers initializes the states even if the param is empty (size() == 0) while some optimizer avoid initializing the states. There is no way FSDP can tell. Instead, FSDP should look up `optim.state`. Fortunatelly, `optim.state` does not rely on FQNs which some internal users change the FQNs.
Differential Revision: [D47285562](https://our.internmc.facebook.com/intern/diff/D47285562/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104765
Approved by: https://github.com/fduwjj
The "for now" is because we still have the issue that when using the parameter `ignored_states` path, we do not recover the ignored modules, so FSDP still wraps those as empty shells (no managed parameters), which is not ideal. This is not a blocking issue as far as I know.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104418
Approved by: https://github.com/rohan-varma
This moves `fully_shard` to use `_auto_wrap()` just like `FullyShardedDataParallel`. This means that `fully_shard` goes through the `_init_param_handle_from_module()` path (i.e. 1 `fully_shard` per "wrap"), removing the need for `_init_param_handles_from_module()` (which was 1 `fully_shard` for all "wraps" of a given policy). `_auto_wrap()` simply calls `fully_shard` on target submodules.
This includes several important fixes:
- We should register the pre/post-forward hooks on the module regardless of it has managed parameters.
- We can permit `_module_handles` to return `[]` in the composable path (for when the module has no managed parameters).
- We should unify the paths for `_get_buffers_and_dtypes_for_computation()` (previously, composable path was buggy in some cases).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104408
Approved by: https://github.com/rohan-varma
This PR is the first in refactoring the auto wrapping, only affecting `ModuleWrapPolicy` for wrapper `FullyShardedDataParallel`. The end goal is to improve the auto wrapping infra to support:
- Checking valid frozen parameters (uniform frozenness per FSDP)
- Checking valid shared parameters (shared parameters assigned to their lowest-common-ancestor module or higher)
- Writing auto wrapping policies that may take multiple passes over the module tree
- Specifying different FSDP kwargs per FSDP instance (instead of enforcing the same for all FSDP instances constructed via an auto wrap policy)
The way I envision achieving this is that, we decouple the actual "wrapping" (which is `_post_order_apply()` in this PR) from constructing the wrapping targets and kwargs (which is `target_module_to_kwargs` in this PR). In that way, a policy reduces to just constructing that latter `target_module_to_kwargs` mapping.
I do not personally recommend the size-based policy, but if we wanted to implement that under this new organization, the tracking of wrapped/nonwrapped numel should be done in the pass over the module tree prior to the actual "wrapping". This modularization keeps the actual "wrapping" part simple.
The change to how `old_dtype` is handled is mainly to avoid keeping a reference to `_override_module_mixed_precision()` function closure in each hook and to allow the function to take in all module clases at once to return which ones actually got overridden for the downstream error message. (We can directly store the global state as a mapping.)
To-do in follow-ups (not in order):
- Add frozen parameter check before `_post_order_apply()`
- Add shared parameter check before `_post_order_apply()`
- Expose wrapping policy that allows per module / per module class kwarg customization (where any unspecified kwarg adopts the root's kwarg)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104346
Approved by: https://github.com/rohan-varma, https://github.com/fegin
In https://github.com/pytorch/pytorch/pull/97645 and some follow up diffs, we made FSDP run in full precision in eval mode, even if mixed precision was specified.
However, this is probably not the best idea and we should provide a flag for users to have control over this a bit more. Adding an env var FSDP_FULL_PREC_IN_EVAL and defaulting it to off, users who want to run eval in fp32 can toggle this before wrapping model in FSDP:
os.environ["FSDP_FULL_PREC_IN_EVAL"] = "1"
Verified that unittests, APS workflow, TNT workloads can run eval appropriately with this change.
Differential Revision: [D47246556](https://our.internmc.facebook.com/intern/diff/D47246556/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104682
Approved by: https://github.com/awgu
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.
Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
This allows us use use_dtensor=True for ShardedStateDictConfig() before calling model.load_state_dict(). It only works for offload_to_cpu=False now.
Next PR will make use_dtensor=True work with offload_to_cpu=True for load_state_dict().
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104087
Approved by: https://github.com/fegin
This addresses https://github.com/pytorch/pytorch/issues/104187.
After this PR, the contract with the user is that:
- If passing `param_init_fn=None`, each `nn.Module.reset_parameters()` should only initialize its own parameters/buffers (like `parameters(recurse=False)`/`buffers(recurse=False)`).
- If passing `param_init_fn` not equal to `None`, then similarly, one call to `param_init_fn(module)` should only initialize `module`'s own parameters/buffers.
With this contract and this PR's changes, meta device initialization through either `reset_parameters()` or `param_init_fn` should be correct. Those functions will run on the original parameter/buffer shapes allowing for correct shape-dependent computations like for fan-in/fan-out, and there will not be any re-initialization of any module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104189
Approved by: https://github.com/rohan-varma
Since we do not call `_FSDPState.__init__()` and only use it for typing, it is not possible for these attributes to be `None`. The purpose of these `assert`s is to make sure that these attributes are set by `_init_process_group_state_for_hybrid_shard()`. If we care to make that explicit, I would posit that we should be using `hasattr` checks, not `is not None` checks, because if indeed `_init_process_group_state_for_hybrid_shard()` did not set these attributes, then even checking that it is not `None` would lead to an `AttributeError`. I do not include these `hasattr` checks for now since `_init_process_group_state_for_hybrid_shard()` is short enough that we can quickly tell by inspection that it sets the desired attributes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104274
Approved by: https://github.com/rohan-varma
This checks that `ignored_modules` and `ignored_states` have the expected type and provides a reasonable error message if not. Otherwise, if someone passes a mix of modules and parameters to `ignored_states` for example, then our code may be silently incorrect.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104273
Approved by: https://github.com/rohan-varma
This fixes https://github.com/pytorch/pytorch/issues/104148 (unfreezing parameters after `n` steps).
- This fixes a bug where we did not delete the post-backward hook state properly for the `requires_grad=False` case.
- This makes the `already_resharded` correct for `SHARD_GRAD_OP`.
- This generalizes `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104186
Approved by: https://github.com/rohan-varma, https://github.com/fegin
# Change
This PR adds two classes to DTensor:
1. `CudaRNGStateTracker`: `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).
2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.
# Warning
- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads(ranks) and cause issue. We need to figure out a compatible solution for that.
- The RNG state may be asynchronous outside of participating ranks. It is harmless in our current use case of submesh though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
Summary:
Details in T133020932
First commit of collective utils library. Ported over from model store, removed scuba logging, error_trait and all dependencies on modelstore.
Test Plan: In the following diffs.
Differential Revision: D45545970
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101037
Approved by: https://github.com/H-Huang