Commit Graph

21 Commits

Author SHA1 Message Date
Wanchao Liang
f139aab2f4 [dynamo] add initial dynamo support for DTensor (#103146)
This PR adds initial dynamo support for DTensor, in particular, it:
- allows DTensor be passed into a compiled function, and allow fakify
DTensor during dynamo tracing by turning the inner local tensor to meta
tensor.
- We use `allow_in_graph` to include `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The dtensor created becomes a normal `TensorVariable` and it would insert any tensor operations to the output graph just like torch.Tensor
- note that dtensor have a new instance method `redistribute` compare to plain tensor, and we currently special handle it in `TensorVariable`

`from_local` and `redistribute` both accepts some non-trival metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encoded the metadata into a new_function (like `functools.partial`) and the new function only accepts prim args (i.e. tensor), then we put `call_function` with this new_function to the graph. This is suggested by @ezyang. The underlying rationale here is that the metadata will not change across the graph invocations so it's safe to encode them.

Captured graph:
```
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
        prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False);  l_x_ = None

        # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
        prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local);  prim_from_local = None
        to_local = prim_redistribute.to_local();  prim_redistribute = None
        add = to_local + 2;  to_local = None
        return (add,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
2023-07-19 16:01:12 +00:00
Wanchao Liang
cb23373264 [dynamo] allow tensor subclass fakification in dynamo (#105308)
This PR adds necessary plumbing through torchdynamo to allow tensor
subclasses with certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to goes through the dynamo fakification pass by
fakifying the tensor subclass internal components.

Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540

Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
2023-07-18 17:28:04 +00:00
Wanchao Liang
bcb9ca4e5a [dtensor] canonicalize detach callsites and use view_as when appropriate (#105239)
This PR canonicalize the detach callsite to only call the detach
from `distribute_tensor`. Change other callsite to view_as and remove the
tensor constructor detach call

This is so that we don't detach local tensor for every op run when
rewrapping the DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
2023-07-18 17:13:37 +00:00
Xilun Wu
a66107a30c [DTensor][Random] Introduce CudaRNGStateTracker to maintain parallel RNG state for DTensor (#103235)
# Change
This PR adds two classes to DTensor:

1. `CudaRNGStateTracker`:  `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).

2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.

# Warning

- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads(ranks) and cause issue. We need to figure out a compatible solution for that.

- The RNG state may be asynchronous outside of participating ranks. It is harmless in our current use case of submesh though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
2023-06-27 19:00:25 +00:00
Wanchao Liang
4cc474dec4 [dtensor] support torch.save/load with DTensor (#103106)
This PR actually enables DTensor to be pickable and add tests to test
torch.save/load works correctly for DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103106
Approved by: https://github.com/kumpera
2023-06-09 04:11:15 +00:00
fduwjj
92923aca61 [TP] Use Stride inferred from local tensor in to_local bwd (#102630)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102630
Approved by: https://github.com/wanchaol
2023-06-01 04:30:24 +00:00
Wanchao Liang
c5d4ee2d73 [dtensor][simple] fix some comments (#102661)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102661
Approved by: https://github.com/fduwjj, https://github.com/XilunWu
2023-06-01 03:23:19 +00:00
Wanchao Liang
70eccdbf92 [dtensor] add necessary logging to APIs and components (#101994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101994
Approved by: https://github.com/wz337
2023-05-23 18:17:54 +00:00
Wanchao Liang
599ae95d1a [dtensor] use stack to manage mesh resources (#101202)
This PR changes the context manager behavior of device mesh, now we use
a mesh env to track the current mesh and save the mesh to a stack so
that we can allow nested context manager
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337
2023-05-11 23:48:36 +00:00
Wanchao Liang
55a1dc7f88 [dtensor] redistributed by default take self mesh instead (#99060)
This PR switches redistribute to default use self mesh instead of
the global mesh, which is more user friendly
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99060
Approved by: https://github.com/mrshenli
2023-04-14 05:14:28 +00:00
Shen Li
02179827cb [Easy] Include SPMD and DTensor files in UFMT checks (#98148)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98148
Approved by: https://github.com/fegin
2023-04-02 15:34:49 +00:00
Wanchao Liang
e9c4904915 [dtensor] remove custom dispatch op (#95629)
Since we removed all custom dispatch ops, we can safely delete this
table as we won't use it for other purposes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95629
Approved by: https://github.com/XilunWu
2023-03-28 02:25:45 +00:00
Kazuaki Ishizaki
35fd5c548e Fix typos under torch/distributed directory (#95638)
This PR fixes typos in comments and messages of `.py` files under torch/distributed directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95638
Approved by: https://github.com/usamah1, https://github.com/H-Huang, https://github.com/kit1980
2023-03-27 21:13:44 +00:00
Wanchao Liang
2a1cb9640c [dtensor] support creating DTensor in submesh (#95458)
This PR supports creating DTensor in a submesh, if the rank is not
participating in the mesh, we assign the local tensor to be empty
tensor, and do nothing in the operator dispatch

Differential Revision: [D43643577](https://our.internmc.facebook.com/intern/diff/D43643577)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95458
Approved by: https://github.com/XilunWu
2023-02-28 17:54:26 +00:00
Wanchao Liang
bb9a05b116 [dtensor] use tracing for metadata prop (#95456)
This PR uses tracing for metadata prop, so that we can get correct
shape/stride metadata without manual calculation by ourselves.

The follow up PR on this would be adopt tracing for the sharding
prop itself

Differential Revision: [D43643578](https://our.internmc.facebook.com/intern/diff/D43643578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95456
Approved by: https://github.com/XilunWu
2023-02-28 17:54:22 +00:00
Wanchao Liang
680fc84e7b [dtensor] group public APIs together (#94524)
This PR groups distribute_tensor/module to api.py

rename some to non-public (ToTensor/FromTensor)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94524
Approved by: https://github.com/XilunWu
2023-02-10 23:40:34 +00:00
Wanchao Liang
b072245178 [dtensor][4/N] refactor dispatching logic and add propagator (#90733)
This PR refactors the dispatching logic to make it more clean, and
isolate the sharding propagation logic out to a separate class.

This is so that we can implement more complicated propagation features
later.

Differential Revision: [D42876251](https://our.internmc.facebook.com/intern/diff/D42876251)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90733
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
2023-02-01 05:02:11 +00:00
fduwjj
77f336600a [PT-D] Enable Meta Tensor Support for DTensor (#92652)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92652
Approved by: https://github.com/XilunWu, https://github.com/wanchaol
2023-01-26 04:54:57 +00:00
Wanchao Liang
7afba50508 [dtensor] delete unused torch_function (#90449)
torch_function is not actually getting used yet today, deleting
it first and we can revisit once we really need it
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90449
Approved by: https://github.com/fduwjj
2022-12-10 01:29:02 +00:00
Wanchao Liang
bf23e0bdbd [dtensor] ufmt distributed._tensor (#89967)
cmd: `ufmt format torch/distributed/_tensor`

copy from Andrew:

Notes
For VSCode users,

Install ufmt: https://pypi.org/project/ufmt/
Install VSCode ufmt extension: https://marketplace.visualstudio.com/items?itemName=omnilib.ufmt
Include in settings.json:
```
{
    "[python]": {
        "editor.defaultFormatter": "omnilib.ufmt",
        "editor.formatOnSave": true,
    },
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89967
Approved by: https://github.com/fduwjj
2022-12-01 20:58:13 +00:00
Wanchao Liang
4b945967de [dtensor] PART 2: move DTensor abstraction and APIs to core distributed (#88176)
This PR moves the core DTensor abstraction and high level APIs to
torch.distributed._tensor folder, which includes the following:
1. DTensor class
2. high level APIs (distribute_tensor/module)
3. dispatching logic
4. redistribute logic

part of https://github.com/pytorch/pytorch/issues/88838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88176
Approved by: https://github.com/fduwjj
2022-11-16 08:07:41 +00:00