Commit Graph

79 Commits

Author SHA1 Message Date
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotations where possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` calls where the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
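For illustration, a minimal sketch of the two patterns applied (function names here are hypothetical):

```python
import warnings
from typing_extensions import deprecated

# pattern 1: annotate with typing_extensions.deprecated where possible
@deprecated("`old_fn` is deprecated, use `new_fn` instead.", category=FutureWarning)
def old_fn():
    ...

# pattern 2: where a plain warnings.warn() call stays, add the missing category
def legacy_fn():
    warnings.warn("`legacy_fn` is deprecated, use `new_fn` instead.", category=FutureWarning)
```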

Resolves #126888

This PR is split from PR #126898.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions from 4.3.0 to 4.9.0 causing internal failures ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Wanchao Liang
b0ef363972 [dtensor] rename _Partial -> Partial for all imports (#127420)
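What the rename means for user code, sketched (import path assumed from the DTensor module layout at the time):

```python
# before: from torch.distributed._tensor.placement_types import _Partial
# after this PR, the public name is used for all imports:
from torch.distributed._tensor import Partial

placement = Partial()  # pending-sum placement, reduced on redistribute
```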
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127420
Approved by: https://github.com/awgu
2024-05-29 21:42:40 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotations where possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` calls where the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
Wanchao Liang
a60b06bd2b [dtensor] update public API docs (#127340)
This PR updates the API documentation for the public-facing APIs.

Each API still needs more examples; the plan is to add them in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127340
Approved by: https://github.com/wz337
ghstack dependencies: #127338, #127339
2024-05-29 05:18:47 +00:00
Wanchao Liang
2c9a420da3 [dtensor] move some modules to private namespace (#127339)
As titled: moves some modules that are mainly for DTensor-internal usage into a private namespace.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127339
Approved by: https://github.com/awgu
ghstack dependencies: #127338
2024-05-29 05:18:47 +00:00
Wanchao Liang
daf1eb44bc try to fix the warning in distribute_tensor (#125476)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125476
Approved by: https://github.com/albanD, https://github.com/awgu
ghstack dependencies: #125475
2024-05-06 18:59:47 +00:00
PyTorch MergeBot
084d818e71 Revert "try to fix the warning in distribute_tensor (#125476)"
This reverts commit 2b41e1d6fc.

Reverted https://github.com/pytorch/pytorch/pull/125476 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but there are real failures on the PR that snuck in during the log classifier outage ([comment](https://github.com/pytorch/pytorch/pull/125476#issuecomment-2094468740))
2024-05-04 22:25:32 +00:00
Wanchao Liang
2b41e1d6fc try to fix the warning in distribute_tensor (#125476)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125476
Approved by: https://github.com/albanD, https://github.com/awgu
ghstack dependencies: #125475
2024-05-04 05:25:13 +00:00
Brian Hirsh
599a2e25f1 Reland "make sure dynamo doesn't inline DTensor __new__ or __torch_dispatch__ (#123347)" (#125288)
Re-land of https://github.com/pytorch/pytorch/pull/123347.

The original PR broke internal builds because of a circular import caused by importing dynamo in the DTensor code. The new version uses `torch._dynamo_disable` to work around the circular import.

This reverts commit 9d88339b53.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125288
Approved by: https://github.com/ezyang, https://github.com/yanboliang, https://github.com/yoyoyocmu, https://github.com/anijain2305, https://github.com/fegin
ghstack dependencies: #124398, #124399, #124400
2024-05-01 21:56:01 +00:00
PyTorch MergeBot
9d88339b53 Revert "make sure dynamo doesn't inline DTensor __new__ or __torch_dispatch__ (#123347)"
This reverts commit 63dcb5b0f2.

Reverted https://github.com/pytorch/pytorch/pull/123347 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/123347#issuecomment-2059994989))
2024-04-16 22:08:24 +00:00
Brian Hirsh
63dcb5b0f2 make sure dynamo doesn't inline DTensor __new__ or __torch_dispatch__ (#123347)
Fixes https://github.com/pytorch/pytorch/issues/122459, https://github.com/pytorch/torchtrain/issues/61

Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above; the logs show that dynamo is inlining `__new__`.

I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.

Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining __new__ ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.

This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.

One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node example values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the `__new__` at all. Hence the `torch._dynamo.disable`.
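A minimal sketch of the fix on a generic wrapper subclass (the class here is hypothetical; the commit applies the same idea to DTensor's `__new__`):

```python
import torch
import torch._dynamo

class MySubclass(torch.Tensor):
    @torch._dynamo.disable  # dynamo skips this frame instead of inlining it
    def __new__(cls, elem):
        # the "partially initialized subclass" window lives entirely inside
        # this frame, so keeping dynamo out avoids the broken invariants above
        r = torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)
        r._elem = elem
        return r
```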

I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```

Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling

So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`

All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where looking at the dynamo logs, we see dynamo try to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn                     /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```

I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.

I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123347
Approved by: https://github.com/zou3519
ghstack dependencies: #122502, #122751, #123348
2024-04-15 17:23:20 +00:00
Wanchao Liang
afee5bea92 [dtensor] refactor schema suggestions in output sharding (#122929)
This PR refactors the schema_suggestions in OutputSharding to be a single
OpSchema instead of a list of schemas; in practice we only ever have one.
For the multiple-resharding case we have also moved to OpStrategy, so no
case needs it to be a list.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122929
Approved by: https://github.com/tianyu-l
2024-04-01 17:39:39 +00:00
Brian Hirsh
e7fa3f7812 AOTDispatch: allow subclasses to correct when we guess metadata of tangents incorrectly (#118670)
This PR is enough to fix https://github.com/pytorch/pytorch/issues/118600.

More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like:

"We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents"

Here, the problem is similar:

"We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass".

This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial).

One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by:

(1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error

(2) In the error message, provide the name of an optional method that the subclass must implement to handle this case:

`def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement.

`__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement.

`__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time.
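A hypothetical sketch of the two proposed hooks on a DTensor-like subclass (method bodies and import paths are illustrative assumptions, not the actual implementation):

```python
from torch.distributed._tensor.placement_types import Replicate, _Partial

class MyDTensor:  # stands in for a traceable tensor subclass with placements
    def __force_same_metadata__(self, metadata_tensor):
        # convert `self` so its placements match the guessed tangent's,
        # e.g. redistribute a Shard(1) runtime tangent to Replicate()
        return self.redistribute(metadata_tensor.device_mesh,
                                 metadata_tensor.placements)

    def __force_standard_metadata__(self):
        # called on (fake) forward outputs at trace time: swap out any
        # placement we cannot convert back to later (e.g. _Partial)
        placements = [Replicate() if isinstance(p, _Partial) else p
                      for p in self.placements]
        return self.redistribute(self.device_mesh, placements)
```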

I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses is definitely a contentious change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118670
Approved by: https://github.com/ezyang
2024-03-22 23:16:08 +00:00
Wanchao Liang
a26480a4d1 [dtensor] move early return check into redistribute autograd function (#121653)
This PR fixes a redistribute bug by moving the early-return check into the
redistribute autograd function: even when we redistribute to the same
placements, the `grad_placements` from the `to_local` call might be
different, so the redistribute backward still needs to happen.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121653
Approved by: https://github.com/awgu
2024-03-12 17:37:30 +00:00
Wanchao Liang
242e03ba86 [dtensor] add async_op option to redistribute and some refactor (#121477)
The async output option was only available via the `full_tensor()` call, but it's
generally useful to expose it directly on the `redistribute` call so the user
can control it.

This PR adds an `async_op` option to the redistribute call, letting the user
control whether the tensor redistribution is performed asynchronously.

By default we set this to False, following the semantics of the c10d
collectives.
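A usage sketch of the new kwarg (assumes an existing DTensor `dt` on a 1-D `mesh`; distributed setup elided):

```python
from torch.distributed._tensor import Replicate

out_sync = dt.redistribute(mesh, [Replicate()])                  # default: async_op=False
out_async = dt.redistribute(mesh, [Replicate()], async_op=True)  # returns without waiting
local = out_async.to_local()  # may wrap a pending collective until first use
```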

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121477
Approved by: https://github.com/wz337
2024-03-09 06:17:23 +00:00
albanD
6791b0c09e Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before now that they don't explicitly disable torch_function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
2024-03-09 01:08:37 +00:00
Wanchao Liang
bc02fca358 [dtensor] to_local backward grad placement passthrough (#121474)
to_local accepts a `grad_placements` argument if the user chooses to pass one;
previously we enforced that the grad_out have the "same" placements as the
current DTensor for safety.

But I realized that we DO NOT need to enforce this constraint. Why?
The backward placement does not need to match the forward tensor placement;
this is already the case for param vs. param.grad (i.e. a param can be
replicate while its grad is partial), so we should not impose this restriction
on activations vs. activation grads either.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121474
Approved by: https://github.com/awgu, https://github.com/yoyoyocmu, https://github.com/yifuwang
2024-03-08 23:11:49 +00:00
Yeounoh Chung
f7ec984b1b [DTensor][XLA] support XLA backend in distribute_module API (#121355)
Addresses #92909

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121355
Approved by: https://github.com/wanchaol
2024-03-08 15:47:33 +00:00
Yeounoh Chung
4f9d4e1ab0 [DTensor][XLA] refactor DTensor _xla API (#113214)
In response to the change pytorch/xla#5776 and #92909

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113214
Approved by: https://github.com/wanchaol
2024-03-07 06:18:05 +00:00
Wanchao Liang
2e50566722 [dtensor] change distribute_module input/output_fn to accept module (#120895)
This is a BC-breaking change to distribute_module. The underlying rationale
is that users sometimes want access to the current module inside the
input_fn/output_fn, for example to read its attributes. This might not be
common, but in some cases access to the module is worthwhile.

An outstanding use case we want to support is float8: to make
float8 work with the TP API, the input_fn/output_fn of the TP parallel
styles need access to the module, which might encapsulate a
`dynamic_linear.emulate` attribute that is useful for input/output casting.

Since this is needed for fp8 and DTensor is still in prototype release,
I feel the change is worth it, and it's better to make it early.

For now this is a soft BC break: we still maintain BC but emit
deprecation messages.
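A sketch of the hook-signature change (parameter names here are assumptions; the module is passed first):

```python
from torch.distributed._tensor import distribute_module

# old hooks:  input_fn(inputs, device_mesh)        output_fn(outputs, device_mesh)
# new hooks:  input_fn(mod, inputs, device_mesh)   output_fn(mod, outputs, device_mesh)

def input_fn(mod, inputs, device_mesh):
    # `mod` lets the hook read module state, e.g. a hypothetical fp8
    # `mod.dynamic_linear.emulate` flag, before casting the inputs
    return inputs

def output_fn(mod, outputs, device_mesh):
    return outputs

# assumes `module`, `mesh`, and `partition_fn` are set up elsewhere:
sharded = distribute_module(module, mesh, partition_fn,
                            input_fn=input_fn, output_fn=output_fn)
```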

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120895
Approved by: https://github.com/tianyu-l
2024-03-04 07:22:32 +00:00
Andrew Gu
87fb8b6218 [DTensor] Relaxed to_local requires_grad warning (#118186)
The existing warning in `DTensor.__new__()` checks `if requires_grad != local_tensor.requires_grad:` and warns with:

> To construct DTensor from `torch.Tensor`, it's recommended to use `local_tensor.detach()` and make `requires_grad` consistent.

Calling `local_tensor.detach()` returns a `Tensor` with `requires_grad=False`, so the warning message refers to the case where `local_tensor.requires_grad is True` but the user passed `requires_grad=False` to `to_local()`.

However, there is the converse case, where `local_tensor.requires_grad is False` but the user passed `requires_grad=True`. In this case, the original `if requires_grad != local_tensor.requires_grad:` check succeeds, and the warning is emitted. However, the warning message does not apply in that case.

This can happen via `_prepare_output_fn` -> `redistribute` -> `Redistribute.forward()`, where `output.requires_grad is False` but it passes `requires_grad=input.requires_grad` which can be `True`.

We should not warn in this case since `Redistribute.forward()` is our own framework code, so this PR relaxes the warning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118186
Approved by: https://github.com/XilunWu, https://github.com/wanchaol
ghstack dependencies: #117994
2024-01-25 15:49:32 +00:00
Wanchao Liang
c170fbd309 [dtensor] refactor redistribute and fix uneven sharding redistribution (#115525)
This PR:
- refactors the redistribute implementation logic to make it more sound:
it figures out the transform information first and then applies the
transformations step by step; we also cache the decisions so they can be
reused
- for uneven sharding, refactors the uneven-sharding logic and uses a logical-
  shape concept for each transform step to fix the uneven-sharding
  multi-mesh redistribute bug

fixes https://github.com/pytorch/pytorch/issues/115310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115525
Approved by: https://github.com/XilunWu
2024-01-22 18:57:44 +00:00
Yue Dong
270ed13e87 [DTensor] Make DTensor from_local backward partial() to replicate() pass through (#115967)
Summary:
This change makes the `DTensor.from_local()` backward pass treat placements from `Partial()` to `Replicate()` as a pass-through, for the following reasons:
1. When we run the backward pass of DTensor.from_local and the target placement is partial() (i.e. from user manual-overwrite code rather than torch_dispatch), we keep the grad as replicate, because converting the gradients back to `Partial()` is meaningless.
2. The current div logic leads to wrong numerical values in the above case.

Test Plan:
**CI**:
CI Tests

**Unit test**:
`buck2 test mode/dev-nosan //caffe2/test/distributed/_tensor:redistribute`
- Passed

**With model training**:
```
# We tested the case where the input tensor is manually overwritten as Partial() and
# the output tensor is manually overwritten to Shard(), then converted to local.

# Before the change: numerical value not correct
Forward pass:
    collective: ReduceScatter
backward pass:
    collective: AllGather + div by process group size

# After the change: div is removed as expected.
Forward pass:
    collective: ReduceScatter
Backward pass:
    collective: AllGather
```

Differential Revision: D52175709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115967
Approved by: https://github.com/wanchaol
2023-12-19 00:16:10 +00:00
Iris Zhang (PyTorch)
23fa9621e4 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099) (#115193)
Summary:

Rename _device_mesh.py to device_mesh.py, update all callsites, add documentation.
We created stubs for the public class and methods in torch.distributed.device_mesh so that torch.distributed.device_mesh can be imported whether or not distributed is available.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/115099
Prior to landing, all CI signals passed. Shipit added the "ci/trunk" label to the PR but DID NOT wait for it and went ahead with committing. More context can be found in the reverted PR above.

Test Plan: CI.

Differential Revision: D51861018

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115193
Approved by: https://github.com/fegin
2023-12-08 08:44:32 +00:00
wz337
dacf5d6e92 [DTensor] Remove assert to allow tensor sharding dimension < Shard(x).ndim (#115114)
Consolidates changes made by @yoyoyocmu. https://www.internalfb.com/diff/D51821717
Remove the assert to allow tensor dimension < Shard(x).ndim; with the current padding, we already support this.

Follow-up: we still need to fix the size mismatch and the `full_tensor()` hang when the tensor is unevenly sharded.
Issue created here: https://github.com/pytorch/pytorch/issues/115310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115114
Approved by: https://github.com/yoyoyocmu, https://github.com/wanchaol
2023-12-07 21:57:30 +00:00
Joel Schlosser
22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00
Nikita Shulga
a827ac71f2 Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)"
This reverts commit eaa64339d6.
2023-12-05 08:59:36 -08:00
Iris Zhang (PyTorch)
eaa64339d6 [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#115099)
Summary:
Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.

Original diff reverted: D51629761
Original PR reverted: https://github.com/pytorch/pytorch/pull/114991
It was failing a public module binding test on macOS due to the change in import order for torch/distributed/fsdp/_common_utils.py. Since the original import still works, we removed the changes in this file.

Test Plan: CI.

Differential Revision: D51825114

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115099
Approved by: https://github.com/wanchaol, https://github.com/fegin
2023-12-05 05:44:52 +00:00
PyTorch MergeBot
3a2e2044cd Revert "[DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)"
This reverts commit 729ac7317a.

Reverted https://github.com/pytorch/pytorch/pull/114991 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114991#issuecomment-1837214567))
2023-12-02 17:55:51 +00:00
Iris Zhang (PyTorch)
729ac7317a [DeviceMesh] Rename _device_mesh.py to device_mesh.py to prepare for beta (#114710) (#114991)
Summary:

Same content of changes as https://github.com/pytorch/pytorch/pull/114710

Rename _device_mesh.py to device_mesh.py, update all callsites, adds documentation.
ghstack-source-id: 208980207
exported-using-ghexport

Test Plan: CI.

Reviewed By: wanchaol

Differential Revision: D51629761

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114991
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/fegin
2023-12-02 04:39:41 +00:00
Andrew Gu
c39c69953f [DTensor] Used new placements for neg dim in distribute_tensor (#113930)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113930
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925
2023-11-20 22:32:58 +00:00
Andrew Gu
e2095a04ae [DTensor] Ensured grad_placements was tuple (#113925)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113925
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134
2023-11-20 22:32:58 +00:00
Andrew Gu
f4ffd46c08 [DTensor] Used new placements for neg dim in from_local (#114134)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114134
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924
2023-11-20 22:32:51 +00:00
Andrew Gu
b41ad7d695 [DTensor] Used new placements for neg dim in redistribute (#113924)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113924
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919
2023-11-20 22:30:16 +00:00
Wanchao Liang
b16e3b5373 [funcol] add two APIs: wait() and numpy() (#113323)
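A usage sketch of the two APIs under their presumed semantics (process-group setup elided; assumes they live on the AsyncCollectiveTensor returned by functional collectives):

```python
import torch.distributed._functional_collectives as funcol

out = funcol.all_reduce(tensor, "sum", group)  # AsyncCollectiveTensor
result = out.wait()   # explicitly block until the collective completes
array = out.numpy()   # convert to a NumPy array, waiting implicitly
```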
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113323
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/wconstab
2023-11-14 09:27:45 +00:00
Wanchao Liang
6ed20af10e [dtensor] refactor op dispatch and fix is_same_size/equal (#112927)
torch.equal/is_same_size currently skip sharding prop and directly do
local tensor compute; this is wrong. For these two ops:

- torch.equal: should not skip sharding prop; the two DTensors need to
have the SAME sharding before comparing local shard values
- torch.is_same_size: needs to completely skip both sharding prop and
local compute

This PR refactors the existing op_dispatch into a class instance
so that we can do custom op handling, then fixes both torch.equal and
torch.is_same_size.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112927
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-11-13 22:46:31 +00:00
Wanchao Liang
9834fb7fd0 [dtensor] full_tensor to return synchronously (#113322)
The full_tensor API should return synchronously instead of returning an
AsyncCollectiveTensor; if the return value is one, we do the wait
directly. This makes the full_tensor API more precise.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113322
Approved by: https://github.com/wz337
2023-11-09 18:02:40 +00:00
Iris Zhang
9af3f98faf [DTensor] Fix DTensor.from_local() returns DTensor with wrong size for uneven sharded tensor (#110781)
Fixes #110762

This PR:
fixes the issue described in #110762 by adding `shape` and `stride` kwargs for creating a DTensor via `DTensor.from_local()`. When `shape` and `stride` are provided, we skip the calculation of `tensor_shape` and `tensor_stride` via `compute_global_tensor_info()`, since `compute_global_tensor_info()` always assumes even sharding.
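A usage sketch of the new kwargs (assumes a 1-D `mesh` over 2 ranks and a `rank` variable; distributed setup elided):

```python
import torch
from torch.distributed._tensor import DTensor, Shard

# global tensor of shape (5,) sharded unevenly across 2 ranks: (3,) and (2,)
local = torch.randn(3 if rank == 0 else 2)
dt = DTensor.from_local(local, mesh, [Shard(0)], shape=(5,), stride=(1,))
# without the kwargs, compute_global_tensor_info() would assume even
# sharding and report a wrong global size
```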

Test plan:
```
python3 test/distributed/_tensor/test_dtensor.py -k test_from_local_uneven_sharding
python3 test/distributed/_tensor/test_dtensor.py -k test_from_local_uneven_sharding_raise_error
```

cc. @wanchaol
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110781
Approved by: https://github.com/wanchaol
2023-11-04 11:21:10 +00:00
Wanchao Liang
2f09da3a21 [dtensor] Introduce full_tensor API to DTensor (#112224)
This PR introduces a `full_tensor` API to DTensor. There were so many
call sites exercising the `redistribute(replicate)` path that I feel
it deserves a separate API; it's mostly just syntactic sugar.
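The sugar, sketched (assumes an existing DTensor `dt`):

```python
from torch.distributed._tensor import Replicate

# dt.full_tensor() is roughly equivalent to:
full = dt.redistribute(dt.device_mesh,
                       [Replicate()] * dt.device_mesh.ndim).to_local()
# versus the new one-liner:
full = dt.full_tensor()
```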
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112224
Approved by: https://github.com/wz337
2023-10-31 00:44:09 +00:00
Iris Zhang
12c1465d76 [DeviceMesh] Make mesh_resources private (#112294)
This is to prepare for moving DeviceMesh into a standalone distributed package.

`_mesh_resources` should only be used in torch.distributed package.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112294
Approved by: https://github.com/fegin
2023-10-28 17:28:46 +00:00
Wanchao Liang
61461f39d1 [dtensor] handle negative dim and fix TP regression (#111750)
TP styles still have some regressions due to negative dim specifications;
fix this by allowing the DTensor API to handle negative dims and normalize them.

i.e. TP uses `Shard(-1)`, and then tries to redistribute `Shard(1) -> Shard(-1)`. This should ideally be a no-op, but currently it runs a decompose-sharding phase that turns this transformation into `Shard(1) -> Replicate -> Shard(-1)`, which is wrong and triggers unnecessary allgathers.
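The normalization being added, as a standalone sketch (helper name hypothetical):

```python
def normalize_dim(dim: int, ndim: int) -> int:
    # map a negative placement dim onto its canonical non-negative form
    return dim + ndim if dim < 0 else dim

assert normalize_dim(-1, 2) == 1  # Shard(-1) on a 2-D tensor is Shard(1)
```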
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111750
Approved by: https://github.com/rohan-varma
2023-10-22 04:25:45 +00:00
Wanchao Liang
1d291e1f19 [dtensor] hide xla imports to avoid warning (#111751)
The xla imports throw warnings when xla is not installed, so we should only
import xla when needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111751
Approved by: https://github.com/rohan-varma
2023-10-22 04:09:10 +00:00
Yeounoh Chung
8376079b97 [DTensor][XLA] Support Xla backend in distribute_tensor API (#110275)
This addresses #92909 and enables XLA backend support for the `distribute_tensor` API.

Test plan: added a unit test case & tested with CloudTPU. The CI should skip this unless it's an XLA workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110275
Approved by: https://github.com/wanchaol, https://github.com/alanwaketan, https://github.com/JackCaoG
2023-10-21 01:17:15 +00:00
fduwjj
fdc29f58c6 [TP] Refactor style to make it work with torch.compile (#111625)
We are refactoring the parallel styles to accomplish the following:
1. Further simplify the code logic to make it more readable for users.
2. Remove the tuple check so that we can work with dynamo for now. Ideally dynamo should support this case, and we will fix that in parallel.
3. Add tests for the newly added parallel styles in UT and torch.compile tests so that we can catch regressions due to code changes.
4. Move the placements early-return check into DTensor, since it is bypassed by dynamo.
5. Remove PairwiseParallelStyle from unit tests in favor of the new Col/Rowwise parallel styles.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111625
Approved by: https://github.com/wanchaol
2023-10-20 19:20:43 +00:00
Brian Hirsh
4d29b40299 torch.compile DTensor E2E (#105236)
This PR updates DTensor to support torch.compile

Cool stuff: there are some new tests in `test_dtensor.py` that show both the forward and backward graphs that we can send to inductor when running a matmul with DTensors. In particular, for this user code:
```
        def fn(x, y):
            dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
            dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
            dt_out = torch.matmul(dt, dt2)
            dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
            return dt_out.to_local()
```

We generate the following fw and backward graphs.

Forward graph:
```
def forward(self, primals_1, primals_2):
    view = torch.ops.aten.view.default(primals_1, [2, 4]);  primals_1 = None
    _to_copy = torch.ops.aten._to_copy.default(view, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0));  view = None
    detach = torch.ops.aten.detach.default(_to_copy);  _to_copy = None
    detach_1 = torch.ops.aten.detach.default(detach);  detach = None
    view_1 = torch.ops.aten.view.default(primals_2, [4, 2]);  primals_2 = None
    _to_copy_1 = torch.ops.aten._to_copy.default(view_1, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0));  view_1 = None
    detach_2 = torch.ops.aten.detach.default(_to_copy_1);  _to_copy_1 = None
    detach_3 = torch.ops.aten.detach.default(detach_2);  detach_2 = None
    detach_4 = torch.ops.aten.detach.default(detach_1)
    all_gather_into_tensor = torch.ops.c10d_functional.all_gather_into_tensor.default(detach_3, 'ptd:0', [0, 1], 2)
    wait_tensor = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor);  all_gather_into_tensor = None
    split = torch.ops.aten.split.Tensor(wait_tensor, 4);  wait_tensor = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    cat = torch.ops.aten.cat.default([getitem, getitem_1], 1);  getitem = getitem_1 = None
    detach_5 = torch.ops.aten.detach.default(cat);  cat = None
    mm = torch.ops.aten.mm.default(detach_4, detach_5);  detach_4 = detach_5 = None
    detach_6 = torch.ops.aten.detach.default(mm);  mm = None
    detach_9 = torch.ops.aten.detach.default(detach_6);  detach_6 = None
    detach_10 = torch.ops.aten.detach.default(detach_9);  detach_9 = None
    t = torch.ops.aten.t.default(detach_1);  detach_1 = None
    detach_13 = torch.ops.aten.detach.default(t);  t = None
    t_1 = torch.ops.aten.t.default(detach_3);  detach_3 = None
    detach_15 = torch.ops.aten.detach.default(t_1);  t_1 = None
    clone = torch.ops.aten.clone.default(detach_15, memory_format = torch.contiguous_format);  detach_15 = None
    return [detach_10, detach_13, clone]
```

Backward graph:
```
def forward(self, detach_13, clone, tangents_1):
    detach_11 = torch.ops.aten.detach.default(tangents_1);  tangents_1 = None
    detach_12 = torch.ops.aten.detach.default(detach_11);  detach_11 = None
    mm_1 = torch.ops.aten.mm.default(detach_13, detach_12);  detach_13 = None
    detach_14 = torch.ops.aten.detach.default(mm_1);  mm_1 = None
    detach_16 = torch.ops.aten.detach.default(detach_12);  detach_12 = None
    all_gather_into_tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor.default(clone, 'ptd:0', [0, 1], 2);  clone = None
    wait_tensor_2 = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor_2);
    detach_17 = torch.ops.aten.detach.default(wait_tensor_2);  wait_tensor_2 = None
    mm_2 = torch.ops.aten.mm.default(detach_16, detach_17);  detach_16 = detach_17 = None
    detach_18 = torch.ops.aten.detach.default(mm_2);  mm_2 = None
    split_1 = torch.ops.aten.split.Tensor(detach_14, 2, 1);  detach_14 = None
    getitem_2 = split_1[0]
    getitem_3 = split_1[1];  split_1 = None
    cat_1 = torch.ops.aten.cat.default([getitem_2, getitem_3]);  getitem_2 = getitem_3 = None
    reduce_scatter_tensor = torch.ops.c10d_functional.reduce_scatter_tensor.default(cat_1, 'SUM', 'ptd:0', [0, 1], 2);  cat_1 = None
    wait_tensor_3 = torch.ops.c10d_functional.wait_tensor.default(reduce_scatter_tensor);  reduce_scatter_tensor = None
    detach_19 = torch.ops.aten.detach.default(wait_tensor_3);  wait_tensor_3 = None
    detach_20 = torch.ops.aten.detach.default(detach_19);  detach_19 = None
    detach_21 = torch.ops.aten.detach.default(detach_20);  detach_20 = None
    detach_22 = torch.ops.aten.detach.default(detach_21);  detach_21 = None
    _to_copy_2 = torch.ops.aten._to_copy.default(detach_22, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  detach_22 = None
    view_2 = torch.ops.aten.view.default(_to_copy_2, [8]);  _to_copy_2 = None
    detach_23 = torch.ops.aten.detach.default(detach_18);  detach_18 = None
    detach_24 = torch.ops.aten.detach.default(detach_23);  detach_23 = None
    _to_copy_3 = torch.ops.aten._to_copy.default(detach_24, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  detach_24 = None
    view_3 = torch.ops.aten.view.default(_to_copy_3, [8]);  _to_copy_3 = None
    return [view_3, view_2]
```

Some of the stuff in this graph looks kind of silly though (e.g. an unnecessary split() + cat(), and all the extra detach() calls).

Stuff that's broken:
- functionalization is pretty horribly broken. In particular, the original strategy I used in this stack was to have functionalization run **above** subclass desugaring. But that doesn't play well with the way we want to compile DTensor. DTensor has a few API's like `.redistribute()`, `.to_local()`, and the `DTensor()` constructor, that we want to put directly into the graph so that we can compile them (e.g. redistribute() will desugar into collective ops). Doing this requires functionalization to run **underneath** the subclass though. I hacked around this for now, by forcing these functions to run functionalization first if they need to.
- the backward test that I have is... wrong. The backward graph that we trace out looks kind of reasonable, but it gives incorrect gradients on one of the two inputs. This needs further debugging (presumably we should be able to stare at the graph and identify which part of it is wrong?).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105236
Approved by: https://github.com/wanchaol
2023-10-11 21:55:27 +00:00
Wanchao Liang
2a76c7f018 [dtensor] skip move to device when device_type match (#110774)
Skip tensor.to in from_local and distribute_tensor when the device mesh's
device_type matches the tensor's device type. Since from_local is on the
critical path of TP, this may also reduce some overhead.
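The skip, sketched (function name hypothetical):

```python
def _maybe_move(local_tensor, device_mesh):
    # only pay for tensor.to() when the device types actually differ
    if local_tensor.device.type != device_mesh.device_type:
        local_tensor = local_tensor.to(device_mesh.device_type)
    return local_tensor
```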
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110774
Approved by: https://github.com/fduwjj
2023-10-09 19:39:11 +00:00
fduwjj
2dc5e166a5 [TP][Inference] Enable DTensor TP inference (#110751)
In https://github.com/pytorch/pytorch/pull/109977, we observed that during inference mode, aten.Linear does not get decomposed. So instead of enabling sharding propagation for the linear op, we use func.decompose so that it gets decomposed into matmul and mm.
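A sketch of that dispatch-level fallback (`OpOverload.decompose` is the torch API the message names; the handler around it is schematic):

```python
import torch

def dispatch(func, args, kwargs):
    if func is torch.ops.aten.linear.default:
        # no decomposition happens upstream under inference mode, so ask the
        # op to decompose itself (into matmul/mm) and re-dispatch those
        return func.decompose(*args, **kwargs)
    ...  # normal sharding propagation path
```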

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110751
Approved by: https://github.com/bdhirsh, https://github.com/wanchaol
2023-10-07 18:57:27 +00:00
Wanchao Liang
c95cf4b4c9 [dtensor] add grad placements kwarg to to_local API (#110629)
When we convert to a local tensor, DTensor can no longer track the autograd
or gradient layout of the local tensor. If the user does something
unexpected, there needs to be a way for them to hint at the gradient layout
of the local tensor.
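A usage sketch of the hint (placement chosen for illustration; `Partial` uses the post-rename public name; assumes an existing DTensor `dt`):

```python
from torch.distributed._tensor import Partial

# tell autograd that the gradient flowing back into this local tensor
# will be a pending partial sum, not a gradient matching dt's placements
local = dt.to_local(grad_placements=[Partial()])
```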
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110629
Approved by: https://github.com/zdevito
2023-10-05 21:34:01 +00:00
Wanchao Liang
9456de937b [dtensor] Fix and improve the sharding cache behavior (#109306)
resolves https://github.com/pytorch/pytorch/issues/109101

The problem is essentially that we were hashing all the arguments, including
scalars (e.g. aten.div(tensor, scalar)); in the optimizer, the scalar might
change every time we call the op, causing a cache miss on every call.

This PR improves the sharding cache behavior by introducing a
RuntimeSchemaInfo, used to record the hashing information needed at runtime
during op registration time. This enables us to (see the sketch below):
* only hash arguments that are tensors or have a static_argnum, so that
cases like aten.div.Tensor(tensor, 0.23231) hit the cache; currently we
hash all args, which excludes those cases
* with the correct cache behavior, optimizers hit the cache again,
resolving the high CPU overhead issue.
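A schematic of that hashing policy (all names here are hypothetical; the spec attribute stands in for whatever the real cache hashes):

```python
from torch.distributed._tensor import DTensor

def sharding_cache_key(op, args, static_argnums):
    # hash only DTensor args (by their sharding spec) plus pinned scalar
    # positions; free-floating scalars like an optimizer's lr/step no
    # longer poison the cache
    parts = []
    for i, arg in enumerate(args):
        if isinstance(arg, DTensor):
            parts.append(arg._spec)   # sharding spec, not the data
        elif i in static_argnums:
            parts.append(arg)         # scalar pinned by RuntimeSchemaInfo
    return (op, tuple(parts))
```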

A simple MLP shows all cache hits, and a single addmm goes from 0.341 ms to 0.319 ms, showing some hashing improvement.
<img width="1172" alt="Screenshot 2023-09-14 at 11 06 07 AM" src="https://github.com/pytorch/pytorch/assets/9443650/3406d673-dd8d-4ad9-9b80-9d4721c430e3">

The Adam optimizer shows aten.div hitting the sharding cache again.
<img width="1016" alt="Screenshot 2023-09-14 at 11 02 10 AM" src="https://github.com/pytorch/pytorch/assets/9443650/4280e8e3-af44-4fc2-8360-ea80b768f1d9">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109306
Approved by: https://github.com/fduwjj
2023-09-15 10:32:49 +00:00