Commit Graph

134 Commits

Author SHA1 Message Date
Wanchao Liang
f026b32008 [device_mesh][BE] reduce_scatter fallback to funcol and remove from DM (#105642)
For a reason similar to https://github.com/pytorch/pytorch/pull/105605
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105642
Approved by: https://github.com/kumpera, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:05 +00:00
Wanchao Liang
2fa063e1e0 [device_mesh][BE] remove allgather from DM (#105614)
For a reason similar to https://github.com/pytorch/pytorch/pull/105605
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105614
Approved by: https://github.com/rohan-varma, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:05 +00:00
Wanchao Liang
4a49f1f46e [device mesh][BE] remove allreduce from DM (#105605)
This PR removes allreduce from DM and uses functional collectives instead.
The rationale is that we don't want to maintain yet another set of
collective APIs; since DM's collectives are now thin wrappers around functional
collectives, these collectives don't really need to live in DM.
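A hedged sketch of the direction, not the exact code in the PR (the group argument form and parameter names in `torch.distributed._functional_collectives` may differ slightly):
```python
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DeviceMesh

# Call the functional collective directly with a (mesh, mesh_dim) group
# instead of going through a DeviceMesh-specific all_reduce wrapper.
mesh = DeviceMesh("cuda", [0, 1])
t = torch.ones(4, device="cuda")
reduced = funcol.all_reduce(t, "sum", (mesh, 0))  # reduce over mesh dim 0
```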
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105605
Approved by: https://github.com/kumpera, https://github.com/wz337, https://github.com/fduwjj
2023-07-27 01:33:02 +00:00
fduwjj
0003d5135d [TP] Enable partial tensor add without redistribute (#105939)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105939
Approved by: https://github.com/wanchaol
2023-07-26 03:12:39 +00:00
Wanchao Liang
e3539a0e54 [dtensor] forward fix for dynamo import with deploy (#105760)
Summary: forward fix to avoid revert

Differential Revision: D47679598

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105760
Approved by: https://github.com/atalman
2023-07-23 07:13:38 +00:00
Mo Mo
7b56238551 fix typo (#105507)
Differential Revision: D47568928

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105507
Approved by: https://github.com/awgu, https://github.com/fduwjj
2023-07-19 20:34:43 +00:00
Wanchao Liang
f139aab2f4 [dynamo] add initial dynamo support for DTensor (#103146)
This PR adds initial dynamo support for DTensor. In particular, it:
- allows a DTensor to be passed into a compiled function, and allows fakifying a
DTensor during dynamo tracing by turning the inner local tensor into a meta
tensor.
- uses `allow_in_graph` to let `DTensor` and `DTensor.from_local` be represented as `TorchVariable`
- the DTensor created becomes a normal `TensorVariable`, and it inserts any tensor operations into the output graph just like a plain torch.Tensor
- note that DTensor has an extra instance method `redistribute` compared to a plain tensor, which we currently special-case in `TensorVariable`

`from_local` and `redistribute` both accept non-trivial metadata as arguments (i.e. DeviceMesh, Placement), which fx.Graph does not support. In order to let these two APIs appear in the dynamo-captured graph, we encode the metadata into a new function (like `functools.partial`) that only accepts prim args (i.e. tensors), then we put a `call_function` with this new function into the graph. This was suggested by @ezyang. The underlying rationale is that the metadata will not change across graph invocations, so it is safe to encode it.

Captured graph:
```
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False);  l_x_ = None

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
        prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local);  prim_from_local = None
        to_local = prim_redistribute.to_local();  prim_redistribute = None
        add = to_local + 2;  to_local = None
        return (add,)
```
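For reference, the eager-side function that produces this graph can be reconstructed from the file/line comments in the captured graph above (a sketch; the real test lives in `test/distributed/_tensor/test_dtensor.py`, and the backend choice here is illustrative):
```python
import torch
from torch.distributed._tensor import DTensor, DeviceMesh, Shard, Replicate

mesh = DeviceMesh("cuda", [0, 1])

def fn(x):
    # matches the source lines referenced in the captured graph above
    dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    return dt.redistribute(mesh, [Replicate()]).to_local() + 2

compiled_fn = torch.compile(fn, backend="eager")
```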

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
2023-07-19 16:01:12 +00:00
Justin Chu
232b96b6e2 [BE] Enable ruff's UP rules and autoformat distributed/ (#105433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105433
Approved by: https://github.com/albanD
2023-07-19 14:27:11 +00:00
Wanchao Liang
cb23373264 [dynamo] allow tensor subclass fakification in dynamo (#105308)
This PR adds the necessary plumbing through torchdynamo to allow tensor
subclasses with a certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to go through the dynamo fakification pass by
fakifying the tensor subclass's internal components.

Some of the tensor subclass contract logic is mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540

Added some tests to verify that simply passing a tensor subclass
(e.g. DTensor) through dynamo eager works as expected.
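A hypothetical minimal subclass illustrating the kind of contract described above; the exact method signatures dynamo expects (per PR #97540) may differ from this sketch:
```python
import torch

class WrapperTensor(torch.Tensor):
    # Illustrative wrapper subclass; real subclasses (e.g. DTensor) also define
    # __torch_dispatch__ to route ops to their inner tensors.
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )
        r._inner = inner
        return r

    def __tensor_flatten__(self):
        # names of the inner tensor attributes, plus any non-tensor metadata
        return ["_inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta):
        return WrapperTensor(inner_tensors["_inner"])
```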

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
2023-07-18 17:28:04 +00:00
Wanchao Liang
bcb9ca4e5a [dtensor] canonicalize detach callsites and use view_as when appropriate (#105239)
This PR canonicalizes the detach callsites to only call detach
from `distribute_tensor`, changes other callsites to `view_as`, and removes the
detach call in the tensor constructor.

This is so that we don't detach the local tensor for every op run when
rewrapping the DTensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
2023-07-18 17:13:37 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Those were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional).
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in a follow-up PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated: to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to squash the older libstdc++ from the conda environment in favor of the one from the OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

Those were reverted due to a conflict with the internal source repo.

Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional).
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add an assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in a follow-up PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to It's dependent change was reverted, so reverting this one as well, to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
PyTorch MergeBot
b4d91b1c5b Revert "[Typing] Fix PEP 484 Violation (#105022)"
This reverts commit 4148b7bada.

Reverted https://github.com/pytorch/pytorch/pull/105022 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/105022#issuecomment-1635967734))
2023-07-14 14:45:09 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None, but the type is not annotated as Optional).
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in a follow-up PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
Iris
4f8ba6f8f6 [DeviceMesh]Add validate mesh flag to DeviceMesh (#104807)
When creating a DeviceMesh, _init_process_group() validates that all calling ranks pass in the same `mesh` argument. In FSDP, we currently create the DeviceMesh based on the pg of the root state, so the mesh will always be valid. This PR adds a flag to DeviceMesh so we can skip the all_gather_tensor used for this validation at construction time.

_validate_mesh defaults to True, but we manually flip it to False when initializing the device mesh in FSDP's _runtime_utils.py (see the sketch below).

A follow-up PR will skip pg creation when the groups already exist for both the 1D and 2D cases, and then delete the _init_process_groups flag.
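A hedged usage sketch of the flag (the keyword name is taken from the description above; its exact position in the DeviceMesh signature may differ):
```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh

# Skip the all_gather_tensor-based mesh validation when the caller already
# knows the mesh is consistent across ranks (as FSDP does for the root pg).
mesh = DeviceMesh(
    "cuda",
    torch.arange(dist.get_world_size()),
    _validate_mesh=False,  # flag added in this PR; defaults to True
)
```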

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
2023-07-12 23:42:13 +00:00
Nikita Shulga
4148b7bada [Typing] Fix PEP 484 Violation (#105022)
Not sure how it worked before, but arguments must be annotated as Optional if they default to None.

Towards enabling mypy-1.4.1 in lintrunner
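A minimal example of the violation being fixed (function names here are illustrative):
```python
from typing import Optional

def lookup(timeout: int = None):            # flagged: implicit Optional violates PEP 484
    ...

def lookup_fixed(timeout: Optional[int] = None):  # compliant under mypy 1.4.1
    ...
```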

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
2023-07-12 10:20:48 +00:00
fduwjj
aa84078c6c [PTD][TP] Add BWD support for colwise embedding sharding (#104820)
Originally, we didn't enable BWD for colwise embedding because we thought it was just for inference, but it turns out that we do need it for training. So let's enable it for now; a unit test is also added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104820
Approved by: https://github.com/fegin
2023-07-10 22:33:20 +00:00
Xilun Wu
e799f565eb [DTensor][TP][Random] Introduce TensorParallelRNGTracker to integrate parallel RNG state with Tensor Parallel (#103910)
This PR enables the automatic use of `TensorParallelRNGTracker` in the Tensor Parallel API. Some unit tests are going to be added to cover it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103910
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-30 08:06:41 +00:00
Wanchao Liang
da06920f47 Replace all_gather in device mesh with functional collective equivalent (#104056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104056
Approved by: https://github.com/kumpera, https://github.com/wanchaol
2023-06-30 05:30:02 +00:00
Xilun Wu
a66107a30c [DTensor][Random] Introduce CudaRNGStateTracker to maintain parallel RNG state for DTensor (#103235)
# Change
This PR adds two classes to DTensor:

1. `CudaRNGStateTracker`: `CudaRNGStateTracker` stores Random Number Generator (RNG) states (`ByteTensor` objects) in a `dict`, mapping a tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region`, which is used when DTensor executes a random op (an operator that calls the RNG). A rough sketch of this shape is given after the warnings below.

2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy for how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.

# Warning

- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads (ranks) and cause issues. We need to figure out a compatible solution for that.

- The RNG state may get out of sync on ranks outside the participating ranks. It is harmless in our current use case of submeshes, though.
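A rough sketch of the structure described above (illustrative only; the real implementation in torch.distributed._tensor may differ):
```python
import contextlib
import torch

class SketchRNGStateTracker:
    # Illustrative stand-in for the CudaRNGStateTracker described above.
    def __init__(self):
        self.rng_states: dict = {}  # tag -> ByteTensor RNG state

    def get_state(self, tag: str) -> torch.Tensor:
        return self.rng_states[tag]

    def set_state(self, tag: str, state: torch.Tensor) -> None:
        self.rng_states[tag] = state

    @contextlib.contextmanager
    def _distribute_region(self, tag: str = "parallel-rng"):
        # Swap the tracked state in around a DTensor random op, then store the
        # advanced state back so subsequent random ops stay coordinated.
        old_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(self.get_state(tag))
        try:
            yield
        finally:
            self.set_state(tag, torch.cuda.get_rng_state())
            torch.cuda.set_rng_state(old_state)
```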
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
2023-06-27 19:00:25 +00:00
Wanchao Liang
4cc474dec4 [dtensor] support torch.save/load with DTensor (#103106)
This PR enables DTensor to be picklable and adds tests to verify that
torch.save/load work correctly with DTensor.
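A usage sketch of what the tests exercise (run under torchrun with a process group initialized; names follow the torch.distributed._tensor API at the time):
```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard, distribute_tensor

mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

torch.save(dt, "dtensor_ckpt.pt")        # DTensor is now picklable
loaded = torch.load("dtensor_ckpt.pt")
assert isinstance(loaded, DTensor)
```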
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103106
Approved by: https://github.com/kumpera
2023-06-09 04:11:15 +00:00
Wanchao Liang
d31707a257 Get rid of dim_groups attribute from DeviceMesh (#103105)
This PR gets rid of the dim_groups attribute from DeviceMesh. The main
motivation is that we should let c10d store the process groups created
during DeviceMesh construction instead of DeviceMesh itself; DeviceMesh
should just handle ranks correctly.

This enables DTensor to become picklable (torch.save/load could be
possible), which I will give a try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
2023-06-09 04:11:15 +00:00
Xilun Wu
675f2597fa [reland][DTensor][3/N] add DTensor constructor function: full (#101436) (#103165)
This is a reland attempt of the reverted PR #101436.

Differential Revision: [D46537531](https://our.internmc.facebook.com/intern/diff/D46537531)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103165
Approved by: https://github.com/wanchaol
2023-06-08 16:18:33 +00:00
Wanchao Liang
8585784a34 [dtensor] fix allgather unpadding logic (#103219)
This PR fixes the allgather unpadding logic so that we only need to unpad
the full tensor instead of first chunking it into small tensors and unpadding
them individually, since we know how our padding algorithm works.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103219
Approved by: https://github.com/wz337, https://github.com/fduwjj
2023-06-08 03:31:24 +00:00
shaoyf42
17737f9d0e [DTensor] Allow DTensor support cuda-like device (#102468)
Allow DTensor to support cuda-like devices; fixes https://github.com/pytorch/pytorch/issues/102442

Currently, DTensor supports cuda and cpu. There are other efforts to make DTensor support third-party devices, for example https://github.com/pytorch/pytorch/pull/101914 and https://github.com/pytorch/pytorch/issues/101911. However, that support only covers a portion of third-party devices and does not properly handle third-party cuda-like devices. Therefore, we would like to extend DTensor to support cuda-like devices; after all, cuda is so popular!

1. Similar to what is done here, we need to initialize the communication backend for the device set by DeviceMesh, so `_default_backend_for_device` is added to `Backend`. It is worth noting that when we register a new backend for a device other than cpu and cuda, we also need to add a new default backend for this device.
2. Adding `_device_handle` to `DeviceMesh` for cuda-like devices, similar to what is done in FSDP. When `_device_handle` is not None, the device behaves like `cuda`. In this way, calls like `torch.cuda.device_count()` need to be changed to `device_mesh._device_handle.device_count()` (see the sketch after this list).
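An illustrative sketch of the dispatch pattern in point 2 (the helper name here is hypothetical; `_device_handle` is the attribute described above):
```python
def _num_devices(device_mesh) -> int:
    # Hypothetical helper: prefer the mesh's device handle (e.g. torch.cuda, or a
    # third-party module with a cuda-like interface) over hard-coding torch.cuda.
    handle = getattr(device_mesh, "_device_handle", None)
    return handle.device_count() if handle is not None else 1
```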
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102468
Approved by: https://github.com/wanchaol
2023-06-07 23:13:53 +00:00
PyTorch MergeBot
0f672e8c67 Revert "[DTensor][3/N] add DTensor constructor function: full (#101436)"
This reverts commit 2ca75d49a8.

Reverted https://github.com/pytorch/pytorch/pull/101436 on behalf of https://github.com/malfet due to Caused internal SEV ([comment](https://github.com/pytorch/pytorch/pull/101436#issuecomment-1575076672))
2023-06-03 17:09:08 +00:00
shaoyf42
fc218a8a13 Fix typos in README of DTensor (#102813)
Fix typos in the README of DTensor. There is still a problem to be fixed: I hit an error when trying to use distribute_module with shard_params. The specific error message is in issue https://github.com/pytorch/pytorch/issues/102812.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102813
Approved by: https://github.com/wanchaol
2023-06-02 19:27:23 +00:00
fduwjj
92923aca61 [TP] Use Stride inferred from local tensor in to_local bwd (#102630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102630
Approved by: https://github.com/wanchaol
2023-06-01 04:30:24 +00:00
Wanchao Liang
c5d4ee2d73 [dtensor][simple] fix some comments (#102661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102661
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-06-01 03:23:19 +00:00
Wanchao Liang
ff58d19c89 DeviceMesh use dispatchable PG to support custom backend (#102336)
This PR switches DeviceMesh to use a dispatchable process group instead.
This enables easier backend integration, as users only need to integrate
their custom backend with the c10d process group, without needing to
change DeviceMesh to plug in the backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102336
Approved by: https://github.com/fduwjj
2023-05-30 19:22:37 +00:00
Wanchao Liang
6e0c741105 [dtensor] hide mesh validation check under init_process_group flag (#101996)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101996
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Wanchao Liang
70eccdbf92 [dtensor] add necessary logging to APIs and components (#101994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101994
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Xilun Wu
2ca75d49a8 [DTensor][3/N] add DTensor constructor function: full (#101436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101436
Approved by: https://github.com/wanchaol
2023-05-23 06:05:40 +00:00
Wanchao Liang
38a29324b0 [dtensor][2/N] more tensor ops to use strategy propagation (#101203)
As titled, this PR adapts a few more tensor ops to use strategy-based
sharding propagation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101203
Approved by: https://github.com/XilunWu
2023-05-22 17:16:14 +00:00
Xilun Wu
010763be9a [DTensor][2/N] add DTensor constructor function: empty (#101022)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101022
Approved by: https://github.com/wanchaol
2023-05-16 16:50:54 +00:00
Xilun Wu
5cc361c736 [DTensor][1/N] add DTensor constructor function: ones (#100933)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100933
Approved by: https://github.com/wanchaol
2023-05-16 16:50:54 +00:00
Xuehai Pan
05f6250815 Add missing torch.distributed.ReduceOp.AVG in type stubs (#101534)
Add missing `AVG` to `torch.distributed.ReduceOp` enum for type annotation.

Ref:

88b6a4577b/torch/csrc/distributed/c10d/Types.hpp (L35-L47)
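A small usage example that the updated stubs now type-check (run under torchrun with a process group initialized):
```python
import torch
import torch.distributed as dist

t = torch.ones(4)
dist.all_reduce(t, op=dist.ReduceOp.AVG)  # averages the tensor across ranks
```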
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101534
Approved by: https://github.com/Skylion007
2023-05-16 15:51:21 +00:00
Iris
568db1b464 [dtensor] Relax condition for _split_tensor() (#101218)
When tensor.size(self.dim) < num_chunks, we fill the empty chunks with empty tensors (https://github.com/pytorch/pytorch/pull/98722). Therefore, we no longer need this assert.

For example, when sharding a tensor with 1 element on 2 ranks along dim 0, the results would be as follows:
```
rank:0, dtensor:DTensor(local_tensor=tensor([0.4963], device='cuda:0'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
rank:1, dtensor:DTensor(local_tensor=tensor([], device='cuda:1'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101218
Approved by: https://github.com/wanchaol
2023-05-14 07:39:27 +00:00
Wanchao Liang
3ae612ba7f [dtensor] remove assertions about submesh checks (#101229)
This PR removes the assertions from the submesh checks and directly returns the local
tensor, so that all the other APIs can work with submeshes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101229
Approved by: https://github.com/fduwjj
2023-05-12 04:20:35 +00:00
Wanchao Liang
599ae95d1a [dtensor] use stack to manage mesh resources (#101202)
This PR changes the context-manager behavior of device mesh: we now use
a mesh env to track the current mesh and save meshes on a stack so
that nested context managers are allowed (see the sketch below).
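An illustrative sketch of the nested behavior (assuming the with-statement DeviceMesh API that existed at the time; mesh shapes here are hypothetical):
```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

outer_mesh = DeviceMesh("cuda", [0, 1, 2, 3])
inner_mesh = DeviceMesh("cuda", [0, 1])

with outer_mesh:              # pushed onto the mesh stack
    with inner_mesh:          # nested: inner_mesh becomes the current mesh
        dt = distribute_tensor(torch.randn(4, 4), placements=[Shard(0)])
    # exiting the inner block pops back to outer_mesh as the current mesh
```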
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337
2023-05-11 23:48:36 +00:00
Wanchao Liang
a1aa32e204 [dtensor] tensor ops to use strategy based sharding prop (#100607)
This is the first in a series of PRs that adapt operator impls to use a
strategy-based approach: each op utilizes OpStrategy and PlacementStrategy
to generate its own strategy. By using the strategy-based approach along
with the op graph, we can enable more advanced op implementations
(decomposition becomes possible) and turn sharding propagation into
something closer to a constraint-satisfaction problem.

This PR alone only adds some basic tensor op strategies, and it works
directly on the op graph that was used for metadata propagation. The tensor
ops added in this PR mainly follow the strategy of one of their args. The
next set of PRs will add strategies for more ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
2023-05-11 02:47:20 +00:00
Shen Li
9439cb0e11 Avoid using einsum for torch.cat DTensor propagation (#100251)
DTensor was reusing `einop_rule` to propagate sharding for torch.cat.
However, einsum only supports up to 52 subscripts (i.e., input tensors).
We have encountered use cases where one cat operator has more than 60
input tensors. Therefore, this commit reimplements sharding prop
rule for cat without using einsum.

Differential Revision: [D45435232](https://our.internmc.facebook.com/intern/diff/D45435232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100251
Approved by: https://github.com/wanchaol
2023-05-03 01:56:18 +00:00
Wanchao Liang
123be4b694 [dtensor] add debug tool to track op coverage (#100124)
This PR adds a debug tool to track the op coverage needed in DTensor.

Note that we specifically target ops after decomp table in inductor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100124
Approved by: https://github.com/XilunWu
2023-05-02 01:45:55 +00:00
Chien-Chin Huang
b94a0ba5bb [SPMD] Add embedding dense backward prop rule for positional embedding (#100038)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100038
Approved by: https://github.com/mrshenli
2023-04-27 16:31:51 +00:00
Wanchao Liang
ad882c5210 [spmd] Use TupleStrategy and enable replicate fused_adam (#99374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99374
Approved by: https://github.com/mrshenli
2023-04-25 19:30:53 +00:00
Wanchao Liang
c1e2fa8189 [dtensor] add StrategyType and TupleStrategy (#99435)
This PR refactors the current StrategyList. It introduces a
StrategyType, which is the base class of strategies, and it has
two sub-strategies:

1. Refactor the previous StrategyList into OpStrategy
2. Add TupleStrategy, a new strategy to deal with tuple cases where
an op could return multiple different OpStrategy results.

This helps support more complicated ops and unblocks compile-mode
FSDP (a loose sketch of the hierarchy follows below).
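A loose sketch of the hierarchy described above (field names are illustrative and may not match the real definitions in torch.distributed._tensor):
```python
from dataclasses import dataclass, field
from typing import List

class StrategyType:
    """Base class for an op's sharding-propagation strategy."""

@dataclass
class PlacementStrategy:
    output_spec: object                      # spec/placements of the op output
    input_specs: List[object] = field(default_factory=list)

@dataclass
class OpStrategy(StrategyType):
    # Refactored from the previous StrategyList: candidate placement strategies for one op.
    strategies: List[PlacementStrategy] = field(default_factory=list)

@dataclass
class TupleStrategy(StrategyType):
    # For ops that return tuples: one child strategy per tuple element.
    childs: List[StrategyType] = field(default_factory=list)
```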
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99435
Approved by: https://github.com/mrshenli
2023-04-22 05:39:20 +00:00
Xilun Wu
ce60997376 [BE][DTensor] validate the mesh argument in DeviceMesh construction (#99094)
## What's in this PR
DeviceMesh's __init__ function now requires all calling ranks to pass the same `mesh` argument.

## Why
We want to enforce an SPMD style of programming with DTensor. Before this PR, the 2-D Parallel API (e.g. _create_1d_device_mesh) defined different DeviceMeshes on different ranks. After this PR, it defines each sub-mesh and simply performs communications on the one it is associated with.

Differential Revision: [D45165511](https://our.internmc.facebook.com/intern/diff/D45165511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99094
Approved by: https://github.com/wanchaol
2023-04-21 23:47:51 +00:00
Iris
0d2b55c459 [DTensor] Change Sharding algorithm to be in line with `torch.chunk()` (#98722)
As functional collectives are being updated, using tensor_split() as the underlying sharding algorithm would require padding and unpadding on multiple ranks. Therefore, we are changing the sharding algorithm to be in line with ``torch.chunk()``, so that padding is only needed on the last couple of ranks in most scenarios. The short comparison below illustrates the difference.
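```python
import torch

x = torch.arange(5)
print([t.tolist() for t in torch.chunk(x, 4)])        # [[0, 1], [2, 3], [4]] -> only trailing chunks shrink
print([t.tolist() for t in torch.tensor_split(x, 4)]) # [[0, 1], [2], [3], [4]] -> several uneven chunks need padding
```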
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98722
Approved by: https://github.com/wanchaol
2023-04-21 02:05:22 +00:00