resolves https://github.com/pytorch/pytorch/issues/109101
The problem is essentially because we were hashing all the arguments, including
the scalar too (i.e. aten.div(tensor, scalar)), in the optimizer, the scalar might
change everytime we call the op, thus cache miss everytime we call the op
This PR improves the sharding cache behavior by introducing a
RuntimeSchemaInfo, used to record some runtime necessary hashing
information during op registration time. This enable us to:
* only hash arguments that are tensor or have static_argnum, this is to
enable many cases like aten.div.Tensor(tensor, 0.23231) hit the cache.
as we currently hashing all args which exclude those cases
* with the correct cache behavior, optimizers will hit the cache again
and resolve the high cpu overhead issue.
simple MLP shows all cache hit and for a single addmm -> 0.319ms (from 0.341ms), shows some hashing improvements:
<img width="1172" alt="Screenshot 2023-09-14 at 11 06 07 AM" src="https://github.com/pytorch/pytorch/assets/9443650/3406d673-dd8d-4ad9-9b80-9d4721c430e3">
Adam optimizer shows aten.div hit sharding cache again
<img width="1016" alt="Screenshot 2023-09-14 at 11 02 10 AM" src="https://github.com/pytorch/pytorch/assets/9443650/4280e8e3-af44-4fc2-8360-ea80b768f1d9">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109306
Approved by: https://github.com/fduwjj
The execution order check seems to have been causing more problems than it prevents. Motivated by an internal issue, we can move this check to only `DISTRIBUTED_DEBUG_LEVEL=DETAIL`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109049
Approved by: https://github.com/fegin
We were using make_fx for strategy based propagation so that we can get
a graph and the shape related metadata, this becomes too much overkill
for the sharding propagation purpose. This change refactors the strategy
propagation to remove the graph based propagation, instead just use the
op to index to the strategy functions.
We also just use a fake shape prop instead of relying on fx tracing for
the shape/stride propagation.
for a future possible decomposed propagation, we will exercise different
codepath to enable that
NOTE that this would also greatly reduce the latency of:
1. first time dtensor operations when populating the cache, the first
iter would become faster again!
2. greatly reduce the test_dtensor_ops.py time again, right now the
whole test finished within 2-3 mins again.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108262
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306, #108261
This PR switches the usage of fx's shape prop TensorMetadata to
dtensor's own dedicated defined TensorMeta, this is because DTensor
only cares three fields: shape/stride/dtype, all other fields are not
necessary and can be inferred from local_tensor directly. This would
help significantly simplify how we deal with the tensor metadata by not
caring other fields.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
function schema doesn't provide us anything as we can also get the schema from `op._schema`, include the op directly in op_schema makes easier for sharding prop to do fake execution, and in principle it should also make the hash comparison faster as we don't need to hash the function schema, instead we just hash the `id(op)` which is constant
This PR is just a refactor to include op to OpSchema instead of func schema, no other logic changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107306
Approved by: https://github.com/fduwjj
This PR:
1. Drop assert for 1D DeviceMesh check to allow DTensor with nD DeviceMesh when creating write_item.
2. Add tests for both placement changes and mesh changes for both 1D and 2D scenarios.
cc. @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106230
Approved by: https://github.com/kumpera
When use hybrid_shard mode FSDP,
state.process_group means gpu_0,1,,,~,7 on node 0,so gpus on node 1 cannot receive parameters, setting process_group to default_group(global_group)can fix this issue
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108331
Approved by: https://github.com/awgu
This PR removes four usages of compute_local_offset() in PyTorch repo and replaces it with the new API compute_local_shape_and_global_offset().
We will be removing compute_local_offset() API in the next diff, as there are usages internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108547
Approved by: https://github.com/wanchaol
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purpose. SimpleProfiler use class variables to record profiling results and it does everything in the Python which can be slow. So it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.
This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.
Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
This PR:
1) Add device_mesh kwarg to FSDP. Remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by user as an kwarg.
2) change use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with sharded model/optim state dict, _use_dtensor flag would be set to True and model/optim state dict would return dtensor state_dict. Otherwise, _use_dtensor flag would be set to False and model/optim state dict would return sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return 2D DTensor state_dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
This PR fixes the new_empty_strided op to become replicate from sharding
when necessary, this is a quick fix to resolve https://github.com/pytorch/pytorch/issues/107661
We'll need to think more about the behavior of this op when it comes to
sharding, one possibility is to follow the input sharding, but given the
output shape of this op might not be the same as the input, it's hard to
say we should follow the input sharding, further improvement needed once
we figure out the op syntax
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107835
Approved by: https://github.com/fduwjj
We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.
This results in messy code during error handling somewhat like this:
```
if "NCCL" in exception_str:
...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
...
if "The client socket has timed out after" in exception_str:
...
if "Broken pipe" in exception_str:
...
if "Connection reset by peer" in exception_str:
...
```
To address this issue, in this PR I've ensured added these error types:
1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and referred to PG backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108191
Approved by: https://github.com/H-Huang
The compute_local_shape_and_global_offset API does the following:
1) Calculate both local_shape and global_offset in one API to replace two API calls (compute_local_size and compute_local_shape).
2) Generate the correct global_offset for checkpointing purposes. We are currently using compute_local_offset for downstream checkpoint components, which could lead to incorrect results. For checkpointing, we need global_offset instead of local_offset. In some cases, global_offset does not equal to local_offset, when a dimension is sharded multipe times on different mesh dimension (e.g. placements = [Shard(0), Shard(0)]).
Follow-up PRs:
1) Replace related downstream components to use compute_local_shape_and_global_offset instead of compute_local_size and compute_local_offset.
2) Audit existing code base to see if we can remove compute_local_size and compute_local_offset, since they are currently being used.
cc. @wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107996
Approved by: https://github.com/wanchaol
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:
(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests
(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.
(3) `_make_wrapper_subclass()` marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second.. does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.
This results in messy code during error handling somewhat like this:
```
if "NCCL" in exception_str:
...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
...
if "The client socket has timed out after" in exception_str:
...
if "Broken pipe" in exception_str:
...
if "Connection reset by peer" in exception_str:
...
```
To address this issue, in this PR I've ensured added these error types:
1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and referred to PG backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107651
Approved by: https://github.com/H-Huang
Given standalone generates args anyways, it seems like it would be more convenient if it explicitly used a random port by default instead of trying to use 29400.
That way users can directly go with `--standalone` instead of having to spell out `--rdzv-backend=c10d --rdzv-endpoint=localhost:0`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107734
Approved by: https://github.com/H-Huang
As the title says, I was trying to test the functional collectives, and, when printing the resulting tensors, sometimes they wouldn't have finished the Async operation yet. According to the comments in the file, "AsyncTensor wrapper applied to returned tensor, which issues wait_tensor() at the time of first use". This is true in most cases, but not when print() is your first use. This PR fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107808
Approved by: https://github.com/fduwjj
**Overview**
This PR runs the HSDP all-reduce as async so that it can overlap with both all-gather and reduce-scatter, which can lead to slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serializes with reduce-scatter, so it can only overlap with one all-gather.
For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.
**Experiment**
<details>
<summary> Example 'before' trace </summary>
<img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5">
</details>
<details>
<summary> Example 'after' trace </summary>
<img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44">
</details>
For the 6-encoder-layer, 6-decoder layer transformer with `d_model=8192`, `nhead=64` on 4 nodes / 32 40 GB A100s via AWS, the end-to-end iteration times are as follows (with AG == all-gather, RS == reduce-scatter, AR == all-reduce; bandwidth reported as algorithmic bandwidth):
- Reference FSDP:
- **1160 ms / iteration**
- ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
- ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
- 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
- **665 ms / iteration**
- ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
- ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
- ~30 ms / encoder AR --> 2.34 GB/s bandwidth
- ~55 ms / decoder AR --> 2.65 GB/s bandwidth
- 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
- **597 ms / iteration**
- ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
- ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
- ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
- ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
- ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
- Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
- **556 ms / iteration**
- Speedup comes from faster overlapped AR
Thus, for this particular workload, the async all-reduce enables 16% iteration-time speedup compared to the existing HSDP and 52% speedup compared to FSDP. These speedups are pronounced due to the workload being communication bound, so any communication time reduction translates directly to speedup.
**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```
Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).
The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half.
Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.
Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068
Approved by: https://github.com/ezyang, https://github.com/fegin
The `broadcast_object_list` function can easily broadcast the state_dict of models/optimizers. However, the `torch.cat` operation performed within `broadcast_object_list` consumes an additional double amount of memory space. This means that only objects with a maximum memory occupancy of half the device capacity can be broadcasted. This PR improves usability by skipping the `torch.cat` operation on object_lists with only a single element.
Before (30G tensor):
<img width="607" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/c0c67931-0851-4f27-81c1-0119c6cd2944">
After (46G tensor):
<img width="600" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/90cd1536-be7c-43f4-82ef-257234afcfa5">
Test Code:
```python
if __name__ == "__main__":
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
fake_tensor = torch.randn(30 * 1024 * 1024 * 1024 // 4)
if dist.get_rank() == 0:
state_dict = {"fake_tensor": fake_tensor}
else:
state_dict = {}
object_list = [state_dict]
dist.broadcast_object_list(object_list, src=0)
print("Rank: ", dist.get_rank(), " Broadcasted Object: ", object_list[0].keys())
dist.barrier()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107509
Approved by: https://github.com/awgu
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.
I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
This PR fixes the requires_grad set when calling distribute_tensor, we
should set the requires_grad of the local tensor after the detach call
to make sure we create the leaf correctly, otherwise it would raise
warnings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107606
Approved by: https://github.com/fduwjj
This updates ruff to 0.285 which is faster, better, and have fixes a bunch of false negatives with regards to fstrings.
I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.
This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms

after (with this change), aten.addmm latency: 0.341ms

overall one layer of mlp time reduced from 13.535 -> 9.665ms
Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
We cannot use inner tensors for finalizers as they are uncollective until waited.
This PR adds a bunch of tests for the observable behavior we want, including the
necessary scafold for us to test code for their waitiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107250
Approved by: https://github.com/wconstab
This allows infra/trainers to get detailed stats about communication
efficiencies without know anything about what model or distributed
training paradigms have been used. This is helpful as infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainer can have access to all
collectives used by the model authors.
This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107233
Approved by: https://github.com/kumpera
https://github.com/pytorch/pytorch/pull/106524 gets merged so fast that we didn't figure out that we should hash both stride and dtype in DTensorSpec. This is a forward fix.
One analysis for why using just shape is not enough.
1. We use the hash value for sharding propogation cache. And the output sharding contains the stride, size of the output DTensor. If we don't consider stride, we will see errors.
2. One reason can be found below:
```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(128, 1), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```
```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(1, 64), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```
The only difference between two op_schame is the tensor stride:
<img width="151" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/161335df-bdfb-47c5-ba79-82616d070d15">
that makes the transpose op generates wrong result and leads to the add_/addmm_ op failing with errors:
```
Traceback (most recent call last):
File "/data/users/fduwjj/pytorch/torch/multiprocessing/spawn.py", line 74, in _wrap
fn(i, *args)
File "/data/users/fduwjj/pytorch/benchmarks/distributed/tensor/tp_benchmark.py", line 210, in run_tp
output.sum().backward()
File "/data/users/fduwjj/pytorch/torch/_tensor.py", line 491, in backward
torch.autograd.backward(
File "/data/users/fduwjj/pytorch/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/api.py", line 252, in __torch_dispatch__
return op_dispatch.operator_dispatch(
File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 116, in operator_dispatch
out, _, _ = _operator_dispatch(op_call, args, kwargs, sharding_propagator)
File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 246, in _operator_dispatch
local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
File "/data/users/fduwjj/pytorch/torch/_ops.py", line 435, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 1
```
Same thing with dtype, if we are using DTensor in the environment of mixed precision, we will run into situations like this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107181
Approved by: https://github.com/wanchaol
ghstack dependencies: #106524
This allows infra/trainers to get detailed stats about communication
efficiencies without know anything about what model or distributed
training paradigms have been used. This is helpful as infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainer can have access to all
collectives used by the model authors.
This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106988
Approved by: https://github.com/kumpera, https://github.com/H-Huang
ghstack dependencies: #107140, #107141, #107160
…out specifying the Backend
When init_process_group is not been done before, it will automatically apply init_process_group within Devicemesh without specifying the backend. Thus, when a third-party device want to use Devicemesh without doing init_process_group before, there comes a problem. In this PR, add a default_device_backend_map for third-party device users to add their backends to this map when they register their backends to pytorch firstly. When doing init_process_group without parameter backend, it will init the backends in this map. Thus, a third-party user can use init_process_group method without specifying the Backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107113
Approved by: https://github.com/wanchaol
Summary:
When loading a CPU state_dict with a pg initialized with
cpu:gloo,cuda:nccl, we hit a gloo crash since dest tensor is on GPU and input
is on CPU.
As a workaround, just enforce that if local_tensor.is_cpu, the dest tensor is
also cpu.
Test Plan: CI
Differential Revision: D48324752
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107172
Approved by: https://github.com/fegin
Move the remaining collectives to a separate file to prepare device mesh
to become a public distributed API
For those remaining utils, we need to upstream them to functional
collectives with proper implementation, added TODO there for a follow up
PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107012
Approved by: https://github.com/fduwjj
AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives API's. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and internally calling `.wait()` to synchronize the data the next time it is used.
Previously, these wait() calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor, immediately causing the sync (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207))
AsyncCollectiveTensor shouldn't need to do a synchronization if you try to detach() it though - in fact, it should be fine to avoid synchronizing if you perform any view ops on it (which just require viewing metadata, but not actual data). This PR tries to update `AsyncCollectiveTensor` to delay `wait()` calls whenever the subclass encounters a view op.
Added some light testing, that just runs some DTensor compute followed by view ops, and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
This fixes a pretty vicious bug relating to `SHARD_GRAD_OP`, mixed precision, EMA, and eval.
**Bug Explanation**
The model has a main module and an EMA module, where the main module is used for training and the EMA module is used for eval. The model has FSDP's fp16 mixed precision enabled. The flow consists of (1) training forward/backward/optimizer -> (2) EMA update (copy main module to EMA module) -> eval forward in `torch.no_grad()`, where this repeats for many iterations.
Consider the _second_ iteration.
- From the first iteration's eval forward, the EMA module has the fp16 unsharded parameters in memory (not freed due to `SHARD_GRAD_OP`).
- In this second iteration's step (2), we perform the EMA update under the `summon_full_params()` context, where FSDP specially forces full precision. This means that the EMA module now uses fp32 unsharded parameters, distinct from the fp16 unsharded parameters still in memory. The EMA update modifies those fp32 parameters, and upon exiting the context, FSDP correctly writes the modifications back to the fp32 sharded parameters.
- In the second iteration's step (3) (eval forward), FSDP checks whether it needs to run the unshard op (including all-gather) but sees it does not since the fp16 unsharded parameters are still in memory. Thus, FSDP uses those fp16 unsharded parameters directly without all-gather. However, these fp16 unsharded parameters are stale and do not include the EMA update!
- In other words, at this point, the fp32 sharded parameters are correct, the fp16 unsharded parameters are stale, and FSDP chooses _not_ to re-all-gather since the fp16 unsharded parameters are in memory.
**Fix Explanation**
This PR fixes this by freeing the fp16 unsharded parameters if they are still allocated when forcing full precision, i.e. using fp32 unsharded parameters in `summon_full_params()`. This ensures that any modifications written back to the fp32 sharded parameters will be persisted via the next all-gather.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106858
Approved by: https://github.com/kumpera
ghstack dependencies: #106857
issue resolved: https://github.com/pytorch/pytorch/issues/97791
before this PR, mixed_precision applies to buffers from ignored modules. see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` for reproduce
after, we avoid applying mixed_precision semantics to buffers from ignored modules
* step 1 initialization: state._ignored_buffer_names contains all the buffers from ignored modules
* step 2 lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3 skip upcasting in state_dict hook: avoid upcasting for ignored buffers in ```_get_buffers_and_dtypes_for_computation```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
Currently, DCP treats tensors as duplicates and only saves them on rank0. This won't work for PiPPy as PiPPy does have unique tensors across different ranks. With the current setup, we would only be saving the tensors on rank0 (coordinator rank).
In this PR, we are changing to letting each rank create its own WriteItem for tensors. For the ones that does replicate across different ranks, we are handling it thru dedup_tensors(), which will dedup the replicate WriteItem so we only do the actual writing once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106415
Approved by: https://github.com/wz337
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.
The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return:
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
---
After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
This does some code organization improvement.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing.
This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999