Commit Graph

2557 Commits

Author SHA1 Message Date
Carlos Mocholí
9df4ee8d38 Fix ColwiseParallel typo (#116151)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116151
Approved by: https://github.com/wanchaol
2023-12-20 06:40:32 +00:00
voznesenskym
77d5f60740 [fsdp][torch.compile] FSDP changes (#115497)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115497
Approved by: https://github.com/albanD
2023-12-19 18:44:36 +00:00
Lucas Pasqualin
d749b4a152 Implements permute_tensor in functional collectives (#115078)
Implementation of `permute_tensor` as per @yifuwang 's suggestion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115078
Approved by: https://github.com/wanchaol, https://github.com/yifuwang
2023-12-19 18:33:28 +00:00
Carlos Mocholí
a31effa15f Update device_mesh.py docs imports (#116074)
These are not importable from `torch.distributed`, at least today.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116074
Approved by: https://github.com/wz337, https://github.com/fegin
2023-12-19 09:44:55 +00:00
wz337
b48abbc020 [DeviceMesh] Fix DeviceMesh docstring (#116053)
1. remove outdated comments
2. fix examples in docstring

Doc after fix:
<img width="706" alt="image" src="https://github.com/pytorch/pytorch/assets/31293777/19f4f03c-0fd7-4e88-bca1-1a6ce693fbb7">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116053
Approved by: https://github.com/wanchaol
2023-12-19 04:05:49 +00:00
Yue Dong
270ed13e87 [DTensor] Make DTensor from_local backward partial() to replicate() pass through (#115967)
Summary:
This change makes the backward pass of `DTensor.from_local()` treat the `Partial()` -> `Replicate()` conversion as a pass-through, for the following reasons:
1. When we run the backward pass of `DTensor.from_local()`, if the target placement is `Partial()` (i.e. set by user code that manually overwrites the placement, rather than by torch_dispatch), we keep the grad as `Replicate()`. This is because converting the gradients back to `Partial()` is meaningless.
2. The current div logic leads to incorrect numerical values in the above case.

Test Plan:
**CI**:
CI Tests

**Unit test**:
`buck2 test mode/dev-nosan //caffe2/test/distributed/_tensor:redistribute`
- Passed

**With model training**:
```
# We tested the case where the input tensor is manually overwritten as Partial() and
# the output tensor is manually overwritten to Shard(), then converted to local.

# Before the change: numerical value not correct
Forward pass:
    collective: ReduceScatter
backward pass:
    collective: AllGather + div by process group size

# After the change: div is removed as expected.
Forward pass:
    collective: ReduceScatter
Backward pass:
    collective: AllGather
```

Differential Revision: D52175709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115967
Approved by: https://github.com/wanchaol
2023-12-19 00:16:10 +00:00
Lucas Pasqualin
8452f41305 Adds allreduce to inductor remap (#115950)
Fixes #115728

Implements a rewrite path for allreduce

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115950
Approved by: https://github.com/wconstab
2023-12-18 22:00:22 +00:00
Tianyu Liu
2a5659a797 add length assertion to PrepareModuleInput and PrepareModuleOutput (#115957)
## summary

`zip(inputs, self.input_layouts, self.desired_input_layouts)` is used in `_prepare_input_fn`; similarly for `_prepare_output_fn`. Without an assertion, unmatched entries in the inputs/outputs are silently dropped, potentially causing unexpected behaviors.
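
A quick illustration of the failure mode the assertion guards against (plain Python, not the TP code itself):

```python
# zip() silently truncates to the shortest iterable, so a missing layout is dropped
# without any error; the added length assertion surfaces the mismatch instead.
inputs = ("x", "y", "z")
input_layouts = ("Shard(0)", "Replicate()")   # one layout short

print(list(zip(inputs, input_layouts)))
# [('x', 'Shard(0)'), ('y', 'Replicate()')] -- 'z' was lost silently
```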

## test plan
`python test/distributed/tensor/parallel/test_tp_style.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115957
Approved by: https://github.com/wanchaol
2023-12-18 21:50:18 +00:00
Wanchao Liang
a1a0b290d2 [tp] further fix the docs (#115974)
a typo resulted in the note section not being rendered properly; this wasn't
visible from the last PR directly, as that PR's docs preview only showed the
first commit :(

Also make the parallelize_module doc example more concrete

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115974
Approved by: https://github.com/wz337
2023-12-18 20:41:53 +00:00
Yue Dong
87ea6fb844 Make input contiguous for DTensor reduce scatter to fix the incorrect numerical values (#115847)
Summary:
This change makes the input tensor contiguous for DTensor reduce-scatter in the case where no padding is needed.

No exception is thrown during training, but without this change we ran into numerical correctness issues.
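
A minimal sketch of the pattern being fixed (plain c10d collectives, not the DTensor internals; assumes a CUDA/NCCL setup with the default process group already initialized, e.g. via torchrun):

```python
import torch
import torch.distributed as dist

world_size = dist.get_world_size()
# A transposed view is a common source of non-contiguous inputs.
inp = torch.randn(8, world_size * 4, device="cuda").t()   # shape (world_size * 4, 8), non-contiguous
out = torch.empty(4, 8, device="cuda")

# Reduce-scatter expects a contiguous input buffer, so make it contiguous first.
dist.reduce_scatter_tensor(out, inp.contiguous())
```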

Test Plan:
**CI**
CI test

**WHEN model test**:
- Verified loss for each iteration within the expected range.
- Verified NE on-par with this change with 4B training data.

Differential Revision: D52170822

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115847
Approved by: https://github.com/wanchaol
2023-12-17 01:35:09 +00:00
Wanchao Liang
61abacf829 [tp] improve documentation (#115880)
Improve the TP documentation in terms of format and descriptions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115880
Approved by: https://github.com/XilunWu
2023-12-15 18:44:22 +00:00
Chien-Chin Huang
8c57fde21f Let all_reduce_coalesced accept one tensor as well (#115650)
This diff introduces a change to the `all_reduce_coalesced` function in `distributed_c10d.py`. The function now accepts a single tensor as well as a list of tensors, which allows for more flexibility in how it is called.

This is just syntactic sugar so the compiler can use `all_reduce_coalesced` without worrying about converting the input to a list.
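
A small usage sketch of the relaxed signature (assumes an initialized default process group):

```python
import torch
import torch.distributed as dist

t1, t2 = torch.ones(4), torch.ones(8)

# Previously the input had to be a list of tensors:
dist.all_reduce_coalesced([t1, t2])

# With this change a single tensor is accepted directly,
# equivalent to wrapping it in a one-element list:
dist.all_reduce_coalesced(t1)
```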

Differential Revision: [D51433236](https://our.internmc.facebook.com/intern/diff/D51433236/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115650
Approved by: https://github.com/wconstab
ghstack dependencies: #115523, #115302, #115648, #115649
2023-12-13 21:32:01 +00:00
Chien-Chin Huang
54d552e991 [funcol] Directly import DeviceMesh to avoid circular dependency (#115649)
This diff aims to directly import DeviceMesh from torch.distributed.device_mesh instead of importing it from dist._tensor. This is done to avoid a circular dependency issue. The code changes in each file of the diff are as follows:

- torch/distributed/_functional_collectives.py: import DeviceMesh from torch.distributed instead of dist._tensor.

Overall, this diff aims to improve the code by avoiding circular dependencies and improving the import statements.

==
The above summary is generated by LLM with minor manual fixes. The following summary is by me.

The original import caused some issues when compiling DDP with compiled_autograd. The root cause of the compilation failure has not been identified, but it is worth fixing the lazy initialization, which indirectly fixes the compilation issues for DDP.
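
A sketch of the import change described above:

```python
# Before: DeviceMesh was pulled in through the DTensor package,
# which could participate in an import cycle.
# from torch.distributed._tensor import DeviceMesh

# After: import it from its own module directly.
from torch.distributed.device_mesh import DeviceMesh
```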

Differential Revision: [D51857246](https://our.internmc.facebook.com/intern/diff/D51857246/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115649
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: #115523, #115302, #115648
2023-12-13 20:44:58 +00:00
Will Constable
c90fdb9ac0 Fix torch.distributed.breakpoint (#115705)
Switches from calling breakpoint() internally to using a subclass of
Pdb.
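
Usage sketch (all ranks call it collectively; only the chosen rank enters the debugger while the others wait):

```python
import torch.distributed as dist

# Rank 0 drops into the Pdb subclass; the other ranks block until it is done.
dist.breakpoint(rank=0)
```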

Fixes #115685

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115705
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-12-13 20:33:56 +00:00
Pavan Balaji
afa62d6237 [nccl-pg] Pass group global rank information to NCCL PG (#114736)
We were only passing a subset of the group-creation information to the
NCCL PG. Specifically, we were missing the information about which global
ranks belong to a particular PG.

This allows the NCCL PG to use this additional information for things
like better trace logging.

Test Plan:

OSS CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114736
Approved by: https://github.com/kwen2501
2023-12-13 18:02:51 +00:00
voznesenskym
310f6ab11a [fsdp] Replace acc_grad hooking with register_post_accumulate_grad_hook on flat_param (#112184)
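
For reference, a minimal standalone illustration of `register_post_accumulate_grad_hook`
(generic tensor usage, not the FSDP flat_param wiring):

```python
import torch

p = torch.randn(4, requires_grad=True)

def on_grad_accumulated(param: torch.Tensor) -> None:
    # Fires after the gradient has been fully accumulated into param.grad,
    # which is the point where FSDP can launch its post-backward logic.
    print(param.grad.norm())

p.register_post_accumulate_grad_hook(on_grad_accumulated)
(p * 2).sum().backward()
```
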
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112184
Approved by: https://github.com/albanD
ghstack dependencies: #115315
2023-12-13 16:24:44 +00:00
Chien-Chin Huang
50db2aa70a [funcol][BE] Apply ufmt to _functional_collectives.py and turn on lintrunner for functional_collective (#115648)
No logic change, just formatting.

Differential Revision: [D51857236](https://our.internmc.facebook.com/intern/diff/D51857236/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115648
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: #115523, #115302
2023-12-13 11:19:29 +00:00
Chien-Chin Huang
db8d409d08 [DCP][BE] Apply ufmt to DCP and turn on lintrunner for DCP (#115302)
No logic change. Just typing and ufmt.

Differential Revision: [D51914982](https://our.internmc.facebook.com/intern/diff/D51914982/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115302
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/LucasLLC
ghstack dependencies: #115523
2023-12-13 10:32:36 +00:00
Chien-Chin Huang
cc28f61fa3 [DCP][BE] Move DCP._state_dict_utils out from DCP (#115523)
DCP._state_dict_utils is also used by FSDP, which can sometimes cause circular imports. Move it out of DCP to avoid the circular import.

Differential Revision: [D52022440](https://our.internmc.facebook.com/intern/diff/D52022440/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115523
Approved by: https://github.com/wz337
2023-12-13 08:59:48 +00:00
Lucas Pasqualin
ffb2a28a67 Fixes expected behavior when no_dist=True in state_dict_loader.load (#115660)
Fixes expected behavior when `no_dist=True` in `state_dict_loader.load`

Fixes #115591

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115660
Approved by: https://github.com/wz337, https://github.com/fegin
2023-12-12 22:21:16 +00:00
fduwjj
40ce9a4cfb [c10d] Create a python c10d API _set_pg_timeout to set timeout (#115453)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115453
Approved by: https://github.com/wconstab, https://github.com/H-Huang
2023-12-12 20:52:43 +00:00
Chien-Chin Huang
d954ef208f [DCP][state_dict] DCP state_dict cannot correctly find FQN when the leaf module is wrapped by FSDP (#115592)
Summary: The original logic has an incorrect assumption that there is at one object name left when traversing the module tree. This is not correct when the leaf module is wrapped by FSDP.

Test Plan: CI

Differential Revision: D52049293

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115592
Approved by: https://github.com/wz337
2023-12-12 19:22:23 +00:00
Iris Z
1eca63c6ac [DeviceMesh] Move helper function 'get_mesh_dim_by_name' to MeshEnv class (#115572)
Move the helper function `get_mesh_dim_by_name` out of the DeviceMesh class to keep the public class cleaner.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115572
Approved by: https://github.com/XilunWu, https://github.com/wanchaol
2023-12-12 06:29:46 +00:00
Kumar Ashutosh
405a0040cf Adds tool to visualize sharding (#114307)
This pull request adds a tool to visualize sharding. It uses the device_mesh and placement details to construct a visualization of the split of a torch dtensor.
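
A hedged usage sketch (the import path is an assumption based on this PR's description; assumes an initialized process group, e.g. launched via torchrun):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed._tensor.debug import visualize_sharding  # path assumed

mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# Prints a table showing which rank holds which slice of the global tensor.
visualize_sharding(dtensor)
```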

Things to fix:

- [x] This implementation only uses the first element of the placement tuple; when can there be more than one element?
- [x] The calculation of the split happens here, but maybe it is already done somewhere internally in the Shard class; can we call that directly here?

Fixes #108746

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114307
Approved by: https://github.com/wanchaol
2023-12-12 06:18:03 +00:00
Wanchao Liang
fbb744fd49 [dtensor] enable radam foreach optimizer (#115566)
As titled, test both non-foreach and foreach optim

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115566
Approved by: https://github.com/XilunWu
ghstack dependencies: #115297, #115564, #115565
2023-12-12 03:57:00 +00:00
Wanchao Liang
4bd661c472 [dtensor] enable adadelta foreach optimizer (#115564)
as titled

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115564
Approved by: https://github.com/XilunWu
ghstack dependencies: #115297
2023-12-12 03:56:55 +00:00
Wanchao Liang
8a27352d6b [dtensor] add a implicit replication flag (#115297)
This PR adds experimental implicit replication support for DTensor to
interoperate with torch.Tensor: under this context manager, DTensor
can work together with torch.Tensor by assuming the torch.Tensor's
sharding layout is replicated.

Note that this is risky for DTensor, so we don't turn it on by default,
but for certain cases where the tensor is known to be replicated, users can
use this to allow DTensor and Tensor computation to work together.
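
A hedged sketch of how the flag can be used; the exact name and import path of the context manager are assumptions (in later releases it surfaces as `implicit_replication` under the DTensor experimental namespace). Assumes a 4-GPU torchrun job:

```python
import torch
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor
from torch.distributed._tensor.experimental import implicit_replication  # assumed path

mesh = DeviceMesh("cuda", [0, 1, 2, 3])
dt = distribute_tensor(torch.randn(4, 4), mesh, [Replicate()])

with implicit_replication():
    # The plain torch.Tensor operand is treated as if it were replicated on the mesh.
    out = dt + torch.ones(4, 4)
```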

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115297
Approved by: https://github.com/awgu
2023-12-12 03:56:48 +00:00
wz337
c70f995b5c [DeviceMesh] Add mesh_dim_names to DeviceMesh __repr__ if it exists (#115579)
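
Example of the improved repr (assumes an 8-rank job launched via torchrun; output shape of the repr is illustrative):

```python
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
print(mesh)  # repr now includes mesh_dim_names=('dp', 'tp') when they are set
```
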
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115579
Approved by: https://github.com/wanchaol
2023-12-12 02:18:34 +00:00
Wanchao Liang
0692240b90 [dtensor] account for empty list when turning to OpStrategy (#115298)
Trying to fix https://github.com/pytorch/pytorch/issues/115065

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115298
Approved by: https://github.com/XilunWu
2023-12-12 00:11:16 +00:00
Howard Huang
99f06c0cc2 [BE] update errors to be more descriptive (#115443)
we call `_check_single_tensor` and `_check_tensor_list` as validation but don't print out the param types that were invalid

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115443
Approved by: https://github.com/XilunWu
2023-12-11 21:21:10 +00:00
Chip Turner
937d616e82 Re-enable type checking for distributed_c10d.py (#115223)
Re-enable type checking for distributed_c10d.py

Type checking for distributed_c10d.py was inadvertently turned off, and type errors have accumulated since.

Note: the backwards-compatibility linter does not like some of these changes, but they were incorrect before. This needs human verification, however.

#suppress-api-compatibility-check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115223
Approved by: https://github.com/wconstab
2023-12-09 11:07:54 +00:00
Yue Dong
485ea9a70a [DTensor] Add DTensor experimental op for LayerNorm backward sharding rule propagation (#115398)
Summary: This diff is a prototype only, to unblock the TP work. The PyTorch distributed team is working on a more generic backward op for `aten.layer_norm`. We will remove this op from the experimental file once that is ready.

Test Plan:
**Local Test**:
Accuracy:
- Dtensor + Checkpoint: first run loss: P884569822 (on-par with baseline: P884213363)
- 2nd by loading saved checkpoint: P884583429 (on-par with baseline: P884271869)

Trace:
- Collective functions are inserted automatically.
- Example: https://fburl.com/perfdoctor/l567ww1x

**MAST Test**:
With: trainer = 128, batch_size=512
- NE on-par:
(see: 4441_ep_bs512_2fsdp_tp_sp_dtensor)
 {F1155318138}

Differential Revision: D51490868

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115398
Approved by: https://github.com/wanchaol
2023-12-09 09:38:56 +00:00
Wanchao Liang
1215f2ffe2 [dtensor] readme typo (#115383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115383
Approved by: https://github.com/awgu
ghstack dependencies: #115365
2023-12-08 21:40:40 +00:00
Iris Zhang (PyTorch)
23fa9621e4 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099) (#115193)
Summary:

Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for the public class and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported whether or not distributed is available().

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, all CI signals passed. Shipit added the "ci/trunk" label to the PR but DID NOT wait for it and went ahead with committing. More context can be found in the reverted PR above.

Test Plan: CI.

Differential Revision: D51861018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
2023-12-08 08:44:32 +00:00
Lucas Pasqualin
5432088098 Adds Checkpointer Wrapper for DCP [3/N] (#114603)
Adds a useful high level wrapper for calling `dist.save/load` with the correct storage readers and writers.

Instead of doing:

```
DCP.save(
    state_dict={...},
    storage_writer=StorageWriter(...)
)

DCP.load(
    state_dict={...},
    storage_reader=StorageReader(...)
)
```

We can now do:

```
checkpointer = Checkpointer(...)

checkpointer.save(state_dict={...})
checkpointer.load(state_dict={...})
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114603
Approved by: https://github.com/fegin, https://github.com/wz337
2023-12-08 01:03:21 +00:00
wz337
dacf5d6e92 [DTensor] Remove assert to allow tensor sharding dimension < Shard(x).ndim (#115114)
Consolidates changes made by @yoyoyocmu: https://www.internalfb.com/diff/D51821717
Remove assert to allow tensor dimension < Shard(x).ndim. With the current padding, we do support this already.

Follow-up: we will still need to fix the size mismatch and `full_tensor()` hang when the tensor is unevenly sharded.
Created issue here: https://github.com/pytorch/pytorch/issues/115310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115114
Approved by: https://github.com/yoyoyocmu, https://github.com/wanchaol
2023-12-07 21:57:30 +00:00
Wanchao Liang
6a6a1e3ef7 [dtensor] update README to make all example runnable (#115365)
as titled, also add torchrun commands

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115365
Approved by: https://github.com/fegin
2023-12-07 20:23:37 +00:00
Tobias Ringwald
43f42bf3cb Updated docs for deprecated torch.set_default_tensor_type (#115041)
Added deprecation note for torch.set_default_tensor_type. Updated docs that referenced this method.
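
For reference, the replacements the docs point users to:

```python
import torch

# Deprecated:
# torch.set_default_tensor_type(torch.cuda.DoubleTensor)

# Preferred: set the default dtype and device separately.
torch.set_default_dtype(torch.float64)
torch.set_default_device("cuda")
```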

Fixes #113646.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115041
Approved by: https://github.com/janeyx99
2023-12-07 16:17:36 +00:00
Chip Turner
78b945484b [c10d] Extend NCCL communicator splitting to more use cases (#114916)
Previously we could only use `ncclCommSplit` when we knew all backends were connected on all shards (due to the need to perform a NOCOLOR split), which in practice meant we could only use it for subgroups that were copies of the entire world.

This change allows for specifying a bound device id to `init_process_group` which tells the pg and its backends that the specified device, and the specified device only, will be associated with this rank.

This guarantee lets us do an early connect (which we could not previously do due to how ProcessGroupNCCL infers devices based on tensors and not the rank number).  And by doing the early connect, we have the guarantee ranks are connected and can perform nocolor splits when needed.
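
A usage sketch of the bound-device-id flow (the keyword name `device_id` is an assumption based on the description above; assumes a torchrun launch):

```python
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])

# Binding the rank to exactly one device lets the NCCL backend connect eagerly,
# so later subgroups can be created via ncclCommSplit without a NOCOLOR split.
dist.init_process_group(
    "nccl",
    device_id=torch.device(f"cuda:{local_rank}"),
)

subgroup = dist.new_group(ranks=[0, 1])
```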

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114916
Approved by: https://github.com/kwen2501
2023-12-07 15:13:01 +00:00
Xilun Wu
5e3631db31 [DTensor] force re-compute sharding when normalized_shape differs in fwd layer norm (#115250)
**Summary**:
#114174 did not test the case where `elementwise_affine=False` (i.e. `weight` and `bias` are `None`), and that test would fail due to cached sharding propagation. The difference in sharding prop between these cases is that when `weight` and `bias` are None, the forward layer norm op is recognized as a "static shape op" and `propagate_op_sharding` is applied rather than `propagate_op_sharding_non_cached`. The fix is to force re-computation of sharding when `normalized_shape` changes, by setting the op schema's `RuntimeSchemaInfo.static_argnum` to include `normalized_shape` (i.e. 1).
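
An illustrative sketch of the registration change (internal DTensor APIs; the decorator and import paths are assumptions, not verified against this commit):

```python
import torch
from torch.distributed._tensor.op_schema import RuntimeSchemaInfo      # assumed path
from torch.distributed._tensor.ops.utils import register_op_strategy   # assumed path

aten = torch.ops.aten

@register_op_strategy(
    aten.native_layer_norm.default,
    # static_argnum=1 makes the cache key include normalized_shape (arg index 1),
    # so a different normalized_shape forces sharding to be re-computed.
    schema_info=RuntimeSchemaInfo(1),
)
def layer_norm_strategy(mesh, op_schema):
    ...
```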

**Test**:
pytest test/distributed/_tensor/test_math_ops.py -s -k layer_norm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115250
Approved by: https://github.com/wanchaol
2023-12-07 07:44:06 +00:00
Wanchao Liang
b6de337d16 [funcol] a few optimizations to funcol (#113324)
Apply a few optimizations to funcol:

- for allgather on a non-0 dim, the resulting tensor already needs to access
the data in order to do torch.cat, so we sync-wait here so that we don't
need to go through ACT (AsyncCollectiveTensor) dispatch for the chunk + cat
altogether (see the usage sketch below)
- add a fast-return path to aten.view, as it's a commonly hit op for
view-related ops
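
A usage sketch of the first case (funcol API; assumes an initialized default group on a CUDA/NCCL setup):

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor

t = torch.randn(4, 4, device="cuda")

# Gathering along dim 1 requires a chunk + cat on the result, which already touches
# the data, so the collective is waited on eagerly instead of staying async.
out = all_gather_tensor(t, gather_dim=1, group=dist.group.WORLD)
```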

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113324
Approved by: https://github.com/XilunWu
2023-12-06 19:25:35 +00:00
Yue Dong
ab120e65fb Fix FSDP + TP state dict in param unflattening (#115105)
Summary:
This diff fixes the param unflattening when using FSDP together with TP. Currently we hardcode the `reshape_size` to be multiplied by 2, whereas it should instead be multiplied by the size of the process group.

Before the fix, an example exception: `shape '[257, 514]' is invalid for input of size 264196`, where the process group size is 4 instead of 2.

Test Plan:
**CI**:
CI test

**Unit test**:
`buck2 test mode/dev-nosan //caffe2/test/distributed/tensor/parallel:fsdp_2d_parallel`
- Passed

**Test model with WHEN**:
- Verified that checkpoint can be saved and resumed successfully;
- Verified the accuracy with window_ne, which is on-par with baseline.
https://pxl.cl/3Wp8w

Differential Revision: D51826120

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115105
Approved by: https://github.com/fegin
2023-12-05 21:19:56 +00:00
Joel Schlosser
22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00
Xilun Wu
4c5fe66880 [DTensor][BE] fix bug in OpStrategy for Tuple output (#115161)
**Summary**:
DTensor sharding propagation returns an `OpStrategy` object in the case of a
Tuple of multiple DTensors with the same `placements`, and this object is later
expanded into a tuple of `DTensorSpec`s. However, the expansion copied the
object's reference instead of copying/creating new objects, and
this led to incorrect overwriting in the tensor meta propagation logic.
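
A generic illustration of the reference-sharing bug (plain Python, not the actual DTensor code):

```python
spec = {"tensor_meta": None}

# Buggy expansion: three references to the same object, so writing tensor meta
# for one output silently overwrites it for all of them.
expanded_bad = [spec] * 3
expanded_bad[0]["tensor_meta"] = "meta_for_output_0"
assert expanded_bad[2]["tensor_meta"] == "meta_for_output_0"   # unintended

# Fixed expansion: create a fresh object per output.
expanded_good = [dict(spec) for _ in range(3)]
expanded_good[0]["tensor_meta"] = "meta_for_output_0"
assert expanded_good[2]["tensor_meta"] is None
```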

**Test**:
pytest test/distributed/_tensor/test_math_ops.py
pytest test/distributed/_tensor/test_dtensor_ops.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115161
Approved by: https://github.com/wanchaol
2023-12-05 18:28:40 +00:00
Nikita Shulga
a827ac71f2 Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)"
This reverts commit eaa64339d6.
2023-12-05 08:59:36 -08:00
Iris Zhang (PyTorch)
eaa64339d6 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, and add documentation.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing a public module binding test on macOS due to the change in import order for torch/distributed/fsdp/_common_utils.py. Since the original import still works, we removed the changes to this file.

Test Plan: CI.

Differential Revision: D51825114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-12-05 05:44:52 +00:00
Wanchao Liang
288b1acaa9 [dtensor] fix empty shape init for dtensor constructors (#115091)
As titled, this PR fixes the empty-shape init case: if we pass in
something like `torch.dtensor.zeros([])`, it should call `torch.zeros([])`
under the hood, not `torch.empty(0)`. This aligns the dtensor constructors
with the torch constructors.
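
The underlying torch distinction, for reference:

```python
import torch

torch.zeros([]).shape   # torch.Size([])  -- a 0-dim scalar tensor
torch.empty(0).shape    # torch.Size([0]) -- a 1-dim tensor with zero elements
```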

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115091
Approved by: https://github.com/XilunWu
2023-12-05 00:51:29 +00:00
Lucas Pasqualin
753c07bbe0 All gather keys before processing Stateful objects in save/load [2/N] (#114304)
Accounts for the case where `state_dict` keys may be present in different orders. Since users may call collectives inside `state_dict` and `load_state_dict`, differently ordered keys could cause a deadlock. This is mostly a defensive move, meant to match the feature in TSS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114304
Approved by: https://github.com/fegin, https://github.com/wz337
2023-12-04 18:31:14 +00:00
PyTorch MergeBot
3a2e2044cd Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)"
This reverts commit 729ac7317a.

Reverted https://github.com/pytorch/pytorch/pull/114991 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114991#issuecomment-1837214567))
2023-12-02 17:55:51 +00:00
Wanchao Liang
28925902fa [TP] fully rewrite Tensor Parallel APIs (#114732)
This PR rewrites the Tensor Parallel implementation. The Tensor Parallel APIs
are supposed to be a very thin wrapper around the DTensor APIs, but the current
implementation got too messy and buggy, and it's really hard to debug what
went wrong when using it. It's crucially important for advanced users and
developers to be able to understand the API and its implementation easily,
without going through all the different types of functions and utils, so that
they can trust what happens under the hood.

In particular this PR:

* Make ParallelStyle a real contract API for parallelize_module to
  take; each concrete ParallelStyle only needs to implement `apply` to
apply the sharding to an nn.Module, and all unnecessary fields are removed. This
also enables easier ParallelStyle authoring going forward.
* Keep the ColwiseParallel and RowwiseParallel public interface, but
  refactor them so that the parameter sharding and the input/output
handling live within the style itself, making it easy to
understand how Linear/Embedding layers are sharded and how the input/output
transformations are performed.
* Remove the private _prepare_input/_prepare_output_fn fields from
  both ColwiseParallel/RowwiseParallel. Since we have thrown deprecation
messages in nightly for a while, TP is a prototype release, and the
fields are private, it should be safe to remove them.
* Refactor the recently landed PrepareModuleInput/Output styles: change
  output_layouts to desired_input/output_layouts, group
  the functions inside the style itself, and drop default arguments for these
two styles so users must specify them and think about the sharding
layouts. Fixed bugs around not handling the
`use_local_output` flag.
* Make default arguments None instead of a Placement object; it is
  standard Python practice not to use a custom object instance as a default
argument.
* Remove all dead APIs (i.e. the PairwiseParallel and SequenceParallel
  styles and all prepare input/output functions), as we have thrown deprecation
 msgs for a while; removing them from the tests is in progress.
* Throw a deprecation warning for `tp_mesh_dim`, as we recommend using device
  mesh slicing/indexing instead of manually specifying the mesh dim.
* Rewrite the documentation for every ParallelStyle and make it
  clearer about what each style does (see the usage sketch below).
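
A usage sketch of the refactored API surface (the module names in the plan are hypothetical; assumes an initialized process group, e.g. an 8-GPU torchrun job):

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(256, 1024)
        self.w2 = nn.Linear(1024, 256)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

tp_mesh = init_device_mesh("cuda", (8,))
model = parallelize_module(
    FeedForward().cuda(),
    tp_mesh,
    # Each entry maps a module FQN to the ParallelStyle that shards it.
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)
```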

TODOs:
* Rewrite TP tests to adjust for the changes we have in this PR
* add more tests to guard the bug fixes

Differential Revision: [D51761183](https://our.internmc.facebook.com/intern/diff/D51761183)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114732
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-12-02 08:18:12 +00:00