**Summary / Goal**
Add a utility to compute expected local tensor sizes and strides under *even sharding* in dtensor.
**Details**
- New function in `torch/distributed/tensor/_utils.py`.
- Computes local sizes/strides given global shape, mesh, and placements.
- Enforces divisibility of global dimension by mesh size (strict even sharding).
- Complements `compute_global_tensor_info`.
**Motivation**
Ensures correctness for stride/layout computations in distributed tensors.
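A minimal sketch of the idea, with a hypothetical name and signature (the actual function added to `torch/distributed/tensor/_utils.py` may differ):

```python
# Hypothetical sketch only; the name, signature, and the contiguous-layout
# assumption for strides are mine, not the PR's.
from torch.distributed.tensor.placement_types import Shard

def expected_local_size_stride_sketch(global_shape, global_stride, mesh_sizes, placements):
    local_shape, local_stride = list(global_shape), list(global_stride)
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            d, n = placement.dim, mesh_sizes[mesh_dim]
            # strict even sharding: the global dim must divide evenly
            assert local_shape[d] % n == 0, "strict even sharding required"
            stride_d = local_stride[d]
            local_shape[d] //= n
            # dims laid out "outside" of d (larger stride) shrink by the same factor
            local_stride = [s // n if s > stride_d else s for s in local_stride]
    return tuple(local_shape), tuple(local_stride)

# e.g. an (8, 6) row-major tensor with [Shard(1)] on a mesh of size 2
# -> local shape (8, 3), local strides (3, 1)
```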
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164296
Approved by: https://github.com/ezyang
The previous uneven `_StridedShard` from https://github.com/pytorch/pytorch/pull/150490 seems to fail on cases like sharding `tensor = torch.arange(6)` with FSDP 2, TP 2.
This PR attempts to reinvent `_StridedShard`.
I didn't test nested `_StridedShard`, because there shouldn't be any use cases. I think it will become quite messy when it comes to **nested uneven** `_StridedShard`. We are probably going to deprecate it anyway after @zpcore 's work https://github.com/pytorch/pytorch/pull/160266 on ordered sharding, so IMO not worth it to make it too general.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163843
Approved by: https://github.com/ezyang
There is only one substantive change: the branch on
`global_offset[shard_dim] <= local_offset[shard_dim]`
is removed because it is unnecessary: you can always treat the
first shard uniformly with the rest of the shards, because your
global offset is guaranteed to be zero in this case anyway.
I also switched the shard_size computation to sym_ite, to make it possible
for LocalTensor to deal with the MPMD-ness here, but it's equivalent
to the old if-then-else.
I tried to rewrite the comments to make it clearer what is going on
algorithmically here.
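A small sketch of the equivalence (illustrative; the variable names here are made up, not from the diff):

```python
import torch

# Old style: a Python-level branch picking the shard size.
def shard_size_branch(is_tail_shard, full_chunk_size, tail_size):
    return tail_size if is_tail_shard else full_chunk_size

# New style: torch.sym_ite expresses the same selection symbolically, so
# LocalTensor can evaluate it per-rank (MPMD) instead of forcing one Python branch.
def shard_size_sym_ite(is_tail_shard, full_chunk_size, tail_size):
    return torch.sym_ite(is_tail_shard, tail_size, full_chunk_size)
```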
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163344
Approved by: https://github.com/albanD, https://github.com/zpcore, https://github.com/tianyu-l
### Summary
Recreating #151990 to mitigate easyCLA failure
The `compute_global_tensor_shape` util function takes in the local tensor shape, device mesh,
and placements. We all-gather the shapes from the shards and construct the global shape
according to the placement type.
Note: currently only implemented for placement types Shard and Replicate; TODO for _StridedShard.
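Roughly, the construction per mesh dim looks like the sketch below (illustrative only, assuming a 1-D mesh for brevity; not the util's actual code):

```python
from torch.distributed.tensor.placement_types import Replicate, Shard

def global_shape_sketch(gathered_local_shapes, placement):
    # gathered_local_shapes: every rank's local shape, all-gathered on one mesh dim
    global_shape = list(gathered_local_shapes[0])
    if isinstance(placement, Shard):
        # shards concatenate along the sharded dim: sum the local sizes there
        global_shape[placement.dim] = sum(s[placement.dim] for s in gathered_local_shapes)
    elif isinstance(placement, Replicate):
        pass  # every rank already holds the full dim
    else:
        raise NotImplementedError("_StridedShard is not handled in this sketch")
    return tuple(global_shape)

# e.g. local shapes [(3,), (2,)] under Shard(0) -> global shape (5,)
```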
### Test
`pytest test/distributed/tensor/test_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152751
Approved by: https://github.com/XilunWu
`compute_local_shape_and_global_offset` util computes the local shape of
a particular shard of a DTensor, and the global offset (which describes
how the shard fits into the global tensor).
When the tensor dim does not evenly divide into the mesh dim, uneven
sharding occurs. In some cases, uneven sharding results in an empty
shard.
e.g.
tensor dim size: 512
mesh dim size: 30
ranks 0..27 have local size 18
rank 28 has local size 8
rank 29 has local size 0 <--- empty shard
The global offset for an empty shard was previously undefined and
returned values that were computed based on logic that assumes no empty
shards. This caused DCP to fail to save a checkpoint, because
deduplication logic could 'throw away' real (non-empty) shards thinking
they were duplicates of zero-sized shards with the same offset.
Now, we define the global offset of an empty shard to be the dim-size,
which is out of bounds of the tensor and can't overlap with any
non-empty shards.
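A minimal 1-D sketch of the convention (hypothetical helper; the real util handles full shapes, multiple mesh dims, and strided sharding):

```python
def local_size_and_global_offset_1d(dim_size, num_chunks, rank):
    # even chunking rounds the chunk size up, so trailing ranks can be short or empty
    full_chunk = -(-dim_size // num_chunks)  # ceil division
    start = min(rank * full_chunk, dim_size)
    local_size = min(start + full_chunk, dim_size) - start
    # new convention: an empty shard's offset is the dim size (out of bounds),
    # so it can never collide with a non-empty shard's offset during DCP dedup
    global_offset = start if local_size > 0 else dim_size
    return local_size, global_offset

# dim_size=512, num_chunks=30:
#   rank 27 -> (18, 486), rank 28 -> (8, 504), rank 29 -> (0, 512)
```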
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150862
Approved by: https://github.com/teja-rao, https://github.com/XilunWu
This enables using FSDP+TP on parameters with dimensions that aren't
evenly divisible by the DP/TP mesh sizes.
- this may not support all possible combinations of strided and regular
shardings, but the support before this PR was not complete anyway
This contains several fixes for different aspects of DTensor behavior
relating to uneven strided sharding:
- original creation of the strided tensor requires fixes in
StridedShard._split_tensor
- full_tensor() reconstruction requires fixes in
StridedShard._to_replicate_tensor to correctly reshuffle the data into
the original pre-sharded order
- Distributed Checkpointing support requires correct computation of the
compute_local_shape_and_global_offset util so it knows how a local
shard maps to the global tensor, for reconstruction during
load/reshard.
This PR also adds a util `_explicit_order_placements` which converts a list of
placements with StridedSharding into a list of placements with only
regular sharding, with the order shuffled such that it is equivalent.
Builds on and completes the work started in https://github.com/pytorch/pytorch/pull/148894
Uneven Sharding Example
-------
(copied from _StridedShard._to_replicate_tensor docstring)
mesh = (DP=2, TP=2)
original = torch.arange(5)
**Applying Sharding**
Step 1 - Apply TP sharding
`tp = distribute_tensor(original, world_mesh['tp'], [Shard(0)])`
local_tensors:
rank0: [0,1,2] rank1: [3,4]
rank2: [0,1,2] rank3: [3,4]
Step 2 - Apply FSDP sharding
`dp_tp = ...` (the process of creating a strided-shard tensor is skipped over as it is hacky and complicated)
dp_tp has placement (_StridedShard(0, split_factor=2), Shard(0))
local_tensors:
rank0: [0,1] rank1: [3]
rank2: [2] rank3: [4]
**Reconstructing the Full Tensor**
Now, say someone wants to reconstruct dp_tp's full tensor. This will invoke 'redistribute' to replicate.
redistribute will first replicate the "Shard(0)" placement on the rightmost mesh dim, then replicate the
StridedShard placement second, which is implemented by this function.
So our starting point (`local_tensor` arg) is the result of replicating the Shard(0) placement across the
TP dim, which looks like this.
Note the discrepancy with the TP-sharded tensors from Step 1 above! We'll fix it by locally shuffling data.
local_tensors:
rank0: [0,1,3] rank1: [0,1,3]
rank2: [2,4] rank3: [2,4]
Step 1: replicate over the DP dimension. Afterwards, each rank can locally reorder the values.
note: we need padding to do this allgather, and we'll need to keep track of the padding amount for later
local_tensors:
rank0: [0,1,3,2,4] rank1: [0,1,3,2,4]
rank2: [0,1,3,2,4] rank3: [0,1,3,2,4]
Step 2: chunk and shuffle values around to account for the wrong order of operations above
and get the original tensor content back
01324# <- our allgather includes padding, if padding was applied in step 1
01324 <- Remove the padding
013, 24 <- chunk once, 'undoing' the DP allgather
01, 3, 2, 4 <- chunk each chunk, 'undoing' the initial (wrong) TP allgather performed by Shard(0)->Replicate()
012, 34 <- interleave with stride=TP mesh dim size
01234 <- concatenate
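For concreteness, here is a plain-tensor sketch of the Step 2 reshuffle on the example above (illustrative; `_StridedShard._to_replicate_tensor` also handles padding and mesh bookkeeping):

```python
import torch

dp, tp = 2, 2
gathered = torch.tensor([0, 1, 3, 2, 4])   # each rank's tensor after the DP allgather (padding removed)
dp_chunk_sizes = [3, 2]                    # uneven contribution from each DP rank
dp_chunks = torch.split(gathered, dp_chunk_sizes)                        # [0,1,3], [2,4]
tp_chunks = [c for chunk in dp_chunks for c in torch.chunk(chunk, tp)]   # [0,1], [3], [2], [4]
# interleave with stride = TP mesh dim size, then concatenate
reordered = [tp_chunks[i + j * tp] for i in range(tp) for j in range(dp)]
full = torch.cat(reordered)                # tensor([0, 1, 2, 3, 4])
```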
Co-authored-by: Luca Wehrstedt <lw@meta.com>
Co-authored-by: Will Constable <whc@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150490
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
The util converts a list of placements in the traditional DTensor format
(e.g. [_StridedShard(0), Shard(0)], where the list position is the mesh_dim and sharding
is always applied left-to-right, from dim 0 to higher dims)
to a more explicitly ordered format, also replacing '_StridedShard' with
simple 'Shard' placements in the process
(e.g. the above becomes [(1, Shard(0)), (0, Shard(0))], where the first
item in each tuple is the mesh_dim and the ordering of the tuples is the
sharding order).
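A hedged sketch of that conversion for the simple case above (not the actual `_explicit_order_placements`; the import path and the limited generality are assumptions):

```python
from torch.distributed.tensor.placement_types import Shard, _StridedShard

def explicit_order_sketch(placements):
    # A _StridedShard on an earlier mesh dim means its split was applied *after*
    # the plain Shard on the later mesh dim, so defer it in the explicit ordering.
    ordered, deferred = [], []
    for mesh_dim, p in enumerate(placements):
        if isinstance(p, _StridedShard):        # check before Shard: it's a subclass
            deferred.append((mesh_dim, Shard(p.dim)))
        else:
            ordered.append((mesh_dim, p))
    return ordered + list(reversed(deferred))

# [_StridedShard(0, split_factor=2), Shard(0)] -> [(1, Shard(0)), (0, Shard(0))]
```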
This is useful so far as a helper for fixing local shape computation for
strided sharding in the uneven shape case (in the following PR), but it may
also be useful more broadly if we can use explicit orderings to simplify
other parts of DTensor logic.
This skips implementing some combinations of _StridedShard that are
not used in the wild today, but could be supported easily.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150493
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
**Summary**
1. This PR removes the public API `compute_local_shape` and replaces its uses with the more general API `compute_local_shape_and_global_offset`.
2. To keep `compute_local_shape_and_global_offset` consistent with `compute_local_shape` on empty shards, it now returns local tensor shape `(0,)` for empty shards, which is more aligned with DTensor's semantics on non-participating ranks.
**Test**
`pytest test/distributed/_tensor/test_dtensor.py`
`pytest test/distributed/_tensor/test_init.py`
`pytest test/distributed/_tensor/test_tensor_ops.py`
Differential Revision: [D62415591](https://our.internmc.facebook.com/intern/diff/D62415591)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135554
Approved by: https://github.com/tianyu-l, https://github.com/wz337
reland of https://github.com/pytorch/pytorch/pull/133113
I have to create a new PR because the previous reverted PR could not be rebased or imported successfully :(
----
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve BC for users still using `torch.distributed._tensor`, I added a shim to redirect old import paths to the new module (see the sketch below)
BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it's safe to land the changes.
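A minimal sketch of what such a shim could look like (illustrative only; the actual redirect in this PR may be implemented differently):

```python
# hypothetical contents of the old torch/distributed/_tensor/__init__.py:
# re-export the new public API so existing imports keep working, and point
# users at the new module name.
import warnings

from torch.distributed.tensor import *  # noqa: F401,F403

warnings.warn(
    "torch.distributed._tensor is deprecated, use torch.distributed.tensor instead.",
    DeprecationWarning,
    stacklevel=2,
)
```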
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
PRs)
* To preserve BC for users still using `torch.distributed._tensor`,
I added a shim to redirect old import paths to the new module
BC preservation is evidenced by the fact that all DTensor tests still
pass without changing the public imports, so it's safe to land the
changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306