33 Commits

Author SHA1 Message Date
Xilun Wu
cc71ab86a6 [DTensor] raise error if the local_tensor argument passed to DTensor.from_local is a DTensor (#164496)
**Summary**
Raise an error when the `local_tensor` argument passed to `DTensor.from_local` is
a DTensor. This prevents users from accidentally calling `from_local` on a DTensor
object.

The error message has the following form:
```
the local_tensor argument only accepts torch.Tensor but got <class 'torch.distributed.tensor.DTensor'> value.
```
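
A minimal sketch of the kind of guard described above, using stand-in classes rather than the real `torch.Tensor` and `DTensor`. The function name `from_local_sketch` and the choice of `RuntimeError` are assumptions for illustration; only the message format is taken from the summary.

```python
class Tensor:
    """Stand-in for torch.Tensor (hypothetical, for illustration only)."""


class DTensor(Tensor):
    """Stand-in for torch.distributed.tensor.DTensor."""


def from_local_sketch(local_tensor: Tensor) -> DTensor:
    # Reject an input that is already a DTensor: wrapping it a second time is
    # almost certainly a mistake. RuntimeError is an assumption of this sketch.
    if isinstance(local_tensor, DTensor):
        raise RuntimeError(
            "the local_tensor argument only accepts torch.Tensor "
            f"but got {type(local_tensor)} value."
        )
    # The real method would wrap local_tensor together with a mesh and placements.
    return DTensor()


try:
    from_local_sketch(DTensor())
except RuntimeError as err:
    message = str(err)
```

With the stand-ins, the class shown in the message is the local `DTensor` class rather than the real one.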

**Test**
`pytest test/distributed/tensor/test_dtensor.py -k test_from_local`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164496
Approved by: https://github.com/ezyang
2025-10-02 21:25:01 +00:00
Sherlock Huang
60a4961ff4 [DTensor] Allow redistribute to Partial if src matches (#164253)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164253
Approved by: https://github.com/zpcore
2025-09-30 22:42:49 +00:00
Yuanyuan Chen
da003d7b95 [3/N] Import Callable from collections.abc in torch/distributed (#164104)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
This PR is the follow-up of #164054.
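
For illustration, the import style that `UP035` asks for looks like the following. The function `apply_twice` is a made-up example, not code from the PR.

```python
# Preferred: the ABC lives in collections.abc and is subscriptable on Python 3.9+.
from collections.abc import Callable


def apply_twice(fn: Callable[[int], int], value: int) -> int:
    """Apply fn to value two times."""
    return fn(fn(value))


result = apply_twice(lambda v: v + 1, 0)
```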

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164104
Approved by: https://github.com/Skylion007
2025-09-30 00:28:53 +00:00
Chien-Chin Huang
dfda239cce [DTensor] Raise an RuntimeError when checkpointing APIs are used with Partial placement (#163941)
A DTensor with a Partial placement shouldn't be checkpointed (DCP.save): the result is not correct and DCP doesn't know how to handle it.

There are several APIs that are only used by checkpointing, e.g., `__create_write_items__`. These APIs should raise an exception if the DTensor, `self`, has a Partial placement.

Ideally, we want to add the following test:

```
with self.assertRaisesRegex(
    RuntimeError, "Any checkpointing related operations are not supported for"
):
    dcp.save({"dtensor": dtensor}, checkpoint_id=tempfile.gettempdir())
```

Although the RuntimeError is raised, it is raised in another thread, because DCP calls the DTensor checkpoint APIs in a separate thread, so `assertRaisesRegex` cannot capture it.
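
The following sketch mimics both points with stand-in classes: a checkpoint-only hook that refuses Partial placements, and a saver that calls the hook from a worker thread so the exception never reaches the caller. `FakeDTensor`, `save_in_worker`, and the tail of the error message are hypothetical; this is not DCP's actual code.

```python
import threading


class Partial:
    """Stand-in for the Partial placement (hypothetical)."""


class Shard:
    """Stand-in for the Shard placement (hypothetical)."""


class FakeDTensor:
    def __init__(self, placements):
        self.placements = placements

    def __create_write_items__(self, fqn, obj):
        # Checkpoint-only hook: refuse to proceed for Partial placements.
        if any(isinstance(p, Partial) for p in self.placements):
            raise RuntimeError(
                "Any checkpointing related operations are not supported for "
                "a DTensor with Partial placement."
            )
        return []


def save_in_worker(tensor, errors):
    """Call the hook from a separate thread, as a threaded saver might."""

    def work():
        try:
            tensor.__create_write_items__("dtensor", tensor)
        except RuntimeError as err:
            # Without this handler the exception would end the worker thread
            # and go to threading.excepthook; it still would not reach the caller.
            errors.append(err)

    worker = threading.Thread(target=work)
    worker.start()
    worker.join()


errors = []
raised_in_caller = False
try:
    save_in_worker(FakeDTensor([Partial()]), errors)
except RuntimeError:
    raised_in_caller = True
```

After this runs, `raised_in_caller` is still `False` while `errors` holds the exception, which is the situation a context manager such as `assertRaisesRegex` in the calling thread cannot observe.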

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163941
Approved by: https://github.com/tianyu-l
2025-09-27 19:50:16 +00:00
Scott Wolchok
5599f487ef Fully native DTensor.__new__ (#162508)
Move the entirety of `__new__` into C++, saving a layer of disable_dynamo and making progress toward all-C++.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162508
Approved by: https://github.com/ezyang
ghstack dependencies: #161695
2025-09-21 18:36:05 +00:00
PyTorch MergeBot
7dd5f7b125 Revert "python fastpath for DTensor detach(), confirm that aliasing DTensorSpec is ok (#160580)"
This reverts commit 4b2d297eec.

Reverted https://github.com/pytorch/pytorch/pull/160580 on behalf of https://github.com/bdhirsh due to this broke shampoo, yanking ([comment](https://github.com/pytorch/pytorch/pull/160580#issuecomment-3287372891))
2025-09-13 02:04:36 +00:00
Brian Hirsh
4b2d297eec python fastpath for DTensor detach(), confirm that aliasing DTensorSpec is ok (#160580)
My goal right now is to try to make the "vanilla" AccumulateGrad path for DTensor (that just calls detach) fast. I'm doing this in two steps:

(1) [this PR]: hardcode aten.detach in DTensor to re-use the input tensor's DTensorSpec, instead of running "real" sharding prop.

(2) [assuming success of 1]: move the detach() call into C++, try adding a DTensor dispatch key, and avoid dispatching back to Python entirely (except for some code that probably needs to allocate a pyobject for the output DTensor, from C++).

I'm pushing this PR first to confirm that I don't break anything with my detach fastpath. I did some manual local testing to confirm that for normal usages of detach, the input and output DTensor have equal DTensorSpec objects. Technically, we previously would allocate a fresh DTensorSpec, and with this change we are just re-using the input tensor's DTensorSpec. So I'm mostly hoping that DTensorSpecs don't generally get mutated.
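
As a rough illustration of step (1), the sketch below contrasts a hardcoded fast path that aliases the input spec with a general path that builds a fresh one. `Spec`, `FakeDTensor`, `propagate_sharding`, and `dispatch` are all hypothetical stand-ins, not the DTensor dispatcher.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Spec:
    """Stand-in for DTensorSpec. Frozen here, so aliasing is safe in this sketch;
    the real class may be mutable, which is the concern raised above."""
    placements: tuple
    shape: tuple


@dataclass
class FakeDTensor:
    local: list
    spec: Spec


def propagate_sharding(op_name: str, spec: Spec) -> Spec:
    # Placeholder for the general path: always builds a fresh spec object.
    return Spec(placements=spec.placements, shape=spec.shape)


def dispatch(op_name: str, x: FakeDTensor) -> FakeDTensor:
    if op_name == "aten.detach":
        # Fast path: detach does not change sharding, so re-use the input spec.
        return FakeDTensor(local=x.local, spec=x.spec)
    return FakeDTensor(local=x.local, spec=propagate_sharding(op_name, x.spec))


x = FakeDTensor(local=[1.0, 2.0], spec=Spec(("Shard(0)",), (4,)))
fast = dispatch("aten.detach", x)
slow = dispatch("aten.alias", x)
```

Both outputs carry an equal spec, but only the fast path returns the very same object, so any later in-place mutation of a spec would be visible through both tensors.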

This by itself does seem to speed up `alias` by quite a bit (roughly a 2.5x speedup, from ~336 us to ~133 us):

**aten.detach(plain_tensor)**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f8da2921790>
_ = x.detach()
  4.80 us
  1 measurement, 100000 runs , 1 thread
```

**aten.detach(DTensor) [before this PR]**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f47cd68e750>
_ = x_dt.detach()
  336.40 us
  1 measurement, 1000 runs , 1 thread
```

**aten.detach(DTensor) [after this PR]**
```
<torch.utils.benchmark.utils.common.Measurement object at 0x7f0a34c05520>
_ = x_dt.detach()
  Median: 133.45 us
  2 measurements, 1000 runs per measurement, 1 thread
```
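
The ratios implied by the three measurements above can be checked directly. The variable names are illustrative.

```python
plain_us = 4.80     # aten.detach on a plain tensor
before_us = 336.40  # aten.detach on a DTensor before the PR
after_us = 133.45   # aten.detach on a DTensor after the PR (median)

speedup = before_us / after_us            # how much faster the fast path is
remaining_overhead = after_us / plain_us  # DTensor cost relative to a plain tensor

print(f"speedup: {speedup:.2f}x, remaining overhead: {remaining_overhead:.1f}x")
```

So the fast path is about 2.5x faster, while a DTensor detach still costs well over an order of magnitude more than a plain tensor detach.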

benchmark script:
```
import torch
import torch.distributed as dist
import torch.utils.benchmark as benchmark
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

# A fake process group avoids launching real worker processes.
fake_store = FakeStore()
dist.init_process_group("fake", store=fake_store, rank=0, world_size=2)

# 1-D mesh over the two (fake) ranks.
mesh = torch.distributed.device_mesh.init_device_mesh('cuda', (2,))
x = torch.randn(4, 4, requires_grad=True)
# run_check=False skips the cross-rank metadata check.
x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)

# Time a single DTensor detach() call.
t0 = benchmark.Timer(
    stmt='_ = x_dt.detach()',
    globals={'x_dt': x_dt},
)
print(t0.blocked_autorange())

dist.destroy_process_group()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160580
Approved by: https://github.com/ezyang
2025-09-09 18:04:56 +00:00
Scott Wolchok
88d94d17e8 Add torch.Tensor._make_dtensor to accelerate DTensor.__new__ further (#161590)
This seems to be a (very very roughly) ~8% improvement on a DTensor benchmark very similar to the one from #160580 (120ish usec -> 110ish usec).

Differential Revision: [D81530105](https://our.internmc.facebook.com/intern/diff/D81530105)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161590
Approved by: https://github.com/albanD
ghstack dependencies: #161466, #161586
2025-09-05 18:43:41 +00:00
Xuehai Pan
3f8e2e91ad [BE][15/16] fix typos in torch/ (torch/distributed/tensor/) (#156605)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156605
Approved by: https://github.com/wanchaol, https://github.com/albanD
2025-07-17 12:08:33 +00:00
Edward Z. Yang
e2f64eedaf Fix DTensor handling of conjugate bit. (#158030)
Fixes https://github.com/pytorch/pytorch/issues/130646 specifically for DTensor

Fixes https://github.com/pytorch/torchtitan/issues/267

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158030
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2025-07-10 18:28:12 +00:00
Aaron Orenstein
e95e8eed0a mypy 1.16.0 (#155821)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-06-14 18:18:43 +00:00
zpcore
50d8168c8b [DTensor] Support in gradient placement for local_map() (#155181)
Support the `in_grad_placements` argument in `torch.distributed.tensor.experimental.local_map()`. The argument helps enforce the placement of the gradient of the input DTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155181
Approved by: https://github.com/wanchaol
2025-06-12 17:07:04 +00:00
Yuanhao Ji
0a7eef140b Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022)
Fixes #153790

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154022
Approved by: https://github.com/Skylion007
2025-05-27 14:10:00 +00:00
Xilun Wu
cbb03e6971 [BE][DTensor] move torch.distributed._tensor import to torch.distributed.tensor in test files (#153225)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153225
Approved by: https://github.com/kwen2501, https://github.com/fegin
2025-05-09 20:40:54 +00:00
Ruisi Zhang
1c5619ef9c [DTensor] Add DTensor redistribute fwd/bwd datatype conversion to enable SimpleFSDP mixed precision training (#150740)
As titled, this PR adds `forward_dtype` and `backward_dtype` conversion to the DTensor `redistribute` API to enable SimpleFSDP's mixed precision training.

In the forward pass, the DTensor can be configured to be cast to `forward_dtype`; in the backward pass, the DTensor can be configured to be cast to `backward_dtype`.

1. **Correctness**: The end-to-end SimpleFSDP mixed precision training integration has been shown to work properly in the PR from this fork: https://github.com/tianyu-l/pytorch_intern24/pull/20. We are now migrating the code to official PyTorch DTensor.

2. **Example Usage**: There is an example in TorchTitan's SimpleFSDP implementation: https://github.com/pytorch/torchtitan/pull/1060.

In the example below, a DTensor `x` is all-gathered along `self.compute_placements`, with its datatype cast to `self.param_dtype`. In the backward pass, the computed gradients are reduce-scattered along `self.grad_placements`, with their datatype cast to `self.reduce_dtype`.

```python
output = x.redistribute(
        placements=self.compute_placements,
        forward_dtype=self.param_dtype,
        backward_dtype=self.reduce_dtype,
).to_local(grad_placements=self.grad_placements)
```

Under the hood, in `class Redistribute(torch.autograd.Function):`, the `forward` function first takes `x`'s local tensor and converts it to `forward_dtype` before all-gathering `x`.

The `backward` function takes `grad_output` and converts it to `backward_dtype` before reduce-scattering `grad_output`.
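
The cast-then-communicate ordering can be sketched with a plain `torch.autograd.Function` on a single process. The class `CastThenIdentity` is hypothetical and omits the collectives entirely; casting the gradient back to the input dtype at the end is an assumption of this sketch, not a statement about `Redistribute`.

```python
import torch


class CastThenIdentity(torch.autograd.Function):
    """Hypothetical stand-in for Redistribute with the collectives left out."""

    @staticmethod
    def forward(ctx, x, forward_dtype, backward_dtype):
        ctx.input_dtype = x.dtype
        ctx.backward_dtype = backward_dtype
        # Cast first; the real forward would all-gather after this point.
        return x.to(forward_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Cast first; the real backward would reduce-scatter after this point.
        grad = grad_output.to(ctx.backward_dtype)
        # Return a gradient in the input's dtype (assumption of this sketch).
        # The two dtype arguments are not tensors, so their gradients are None.
        return grad.to(ctx.input_dtype), None, None


x = torch.ones(4, dtype=torch.float32, requires_grad=True)
out = CastThenIdentity.apply(x, torch.bfloat16, torch.float32)
out.sum().backward()
```

Here the forward output is in `bfloat16` while the gradient that reaches `x` is in `float32`, mirroring the low-precision compute and higher-precision reduction split described above.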

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150740
Approved by: https://github.com/tianyu-l
2025-04-13 05:49:03 +00:00
Tugsbayasgalan Manlaibaatar
6b1b95ad2a Support subclass constructor capturing in export (#147014)
Notable TODOs:
1. Need to implement AutogradHOP to get rid of subclasses before serializing
2. Need to implement mechanism to figure out what subclasses will be used in export when they are not expressed in the inputs

Differential Revision: [D69640673](https://our.internmc.facebook.com/intern/diff/D69640673)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147014
Approved by: https://github.com/bdhirsh
2025-03-16 18:19:19 +00:00
Xuehai Pan
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
Xilun Wu
ef61c290e1 [DTensor][random] defer DTensor RNG state sync until first random op call or manual_seed call; support more flexible OffsetBasedRNGTracker init (#147025)
Resolves https://github.com/pytorch/pytorch/issues/146767.

May also resolve https://github.com/pytorch/pytorch/issues/147584.

### Summary
This PR removes the RNG tracker init from the `distribute_tensor` call for the following reasons:

1. If the user does not use random ops on DTensor, there's no need to initialize the DTensor RNG, which currently requires a CUDA device to be present.
2. This complies with the 0-communication semantics of `src_data_rank=None` shard distribution.

Besides, `OffsetBasedRNGTracker` only accepts a `DeviceMesh` argument in its constructor.

### Consequence

DTensor RNG initialization is delayed until the first DTensor random op call or the first call to `torch.distributed.tensor.random.manual_seed`.

### Test
`pytest test/distributed/tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`
`pytest test/distributed/tensor/parallel/test_tp_style.py`

Differential Revision: [D70201856](https://our.internmc.facebook.com/intern/diff/D70201856)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147025
Approved by: https://github.com/kwen2501
2025-02-26 17:33:22 +00:00
Ke Wen
4879f8f919 [TP] Add warning when module is distributed twice (#147006)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147006
Approved by: https://github.com/XilunWu
2025-02-13 06:49:17 +00:00
Aaron Orenstein
c95efc37ba PEP585 update - torch/distributed/tensor (#145141)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145141
Approved by: https://github.com/bobrenjc93
2025-01-18 20:01:59 +00:00
bobrenjc93
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
Wanchao Liang
b1c2c3967a [dtensor] deprecate _shard_tensor to use src_data_rank=None (#144171)
As titled: we can achieve no-communication sharding for the inference case with
`src_data_rank=None`, so deprecate the private API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144171
Approved by: https://github.com/awgu
2025-01-09 22:26:45 +00:00
Wanchao Liang
eb7a303d21 [dtensor] expose the __create_chunk_list__ in the doc (#144100)
As titled, this PR exposes this dunder method as a public API in the docs,
so that different checkpoint implementations can leverage this protocol
instead of exposing a separate API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144100
Approved by: https://github.com/awgu
ghstack dependencies: #144099
2025-01-03 20:06:23 +00:00
Wanchao Liang
48a05ee773 [dtensor] improve doc of the DTensor class (#144099)
As titled: explicitly list all public members to make sure the public
API stays consistent, and use groupwise member order to make the docs
look better.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144099
Approved by: https://github.com/awgu
2025-01-03 05:35:44 +00:00
Wanchao Liang
f242dbb76f [dtensor] add src_data_rank to distribute_tensor API (#143883)
As titled, this PR adds a kwarg `src_data_rank` to the `distribute_tensor`
API, to allow the user to specify a particular rank as the source of the full
tensor data. Previously we used group_rank=0 by default as the source of
truth for single-device semantics. This new option:

* gives advanced users the flexibility to choose the source data rank
* allows users to specify None explicitly, which means we skip the
  communications needed (scatter/broadcast) for cases that do not
  care about single-device semantics (e.g., loading from a checkpoint)
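
The difference between the two modes can be simulated without any real communication. In the sketch below, each entry of `full_per_rank` is the "full tensor" a rank holds locally; `chunk` and `distribute_sketch` are hypothetical helpers, and the even, contiguous split is an assumption.

```python
from typing import Optional


def chunk(data: list, num_chunks: int) -> list:
    """Split data into num_chunks contiguous pieces (the last may be shorter)."""
    size = -(-len(data) // num_chunks)  # ceiling division
    return [data[i * size:(i + 1) * size] for i in range(num_chunks)]


def distribute_sketch(full_per_rank: list, src_data_rank: Optional[int]) -> list:
    """Return the shard each rank ends up with."""
    world_size = len(full_per_rank)
    if src_data_rank is None:
        # No communication: every rank slices its own local copy.
        return [chunk(full_per_rank[r], world_size)[r] for r in range(world_size)]
    # Scatter: every rank receives its piece of the source rank's data.
    return chunk(full_per_rank[src_data_rank], world_size)


# Two ranks that happen to hold different "full" tensors.
per_rank = [[0, 1, 2, 3], [10, 11, 12, 13]]
from_rank0 = distribute_sketch(per_rank, src_data_rank=0)
from_rank1 = distribute_sketch(per_rank, src_data_rank=1)
local_only = distribute_sketch(per_rank, src_data_rank=None)
```

With a source rank, all shards come from one consistent tensor. With `None`, the result is only consistent if every rank already holds the same data, as is the case after loading the same checkpoint everywhere.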

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143883
Approved by: https://github.com/XilunWu, https://github.com/tianyu-l
2025-01-02 05:35:52 +00:00
Ke Wen
a58d2f14e8 [DTensor] Add a private util for sharding tensor (#142288)
Locally shards a full tensor based on indicated sharding arrangement, and returns a DTensor containing the local shard.

Warning: this is a private API intended to skip the communication otherwise required by `distribute_tensor`. It is only applicable when all ranks have the same `full_tensor`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142288
Approved by: https://github.com/wz337
2024-12-07 05:30:18 +00:00
IvanKobzarev
781c68c865 [aotd] coerce_same_metadata_as_tangent with expected_type for e.g.AsyncCollectiveTensor (#139095)
Based on discussion here: https://github.com/pytorch/pytorch/pull/138731

Introduces the ability for a subclass to implement type conversion to `expected_type`.
```
    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
    ):
```
Here, `expected_type=None` means that `SubclassClass` is expected.

E.g., for `DTensor` we may find an `AsyncCollectiveTensor` tangent where we expected `Tensor`; in this case
`__coerce_same_metadata_as_tangent__` will be called with `expected_type=Tensor` at runtime.

Adds an implementation to `AsyncCollectiveTensor` that just triggers `wait()`.
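
A minimal sketch of the protocol with stand-in classes. `PlainTensor` and `AsyncTensor` are hypothetical, and returning `None` for an unsupported request is an assumption of this sketch.

```python
from typing import Any, Optional


class PlainTensor:
    """Stand-in for torch.Tensor (hypothetical)."""

    def __init__(self, value):
        self.value = value


class AsyncTensor:
    """Stand-in for AsyncCollectiveTensor: wraps a result that may be pending."""

    def __init__(self, inner: PlainTensor):
        self._inner = inner
        self.waited = False

    def wait(self) -> PlainTensor:
        # Block until the collective finishes, then hand back the plain result.
        self.waited = True
        return self._inner

    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[type] = None
    ):
        if expected_type is PlainTensor:
            # A plain tensor was expected: waiting yields exactly that.
            return self.wait()
        # Any other request is unsupported in this sketch.
        return None


tangent = AsyncTensor(PlainTensor(3.0))
coerced = tangent.__coerce_same_metadata_as_tangent__(None, expected_type=PlainTensor)
```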

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139095
Approved by: https://github.com/bdhirsh
2024-11-07 16:24:48 +00:00
zeshengzong
e374d6850a [distributed][test] Remove unused variable and fix doc typo (#136943)
Refactor distributed test code:
- Fix TODO: Remove unused variable
- Fix doc typo
- Migrate deprecated method call `load_state_dict` and `save_state_dict`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136943
Approved by: https://github.com/H-Huang
2024-10-02 08:31:53 +00:00
Aaron Gokaslan
31715be72a [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.2 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-16 19:44:11 +00:00
PyTorch MergeBot
3117f2cf67 Revert "[BE]: Update mypy to 1.11.2 (#133816)"
This reverts commit 55299cfc22.

Reverted https://github.com/pytorch/pytorch/pull/133816 on behalf of https://github.com/jeanschmidt due to seems to have broken https://github.com/pytorch/pytorch/actions/runs/10865710499/job/30155699792 on main ([comment](https://github.com/pytorch/pytorch/pull/133816#issuecomment-2352377684))
2024-09-16 09:11:16 +00:00
Aaron Gokaslan
55299cfc22 [BE]: Update mypy to 1.11.2 (#133816)
Updates mypy to 1.11.2 to improve type inference

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133816
Approved by: https://github.com/ezyang
2024-09-14 21:40:36 +00:00
Xilun Wu
de8a8653c0 [dtensor][BE] replace compute_local_shape with compute_local_shape_and_global_offset (#135554)
**Summary**
1. This PR removes the public API `compute_local_shape` and replaces its uses with the more general API `compute_local_shape_and_global_offset`.
2. To keep `compute_local_shape_and_global_offset` consistent with `compute_local_shape` on empty shards, it now returns local tensor shape `(0,)` for empty shards, which is more aligned with DTensor's semantics on non-participating ranks.
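
To make the shape-and-offset idea concrete, here is a simplified sketch for a tensor sharded along dimension 0 over a 1-D mesh. The helper `local_shape_and_offset` is hypothetical and much narrower than the real API. Using a `None` shard index to model a rank that holds no shard, and the empty offset returned with it, are assumptions; the `(0,)` shape follows the summary.

```python
from typing import Optional


def local_shape_and_offset(
    global_shape: tuple, num_shards: int, shard_index: Optional[int]
) -> tuple:
    """Return (local_shape, global_offset) for sharding along dimension 0."""
    if shard_index is None:
        # Rank without a shard: report the flat empty shape.
        return (0,), ()
    full = -(-global_shape[0] // num_shards)  # ceiling division: rows per shard
    start = min(shard_index * full, global_shape[0])
    stop = min(start + full, global_shape[0])
    local_shape = (stop - start,) + tuple(global_shape[1:])
    # Only the sharded dimension has a non-zero offset.
    global_offset = (start,) + (0,) * (len(global_shape) - 1)
    return local_shape, global_offset


first = local_shape_and_offset((5, 3), num_shards=2, shard_index=0)
second = local_shape_and_offset((5, 3), num_shards=2, shard_index=1)
absent = local_shape_and_offset((5, 3), num_shards=2, shard_index=None)
```

Returning both values from one call is what lets a single function replace the shape-only helper.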

**Test**
`pytest test/distributed/_tensor/test_dtensor.py`
`pytest test/distributed/_tensor/test_init.py`
`pytest test/distributed/_tensor/test_tensor_ops.py`

Differential Revision: [D62415591](https://our.internmc.facebook.com/intern/diff/D62415591)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135554
Approved by: https://github.com/tianyu-l, https://github.com/wz337
2024-09-12 06:30:09 +00:00
Wanchao Liang
cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I had to create a new PR because the previously reverted PR could be neither rebased nor imported successfully :(

----

Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve the BC for users still using the torch.distributed._tensor, I added a shim script to redirect old path calls to the new module
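
One common way to build such a redirect is to register the old module name as an alias of the new module in `sys.modules`. The sketch below uses made-up module names (`demo_tensor`, `demo_old_tensor`); whether the actual shim works this way is not stated here, so treat it as an assumption.

```python
import importlib
import sys
import types

# Hypothetical new public module with one exported class.
new_module = types.ModuleType("demo_tensor")
new_module.DTensor = type("DTensor", (), {})
sys.modules["demo_tensor"] = new_module

# The shim: register the old name so importing it yields the new module.
sys.modules["demo_old_tensor"] = sys.modules["demo_tensor"]

old_path = importlib.import_module("demo_old_tensor")
new_path = importlib.import_module("demo_tensor")
```

Because both names resolve to the same module object, classes imported through the old path are identical to those from the new one, so `isinstance` checks keep working across both.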

That BC is preserved is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it's safe to land the changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00