Commit Graph

92 Commits

Author SHA1 Message Date
Wanchao Liang
a1aa32e204 [dtensor] tensor ops to use strategy based sharding prop (#100607)
This is the first in a series of PRs that adapts operator impls to use a
strategy based approach: each op utilizes OpStrategy and PlacementStrategy
to generate its own strategy. By utilizing the strategy based
approach along with the op graph, we can enable more advanced op
implementations (decomp is possible) and turn sharding prop into
more of a constraint satisfaction problem.

This PR alone only adds some basic tensor op strategies, and it works directly
on the op graph that was used for metadata propagation. The tensor ops
added in this PR mainly follow one of the arg strategies. The next set of
PRs will add strategies for more ops.
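To make the idea concrete, here is a minimal, self-contained sketch of a "follow one of the args" strategy rule; the classes below are illustrative stand-ins for the real DTensor types named above, not their actual definitions.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PlacementStrategy:
    # one candidate placement: the output spec plus the specs each input
    # would need for this candidate to apply
    output_spec: object
    input_specs: Optional[List[object]] = None

@dataclass
class OpStrategy:
    # all candidate placements for a single op
    strategies: List[PlacementStrategy]

def follow_arg_strategy(arg_strategy: OpStrategy) -> OpStrategy:
    # the tensor ops in this PR mostly "follow" an argument: every placement
    # the argument supports is also a valid placement for the output
    return OpStrategy(
        [
            PlacementStrategy(output_spec=s.output_spec, input_specs=[s.output_spec])
            for s in arg_strategy.strategies
        ]
    )
```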
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
2023-05-11 02:47:20 +00:00
Shen Li
9439cb0e11 Avoid using einsum for torch.cat DTensor propagation (#100251)
DTensor was reusing `einop_rule` to propagate sharding for torch.cat.
However, einsum only supports up to 52 subscripts (i.e., input tensors),
and we have encountered use cases where one cat operator has more than 60
input tensors. Therefore, this commit reimplements the sharding prop
rule for cat without using einsum.
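A rough sketch of the idea (a hypothetical helper, not the actual rule in the codebase): walk the inputs directly instead of assigning each one an einsum subscript, so the number of inputs is unbounded.

```python
from typing import List, Sequence

def cat_sharding_sketch(input_placements: Sequence[List], cat_dim: int) -> List:
    # all inputs must agree on their placements; inputs sharded on cat_dim
    # itself would need redistribution first (omitted here for brevity)
    first = list(input_placements[0])
    for placements in input_placements[1:]:
        assert list(placements) == first, "cat inputs must share placements"
    return first  # the output simply follows the inputs' common placement
```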

Differential Revision: [D45435232](https://our.internmc.facebook.com/intern/diff/D45435232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100251
Approved by: https://github.com/wanchaol
2023-05-03 01:56:18 +00:00
Wanchao Liang
123be4b694 [dtensor] add debug tool to track op coverage (#100124)
This PR adds a debug tool to track the op coverage needed in DTensor.

Note that we specifically target ops after the decomp table in inductor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100124
Approved by: https://github.com/XilunWu
2023-05-02 01:45:55 +00:00
Chien-Chin Huang
b94a0ba5bb [SPMD] Add embedding dense backward prop rule for positional embedding (#100038)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100038
Approved by: https://github.com/mrshenli
2023-04-27 16:31:51 +00:00
Wanchao Liang
ad882c5210 [spmd] Use TupleStrategy and enable replicate fused_adam (#99374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99374
Approved by: https://github.com/mrshenli
2023-04-25 19:30:53 +00:00
Wanchao Liang
c1e2fa8189 [dtensor] add StrategyType and TupleStrategy (#99435)
This PR refactors the current StrategyList. It introduces a
StrategyType, which is the base class of strategies, and it has
two sub-strategies:

1. Refactor the previous StrategyList into OpStrategy.
2. Add TupleStrategy, a new strategy to deal with tuple cases where
an op could return multiple different OpStrategy results.

This helps support more complicated ops and unblocks compile-mode
FSDP
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99435
Approved by: https://github.com/mrshenli
2023-04-22 05:39:20 +00:00
Xilun Wu
ce60997376 [BE][DTensor] validate the mesh argument in DeviceMesh construction (#99094)
## What's in this PR
DeviceMesh's __init__ function now requires all calling ranks to pass the same `mesh` argument.

## Why
We want to enforce the SPMD style of programming with DTensor. Before this PR, the 2-D Parallel API (e.g. _create_1d_device_mesh) defined different DeviceMeshes on different ranks. After this PR, it defines every sub-mesh on each rank and simply performs communications on the one that the rank is associated with.
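A small illustration of the SPMD requirement (assumes a 4-rank job with torch.distributed already initialized; the module path is an assumption for this era of the code):

```python
from torch.distributed._tensor import DeviceMesh

# every rank passes the identical `mesh` argument; no per-rank variation
mesh_2d = DeviceMesh("cuda", [[0, 1], [2, 3]])
# each rank then performs communications only on the sub-meshes it belongs to
```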

Differential Revision: [D45165511](https://our.internmc.facebook.com/intern/diff/D45165511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99094
Approved by: https://github.com/wanchaol
2023-04-21 23:47:51 +00:00
Iris
0d2b55c459 [DTensor] Change Sharding algorithm to be in line with `torch.chunk()` (#98722)
As functional collectives are being updated, using tensor_split() as the underlying sharding algorithm would require padding and unpadding on multiple ranks. Therefore, we are changing the sharding algorithm to be in line with ``torch.chunk()``, so that padding is only needed on the last two ranks in most scenarios.
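For intuition, compare the two splitting schemes on a plain tensor: torch.chunk leaves only the trailing ranks short (so only they ever need padding), while torch.tensor_split spreads the remainder across ranks.

```python
import torch

t = torch.arange(10)
print([c.numel() for c in torch.chunk(t, 4)])         # [3, 3, 3, 1] -> pad only the tail
print([c.numel() for c in torch.tensor_split(t, 4)])  # [3, 3, 2, 2] -> padding needed on more ranks
```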
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98722
Approved by: https://github.com/wanchaol
2023-04-21 02:05:22 +00:00
Xilun Wu
964c7e3e85 [BE][DTensor] fix DTensor equal op (#99014)
## What problem this PR solves?
#97170 fixed the `equal` operator's return type (old: Tensor, now: bool) by giving it the correct sharding propagation. This is consistent with the `aten::equal` op. However, the correctness only holds at the local-result level:
* the `equal` op returns True if the local copy of dtensor A equals the local copy of dtensor B

This is not the correct semantics of `equal`, which should return True only if all local copies of A are equal to the corresponding local copies of B.

## What is this PR?

1. For non-participating ranks, if the return type is scalar, `local_results` is set to `None`, which means the default value is the reduced result of the participating ranks only.
2. For all ranks, if the return type is scalar and the `op_call` is `aten::equal` (because `aten::equal` is the only function that returns a scalar value and needs communication), all-gather the `local_results` within the `default pg` and reduce them with `operator.and_`; the result becomes the new `local_result` (see the sketch below).
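A minimal sketch of that reduction (a hypothetical helper; the real logic lives in DTensor's operator dispatch):

```python
import operator
from functools import reduce
import torch.distributed as dist

def reduce_equal_results(local_result) -> bool:
    # gather every rank's local bool over the default pg ...
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_result)
    # ... and AND them together; non-participating ranks contribute None
    return reduce(operator.and_, [r for r in gathered if r is not None], True)
```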

## Result/Impact
For non-participating ranks, when the return type is scalar:

1. If the op is `aten::equal`, the return value is the same as on all other ranks.
2. If the op is not `aten::equal`, the return value is None. Before this PR, this would raise "NotImplementedError", but that path had not been tested.

For participating ranks, when the return type is scalar:

1. If the op is `aten::equal`, the return value is the equality of the two dtensor operands - True if all copies are equal, False otherwise.
2. If the op is not `aten::equal`, it is simply the local computation result.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99014
Approved by: https://github.com/wanchaol
2023-04-18 03:22:44 +00:00
Wanchao Liang
55a1dc7f88 [dtensor] redistribute by default takes self mesh instead (#99060)
This PR switches redistribute to use the DTensor's own (self) mesh by default instead of
the global mesh, which is more user friendly
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99060
Approved by: https://github.com/mrshenli
2023-04-14 05:14:28 +00:00
Wanchao Liang
15686950b7 [spmd] quick fix on batch input view issue (#98813)
This is a quick fix/hack to get around the issue that some
"global" tensor view operations are invalid, yet they somehow get
triggered by some models, even though the mini-batch input itself does not
have this issue.

Since ultimately we should remove the dtensor expand and use the new
expansion, this hack is only a temporary unblock
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98813
Approved by: https://github.com/yifuwang, https://github.com/mrshenli
2023-04-11 14:27:01 +00:00
Xilun Wu
7ecbce374e [DTensor][3/N] enable aten.native_dropout (#98577)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98577
Approved by: https://github.com/wanchaol
2023-04-10 23:57:04 +00:00
Xilun Wu
e686a1e1b3 [DTensor][2/N] add Philox offset adjustment logic in operator_dispatch (#98199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98199
Approved by: https://github.com/wanchaol
2023-04-10 23:57:04 +00:00
Xilun Wu
67963c32bd [DTensor][1/N] add DTensor RNG state APIs (#98198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98198
Approved by: https://github.com/wanchaol
2023-04-10 23:57:00 +00:00
Shen Li
1be3549a27 Enable replicated embedding in SPMD for NLP models (#98686)
For models like NanoGPT, embeddings are replicated and input ids
are sharded. In this case, output lookups should be sharded to
match ids.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98686
Approved by: https://github.com/yifuwang
2023-04-09 02:13:10 +00:00
Shen Li
11b0a84f3e Enable LogSoftmax for SPMD tracing (#98380)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98380
Approved by: https://github.com/wanchaol
2023-04-06 04:41:37 +00:00
Shen Li
d0eafed7fb [Easy] Fix minor errors in DTensor examples (#98430)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98430
Approved by: https://github.com/wanchaol
2023-04-05 21:44:01 +00:00
Wanchao Liang
dcec2100b1 [dtensor] add placement strategy and einsum strategy (#98227)
This adds placement strategy to the op schema and implements an einsum
strategy. It is a basic building block for compile-mode expansion
and new op implementations
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98227
Approved by: https://github.com/XilunWu
2023-04-05 17:09:32 +00:00
Yifu Wang
0830808dde [spmd expansion] speed up expansion by ~5x (#98389)
According to profiling, the two most expensive operations in spmd expansion are propagate_op_sharding and make_fx (run for every dispatcher op node). This PR makes the following changes to speed up spmd expansion:
- We were unnecessarily doing propagate_op_sharding twice for every op. Remove one.
- When no tensor redistribution is required, we only need to update the non-tensor args of the node according to op_schema and can avoid building a GraphModule just for the node.

On a DDP use case + foreach Adam, this change speeds up spmd expansion by ~5x (~10 min -> ~2 min).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98389
Approved by: https://github.com/mrshenli
2023-04-05 16:31:40 +00:00
Kazuaki Ishizaki
6514d71add Fix typos under torch/distributed directory (#98225)
This PR fixes typos in comments and messages of `.py` files under the `torch/distributed` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98225
Approved by: https://github.com/soulitzer, https://github.com/kit1980
2023-04-05 00:21:33 +00:00
Rodrigo Kumpera
9ad66dd588 Switch reduce_scatter and all_gather in DeviceMesh to use functional collectives (#96226)
Among the changes is the introduction of gather_dim and scatter_dim in DeviceMesh collectives to simplify user code.

The current plan is to keep padding and gather/scatter dim support in DeviceMesh while we explore  optimization opportunities in Inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96226
Approved by: https://github.com/wanchaol
2023-04-04 00:58:33 +00:00
Shen Li
96403cfcec [Easy] Fix lint error on DTensor math_ops.py (#98170)
This lint error is caused by conflicts between #97996 and #98148

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98170
Approved by: https://github.com/yifuwang
2023-04-02 19:11:05 +00:00
Shen Li
02179827cb [Easy] Include SPMD and DTensor files in UFMT checks (#98148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98148
Approved by: https://github.com/fegin
2023-04-02 15:34:49 +00:00
Wanchao Liang
7fcff01b50 [reland] switch mean to use reduction linear (#97996)
mean is actually a reduction-linear formula if the final reduction
is a partial sum (which it currently is), so switch to use that instead
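A tiny worked example of why this holds: with partial sums as the intermediate result, the global mean is just the sum of the per-shard sums divided by the global element count.

```python
import torch

x = torch.arange(8.0)
shards = torch.chunk(x, 4)                        # pretend these live on 4 ranks
partial_sums = [s.sum() for s in shards]          # local reduction on each rank
global_mean = torch.stack(partial_sums).sum() / x.numel()
assert torch.equal(global_mean, x.mean())         # identical to the full mean
```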
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97996
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
2023-04-02 03:19:56 +00:00
Shen Li
e8d39606eb [SPMD] Enable fused Adam in full train step tracing (#98113)
Differential Revision: [](https://our.internmc.facebook.com/intern/diff/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98113
Approved by: https://github.com/yifuwang, https://github.com/fegin
2023-04-01 15:54:13 +00:00
Shen Li
bccf2ef0ce Format DTensor dispatch.py and _meta_registrations.py (#98114)
Format-only changes with black and lintrunner to prepare for the commit on top.

Differential Revision: [D44603809](https://our.internmc.facebook.com/intern/diff/D44603809)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98114
Approved by: https://github.com/yifuwang, https://github.com/fegin
2023-04-01 15:54:13 +00:00
PyTorch MergeBot
9e3b34775b Revert "[dtensor] switch mean to use reduction linear (#97996)"
This reverts commit 1b323b313c.

Reverted https://github.com/pytorch/pytorch/pull/97996 on behalf of https://github.com/huydhn due to Sorry for reverting your PR, but it fails a test on CPU 1b323b313c
2023-03-31 16:44:01 +00:00
Wanchao Liang
1b323b313c [dtensor] switch mean to use reduction linear (#97996)
mean is actually a reduction-linear formula if the final reduction
is a partial sum (which it currently is), so switch to use that instead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97996
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
2023-03-30 22:48:16 +00:00
Wanchao Liang
47ce41e732 [dtensor] remove DeviceMesh typing hack guard type imports (#97889)
This PR relands https://github.com/pytorch/pytorch/pull/94526
and tries to guard the type import for older numpy versions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97889
Approved by: https://github.com/fegin
2023-03-29 23:29:41 +00:00
PyTorch MergeBot
4114c1ea02 Revert "[dtensor] remove typing hack of DeviceMesh (#94526)"
This reverts commit 70b063db0e.

Reverted https://github.com/pytorch/pytorch/pull/94526 on behalf of https://github.com/atalman due to breaking internal builds
2023-03-29 17:33:58 +00:00
Wanchao Liang
70b063db0e [dtensor] remove typing hack of DeviceMesh (#94526)
This removes the typing hack, part of https://github.com/pytorch/pytorch/pull/92931
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94526
Approved by: https://github.com/XilunWu
2023-03-29 00:23:47 +00:00
Wanchao Liang
08c1d1a871 [dtensor] set cuda device automatically, and refactor error handling (#97583)
This PR detects whether device_type is cuda; if cuda is passed in,
we set the current cuda device for each process/thread automatically
(this assumes homogeneous devices).
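A minimal sketch of the automatic device selection (assumes an initialized process group and homogeneous GPU counts per host; not the exact implementation):

```python
import torch
import torch.distributed as dist

if torch.cuda.is_available():
    # map the global rank onto a local CUDA device index
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
```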

Also refactored error handling code
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97583
Approved by: https://github.com/wz337, https://github.com/XilunWu
2023-03-28 02:25:45 +00:00
Wanchao Liang
e9c4904915 [dtensor] remove custom dispatch op (#95629)
Since we removed all custom dispatch ops, we can safely delete this
table as we won't use it for other purposes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95629
Approved by: https://github.com/XilunWu
2023-03-28 02:25:45 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under the torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Chien-Chin Huang
f3cf3d7620 [DTensor] Fix the default PG condition for DeviceMesh (#97384)
The current condition to use the default PG is `len(unique_mesh_values) == WORLD_SIZE - 1`. The `- 1` is not correct and seems to be an incorrect fix from https://github.com/pytorch/pytorch/pull/96861.
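A sketch of the corrected condition (variable and helper names assumed from the message, not the actual code):

```python
import torch.distributed as dist

def mesh_covers_world(unique_mesh_values) -> bool:
    # the default PG can only be reused when the mesh contains every rank,
    # i.e. the count of unique mesh values equals the world size -- no "- 1"
    return len(unique_mesh_values) == dist.get_world_size()
```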

Differential Revision: [D44314317](https://our.internmc.facebook.com/intern/diff/D44314317/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97384
Approved by: https://github.com/wanchaol
2023-03-24 00:04:34 +00:00
Xilun Wu
c2d7508276 [DTensor] default value for DTensor ops on non-participating devices (#95852)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95852
Approved by: https://github.com/wanchaol
2023-03-23 19:30:02 +00:00
Xilun Wu
103f4c99f0 [DTensor] implement aten.equal sharding prop (#97170)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97170
Approved by: https://github.com/wanchaol
2023-03-23 19:30:02 +00:00
Wanchao Liang
16e7e5a24b [dtensor] lazy init process groups in device mesh (#96700)
This PR adds a private flag to allow lazy process group initialization,
replacing the previous `dim_groups` arg, as no one is using that now

This can help avoid creating process groups when they are not necessary

Differential Revision: [D44044664](https://our.internmc.facebook.com/intern/diff/D44044664)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96700
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-03-20 17:50:04 +00:00
Shintaro Iwasaki
95575f0a5f [DTensor] Fix _get_or_create_default_group() (#96961)
Summary:
This PR fixes `_get_or_create_default_group()` of `DeviceMesh`. When `mesh` of the first created `DeviceMesh` is not `[0, 1, 2, ... WORLD_SIZE - 1]` and `is_initialized() == False`, it wrongly asserts. This PR fixes this issue by removing these assertions.

 ---

More specifically, `_get_or_create_default_group()` has 4 checks:

1. `DeviceMesh must include every process in WORLD`
2. `DeviceMesh cannot have duplicate values`
3. `DeviceMesh ranks must start from 0`
4. `DeviceMesh should have all ranks of WORLD`

1, 3, and 4 are not satisfied when `self.mesh` is not `[0, 1, 2, ... WORLD_SIZE - 1]`.

2 is a valid check, but it is also checked in `__init__()`, so we don't need to check it again in this function.

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D44098849

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96961
Approved by: https://github.com/wanchaol
2023-03-17 15:52:19 +00:00
Shintaro Iwasaki
397fb2762e [DTensor] Fix DeviceMesh (#96861)
Summary: This Diff fixes some DeviceMesh issues that block internal DTensor integration. Specifically, when `self.mesh = [2, 3]` while `world_size = 4`, because `unique_mesh_values[-1] == 3`, it takes the first short-cut branch and uses `default_pg`. Let's check the length instead of the last value of `unique_mesh_values`.
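A worked illustration of the failure mode (hypothetical standalone snippet mirroring the values in the summary):

```python
WORLD_SIZE = 4
unique_mesh_values = [2, 3]                # from mesh = [2, 3]

# the old check keyed off the last value and misfires here:
old_uses_default_pg = unique_mesh_values[-1] == WORLD_SIZE - 1   # True (wrong)
# the length-based check correctly rejects the default-pg shortcut:
new_uses_default_pg = len(unique_mesh_values) == WORLD_SIZE      # False (correct)
```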

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D44079872

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96861
Approved by: https://github.com/wanchaol
2023-03-16 16:40:38 +00:00
fduwjj
3405ac8a08 [TP][DTensor Op] Enable Embedding op for DTensor (#96702)
We enabled col-wise embedding for TP users.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96702
Approved by: https://github.com/wanchaol
2023-03-16 05:18:07 +00:00
Wanchao Liang
789fc4c292 [dtensor] refactor shape/offset calculation (#95923)
Shape/offset calculation is commonly used, so extract it into a separate util
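A hypothetical version of the kind of util being factored out: compute the local shard size and offset when one dimension is chunk-sharded across a number of ranks.

```python
def local_shard_size_and_offset(dim_size: int, num_shards: int, rank: int):
    # torch.chunk-style sharding: every rank gets ceil(dim_size / num_shards)
    # elements, except the trailing ranks, which may be short or even empty
    full = (dim_size + num_shards - 1) // num_shards
    offset = min(rank * full, dim_size)
    size = max(min(dim_size - offset, full), 0)
    return size, offset

# e.g. dim_size=10 over 4 shards -> sizes (3, 3, 3, 1), offsets (0, 3, 6, 9)
```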

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95923
Approved by: https://github.com/fduwjj
2023-03-05 06:33:32 +00:00
Ke Sang
6c061e5145 [DTensor][Shampoo] add _tensor.zeros function (#95863)
Summary:
implement the zeros function inside the DTensor API
- the user specifies the zeros tensor shape, and the function creates the local zero tensor given the placement information

Test Plan:
- unit test for the util function compute_local_tensor_size
- unit test for _tensor.zeros
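A hedged usage sketch (assumes a torch.zeros-like factory signature under torch.distributed._tensor and an already-initialized 4-rank process group):

```python
import torch.distributed._tensor as dt

mesh = dt.DeviceMesh("cuda", list(range(4)))
zeros_dt = dt.zeros(16, 8, device_mesh=mesh, placements=[dt.Shard(0)])
# each rank only materializes its local 4x8 zero shard
```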

Reviewed By: wanchaol

Differential Revision: D43630718

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95863
Approved by: https://github.com/wanchaol
2023-03-03 19:36:44 +00:00
Rodrigo Kumpera
7ebd816aab Switch DTensor to use funcol::all_reduce. (#95804)
This is relanding the troubling part of #95009 that caused a regression.

BC: This changes the signature and semantics of DeviceMesh::all_reduce.

DeviceMesh::all_reduce now uses a functional collective under the hood which makes it more easily traceable.
You no longer need to use CommTensor to get a trace.

all_reduce is now async-only and uses AsyncCollectiveTensor to ensure proper stream synchronization.

Signature change: removed the async_op param and changed the return type from Optional[Work] to torch.Tensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95804
Approved by: https://github.com/fegin
2023-03-02 17:55:01 +00:00
Kazuaki Ishizaki
b3d8fae042 Fix typos in documents under torch directory (#95709)
This PR fixes typos in `.md` files under the `torch` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95709
Approved by: https://github.com/Skylion007, https://github.com/kit1980
2023-03-01 23:43:35 +00:00
Wanchao Liang
7a772bfff9 [dtensor] add submesh example to checkpoint_example (#95655)
This PR adds a submesh example for checkpointing purposes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95655
Approved by: https://github.com/XilunWu
2023-03-01 08:19:27 +00:00
Wanchao Liang
2a1cb9640c [dtensor] support creating DTensor in submesh (#95458)
This PR supports creating a DTensor in a submesh. If the rank is not
participating in the mesh, we assign the local tensor to be an empty
tensor and do nothing in operator dispatch.
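A hedged sketch of the behavior (assumes a 4-rank world with an initialized process group; API names from torch.distributed._tensor):

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

submesh = DeviceMesh("cuda", [0, 1])                  # only ranks 0 and 1 participate
dt = distribute_tensor(torch.randn(8, 8), submesh, [Shard(0)])
# on ranks 2 and 3, the local tensor is empty and operator dispatch is a no-op
```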

Differential Revision: [D43643577](https://our.internmc.facebook.com/intern/diff/D43643577)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95458
Approved by: https://github.com/XilunWu
2023-02-28 17:54:26 +00:00
Wanchao Liang
261eb46ddd [dtensor] refactor get_coordinate (#95457)
This refactors get_coordinate to return an Optional[List] instead of
the coordinate on a dim directly, so that we can easily check whether the
rank is inside the mesh

Differential Revision: [D43643579](https://our.internmc.facebook.com/intern/diff/D43643579)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95457
Approved by: https://github.com/XilunWu
2023-02-28 17:54:26 +00:00
Wanchao Liang
bb9a05b116 [dtensor] use tracing for metadata prop (#95456)
This PR uses tracing for metadata prop, so that we can get correct
shape/stride metadata without manual calculation.

The follow-up PR will adopt tracing for the sharding
prop itself

Differential Revision: [D43643578](https://our.internmc.facebook.com/intern/diff/D43643578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95456
Approved by: https://github.com/XilunWu
2023-02-28 17:54:22 +00:00
PyTorch MergeBot
d950f45577 Revert "[Functional Collectives] Migrate DeviceMesh::all_reduce to use functional all_reduce. (#95009)"
This reverts commit 0765dbc25e.

Reverted https://github.com/pytorch/pytorch/pull/95009 on behalf of https://github.com/jeanschmidt due to this PR is causing internal breakages. Check https://fburl.com/diff/me41urq8
2023-02-27 19:21:58 +00:00