Commit Graph

29 Commits

Author SHA1 Message Date
Wanchao Liang
09f3e08bcc [dtensor][3/n] use dedicated TensorMeta instead of the fx one (#108261)
This PR switches the usage of fx's shape prop TensorMetadata to
dtensor's own dedicated defined TensorMeta, this is because DTensor
only cares three fields: shape/stride/dtype, all other fields are not
necessary and can be inferred from local_tensor directly. This would
help significantly simplify how we deal with the tensor metadata by not
caring other fields.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
2023-09-13 04:08:02 +00:00
Brian Hirsh
5efd63b1b8 better support for fakeifying and dynamoing through torch_dispatch subclasses (with dynamic shapes) (#107415)
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.

(3) `_make_wrapper_subclass()` marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second.. does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
2023-08-29 02:36:48 +00:00
Wanchao Liang
945fa7e8a8 [dtensor] fix requires_grad in distribute_tensor (#107606)
This PR fixes the requires_grad set when calling distribute_tensor, we
should set the requires_grad of the local tensor after the detach call
to make sure we create the leaf correctly, otherwise it would raise
warnings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107606
Approved by: https://github.com/fduwjj
2023-08-22 23:08:13 +00:00
Wanchao Liang
d8f2ef10a6 [dtensor][1/n] refactor op dispatch logic to reduce overhead (#107305)
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)

overall one layer of mlp time reduced from 13.535 -> 9.665ms

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
2023-08-18 18:30:46 +00:00
fduwjj
4a6ca4cc05 [TP][DTensor Perf] Some perf improvement to reduce DTensor CPU overhead (#106524)
By inspecting a small TP benchmark, we found couple things we can optimize:
1. We call deep_copy so many times when we initialize DTensor.
2. Some shading_prop is not cached successfully.
3. We are still calling redistribute when not necessary.

![image](https://github.com/pytorch/pytorch/assets/6937752/b847d110-eea1-45df-9298-066d0ba07dd7)

![image](https://github.com/pytorch/pytorch/assets/6937752/fc08f564-caed-496b-80d7-275c1dba3806)

![image](https://github.com/pytorch/pytorch/assets/6937752/fdc06cc4-a4ba-48e8-a118-c041bbd04f5e)

So we want to:
1. Remove the deep_copy, and we now make placements a tuple so we are sure it's immutable.
2. Somehow the op_schema gets changed during sharding_op propogation, so we store a hash version of it before passing it to sharding_prop. Ideally we want to figure out why `op_schema` gets changed, but looks like in both index and detach/view op, all get changed, it might take more time to debug.
3. Also when we do hashing of op_schema, we want to hash the entire args_schema not just the args_spec which only contains the DTensorSpec from args which are Dtensors.
4. It turns out that sometimes, DTensor has mem_format to be None (not contiguous) and this will lead to redistribute get triggered, so that we only need to compare type/shape and stride in the metadata.

Also we need to ensure _Partial and Shard have different hash value in the DTensorSpec.

![image](https://github.com/pytorch/pytorch/assets/6937752/321e6890-1ab6-4975-adc9-524c6ef9a76b)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106524
Approved by: https://github.com/wanchaol
2023-08-14 20:03:19 +00:00
Wanchao Liang
c9cbcb2449 [device_mesh] move remaining collectives to a separate file (#107012)
Move the remaining collectives to a separate file to prepare device mesh
to become a public distributed API

For those remaining utils, we need to upstream them to functional
collectives with proper implementation, added TODO there for a follow up
PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107012
Approved by: https://github.com/fduwjj
2023-08-11 23:49:27 +00:00
alanhe151220037
1afbc985fe Make RNGStateTracker support cuda-like device (#106771)
replace  `CudaRNGStateTracker` with `RNGStateTracker` by rewriting some Cuda-binding code with `device_handle`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106771
Approved by: https://github.com/wanchaol
2023-08-10 19:14:33 +00:00
Iris
0cba33e176 [DTensor]Minor Docstring Update (#106250)
Fix docstring to reflect change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106250
Approved by: https://github.com/wanchaol
2023-08-02 00:27:29 +00:00
Wanchao Liang
f139aab2f4 [dynamo] add initial dynamo support for DTensor (#103146)
This PR adds initial dynamo support for DTensor, in particular, it:
- allows DTensor be passed into a compiled function, and allow fakify
DTensor during dynamo tracing by turning the inner local tensor to meta
tensor.
- We use `allow_in_graph` to include `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The dtensor created becomes a normal `TensorVariable` and it would insert any tensor operations to the output graph just like torch.Tensor
- note that dtensor have a new instance method `redistribute` compare to plain tensor, and we currently special handle it in `TensorVariable`

`from_local` and `redistribute` both accepts some non-trival metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encoded the metadata into a new_function (like `functools.partial`) and the new function only accepts prim args (i.e. tensor), then we put `call_function` with this new_function to the graph. This is suggested by @ezyang. The underlying rationale here is that the metadata will not change across the graph invocations so it's safe to encode them.

Captured graph:
```
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False);  l_x_ = None

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
        prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local);  prim_from_local = None
        to_local = prim_redistribute.to_local();  prim_redistribute = None
        add = to_local + 2;  to_local = None
        return (add,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
2023-07-19 16:01:12 +00:00
Wanchao Liang
cb23373264 [dynamo] allow tensor subclass fakification in dynamo (#105308)
This PR adds necessary plumbing through torchdynamo to allow tensor
subclasses with certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to goes through the dynamo fakification pass by
fakifying the tensor subclass internal components.

Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540

Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
2023-07-18 17:28:04 +00:00
Wanchao Liang
bcb9ca4e5a [dtensor] canonicalize detach callsites and use view_as when appropriate (#105239)
This PR canonicalize the detach callsite to only call the detach
from `distribute_tensor`. Change other callsite to view_as and remove the
tensor constructor detach call

This is so that we don't detach local tensor for every op run when
rewrapping the DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
2023-07-18 17:13:37 +00:00
Xilun Wu
a66107a30c [DTensor][Random] Introduce CudaRNGStateTracker to maintain parallel RNG state for DTensor (#103235)
# Change
This PR adds two classes to DTensor:

1. `CudaRNGStateTracker`:  `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).

2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.

# Warning

- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads(ranks) and cause issue. We need to figure out a compatible solution for that.

- The RNG state may be asynchronous outside of participating ranks. It is harmless in our current use case of submesh though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
2023-06-27 19:00:25 +00:00
Wanchao Liang
4cc474dec4 [dtensor] support torch.save/load with DTensor (#103106)
This PR actually enables DTensor to be pickable and add tests to test
torch.save/load works correctly for DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103106
Approved by: https://github.com/kumpera
2023-06-09 04:11:15 +00:00
fduwjj
92923aca61 [TP] Use Stride inferred from local tensor in to_local bwd (#102630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102630
Approved by: https://github.com/wanchaol
2023-06-01 04:30:24 +00:00
Wanchao Liang
c5d4ee2d73 [dtensor][simple] fix some comments (#102661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102661
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-06-01 03:23:19 +00:00
Wanchao Liang
70eccdbf92 [dtensor] add necessary logging to APIs and components (#101994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101994
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Wanchao Liang
599ae95d1a [dtensor] use stack to manage mesh resources (#101202)
This PR changes the context manager behavior of device mesh, now we use
a mesh env to track the current mesh and save the mesh to a stack so
that we can allow nested context manager
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337
2023-05-11 23:48:36 +00:00
Wanchao Liang
55a1dc7f88 [dtensor] redistributed by default take self mesh instead (#99060)
This PR switches redistribute to default use self mesh instead of
the global mesh, which is more user friendly
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99060
Approved by: https://github.com/mrshenli
2023-04-14 05:14:28 +00:00
Shen Li
02179827cb [Easy] Include SPMD and DTensor files in UFMT checks (#98148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98148
Approved by: https://github.com/fegin
2023-04-02 15:34:49 +00:00
Wanchao Liang
e9c4904915 [dtensor] remove custom dispatch op (#95629)
Since we removed all custom dispatch ops, we can safely delete this
table as we won't use it for other purposes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95629
Approved by: https://github.com/XilunWu
2023-03-28 02:25:45 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Wanchao Liang
2a1cb9640c [dtensor] support creating DTensor in submesh (#95458)
This PR supports creating DTensor in a submesh, if the rank is not
participating in the mesh, we assign the local tensor to be empty
tensor, and do nothing in the operator dispatch

Differential Revision: [D43643577](https://our.internmc.facebook.com/intern/diff/D43643577)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95458
Approved by: https://github.com/XilunWu
2023-02-28 17:54:26 +00:00
Wanchao Liang
bb9a05b116 [dtensor] use tracing for metadata prop (#95456)
This PR uses tracing for metadata prop, so that we can get correct
shape/stride metadata without manual calculation by ourselves.

The follow up PR on this would be adopt tracing for the sharding
prop itself

Differential Revision: [D43643578](https://our.internmc.facebook.com/intern/diff/D43643578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95456
Approved by: https://github.com/XilunWu
2023-02-28 17:54:22 +00:00
Wanchao Liang
680fc84e7b [dtensor] group public APIs together (#94524)
This PR groups distribute_tensor/module to api.py

rename some to non-public (ToTensor/FromTensor)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94524
Approved by: https://github.com/XilunWu
2023-02-10 23:40:34 +00:00
Wanchao Liang
b072245178 [dtensor][4/N] refactor dispatching logic and add propagator (#90733)
This PR refactors the dispatching logic to make it more clean, and
isolate the sharding propagation logic out to a separate class.

This is so that we can implement more complicated propagation features
later.

Differential Revision: [D42876251](https://our.internmc.facebook.com/intern/diff/D42876251)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90733
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
2023-02-01 05:02:11 +00:00
fduwjj
77f336600a [PT-D] Enable Meta Tensor Support for DTensor (#92652)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92652
Approved by: https://github.com/XilunWu, https://github.com/wanchaol
2023-01-26 04:54:57 +00:00
Wanchao Liang
7afba50508 [dtensor] delete unused torch_function (#90449)
torch_function is not actually getting used yet today, deleting
it first and we can revisit once we really need it
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90449
Approved by: https://github.com/fduwjj
2022-12-10 01:29:02 +00:00
Wanchao Liang
bf23e0bdbd [dtensor] ufmt distributed._tensor (#89967)
cmd: `ufmt format torch/distributed/_tensor`

copy from Andrew:

Notes
For VSCode users,

Install ufmt: https://pypi.org/project/ufmt/
Install VSCode ufmt extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
Include in settings.json:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89967
Approved by: https://github.com/fduwjj
2022-12-01 20:58:13 +00:00
Wanchao Liang
4b945967de [dtensor] PART 2: move DTensor abstraction and APIs to core distributed (#88176)
This PR moves the core DTensor abstraction and high level APIs to
torch.distributed._tensor folder, which includes the following:
1. DTensor class
2. high level APIs (distribute_tensor/module)
3. dispatching logic
4. redistribute logic

part of https://github.com/pytorch/pytorch/issues/88838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88176
Approved by: https://github.com/fduwjj
2022-11-16 08:07:41 +00:00