Commit Graph

2331 Commits

Chien-Chin Huang
1b3e5b53f3 [FSDP][optim_state_dict] Add device to _shard_utils.py to explicitly use the device from fsdp_state (#109631)
_get_pg_default_device does not always get the device we want. This PR lets the user explicitly tell us the correct device.

Differential Revision: [D49425743](https://our.internmc.facebook.com/intern/diff/D49425743/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109631
Approved by: https://github.com/awgu, https://github.com/fduwjj, https://github.com/wz337
2023-09-20 01:59:38 +00:00
Wanchao Liang
9a95b4bc7b [dtensor] quick fix to #109306 (#109428)
It looks like the op argument schema type check is not reliable: for
things like aten.div.Tensor(Tensor, Tensor), the second argument can still be
a float/scalar for some reason, so switch to checking with the instance type
directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109428
Approved by: https://github.com/awgu, https://github.com/fegin
2023-09-16 20:53:55 +00:00
wz337
0aedacb4f7 [2D][1/N] Add _enable_extension to fsdp state (#109242)
Add _enable_extension to fsdp state. We will use this to determine whether we should enable the extension or not.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109242
Approved by: https://github.com/fegin
2023-09-16 19:03:10 +00:00
wz337
322bf50dbe [2D][2/N][DeviceMesh] Add get_parent_mesh_dim() in _MeshEnv class (#109330)
Adding some additional APIs that are needed for 2D workflow.

Since each parallelism is only aware of its own mesh when we construct the 2D state_dict, we need to know the mesh_dim of the child mesh within the parent mesh, so we can use it to create a DTensor that is 2D-sound.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109330
Approved by: https://github.com/fegin, https://github.com/fduwjj, https://github.com/wanchaol
2023-09-16 19:03:04 +00:00
Edward Z. Yang
ec8b58f5ba Add support for tolist on AsyncCollectiveTensor (#109377)
This has to be done by hand because tolist isn't supported on tensor subclasses.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109377
Approved by: https://github.com/wconstab, https://github.com/fduwjj
2023-09-15 21:48:13 +00:00
Brian
806c52b4c9 Update chunk_sharding_spec.py (#108915)
Fixes #108869

Implements the first solution proposed in the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108915
Approved by: https://github.com/wanchaol, https://github.com/wz337
2023-09-15 21:43:15 +00:00
Brian
ab99a95470 Update planner.py (#107998)
Fixes #107997
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107998
Approved by: https://github.com/wz337
2023-09-15 18:12:45 +00:00
Wanchao Liang
9456de937b [dtensor] Fix and improve the sharding cache behavior (#109306)
resolves https://github.com/pytorch/pytorch/issues/109101

The problem is essentially that we were hashing all the arguments, including
the scalar (i.e. aten.div(tensor, scalar)). In the optimizer, the scalar might
change every time we call the op, so we got a cache miss every time we called the op.

This PR improves the sharding cache behavior by introducing a
RuntimeSchemaInfo, used to record the hashing information needed at runtime
during op registration time. This enables us to:
* only hash arguments that are tensors or have static_argnum (see the sketch
below); this lets many cases like aten.div.Tensor(tensor, 0.23231) hit the
cache, whereas we currently hash all args, which excludes those cases
* with the correct cache behavior, optimizers will hit the cache again
and resolve the high CPU overhead issue.
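
For illustration only, a minimal sketch of that idea; `RuntimeSchemaInfo`'s real fields and the DTensor internals differ, and `static_argnum`/`cache_key` here are assumed names used just to show the technique:
```python
# Hedged sketch of the cache-key idea: hash only tensor metadata plus scalar
# positions explicitly marked as static, so changing scalars (e.g. an
# optimizer's step count) do not cause cache misses.
from dataclasses import dataclass
from typing import Any, Tuple

import torch


@dataclass(frozen=True)
class RuntimeSchemaInfo:        # assumed shape, not the real DTensor class
    static_argnum: int = 100    # args at index >= this are hashed as-is


def cache_key(op: Any, args: Tuple[Any, ...], info: RuntimeSchemaInfo) -> Tuple:
    hashed = []
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor):
            # hash the metadata that matters for sharding, not the values
            hashed.append((tuple(arg.shape), arg.stride(), arg.dtype))
        elif i >= info.static_argnum:
            hashed.append(arg)  # scalar that affects sharding -> include it
        # else: skip scalars such as the changing divisor in aten.div
    return (id(op), tuple(hashed))
```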

A simple MLP shows all cache hits, and a single addmm drops to 0.319ms (from 0.341ms), showing some hashing improvements:
<img width="1172" alt="Screenshot 2023-09-14 at 11 06 07 AM" src="https://github.com/pytorch/pytorch/assets/9443650/3406d673-dd8d-4ad9-9b80-9d4721c430e3">

The Adam optimizer shows aten.div hitting the sharding cache again:
<img width="1016" alt="Screenshot 2023-09-14 at 11 02 10 AM" src="https://github.com/pytorch/pytorch/assets/9443650/4280e8e3-af44-4fc2-8360-ea80b768f1d9">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109306
Approved by: https://github.com/fduwjj
2023-09-15 10:32:49 +00:00
Rodrigo Kumpera
2bca5f2af7 [C10D] Track pg name in c++. (#108813)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108813
Approved by: https://github.com/wconstab
2023-09-15 01:10:29 +00:00
Rohan Varma
25bf1a49c0 [FSDP][Wrap] ModuleWrapPolicy callable (#109117)
Makes ModuleWrapPolicy callable; in my case this is needed for
composition with or_policy. We should also make or_policy a public interface,
IMO.

Differential Revision: [D49175112](https://our.internmc.facebook.com/intern/diff/D49175112/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109117
Approved by: https://github.com/fegin
ghstack dependencies: #109116
2023-09-14 07:14:18 +00:00
Rohan Varma
f558e86fa0 [FSDP] continue if param not exist in sharded load (#109116)
If I add a param and then wrap with FSDP and load a state dict, don't
hard error here when strict=False.

Differential Revision: [D49170812](https://our.internmc.facebook.com/intern/diff/D49170812/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109116
Approved by: https://github.com/fegin
2023-09-14 07:14:18 +00:00
Andrew Gu
54dd65f93a [FSDP] Only check exec order if DETAIL (#109049)
The execution order check seems to have been causing more problems than it prevents. Motivated by an internal issue, we can move this check to only `DISTRIBUTED_DEBUG_LEVEL=DETAIL`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109049
Approved by: https://github.com/fegin
2023-09-13 20:40:38 +00:00
Randolf Scholz
32f50b7021 Improve type annotations for jit.script (#108782)
Fixes #108781

- [x] added `@overload` for `jit.script` (a simplified sketch follows this list)
- [x] added typing unittest in `test/typing/pass/jit.py`
    - NOTE: unittest is not automatically checked by mypy when executing lintrunner currently. (how to fix?)
- [x] used `stubgen` to create [torch/jit/_script.pyi](https://github.com/pytorch/pytorch/pull/108782/files#diff-738e66abee2523a952b3ddbaecf95e187cce559473cf8c1b3da7c247ee5d1132) and added overloads there. (adding them inside `_script.py` itself interfered with JIT engine)
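
As a rough illustration of the `@overload` item above, a stub-style sketch might look like the following; the signatures are simplified assumptions, not the actual contents of `torch/jit/_script.pyi`:
```python
# Hedged sketch: scripting an nn.Module yields a ScriptModule, while scripting
# a plain function keeps its callable type for static type checkers.
from typing import Any, Callable, TypeVar, overload

import torch

F = TypeVar("F", bound=Callable[..., Any])


@overload
def script(obj: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.jit.RecursiveScriptModule: ...
@overload
def script(obj: F, *args: Any, **kwargs: Any) -> F: ...
def script(obj, *args, **kwargs):
    return torch.jit.script(obj, *args, **kwargs)  # delegate to the real API
```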

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108782
Approved by: https://github.com/ezyang
2023-09-13 19:20:25 +00:00
Wanchao Liang
375d2ca6c9 [dtensor][4/n] don't use make_fx for strategy propagation (#108262)
We were using make_fx for strategy-based propagation so that we could get
a graph and the shape-related metadata, but this is overkill
for the sharding propagation purpose. This change refactors the strategy
propagation to remove the graph-based propagation and instead just uses the
op to index into the strategy functions.

We also just use a fake shape prop instead of relying on fx tracing for
the shape/stride propagation.

For a possible future decomposed propagation, we will exercise a different
codepath to enable that.

NOTE that this also greatly reduces the latency of:
1. first-time dtensor operations when populating the cache; the first
iter becomes faster again!
2. test_dtensor_ops.py; right now the
whole test finishes within 2-3 mins again.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108262
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306, #108261
2023-09-13 04:08:02 +00:00
Wanchao Liang
09f3e08bcc [dtensor][3/n] use dedicated TensorMeta instead of the fx one (#108261)
This PR switches the usage of fx's shape-prop TensorMetadata to
dtensor's own dedicated TensorMeta. This is because DTensor
only cares about three fields (shape/stride/dtype); all other fields are not
necessary and can be inferred from the local_tensor directly. This
significantly simplifies how we deal with the tensor metadata by not
tracking the other fields.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
2023-09-13 04:08:02 +00:00
Wanchao Liang
fc1dcfb9ab [dtensor][2/n] use op overload instead of function schema (#107306)
The function schema doesn't provide us anything, as we can also get the schema from `op._schema`. Including the op directly in op_schema makes it easier for sharding prop to do fake execution, and in principle it should also make the hash comparison faster, as we don't need to hash the function schema; instead we just hash `id(op)`, which is constant.

This PR is just a refactor to include the op in OpSchema instead of the function schema; no other logic changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107306
Approved by: https://github.com/fduwjj
2023-09-13 04:08:02 +00:00
wz337
6dc56d3490 [DTensor] Remove compute_local_offset from _utils.py (#109096)
Separating internal changes from OSS changes. This PR only removes compute_local_offset from the OSS directory.

This replaces https://github.com/pytorch/pytorch/pull/108965
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109096
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-09-12 21:55:15 +00:00
Iris
b6f9d4dbc4 [DCP] Enable nD device_mesh resharding DTensor in DCP and add associated tests (#106230)
This PR:
     1. Drop assert for 1D DeviceMesh check to allow DTensor with nD DeviceMesh when creating write_item.
     2. Add tests for both placement changes and mesh changes for both 1D and 2D scenarios.

cc. @kumpera  @wanchaol  @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106230
Approved by: https://github.com/kumpera
2023-09-12 00:47:58 +00:00
Brian Vaughan
bb14805bcd fix an incorrect indent in documentation (#108273)
doc for `torch.distributed.send(tensor, dst, group=None, tag=0)` was rendering incorrectly here: https://pytorch.org/docs/stable/distributed.html due to lack of indent (it was interpreting the continuation as a new argument).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108273
Approved by: https://github.com/awgu, https://github.com/kit1980
2023-09-11 21:27:52 +00:00
redwrasse
ba4782e3c0 cleanup typos; redundant parentheses (#109003)
- minor spelling fixes in `aten/src/ATen/core/TransformationHelper.h`
- remove redundant parentheses in control statements in `torch/distributed/algorithms/_quantization/quantization.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109003
Approved by: https://github.com/davidradl, https://github.com/H-Huang
2023-09-11 17:09:17 +00:00
lxg2015
e19a855b4d [HSDP] Fix Node 1 unable receive parameters from Node 0 (#108331)
When using hybrid_shard mode FSDP,
state.process_group covers GPUs 0-7 on node 0, so GPUs on node 1 cannot receive parameters; setting process_group to the default group (global group) fixes this issue.

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108331
Approved by: https://github.com/awgu
2023-09-11 15:13:28 +00:00
redwrasse
f81eacd30c typo fix strategy_comb in basic_strategy.py (#108972)
Typo fix `startegy_comb` -> `strategy_comb` in `basic_strategy.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108972
Approved by: https://github.com/Skylion007
2023-09-10 15:58:15 +00:00
wz337
311fbe43e6 [DeviceMesh] Fix __getitem__ docstring typo (#108837)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108837
Approved by: https://github.com/wanchaol
2023-09-09 01:46:14 +00:00
Matthew Hoffman
e40d6ae0a7 Improve torch.cuda.amp type hints (#108630)
Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported:
* [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
* [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
* [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
* [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
* [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` to differentiate when a `torch.Tensor` is returned vs. an `Iterable[torch.Tensor]`, based on the type of the `outputs` parameter (a rough sketch follows below).
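
For illustration only, a minimal sketch of such overloads on a stand-in class; the exact return types in the real stubs are an assumption here:
```python
# Hedged sketch: scale(Tensor) -> Tensor, while scale(iterable of Tensors)
# -> list of Tensors, so type checkers can track the output shape.
from typing import Iterable, List, Union, overload

import torch


class GradScalerStub:  # illustrative stand-in, not torch.cuda.amp.GradScaler
    @overload
    def scale(self, outputs: torch.Tensor) -> torch.Tensor: ...
    @overload
    def scale(self, outputs: Iterable[torch.Tensor]) -> List[torch.Tensor]: ...
    def scale(self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]]):
        if isinstance(outputs, torch.Tensor):
            return outputs * 2.0          # placeholder for the real scaling
        return [t * 2.0 for t in outputs]
```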

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108630
Approved by: https://github.com/ezyang
2023-09-08 06:06:25 +00:00
Rodrigo Kumpera
b26af5d5ac [c10d] Add TCPStore libuv backend support to c10d rendezvous. (#108284)
This enables libuv under env:// and tcp:// URLs.

Under env://, either use the environment variable USE_LIBUV=1
or the URL parameter use_libuv=1.

Under tcp://, use the URL parameter use_libuv=1.
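
A minimal sketch of the env-var route, assuming a single-process gloo run purely for illustration (address and port are placeholders):
```python
# Hedged sketch: opt into the libuv-backed TCPStore via the environment
# variable before initializing the process group with env:// rendezvous.
import os

import torch.distributed as dist

os.environ["USE_LIBUV"] = "1"            # enable the libuv TCPStore backend
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"

dist.init_process_group("gloo", init_method="env://", rank=0, world_size=1)
dist.destroy_process_group()
```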
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108284
Approved by: https://github.com/H-Huang, https://github.com/XilunWu
2023-09-07 21:39:58 +00:00
wz337
7bc25e38c0 [HSDP] Raise error when HSDP device_mesh has a parent_mesh (#108603)
Since we don't currently support HSDP + TP, raise an error during HSDP initialization if the device_mesh passed in has a parent mesh.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108603
Approved by: https://github.com/awgu
2023-09-07 04:17:10 +00:00
wz337
ca2cdb3009 [DeviceMesh] Minor docstring update for init_device_mesh and rename variables (#108391)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108391
Approved by: https://github.com/wanchaol
2023-09-06 08:27:11 +00:00
wz337
49aa8d19dd [DTensor] Replace usage of compute_local_offset by compute_local_shape_and_global_offset (#108547)
This PR removes four usages of compute_local_offset() in the PyTorch repo and replaces them with the new API compute_local_shape_and_global_offset().

We will remove the compute_local_offset() API in the next diff, as there are still usages internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108547
Approved by: https://github.com/wanchaol
2023-09-06 04:53:44 +00:00
Rohan Varma
208fd1cb84 [RFC] Somewhat BC breaking: make checkpoint_wrapper default to NO_REENTRANT (#108435)
We should use no_reentrant. There are a lot of users of this API, but
it is in a prototype state so should be fine to change.

Differential Revision: [D48898148](https://our.internmc.facebook.com/intern/diff/D48898148/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108435
Approved by: https://github.com/awgu
ghstack dependencies: #108032, #108033
2023-09-05 21:43:41 +00:00
Rohan Varma
db6d09c086 [RFC][FSDP] Don't move ignored params / buffers to device (#108033)
Since these are ignored by FSDP, don't move them.

Differential Revision: [D48727044](https://our.internmc.facebook.com/intern/diff/D48727044/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108033
Approved by: https://github.com/awgu
ghstack dependencies: #108032
2023-09-05 21:43:41 +00:00
Rohan Varma
3334ec3a00 [RFC] Don't materialize ignored modules for FSDP (#108032)
Per title. This seems needed for cases where I have a large embedding
I want to separately manage, but FSDP would initialize it and thus consume the
memory.

Currently the interaction with torchdistX materialize_module is not tested;
this can be done as follow-up work.
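
A minimal sketch of the user-level pattern this enables, with a made-up model; only `ignored_modules` is taken from the FSDP API, and a default process group is assumed to be initialized already:
```python
# Hedged sketch: keep a large embedding out of FSDP's management (and, per
# this PR, out of its materialization) by passing it via ignored_modules.
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(1_000_000, 256)  # managed separately
        self.proj = nn.Linear(256, 256)

    def forward(self, idx):
        return self.proj(self.embedding(idx))


model = ToyModel()
fsdp_model = FSDP(model, ignored_modules=[model.embedding])
```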

Differential Revision: [D48722046](https://our.internmc.facebook.com/intern/diff/D48722046/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108032
Approved by: https://github.com/awgu
2023-09-05 21:43:41 +00:00
Danielle Pintz
fee9fc1df0 [pytorch] Update docstring for FSDP.set_state_dict_type (#103864)
Summary: I noticed optim_state_dict_config was missing from the Args section

Test Plan: N/A

Reviewed By: rohan-varma

Differential Revision: D46670165

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103864
Approved by: https://github.com/rohan-varma, https://github.com/fegin, https://github.com/fduwjj
2023-09-05 21:43:31 +00:00
wz337
66af4f6ec7 [HSDP] Add device_mesh to FSDP kwarg and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add a device_mesh kwarg to FSDP and remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by the user as a kwarg (see the sketch after this list).
2) Change the use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, the _use_dtensor flag is set to False and the model/optim state dict returns a sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict.
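
A rough sketch of the new entry point, assuming a 2-node x 4-GPU layout, an already-initialized process group, and the current import path for init_device_mesh (which has moved across releases):
```python
# Hedged sketch: build a 2D mesh (replicate across nodes, shard within a
# node) and hand it to FSDP via the new device_mesh kwarg; sharded state
# dicts then come back as DTensors (_use_dtensor is flipped on internally).
import torch.nn as nn

from torch.distributed.device_mesh import init_device_mesh  # path may differ
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

mesh_2d = init_device_mesh("cuda", (2, 4))   # (replicate_dim, shard_dim)
model = nn.Linear(1024, 1024).cuda()

hsdp_model = FSDP(
    model,
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```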

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-05 21:21:21 +00:00
Danielle Pintz
b2c6383f44 [pytorch] Small fix to docstring of FSDP.optim_state_dict_to_load (#108383)
Summary: Fix ordering of args in docstring

Test Plan: N/A

Differential Revision: D48889668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108383
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wz337
2023-09-05 14:56:56 +00:00
Xilun Wu
a78b78cd76 [DTensor][random] add DTensor constructor: randn (#108285)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108285
Approved by: https://github.com/wanchaol
2023-09-01 20:28:41 +00:00
Chien-Chin Huang
591cb776af [FSDP][state_dict][optim_state_dict] Log slow optim and model state_dict paths (#108290)
This PR adds SimpleProfiler for FSDP state_dict/load_state_dict logging purposes. SimpleProfiler uses class variables to record profiling results and does everything in Python, which can be slow. So it is only suitable for logging slow actions such as initialization and state_dict/load_state_dict.

This PR uses SimpleProfiler to log some critical/slow paths of the model and optimizer state_dict/load_state_dict.

Differential Revision: [D48774406](https://our.internmc.facebook.com/intern/diff/D48774406/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108290
Approved by: https://github.com/wz337
2023-09-01 06:57:59 +00:00
wz337
90ef3b82d1 [DeviceMesh] Add unique mesh_dim_name check in init_device_mesh() (#108326)
Each mesh_dim_name in mesh_dim_names needs to be unique. This PR adds the check when calling init_device_mesh().
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108326
Approved by: https://github.com/wanchaol
2023-09-01 02:14:18 +00:00
PyTorch MergeBot
ab5b4c4419 Revert "[HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)"
This reverts commit cc220e45a8.

Reverted https://github.com/pytorch/pytorch/pull/107533 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing in trunk with the same failure on test_dynamo_distributed cc220e45a8 ([comment](https://github.com/pytorch/pytorch/pull/107533#issuecomment-1701983247))
2023-09-01 01:26:30 +00:00
Jirka Borovec
9178deedff removing some redundant str splits (#106089)
drop some redundant string splits, no factual changes, just cleaning the codebase

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106089
Approved by: https://github.com/albanD, https://github.com/malfet
2023-09-01 00:22:58 +00:00
wz337
cc220e45a8 [HSDP] Add device_mesh to FSDP and add dtensor state_dict support for HSDP (#107533)
This PR:
1) Add a device_mesh kwarg to FSDP and remove init_device_mesh() from _runtime_utils.py, as device_mesh would be passed in by the user as a kwarg.
2) Change the use_dtensor flag for state_dict_config and optim_state_dict_config to be private. If device_mesh is used with a sharded model/optim state dict, the _use_dtensor flag is set to True and the model/optim state dict returns a DTensor state_dict. Otherwise, the _use_dtensor flag is set to False and the model/optim state dict returns a sharded_tensor state_dict.
3) Update _optim_utils.py, _shard_utils.py, and _state_dict_utils.py to add support for HSDP to return a 2D DTensor state_dict.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107533
Approved by: https://github.com/fegin, https://github.com/awgu, https://github.com/wanchaol
2023-09-01 00:15:00 +00:00
Wanchao Liang
a29b9101fa [dynamo] fix dynamo + DTensor to work with 2d (#108329)
Pair-debugged with @wconstab, and we found some issues in both dynamo and
the TP's FSDP extension side. This PR fixes the dynamo + DTensor integration
so that the current graph-break FSDP can work with tensor parallel by moving
the torch.compile call to after FSDP wrapping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108329
Approved by: https://github.com/Skylion007, https://github.com/wconstab
2023-08-31 22:46:26 +00:00
Wanchao Liang
eafc05887f [dtensor] fix two more requires_grad callsite (#108358)
redistribute returns a new DTensor, and those returned DTensors should
follow the input DTensor's requires_grad instead of the input local
tensor's requires_grad.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108358
Approved by: https://github.com/fduwjj
2023-08-31 22:25:40 +00:00
Wanchao Liang
74ff028839 [dtensor] fix new_empty_strided op (#107835)
This PR fixes the new_empty_strided op to become replicated (from sharded)
when necessary. This is a quick fix to resolve https://github.com/pytorch/pytorch/issues/107661.

We'll need to think more about the behavior of this op when it comes to
sharding. One possibility is to follow the input sharding, but given that the
output shape of this op might not be the same as the input's, it's hard to
say we should follow the input sharding; further improvement is needed once
we figure out the op syntax.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107835
Approved by: https://github.com/fduwjj
2023-08-31 18:27:35 +00:00
Pritam Damania
704b0b3c67 [RESUBMIT] Standardize on error types for distributed errors. (#108191)
We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.

This results in messy code during error handling somewhat like this:
```
if "NCCL" in exception_str:
  ...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
  ...
if "The client socket has timed out after" in exception_str:
  ...
if "Broken pipe" in exception_str:
  ...
if "Connection reset by peer" in exception_str:
  ...
```

To address this issue, in this PR I've added these error types (a usage sketch follows the list):

1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and referred to PG backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
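
For illustration, a sketch of the cleaner handling this hierarchy allows; exactly which operation raises which subtype, and how the types are exported, is assumed here rather than taken from the PR:
```python
# Hedged sketch: catch typed distributed errors instead of matching strings.
import torch.distributed as dist


def robust_init(backend: str = "nccl") -> None:
    try:
        dist.init_process_group(backend)
    except dist.DistNetworkError:
        print("network/socket problem; maybe retry with backoff")
    except dist.DistStoreError:
        print("rendezvous store failed")
    except dist.DistBackendError:
        print(f"{backend} backend error")
    except dist.DistError:
        print("some other distributed error")
```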

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108191
Approved by: https://github.com/H-Huang
2023-08-30 21:47:39 +00:00
wz337
13e4cce83c [DTensor] Add util API to compute_local_shape_and_global_offset for checkpointing purpose (#107996)
The compute_local_shape_and_global_offset API does the following:
1) Calculate both local_shape and global_offset in one API to replace two API calls (compute_local_size and compute_local_shape).
2) Generate the correct global_offset for checkpointing purposes. We are currently using compute_local_offset for downstream checkpoint components, which could lead to incorrect results. For checkpointing, we need global_offset instead of local_offset. In some cases, global_offset does not equal local_offset, such as when a dimension is sharded multiple times on different mesh dimensions (e.g. placements = [Shard(0), Shard(0)]).

Follow-up PRs:
1) Replace related downstream components to use compute_local_shape_and_global_offset instead of compute_local_size and compute_local_offset.
2) Audit existing code base to see if we can remove compute_local_size and compute_local_offset, since they are currently being used.

cc. @wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107996
Approved by: https://github.com/wanchaol
2023-08-30 02:46:50 +00:00
Brian Hirsh
5efd63b1b8 better support for fakeifying and dynamoing through torch_dispatch subclasses (with dynamic shapes) (#107415)
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.

(3) `_make_wrapper_subclass()` marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second.. does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
2023-08-29 02:36:48 +00:00
PyTorch MergeBot
d4ff06ec84 Revert "Standardize on error types for distributed errors. (#107651)"
This reverts commit 0e2317479b.

Reverted https://github.com/pytorch/pytorch/pull/107651 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing inductor test in trunk for one of its model moco ([comment](https://github.com/pytorch/pytorch/pull/107651#issuecomment-1696578138))
2023-08-28 23:58:33 +00:00
Pritam Damania
0e2317479b Standardize on error types for distributed errors. (#107651)
We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.

This results in messy code during error handling somewhat like this:
```
if "NCCL" in exception_str:
  ...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
  ...
if "The client socket has timed out after" in exception_str:
  ...
if "Broken pipe" in exception_str:
  ...
if "Connection reset by peer" in exception_str:
  ...
```

To address this issue, in this PR I've added these error types:

1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and referred to PG backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107651
Approved by: https://github.com/H-Huang
2023-08-28 21:58:15 +00:00
wz337
264df88a2d [C10D][Logger]Add more info to c10d logger (#107331)
This PR adds pg_name and world_size to c10d logging.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107331
Approved by: https://github.com/kumpera
2023-08-28 15:10:56 +00:00
wz337
781b7ebe91 [DeviceMesh] Expose init_device_mesh (#107969)
As title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107969
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-08-26 06:48:17 +00:00
Kunal Bhalla
af229ecd34 [RFC] Change --standalone to bind to a random port (#107734)
Given standalone generates args anyways, it seems like it would be more convenient if it explicitly used a random port by default instead of trying to use 29400.

That way users can directly go with `--standalone` instead of having to spell out `--rdzv-backend=c10d --rdzv-endpoint=localhost:0`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107734
Approved by: https://github.com/H-Huang
2023-08-25 22:13:44 +00:00
dilililiwhy
ff37f6018d Enable custom device support in fsdp checkpoint (#107289)
Fixes https://github.com/pytorch/pytorch/issues/104390
Enable custom device (privateuse1 backend) support in checkpointing via a dynamic abstract device module.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107289
Approved by: https://github.com/wz337
2023-08-25 11:50:03 +00:00
weifengpy
ec10b17cfb [FSDP] verify backward_prefetch works correctly with unit test (#107058)
issue resolved: https://github.com/pytorch/pytorch/pull/105984

context:
* CI did not catch the commit that breaks backward_prefetch https://github.com/pytorch/pytorch/pull/105006
* we had an action item to add unit test to prevent similar cases: https://github.com/pytorch/pytorch/pull/105984

what's included in this unit test
* monkey patch
torch.distributed.fsdp._runtime_utils._get_handle_to_prefetch and check which handles are prefetched

for backward_prefetch = BackwardPrefetch.BACKWARD_PRE
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* pre-backward hook order: root -> decoder 5...0 -> encoder 5...0
* prefetch order: decoder 5...0 -> encoder 5...0 -> None
  * when current_handle=encoder 0, _get_handle_to_prefetch returns None

for backward_prefetch = BackwardPrefetch.BACKWARD_POST
* state._exec_order_data.handles_post_forward_order equals forward order: encoder 0...5 -> decoder 0...5 -> root
* post-backward hook (AccumulateGrad) order: decoder 5, 4...0 -> encoder 5...0 -> root
* prefetch order: decoder 4...0 -> encoder 5...0 -> None -> None
  * 1st None: when current_handle=encoder 0, _get_handle_to_prefetch returns None
  * 2nd None: when current_handle=root, we get decoder 5 inside _get_handle_to_prefetch but is not needed. so returns None
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107058
Approved by: https://github.com/awgu
2023-08-25 01:12:43 +00:00
wz337
d707724ac9 [DeviceMesh] init_device_mesh docstring update to include 1D mesh initialization (#107805)
As title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107805
Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2023-08-24 01:28:22 +00:00
fduwjj
3828cd4b79 [TP][EZ] Update doc for TP parallel style (#107819)
We need to update the doc for PairwiseParallel and SequenceParallel so that users don't get the wrong impression that these work for ``nn.Transformer``.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107819
Approved by: https://github.com/awgu, https://github.com/wanchaol
2023-08-24 00:13:52 +00:00
Antoni Viros i Martin
2c45a579ca Add wait_tensor so print always has a correct result for AsyncCollectiveTensor (#107808)
As the title says, I was trying to test the functional collectives, and, when printing the resulting tensors, sometimes they wouldn't have finished the Async operation yet. According to the comments in the file, "AsyncTensor wrapper applied to returned tensor, which issues wait_tensor() at the time of first use". This is true in most cases, but not when print() is your first use. This PR fixes that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107808
Approved by: https://github.com/fduwjj
2023-08-24 00:00:23 +00:00
Andrew Gu
2515ab93c4 [FSDP][Docs] Add note on NCCL_CROSS_NIC=1 for HSDP (#107784)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107784
Approved by: https://github.com/fegin
ghstack dependencies: #106068, #106080
2023-08-23 22:00:50 +00:00
wz337
cdd0821f00 [2/N][DeviceMesh] Overriding __getitem__ for DeviceMesh to support Mesh Slicing (#107730)
Add support for DeviceMesh slicing by overloading __getitem__ for DeviceMesh.

With this change, you can do:
```
mesh_shape = (2, 4)
mesh_dim_names = ("DP", "TP")
two_d_mesh = init_device_mesh(
    self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names
)
tp_mesh = two_d_mesh["TP"]
```

cc. @wanchaol, @fduwjj
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107730
Approved by: https://github.com/wanchaol
2023-08-23 20:35:30 +00:00
Andrew Gu
2b964d6efd [FSDP] Enable async all-reduce for HSDP (#106080)
**Overview**
This PR runs the HSDP all-reduce as async so that it can overlap with both all-gather and reduce-scatter, which can lead to slight end-to-end speedups when the sharding process group is fully intra-node. Previously, the all-reduce serializes with reduce-scatter, so it can only overlap with one all-gather.

For some clusters (e.g. our AWS cluster), `NCCL_CROSS_NIC=1` improves inter-node all-reduce times when overlapped with intra-node all-gather/reduce-scatter.
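
As a rough illustration of the setup being measured, a sketch of enabling HSDP (model and mesh sizes are placeholders, and an initialized process group is assumed); `NCCL_CROSS_NIC=1` would be set in the launch environment rather than in code:
```python
# Hedged sketch: hybrid sharding shards within a node and all-reduces across
# nodes; this PR makes that inter-node all-reduce overlap-friendly (async).
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

model = nn.Linear(8192, 8192).cuda()
hsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)
```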

**Experiment**
<details>
<summary> Example 'before' trace </summary>
<img width="559" alt="hsdp_32gpus_old" src="https://github.com/pytorch/pytorch/assets/31054793/15222b6f-2b64-4e0b-a212-597335f05ba5">

</details>

<details>
<summary> Example 'after' trace </summary>
<img width="524" alt="hsdp_32gpus_new" src="https://github.com/pytorch/pytorch/assets/31054793/94f63a1d-4255-4035-9e6e-9e10733f4e44">

</details>

For the 6-encoder-layer, 6-decoder layer transformer with `d_model=8192`, `nhead=64` on 4 nodes / 32 40 GB A100s via AWS, the end-to-end iteration times are as follows (with AG == all-gather, RS == reduce-scatter, AR == all-reduce; bandwidth reported as algorithmic bandwidth):
- Reference FSDP:
    - **1160 ms / iteration**
    - ~23 ms / encoder AG/RS --> 24.46 GB/s bandwidth
    - ~40 ms / decoder AG/RS --> 26.5 GB/s bandwidth
    - 50 GB/s theoretical inter-node bandwidth
- Baseline 8-way HSDP (only overlap AR with AG) -- intra-node AG/RS, inter-node AR:
    - **665 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~5 ms / decoder AG/RS --> 212 GB/s bandwidth
    - ~30 ms / encoder AR --> 2.34 GB/s bandwidth
    - ~55 ms / decoder AR --> 2.65 GB/s bandwidth
    - 300 GB/s theoretical intra-node bandwidth
- New 8-way HSDP (overlap AR with AG and RS) -- intra-node AG/RS, inter-node AR:
    - **597 ms / iteration**
    - ~3 ms / encoder AG/RS --> 187.5 GB/s bandwidth
    - ~6.2 ms / decoder AG/RS --> 170.97 GB/s bandwidth (slower)
    - ~23 ms / encoder AR (non-overlapped) --> 3.057 GB/s bandwidth (faster)
    - ~49 ms / decoder AR (non-overlapped) --> 2.70 GB/s bandwidth (faster)
    - ~100 ms / decoder AR (overlapped) --> 1.325 GB/s bandwidth (slower)
    - Overlapping with reduce-scatter reduces all-reduce bandwidth utilization even though the all-reduce is inter-node and reduce-scatter is intra-node!
- New 8-way HSDP (overlap AR with AG and RS) with `NCCL_CROSS_NIC=1`:
    - **556 ms / iteration**
    - Speedup comes from faster overlapped AR

Thus, for this particular workload, the async all-reduce enables 16% iteration-time speedup compared to the existing HSDP and 52% speedup compared to FSDP. These speedups are pronounced due to the workload being communication bound, so any communication time reduction translates directly to speedup.

**Unit Test**
This requires >= 4 GPUs:
```
python -m pytest test/distributed/fsdp/test_fsdp_hybrid_shard.py -k test_fsdp_hybrid_shard_parity
```

Differential Revision: [D47852456](https://our.internmc.facebook.com/intern/diff/D47852456)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106080
Approved by: https://github.com/ezyang
ghstack dependencies: #106068
2023-08-23 18:36:15 +00:00
Andrew Gu
50e1378680 [FSDP] Break up _post_backward_hook into smaller funcs (#106068)
The post-backward hook has some complexity due to the different paths: {no communication hook, communication hook} x {`NO_SHARD`, `FULL_SHARD`/`SHARD_GRAD_OP`, `HYBRID_SHARD`/`_HYBRID_SHARD_ZERO2`} plus some options like CPU offloading and `use_orig_params=True` (requiring using sharded gradient views).

The PR following this one that adds async all-reduce for HSDP further complicates this since the bottom-half after all-reduce must still be run in the separate all-reduce stream, making it more unwieldy to unify with the existing bottom-half.

Nonetheless, this PR breaks up the post-backward hook into smaller logical functions to hopefully help readability.

Differential Revision: [D47852461](https://our.internmc.facebook.com/intern/diff/D47852461)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106068
Approved by: https://github.com/ezyang, https://github.com/fegin
2023-08-23 18:36:15 +00:00
Codle
42738c56a0 Skip the extra copy operation in broadcast_object_list if tensor_list has only one element (#107509)
The `broadcast_object_list` function can easily broadcast the state_dict of models/optimizers. However, the `torch.cat` operation performed within `broadcast_object_list` consumes an additional, equally large amount of memory. This means that only objects occupying at most half the device capacity can be broadcast. This PR improves usability by skipping the `torch.cat` operation on object_lists with only a single element.

Before (30G tensor):
<img width="607" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/c0c67931-0851-4f27-81c1-0119c6cd2944">

After (46G tensor):
<img width="600" alt="image" src="https://github.com/pytorch/pytorch/assets/22362311/90cd1536-be7c-43f4-82ef-257234afcfa5">

Test Code:
```python
if __name__ == "__main__":
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    fake_tensor = torch.randn(30 * 1024 * 1024 * 1024 // 4)

    if dist.get_rank() == 0:
        state_dict = {"fake_tensor": fake_tensor}
    else:
        state_dict = {}
    object_list = [state_dict]
    dist.broadcast_object_list(object_list, src=0)
    print("Rank: ", dist.get_rank(), " Broadcasted Object: ", object_list[0].keys())
    dist.barrier()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107509
Approved by: https://github.com/awgu
2023-08-23 17:19:10 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
Wanchao Liang
979e706f8e [dtensor] update some comments (#107608)
This update some comments from the follow up of https://github.com/pytorch/pytorch/pull/107305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107608
Approved by: https://github.com/fduwjj
ghstack dependencies: #107606
2023-08-22 23:08:13 +00:00
Wanchao Liang
945fa7e8a8 [dtensor] fix requires_grad in distribute_tensor (#107606)
This PR fixes how requires_grad is set when calling distribute_tensor: we
should set the requires_grad of the local tensor after the detach call
to make sure we create the leaf correctly; otherwise it would raise
warnings.
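
A tiny sketch of the leaf-creation pattern in question, on plain tensors rather than the actual DTensor internals:
```python
# Hedged sketch: detach first, then restore requires_grad, so the result is a
# proper leaf tensor carrying the same requires_grad as the source.
import torch

src = torch.randn(4, 4, requires_grad=True)

local = src.detach()                     # leaf, requires_grad=False
local.requires_grad_(src.requires_grad)  # set the flag after the detach

assert local.is_leaf and local.requires_grad
```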
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107606
Approved by: https://github.com/fduwjj
2023-08-22 23:08:13 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please hep them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Brian
3361fae89b Fix FP16Planner documentation (#107620)
Fixes #107619

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107620
Approved by: https://github.com/awgu
2023-08-22 02:05:27 +00:00
wz337
f5d1df3c2f [1/N] Introduce init_device_mesh() (#107254)
This PR introduces init_device_mesh() as an API to standardize the device_mesh initialization UX.

The functionality of slicing out a submesh from a given mesh will come in later PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107254
Approved by: https://github.com/wanchaol
2023-08-21 21:13:47 +00:00
Wanchao Liang
da765995fb [2d] remove ShardedTensor from fsdp extension (#107472)
2D parallel won't use ShardedTensor, and it causes headaches for dynamo
to recognize it, so remove it from the runtime flatten/unflatten path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107472
Approved by: https://github.com/fduwjj
2023-08-21 17:16:07 +00:00
Brian
24968383b5 Fix RenamePlanner documentation (#107535)
Fixes #107490

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107535
Approved by: https://github.com/awgu, https://github.com/fduwjj
2023-08-21 07:51:57 +00:00
Chien-Chin Huang
7ba513b6e4 [FSDP][state_dict] Expose optimizer state_dict config (#105949)
The optimizer state_dict configs are not exposed. This PR exposes the two dataclasses.

Differential Revision: [D47766024](https://our.internmc.facebook.com/intern/diff/D47766024/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105949
Approved by: https://github.com/rohan-varma
2023-08-21 07:29:49 +00:00
Xilun Wu
5ce88e7e71 remove unnecessary import introduced in PR 106535 (#107440)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107440
Approved by: https://github.com/fduwjj
ghstack dependencies: #106535
2023-08-21 05:29:31 +00:00
Aaron Gokaslan
b1e8e01e50 [BE]: Apply PYI autofixes to various types (#107521)
Applies some autofixes from the ruff PYI rules to improve the typing of PyTorch. I haven't enabled most of these ruff rules yet as they do not have autofixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107521
Approved by: https://github.com/ezyang
2023-08-20 02:42:21 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, it seems there are no instances of it in our codebase, so I'm enabling it so that it stays that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Wanchao Liang
d8f2ef10a6 [dtensor][1/n] refactor op dispatch logic to reduce overhead (#107305)
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch and simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten calls to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly; this not only helps TP, but general DTensor
operations could be faster too (see the sketch after this list)
4. change the view ops behavior of in-place changing the op_schema, which
is dangerous for sharding prop caching; model view ops as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute,
so that we don't need an explicit op schema comparison to know it.
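
For point 3, a minimal sketch of swapping a hand-rolled cache for functools.lru_cache; the class/method names loosely mirror DTensor's sharding propagator but are placeholders, and the op schema is assumed to be hashable:
```python
# Hedged sketch: memoize expensive sharding propagation per unique (hashable)
# op schema with lru_cache instead of a custom cache object.
from functools import lru_cache


class ShardingPropagator:
    @lru_cache(maxsize=None)   # cache is keyed on (self, op_schema)
    def propagate_op_sharding(self, op_schema):
        # stand-in for the real (expensive) propagation rule
        return f"output sharding for {op_schema}"


prop = ShardingPropagator()
prop.propagate_op_sharding("aten.addmm.default")  # computed once
prop.propagate_op_sharding("aten.addmm.default")  # served from the cache
```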

This should help with further reducing the CPU overhead. Benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)

overall one layer of mlp time reduced from 13.535 -> 9.665ms

Apart from the overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactoring is needed to make things cleaner, which will be done in later PRs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
2023-08-18 18:30:46 +00:00
Xilun Wu
3699c6adaa [DTensor][random] add DTensor constructor: rand (#106535)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106535
Approved by: https://github.com/fduwjj, https://github.com/wanchaol
2023-08-18 07:39:34 +00:00
Rodrigo Kumpera
bbf03561a9 [functional collectives] Move back to registering finalizers on wrappers. (#107250)
We cannot use inner tensors for finalizers as they are not collectable until waited on.

This PR adds a bunch of tests for the observable behavior we want, including the
necessary scaffolding for us to test whether code waits on them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107250
Approved by: https://github.com/wconstab
2023-08-17 21:08:28 +00:00
fduwjj
983fd5ba79 [2D][TP] Enable DDP TP integration with unit test (#106583)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106583
Approved by: https://github.com/kumpera, https://github.com/fegin, https://github.com/wanchaol
ghstack dependencies: #107313
2023-08-17 02:54:17 +00:00
fduwjj
f3b0d83fe3 [EZ][TP] Refactor FSDP 2D integration extension code so that it can re-used (#107313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107313
Approved by: https://github.com/wz337
2023-08-16 22:01:17 +00:00
Chien-Chin Huang
f6a9c15421 [FSDP][state_dict] Make optim_state_dict_to_load work with use_orig_param=False + NO_SHARD (#107185)
Summary: As title

Test Plan: CI

Reviewed By: wz337

Differential Revision: D48329724

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107185
Approved by: https://github.com/fegin
2023-08-15 21:42:41 +00:00
Shen Li
45128ab67c [Reland] Add OnCompletion Hook to ProcessGroup (#106988) (#107233)
This allows infra/trainers to get detailed stats about communication
efficiency without knowing anything about what model or distributed
training paradigms have been used. This is helpful as the infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainers have access to all
collectives used by the model authors.

This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107233
Approved by: https://github.com/kumpera
2023-08-15 17:35:14 +00:00
PyTorch MergeBot
fd214aa8be Revert "Add OnCompletion Hook to ProcessGroup (#106988)"
This reverts commit ba1da47e8f.

Reverted https://github.com/pytorch/pytorch/pull/106988 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing Windows build with some linker error.  The Windows failures on PR looks legit ([comment](https://github.com/pytorch/pytorch/pull/106988#issuecomment-1678580899))
2023-08-15 08:24:33 +00:00
fduwjj
d6c120d7f9 [TP][DTensor Perf]Fix DTensor Spec hash (#107181)
https://github.com/pytorch/pytorch/pull/106524 got merged so fast that we didn't realize that we should hash both stride and dtype in DTensorSpec. This is a forward fix.

One analysis of why using just the shape is not enough:
1. We use the hash value for the sharding propagation cache, and the output sharding contains the stride and size of the output DTensor. If we don't consider the stride, we will see errors.
2. One reason can be found below:
```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(128, 1), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```

```
OpSchema(func_schema=aten::t(Tensor(a) self) -> Tensor(a), args_schema=(DTensorSpec(mesh=DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),), tensor_meta=TensorMetadata(shape=torch.Size([64, 128]), dtype=torch.float32, requires_grad=False, stride=(1, 64), memory_format=None, is_quantized=False, qparams={})),), kwargs_schema={})
```

The only difference between the two op_schemas is the tensor stride:
<img width="151" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/161335df-bdfb-47c5-ba79-82616d070d15">

That makes the transpose op generate a wrong result and leads to the add_/addmm_ ops failing with errors:
```
Traceback (most recent call last):
  File "/data/users/fduwjj/pytorch/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/data/users/fduwjj/pytorch/benchmarks/distributed/tensor/tp_benchmark.py", line 210, in run_tp
    output.sum().backward()
  File "/data/users/fduwjj/pytorch/torch/_tensor.py", line 491, in backward
    torch.autograd.backward(
  File "/data/users/fduwjj/pytorch/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/api.py", line 252, in __torch_dispatch__
    return op_dispatch.operator_dispatch(
  File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 116, in operator_dispatch
    out, _, _ = _operator_dispatch(op_call, args, kwargs, sharding_propagator)
  File "/data/users/fduwjj/pytorch/torch/distributed/_tensor/dispatch.py", line 246, in _operator_dispatch
    local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
  File "/data/users/fduwjj/pytorch/torch/_ops.py", line 435, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 1
```

The same applies to dtype: if we are using DTensor in a mixed-precision environment, we will run into similar situations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107181
Approved by: https://github.com/wanchaol
ghstack dependencies: #106524
2023-08-15 05:33:10 +00:00
Shen Li
ba1da47e8f Add OnCompletion Hook to ProcessGroup (#106988)
This allows infra/trainers to get detailed stats about communication
efficiency without knowing anything about what model or distributed
training paradigms have been used. This is helpful as the infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainers have access to all
collectives used by the model authors.

This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106988
Approved by: https://github.com/kumpera, https://github.com/H-Huang
ghstack dependencies: #107140, #107141, #107160
2023-08-15 04:32:23 +00:00
Bruce Jiang
2624da638d Support third-party devices to use the init_process_group method without specifying the Backend (#107113)

When init_process_group has not been called beforehand, DeviceMesh will automatically call init_process_group without specifying the backend. Thus, when a third-party device wants to use DeviceMesh without calling init_process_group first, a problem arises. In this PR, add a default_device_backend_map so that third-party device users can add their backend to this map when they first register their backend with PyTorch. When init_process_group is called without the backend parameter, it will initialize the backends in this map. Thus, a third-party user can call init_process_group without specifying the Backend (a registration sketch follows below).
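
For illustration, a sketch of the registration-time flow; the backend/device names and the creator signature are placeholders, not part of the PR:
```python
# Hedged sketch: register a third-party backend together with its device type,
# then rely on the default device->backend map when init_process_group is
# later called without an explicit backend argument.
import torch.distributed as dist


def _create_my_backend(store, rank, world_size, timeout):
    raise NotImplementedError("construct and return the third-party backend here")


dist.Backend.register_backend(
    "my_backend",
    _create_my_backend,
    devices=["my_device"],   # ties the device type to this backend
)

# Later, user code on the third-party device can simply do:
# dist.init_process_group()   # backend resolved from the default map
```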

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107113
Approved by: https://github.com/wanchaol
2023-08-15 03:46:07 +00:00
Rohan Varma
ddf36c82b8 [PT-D][FSDP] Handle corner case of load with multi-backend PG (#107172)
Summary:
When loading a CPU state_dict with a pg initialized with
cpu:gloo,cuda:nccl, we hit a gloo crash since dest tensor is on GPU and input
is on CPU.

As a workaround, just enforce that if local_tensor.is_cpu, the dest tensor is
also cpu.

Test Plan: CI

Differential Revision: D48324752

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107172
Approved by: https://github.com/fegin
2023-08-14 23:24:44 +00:00
Jirka
858b465d74 fix str splits in single line (#106005)
Simple formatting improvements and two spelling fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106005
Approved by: https://github.com/H-Huang
2023-08-14 23:07:38 +00:00
fduwjj
4a6ca4cc05 [TP][DTensor Perf] Some perf improvement to reduce DTensor CPU overhead (#106524)
By inspecting a small TP benchmark, we found a couple of things we can optimize:
1. We call deep_copy many times when we initialize DTensor.
2. Some sharding_prop results are not cached successfully.
3. We are still calling redistribute when it is not necessary.

![image](https://github.com/pytorch/pytorch/assets/6937752/b847d110-eea1-45df-9298-066d0ba07dd7)

![image](https://github.com/pytorch/pytorch/assets/6937752/fc08f564-caed-496b-80d7-275c1dba3806)

![image](https://github.com/pytorch/pytorch/assets/6937752/fdc06cc4-a4ba-48e8-a118-c041bbd04f5e)

So we want to:
1. Remove the deep_copy; we now make placements a tuple so we are sure it's immutable.
2. Somehow the op_schema gets changed during sharding propagation, so we store a hashed version of it before passing it to sharding_prop. Ideally we want to figure out why `op_schema` gets changed, but it looks like it gets changed in both the index and detach/view ops, so it might take more time to debug.
3. Also, when we hash the op_schema, we want to hash the entire args_schema, not just the args_spec, which only contains the DTensorSpec from args that are DTensors.
4. It turns out that sometimes DTensor has mem_format set to None (not contiguous), and this leads to redistribute getting triggered, so we only need to compare type/shape and stride in the metadata.

We also need to ensure _Partial and Shard have different hash values in the DTensorSpec.

![image](https://github.com/pytorch/pytorch/assets/6937752/321e6890-1ab6-4975-adc9-524c6ef9a76b)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106524
Approved by: https://github.com/wanchaol
2023-08-14 20:03:19 +00:00
Wanchao Liang
c9cbcb2449 [device_mesh] move remaining collectives to a separate file (#107012)
Move the remaining collectives to a separate file to prepare DeviceMesh
to become a public distributed API.

For those remaining utils, we need to upstream them to functional
collectives with a proper implementation; added a TODO there for a follow-up
PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107012
Approved by: https://github.com/fduwjj
2023-08-11 23:49:27 +00:00
Michael Voznesensky
42660015b4 [Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly (#106886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106886
Approved by: https://github.com/awgu, https://github.com/wconstab
ghstack dependencies: #106884
2023-08-11 22:35:50 +00:00
Wanchao Liang
5c48ff20b5 AsyncCollectiveTensor: dont sync on view ops (#105240)
AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives API's. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and internally calling `.wait()` to synchronize the data the next time it is used.

Previously, these wait() calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor, immediately causing the sync (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207))

AsyncCollectiveTensor shouldn't need to do a synchronization if you try to detach() it though - in fact, it should be fine to avoid synchronizing if you perform any view ops on it (which just require viewing metadata, but not actual data). This PR tries to update `AsyncCollectiveTensor` to delay `wait()` calls whenever the subclass encounters a view op.

Added some light testing that just runs some DTensor compute followed by view ops and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()`.
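
As a rough illustration of the behavior (not the PR's actual test), a sketch using the private functional-collectives module; the module path, the group argument form, and the wrapper type are assumptions:
```python
# Hedged sketch: the result of a functional collective comes back wrapped; a
# pure view op should not force the pending wait, while the first real use of
# the data does.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# assumes init_process_group() has already been called on a CUDA backend
t = torch.ones(4, device="cuda")
out = funcol.all_reduce(t, "sum", group=dist.group.WORLD)

v = out.view(2, 2)    # metadata-only: no wait_tensor() issued yet (per this PR)
total = v.sum()       # first real use of the data triggers the pending wait
```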

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
2023-08-11 19:20:25 +00:00
Andrew Gu
7b94d93431 [FSDP] Fix train -> EMA -> eval with mixed precision (#106858)
This fixes a pretty vicious bug relating to `SHARD_GRAD_OP`, mixed precision, EMA, and eval.

**Bug Explanation**
The model has a main module and an EMA module, where the main module is used for training and the EMA module is used for eval. The model has FSDP's fp16 mixed precision enabled. The flow consists of (1) training forward/backward/optimizer -> (2) EMA update (copy main module to EMA module) -> (3) eval forward in `torch.no_grad()`, where this repeats for many iterations.

Consider the _second_ iteration.
- From the first iteration's eval forward, the EMA module has the fp16 unsharded parameters in memory (not freed due to `SHARD_GRAD_OP`).
- In this second iteration's step (2), we perform the EMA update under the `summon_full_params()` context, where FSDP specially forces full precision.  This means that the EMA module now uses fp32 unsharded parameters, distinct from the fp16 unsharded parameters still in memory. The EMA update modifies those fp32 parameters, and upon exiting the context, FSDP correctly writes the modifications back to the fp32 sharded parameters.
- In the second iteration's step (3) (eval forward), FSDP checks whether it needs to run the unshard op (including all-gather) but sees it does not since the fp16 unsharded parameters are still in memory. Thus, FSDP uses those fp16 unsharded parameters directly without all-gather. However, these fp16 unsharded parameters are stale and do not include the EMA update!
- In other words, at this point, the fp32 sharded parameters are correct, the fp16 unsharded parameters are stale, and FSDP chooses _not_ to re-all-gather since the fp16 unsharded parameters are in memory.

**Fix Explanation**
This PR fixes this by freeing the fp16 unsharded parameters if they are still allocated when forcing full precision, i.e. using fp32 unsharded parameters in `summon_full_params()`. This ensures that any modifications written back to the fp32 sharded parameters will be persisted via the next all-gather.
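A hedged sketch of the fix's control flow (the function and attribute names below are illustrative, not the real FSDP internals):

```
def _unshard_with_forced_full_precision(handle):
    # If a stale low-precision (e.g. fp16) unsharded flat parameter is
    # still resident from a previous forward, free it first so the next
    # unshard re-all-gathers from the up-to-date fp32 sharded params.
    if handle.uses_low_precision() and handle.low_precision_param_allocated():
        handle.free_low_precision_unsharded_param()
    # All-gather into the full-precision unsharded flat parameter.
    handle.unshard(force_full_precision=True)
```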
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106858
Approved by: https://github.com/kumpera
ghstack dependencies: #106857
2023-08-10 19:32:43 +00:00
alanhe151220037
1afbc985fe Make RNGStateTracker support cuda-like device (#106771)
Replace `CudaRNGStateTracker` with `RNGStateTracker` by rewriting some CUDA-binding code to go through a `device_handle`.
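A hedged sketch of the `device_handle` idea (the helper and class names are illustrative): look up the backend module from the device type instead of hard-coding `torch.cuda`, so any cuda-like device that exposes the same RNG API works.

```
import torch


def _get_device_handle(device_type: str = "cuda"):
    # e.g. "cuda" -> torch.cuda; a custom cuda-like backend resolves the
    # same way as long as it registers a torch.<device_type> module.
    return getattr(torch, device_type, None) if device_type != "cpu" else None


class RNGStateTrackerSketch:
    def __init__(self, device_type: str = "cuda"):
        self._device_handle = _get_device_handle(device_type)

    def get_rng_state(self):
        return self._device_handle.get_rng_state()

    def set_rng_state(self, state):
        self._device_handle.set_rng_state(state)
```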

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106771
Approved by: https://github.com/wanchaol
2023-08-10 19:14:33 +00:00
weifengpy
4bc846c101 [FSDP] Ignore buffer type casting in ignored modules (#106766)
issue resolved: https://github.com/pytorch/pytorch/issues/97791

Before this PR, mixed_precision applied to buffers from ignored modules; see ```test_state_dict_with_ignored_modules(mixed_precision=True)``` to reproduce.

After this PR, we avoid applying mixed_precision semantics to buffers from ignored modules (a sketch of the filtering follows the steps below):
* step 1, initialization: state._ignored_buffer_names records all the buffer names from ignored modules
* step 2, lazy init at runtime: skip ignored buffers in ```_get_buffers_and_dtypes_for_computation```
* step 3, state_dict hook: avoid upcasting ignored buffers in ```_get_buffers_and_dtypes_for_computation```
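A hedged sketch of the filtering in step 2 (simplified; only `state._ignored_buffer_names` and the function name come from the PR, the rest is illustrative):

```
def _get_buffers_and_dtypes_for_computation(state, root_module):
    # Collect buffers to cast for computation, skipping any buffer whose
    # FQN was recorded as belonging to an ignored module at init time.
    buffers, buffer_dtypes = [], []
    for name, buffer in root_module.named_buffers():
        if name in state._ignored_buffer_names:
            continue  # leave ignored-module buffers in their original dtype
        buffers.append(buffer)
        buffer_dtypes.append(state.mixed_precision.buffer_dtype)
    return buffers, buffer_dtypes
```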
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106766
Approved by: https://github.com/awgu
2023-08-09 23:09:43 +00:00
Andrew Gu
6f036c9637 [FSDP][Easy] zeros -> empty for immediately freed tensors (#106857)
Since we immediately free these tensors' storage (via `_free_storage()`), there is no reason to zero them after allocation:
92e5b124c8/torch/distributed/fsdp/flat_param.py (L1140-L1145)
92e5b124c8/torch/distributed/fsdp/flat_param.py (L1155-L1161)
92e5b124c8/torch/distributed/fsdp/flat_param.py (L1166-L1171)
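A hedged illustration of the change (the `_free_storage` stand-in below is an assumption, not the FSDP helper itself):

```
import torch


def _alloc_then_free_placeholder(numel, dtype, device):
    # empty() skips the pointless zero-fill since the storage is about to
    # be freed anyway; previously this used torch.zeros().
    tensor = torch.empty(numel, dtype=dtype, device=device)
    tensor.untyped_storage().resize_(0)  # stand-in for FSDP's _free_storage()
    return tensor
```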
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106857
Approved by: https://github.com/Skylion007
2023-08-09 17:26:33 +00:00
Eddy Ogola Onyango
cbcd9083be [DCP] Modify tensor saving logic in DCP (#106415)
Currently, DCP treats tensors as duplicates and only saves them on rank0. This won't work for PiPPy, since PiPPy has unique tensors across different ranks; with the current setup, we would only save the tensors on rank0 (the coordinator rank).

In this PR, we change to letting each rank create its own WriteItem for its tensors. For tensors that are replicated across ranks, we handle deduplication through dedup_tensors(), which removes the duplicate WriteItems so the actual writing happens only once.
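A hedged sketch of the dedup idea (not the actual `dedup_tensors()` signature): every rank proposes WriteItems for its tensors, and replicated items, which share the same index, are kept only on the lowest rank that proposed them.

```
def dedup_write_items(rank_to_items):
    # rank_to_items: dict[int, list] where each item has an `.index`
    # identifying the tensor (or tensor chunk) it writes.
    owner = {}
    for rank in sorted(rank_to_items):
        for item in rank_to_items[rank]:
            owner.setdefault(item.index, rank)  # lowest rank wins
    return {
        rank: [item for item in items if owner[item.index] == rank]
        for rank, items in rank_to_items.items()
    }
```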
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106415
Approved by: https://github.com/wz337
2023-08-09 00:16:10 +00:00
Michael Voznesensky
d1a99a083f Reland Simplify handle indexing (#105006) (#106357)
This reverts commit a9a3c45649.

This PR changes the following:
- `_ExecOrderData.handle_to_handle_index` -> `FlatParamHandle._handle_index`
- `_ExecOrderData.handles_to_pre_forward_order_index` -> `FlatParamHandle._pre_forward_order_index`
- `_ExecOrderData.handles_to_post_forward_order_index` -> `FlatParamHandle._post_forward_index`
- `_FSDPState._needs_pre_forward_unshard` -> `FlatParamHandle._needs_pre_forward_unshard`
- `_FSDPState._needs_pre_backward_unshard` -> `FlatParamHandle._needs_pre_backward_unshard`
- `_FSDPState._handles_prefetched` -> `FlatParamHandle._prefetched`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106357
Approved by: https://github.com/awgu
2023-08-03 19:17:32 +00:00
fduwjj
578d9fee42 [DTensor][EZ] op schema comparison so that no redistribute is called (#106158)
When looking at TP traces more carefully, I found that for cases where input resharding is not needed, we still call redistribute within sharding propagation. Upon closer inspection, it looks like the way we compare different op_schemas is not correct.

One example can be seen in the following trace:
<img width="1146" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/7322d26f-7029-41f9-8f8c-5f27a6bb98f9">

As you can see, no collectives are called, and this redistribute is not needed.

With this change:

<img width="1491" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/eb4a971f-44c1-4d83-8671-fce94cfa926c">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106158
Approved by: https://github.com/Skylion007, https://github.com/wanchaol
2023-08-03 19:17:10 +00:00
Andrew Gu
57fba6fd86 [FSDP][9/N] Introduce CustomPolicy (#104986)
This PR adds a new `CustomPolicy` that acts like the existing `lambda_auto_wrap_policy` except it (1) leverages the new auto wrapping infrastructure and (2) allows overriding FSDP kwargs for particular instances. (1) gives it access to the validation checks (like for frozen parameters), and (2) makes it as expressive as manual wrapping. This should allow us to effectively deprecate manual wrapping if desired.

The API is as follows:
```
def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
    ...
policy = CustomPolicy(lambda_fn)
```
The `lambda_fn` can return one of the following (see the usage sketch after the list):
- `False` or `{}` to indicate no wrapping
- `True` to indicate wrapping while inheriting the root's FSDP kwargs
- Non-empty `dict` to indicate wrapping while overriding the specified FSDP kwargs and inheriting the rest from the root
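A hedged usage sketch (the module classes and the import path for `CustomPolicy` are assumptions for illustration):

```
import torch.nn as nn
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import CustomPolicy


def lambda_fn(module: nn.Module):
    if isinstance(module, nn.Embedding):
        # Wrap, overriding just this instance's sharding strategy; all
        # other kwargs are inherited from the root.
        return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP}
    if isinstance(module, nn.TransformerEncoderLayer):
        return True   # wrap with the root's FSDP kwargs
    return False      # everything else: no wrapping


policy = CustomPolicy(lambda_fn)
# model = FSDP(model, auto_wrap_policy=policy, ...)
```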

---

After this PR, the follow-up work items for auto wrapping are:
1. Add shared parameter validation
2. (Longer-term / exploratory) Add a policy that provides a reasonable auto wrapping with "minimal" user input

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104986
Approved by: https://github.com/ezyang
ghstack dependencies: #104427, #104967, #104999, #104969
2023-08-03 12:46:36 +00:00
Andrew Gu
15953fdf35 [FSDP][8/N] Replace _FSDPPolicy.policy with _Policy._run_policy (#104969)
This makes some code organization improvements.
- It renames `_FSDPPolicy` to `_Policy` to show that it is not only for FSDP but for any module-level API.
- It formalizes the contract that such a policy should return something like `target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]]` that maps each module to wrap to its kwargs. It does so by requiring a `_run_policy` abstract method (this time private since users do not need to care about it). Then, our auto wrapping can just call `_run_policy()` to generate the dict and do any validation or post-processing (a simplified sketch follows).
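A hedged sketch of the `_run_policy` contract (simplified; not the actual FSDP source):

```
from typing import Any, Dict, Set

import torch.nn as nn


class _PolicySketch:
    def _run_policy(
        self,
        root_module: nn.Module,
        ignored_modules: Set[nn.Module],
        root_kwargs: Dict[str, Any],
    ) -> Dict[nn.Module, Dict[str, Any]]:
        raise NotImplementedError


class ModuleWrapPolicySketch(_PolicySketch):
    def __init__(self, module_classes):
        self._module_classes = tuple(module_classes)

    def _run_policy(self, root_module, ignored_modules, root_kwargs):
        # Map each module to wrap to the kwargs it should be wrapped with;
        # auto wrapping then validates and applies this dict.
        target_module_to_kwargs = {}
        for module in root_module.modules():
            if module in ignored_modules:
                continue
            if isinstance(module, self._module_classes):
                target_module_to_kwargs[module] = dict(root_kwargs)
        return target_module_to_kwargs
```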

This PR is technically BC-breaking because it removes the public `ModuleWrapPolicy.policy`. However, I do not think anyone was using that anyway, so this is a pretty safe breakage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104969
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967, #104999
2023-08-03 12:42:14 +00:00
Andrew Gu
640a96dfbb [FSDP][Easy] Allow ModuleWrapPolicy to take Iterable (#104999)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104999
Approved by: https://github.com/rohan-varma
ghstack dependencies: #104427, #104967
2023-08-02 22:03:03 +00:00