Commit Graph

17 Commits

Author SHA1 Message Date
Simon Fan
cbaa07e438 [dtensor] add util to compute expected local sizes/strides for even sharding (#164296)
Reviewed GPT-5 summary:

**Summary / Goal**
Add a utility to compute expected local tensor sizes and strides under *even sharding* in dtensor.

**Details**
- New function in `torch/distributed/tensor/_utils.py`.
- Computes local sizes/strides given global shape, mesh, and placements.
- Enforces divisibility of global dimension by mesh size (strict even sharding).
- Complements `compute_global_tensor_info`.

**Motivation**
Ensures correctness for stride/layout computations in distributed tensors.
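
A minimal sketch of the arithmetic described above, assuming a contiguous (row-major) global layout and plain `Shard`/`Replicate` placements; the function name and signature below are illustrative, not the actual util added in `_utils.py` (which takes a mesh and placements rather than raw sizes):

```python
from math import prod
from typing import Optional, Sequence

def even_shard_local_size_stride(
    global_size: Sequence[int],
    mesh_sizes: Sequence[int],            # size of each device-mesh dimension
    shard_dims: Sequence[Optional[int]],  # tensor dim sharded on each mesh dim, None = Replicate
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    local_size = list(global_size)
    for mesh_size, dim in zip(mesh_sizes, shard_dims):
        if dim is None:
            continue
        # strict even sharding: the global dim must divide evenly by the mesh-dim size
        assert local_size[dim] % mesh_size == 0, f"dim {dim} not divisible by mesh size {mesh_size}"
        local_size[dim] //= mesh_size
    # contiguous strides recomputed from the local sizes
    local_stride = tuple(prod(local_size[i + 1:]) for i in range(len(local_size)))
    return tuple(local_size), local_stride

# e.g. an (8, 6) tensor with Shard(0) over a mesh dim of size 2 and Shard(1) over size 3:
# even_shard_local_size_stride((8, 6), (2, 3), (0, 1)) -> ((4, 2), (2, 1))
```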

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164296
Approved by: https://github.com/ezyang
2025-10-10 02:34:27 +00:00
Tianyu Liu
92f7361e27 [DTensor] fix uneven _StridedShard (#163843)
The previous uneven `_StridedShard` from https://github.com/pytorch/pytorch/pull/150490 seems to fail on cases like sharding `tensor = torch.arange(6)` with FSDP 2, TP 2.

This PR attempts to reinvent `_StridedShard`.

I didn't test nested `_StridedShard`, because there shouldn't be any use cases. I think it would become quite messy when it comes to **nested uneven** `_StridedShard`. We are probably going to deprecate it anyway after @zpcore's work in https://github.com/pytorch/pytorch/pull/160266 on ordered sharding, so IMO it's not worth making it too general.
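
For concreteness, a tiny pure-`torch` illustration of the local shards this case should produce (chunk-then-chunk is only a stand-in for the DTensor sharding path, shown to make the uneven split visible):

```python
import torch

tensor = torch.arange(6)
tp_shards = list(torch.chunk(tensor, 2))                  # TP=2: [0,1,2] and [3,4,5]
dp_shards = [list(torch.chunk(t, 2)) for t in tp_shards]  # FSDP=2 splits each TP shard unevenly
# expected local shards under (_StridedShard(0, split_factor=2), Shard(0)):
#   rank0: [0, 1]   rank1: [3, 4]
#   rank2: [2]      rank3: [5]
```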

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163843
Approved by: https://github.com/ezyang
2025-09-25 22:12:29 +00:00
Edward Z. Yang
c261c71f3e Simplify _compute_local_shape_and_global_offset and make it SPMD. (#163344)
There is only one substantive change: the branch on
`global_offset[shard_dim] <= local_offset[shard_dim]`
is removed because it is unnecessary: you can always treat the
first shard uniformly with the rest of the shards, because your
global offset is guaranteed to be zero in this case anyway.

I also switch the shard_size case to sym_ite, to make it possible
for LocalTensor to deal with the MPMD-ness here, but it's equivalent
to the old if-then-else.

I tried to rewrite the comments to be more clear what is going on
algorithmically here.
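
As a concrete illustration of the uniform treatment (not the exact code, which uses `sym_ite` so symbolic shapes also work; min/max is used below only for brevity), the per-rank arithmetic for one sharded dim under `torch.chunk`-style sharding can be written as:

```python
def local_size_and_offset(dim_size: int, num_shards: int, rank: int) -> tuple[int, int]:
    full_chunk = (dim_size + num_shards - 1) // num_shards  # ceil division
    offset = min(rank * full_chunk, dim_size)      # same formula for the first shard and all others
    size = max(min(full_chunk, dim_size - offset), 0)       # clamps to 0 for empty trailing shards
    return size, offset

# dim_size=5, num_shards=4 -> (2, 0), (2, 2), (1, 4), (0, 5)
```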

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163344
Approved by: https://github.com/albanD, https://github.com/zpcore, https://github.com/tianyu-l
2025-09-24 02:24:09 +00:00
Xilun Wu
7376111d59 [BE] fix compute_global_tensor_shape test (#161441)
Fixes #161154

**Test**
`pytest test/distributed/tensor/test_utils.py -s -k test_compute_global_tensor_shape_1D`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161441
Approved by: https://github.com/kwen2501
2025-08-26 03:22:29 +00:00
Xuehai Pan
3f8e2e91ad [BE][15/16] fix typos in torch/ (torch/distributed/tensor/) (#156605)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156605
Approved by: https://github.com/wanchaol, https://github.com/albanD
2025-07-17 12:08:33 +00:00
Dharak Kharod
a78eec88b8 Implement util function compute_global_tensor_shape for 1D device mesh (#152751)
### Summary

Recreating #151990 to mitigate easyCLA failure

The `compute_global_tensor_shape` util function takes in a local tensor shape, a device mesh,
and placements. We all-gather the shapes from the shards and construct the global shape
according to the placement type.

Note: currently only implemented for placement types Shard and Replicate; TODO for _StridedShard.
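
A rough sketch of that approach for a 1-D mesh; the helper below is illustrative only (the real util lives in `torch/distributed/tensor/_utils.py` and its exact signature may differ):

```python
import torch.distributed as dist
from torch.distributed.tensor import Replicate, Shard

def global_shape_1d(local_shape, mesh, placement):
    group = mesh.get_group(0)                        # the single mesh dimension
    shapes = [None] * dist.get_world_size(group)
    dist.all_gather_object(shapes, tuple(local_shape), group=group)
    if isinstance(placement, Replicate):
        assert all(s == shapes[0] for s in shapes)   # replicas must all agree
        return shapes[0]
    if isinstance(placement, Shard):                 # plain Shard only; _StridedShard not handled here
        global_shape = list(shapes[0])
        global_shape[placement.dim] = sum(s[placement.dim] for s in shapes)  # stack along sharded dim
        return tuple(global_shape)
    raise NotImplementedError(f"placement {placement} not supported in this sketch")
```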

### Test

`pytest test/distributed/tensor/test_utils.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152751
Approved by: https://github.com/XilunWu
2025-05-05 02:44:31 +00:00
Will Constable
c3bc6b3542 [DTensor] Fix empty shard global-offset calculation (#150862)
`compute_local_shape_and_global_offset` util computes the local shape of
a particular shard of a DTensor, and the global offset (which describes
how the shard fits into the global tensor).

When the tensor dim does not evenly divide into the mesh dim, uneven
sharding occurs.  In some cases, uneven sharding results in an empty
shard.

e.g.
   tensor dim size: 512
   mesh dim size: 30
   ranks 0..27 have local size 18
   rank 28 has local size 8
   rank 29 has local size 0 <--- empty shard

The global offset for an empty shard was previously undefined and
returned values that were computed based on logic that assumes no empty
shards.  This caused DCP to fail to save a checkpoint, because
deduplication logic could 'throw away' real (non-empty) shards thinking
they were duplicates of zero-sized shards with the same offset.

Now, we define the global offset of an empty shard to be the dim-size,
which is out of bounds of the tensor and can't overlap with any
non-empty shards.
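
A quick numeric check of the example and the new rule, using the same `torch.chunk`-style arithmetic (illustrative only, not the actual util):

```python
dim_size, mesh_size = 512, 30
full_chunk = -(-dim_size // mesh_size)               # ceil(512 / 30) = 18
for rank in (0, 27, 28, 29):
    offset = min(rank * full_chunk, dim_size)
    size = max(min(full_chunk, dim_size - offset), 0)
    print(rank, size, offset)
# rank 0  -> size 18, offset 0
# rank 27 -> size 18, offset 486
# rank 28 -> size  8, offset 504
# rank 29 -> size  0, offset 512   <- empty shard's offset is defined as the dim size
```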

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150862
Approved by: https://github.com/teja-rao, https://github.com/XilunWu
2025-04-11 22:25:57 +00:00
Will Constable
a8b48ff14c [DTensor] clean up _local_shard_size_and_offset (#150650)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150650
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
ghstack dependencies: #150490
2025-04-09 22:07:48 +00:00
Will Constable
3532dd4f1e [DTensor] StridedShard support uneven sharding (#150490)
This enables using FSDP+TP on parameters with dimensions that aren't
evenly divisible by the DP/TP mesh sizes.

- this may not support all possible combinations of strided and regular
  shardings, but the support before this PR was not complete anyway

This contains several fixes for different aspects of DTensor behavior
relating to uneven strided sharding:
- original creation of the strided tensor requires fixes in
  StridedShard._split_tensor
- full_tensor() reconstruction requires fixes in
  StridedShard._to_replicate_tensor to correctly reshuffle the data into
  the original pre-sharded order
- Distributed Checkpointing support requires correct computation of the
  compute_local_shape_and_global_offset util so it knows how a local
  shard maps to the global tensor, for reconstruction during
  load/reshard.

This PR also adds a util `_explicit_order_placements` which converts a list of
placements with StridedSharding into a list of placements with only
regular sharding, with the order shuffled such that it is equivalent.

Builds on and completes the work started in https://github.com/pytorch/pytorch/pull/148894

Uneven Sharding Example
-------
(copied from _StridedShard._to_replicate_tensor docstring)

mesh = (DP=2, TP=2)
original = torch.arange(5)

**Applying Sharding**

Step 1 - Apply TP sharding
`tp = distribute_tensor(x, world_mesh['tp'], [Shard(0)])`

local_tensors:
rank0: [0,1,2]    rank1: [3,4]
rank2: [0,1,2]    rank3: [3,4]

Step 2 - Apply FSDP sharding
`dp_tp = ...` (the process of creating a strided-shard tensor is skipped over as it is hacky and complicated)
dp_tp has placement (_StridedShard(0, split_factor=2), Shard(0))
local_tensors:
rank0: [0,1]  rank1: [3]
rank2: [2]    rank3: [4]

**Reconstructing the Full Tensor**
Now, say someone wants to reconstruct dp_tp's full tensor. This will invoke 'redistribute' to replicate.
redistribute will first replicate the "Shard(0)" placement on the rightmost mesh dim, then replicate the
StridedShard placement, which is implemented by this function.
So our starting point (`local_tensor` arg) is the result of replicating the Shard(0) placement across the
TP dim, which looks like this.

Note the discrepancy with the 'tp sharded tensor' line above!  We'll fix it by locally shuffling data.

local_tensors:
rank0: [0,1,3]  rank1: [0,1,3]
rank1: [2,4]    rank3: [2,4]

Step 1: replicate over the DP dimension.  Afterwards, each rank can locally sort the values.
  note: we need padding to do this allgather, and we'll need to keep track of the padding amount for later
local_tensors:
rank0: [0,1,3,2,4]    rank1: [0,1,3,2,4]
rank2: [0,1,3,2,4]    rank3: [0,1,3,2,4]

Step 2: chunk and shuffle values around to account for the wrong order of operations above
and get the original tensor content back

01324#       <- our allgather includes padding, if padding was applied in step 1
01324        <- Remove the padding
013, 24      <- chunk once, 'undoing' the DP allgather
01, 3, 2, 4  <- chunk each chunk, 'undoing' the initial (wrong) TP allgather performed by Shard(0)->Replicate()
012, 34      <- interleave with stride=TP mesh dim size
01234        <- concatenate
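
The reshuffle in Step 2 can be traced with plain `torch` ops for this particular example (here `torch.chunk` happens to produce the right split sizes; the real `_to_replicate_tensor` fix has to track the uneven sizes and padding explicitly):

```python
import torch

dp_size, tp_size = 2, 2
allgathered = torch.tensor([0, 1, 3, 2, 4])          # step 1 result, padding already removed

dp_chunks = torch.chunk(allgathered, dp_size)        # 'undo' the DP allgather: [0,1,3], [2,4]
tp_chunks = [torch.chunk(c, tp_size) for c in dp_chunks]
# 'undo' the wrong TP order: [[0,1], [3]], [[2], [4]]

# interleave with stride = TP mesh dim size, then concatenate
reordered = [tp_chunks[dp][tp] for tp in range(tp_size) for dp in range(dp_size)]
print(torch.cat(reordered))                          # tensor([0, 1, 2, 3, 4])
```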

Co-authored-by: Luca Wehrstedt <lw@meta.com>
Co-authored-by: Will Constable <whc@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150490
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
2025-04-09 22:07:48 +00:00
Will Constable
c59aaa03ff [DTensor] add _explicit_order_placements util (#150493)
The util converts a list of placements in the traditional DTensor format
(e.g. [_StridedShard(0), Shard(0)], where the list position is the mesh_dim and sharding
is always applied left-to-right, from dim 0 to higher dims)

to a more explicitly ordered format, also replacing '_StridedShard' with
simple 'Shard' placements in the process
(e.g. the above becomes [(1, Shard(0)), (0, Shard(0))], where the first
item in each tuple is the mesh_dim and the ordering of the tuples is the
sharding order).

This is useful so far as a helper for fixing local shape computation for
strided sharding in the uneven shape case in the following PR, but it may
also be useful more broadly if we can use explicit orderings to simplify
other parts of DTensor logic.

This skips implementing some combinations of _StridedSharding that are
not currently used in the wild today, but could be supported easily.
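
A simplified sketch of the conversion (the real util also validates `split_factor` against the mesh sizes and rejects the unsupported `_StridedShard` combinations mentioned above; `_StridedShard` is a private placement type):

```python
from torch.distributed.tensor.placement_types import Shard, _StridedShard

def explicit_order_placements_sketch(placements):
    ordered, deferred = [], []
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, _StridedShard):
            # a strided shard on this mesh dim logically applies *after* the plain
            # shards to its right, so defer it to the end of the ordering
            deferred.append((mesh_dim, Shard(placement.dim)))
        else:
            ordered.append((mesh_dim, placement))
    return ordered + deferred

# [_StridedShard(0, split_factor=2), Shard(0)] -> [(1, Shard(0)), (0, Shard(0))]
print(explicit_order_placements_sketch([_StridedShard(0, split_factor=2), Shard(0)]))
```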

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150493
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
2025-04-09 16:55:24 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Aaron Orenstein
c95efc37ba PEP585 update - torch/distributed/tensor (#145141)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145141
Approved by: https://github.com/bobrenjc93
2025-01-18 20:01:59 +00:00
bobrenjc93
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
Xilun Wu
de8a8653c0 [dtensor][BE] replace compute_local_shape with compute_local_shape_and_global_offset (#135554)
**Summary**
1. This PR removes the public API `compute_local_shape` and replaces its uses with the more general API `compute_local_shape_and_global_offset`.
2. To keep `compute_local_shape_and_global_offset` consistent with `compute_local_shape` on empty shards, it now returns local tensor shape `(0,)` for empty shards, which is more aligned with DTensor's semantics on non-participating ranks.

**Test**
`pytest test/distributed/_tensor/test_dtensor.py`
`pytest test/distributed/_tensor/test_init.py`
`pytest test/distributed/_tensor/test_tensor_ops.py`

Differential Revision: [D62415591](https://our.internmc.facebook.com/intern/diff/D62415591)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135554
Approved by: https://github.com/tianyu-l, https://github.com/wz337
2024-09-12 06:30:09 +00:00
Wanchao Liang
cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I had to create a new PR because the previously reverted PR could neither be rebased nor imported successfully :(

----

Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve BC for users still using `torch.distributed._tensor`, I added a shim script to redirect old path calls to the new module

BC preservation is evidenced by the fact that all DTensor tests still work without changing the public imports, so it's safe to land the changes.
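
One minimal way such a redirect shim can be structured (illustrative only; not the actual file this PR adds):

```python
# torch/distributed/_tensor/__init__.py (sketch)
import sys

import torch.distributed.tensor as _public

# point the old private module path at the new public module so that
# `import torch.distributed._tensor` and attribute access keep working
sys.modules[__name__] = _public
```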

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00
PyTorch MergeBot
35f36363ec Revert "[dtensor] move DTensor to public namespace (#133113)"
This reverts commit 2ee6b97464.

Reverted https://github.com/pytorch/pytorch/pull/133113 on behalf of https://github.com/wanchaol due to looks like it break some internal type imports ([comment](https://github.com/pytorch/pytorch/pull/133113#issuecomment-2295670911))
2024-08-19 05:00:19 +00:00
Wanchao Liang
2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
  PRs)
* To preserve BC for users still using `torch.distributed._tensor`,
  I added a shim script to redirect old path calls to the new module

BC preservation is evidenced by the fact that all DTensor tests still
work without changing the public imports, so it's safe to land the
changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00