We are refactoring parallel style to solve the following things:
1. To further simplifying code logic to make more readable for users.
2. To remove tuple check so that we can work with dynamo for now. Ideally dynamo needs to support this case and we will fix it in parallel.
3. Add tests for newly added parallel style in UT and torch compile test so that we can capture regression due to code change.
4. Move placements early return check into DTensor since it is by passed by dynamo.
5. Remove PairwiseParallelStyle from unit tests to use the new Col/Rowwise parallel style.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111625
Approved by: https://github.com/wanchaol
This PR updates DTensor to support torch.compile
Cool stuff: there are some new tests in `test_dtensor.py` that show both the forward and backward graphs that we can send to inductor, when running a matmul with DTensor's. In particular, for this user code:
```
def fn(x, y):
dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
dt_out = torch.matmul(dt, dt2)
dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
return dt_out.to_local()
```
We generate the following fw and backward graphs.
Forward graph:
```
def forward(self, primals_1, primals_2):
view = torch.ops.aten.view.default(primals_1, [2, 4]); primals_1 = None
_to_copy = torch.ops.aten._to_copy.default(view, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view = None
detach = torch.ops.aten.detach.default(_to_copy); _to_copy = None
detach_1 = torch.ops.aten.detach.default(detach); detach = None
view_1 = torch.ops.aten.view.default(primals_2, [4, 2]); primals_2 = None
_to_copy_1 = torch.ops.aten._to_copy.default(view_1, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view_1 = None
detach_2 = torch.ops.aten.detach.default(_to_copy_1); _to_copy_1 = None
detach_3 = torch.ops.aten.detach.default(detach_2); detach_2 = None
detach_4 = torch.ops.aten.detach.default(detach_1)
all_gather_into_tensor = torch.ops.c10d_functional.all_gather_into_tensor.default(detach_3, 'ptd:0', [0, 1], 2)
wait_tensor = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
split = torch.ops.aten.split.Tensor(wait_tensor, 4); wait_tensor = None
getitem = split[0]
getitem_1 = split[1]; split = None
cat = torch.ops.aten.cat.default([getitem, getitem_1], 1); getitem = getitem_1 = None
detach_5 = torch.ops.aten.detach.default(cat); cat = None
mm = torch.ops.aten.mm.default(detach_4, detach_5); detach_4 = detach_5 = None
detach_6 = torch.ops.aten.detach.default(mm); mm = None
detach_9 = torch.ops.aten.detach.default(detach_6); detach_6 = None
detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
t = torch.ops.aten.t.default(detach_1); detach_1 = None
detach_13 = torch.ops.aten.detach.default(t); t = None
t_1 = torch.ops.aten.t.default(detach_3); detach_3 = None
detach_15 = torch.ops.aten.detach.default(t_1); t_1 = None
clone = torch.ops.aten.clone.default(detach_15, memory_format = torch.contiguous_format); detach_15 = None
return [detach_10, detach_13, clone]
```
Backward graph:
```
def forward(self, detach_13, clone, tangents_1):
detach_11 = torch.ops.aten.detach.default(tangents_1); tangents_1 = None
detach_12 = torch.ops.aten.detach.default(detach_11); detach_11 = None
mm_1 = torch.ops.aten.mm.default(detach_13, detach_12); detach_13 = None
detach_14 = torch.ops.aten.detach.default(mm_1); mm_1 = None
detach_16 = torch.ops.aten.detach.default(detach_12); detach_12 = None
all_gather_into_tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor.default(clone, 'ptd:0', [0, 1], 2); clone = None
wait_tensor_2 = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor_2);
detach_17 = torch.ops.aten.detach.default(wait_tensor_2); wait_tensor_2 = None
mm_2 = torch.ops.aten.mm.default(detach_16, detach_17); detach_16 = detach_17 = None
detach_18 = torch.ops.aten.detach.default(mm_2); mm_2 = None
split_1 = torch.ops.aten.split.Tensor(detach_14, 2, 1); detach_14 = None
getitem_2 = split_1[0]
getitem_3 = split_1[1]; split_1 = None
cat_1 = torch.ops.aten.cat.default([getitem_2, getitem_3]); getitem_2 = getitem_3 = None
reduce_scatter_tensor = torch.ops.c10d_functional.reduce_scatter_tensor.default(cat_1, 'SUM', 'ptd:0', [0, 1], 2); cat_1 = None
wait_tensor_3 = torch.ops.c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None
detach_19 = torch.ops.aten.detach.default(wait_tensor_3); wait_tensor_3 = None
detach_20 = torch.ops.aten.detach.default(detach_19); detach_19 = None
detach_21 = torch.ops.aten.detach.default(detach_20); detach_20 = None
detach_22 = torch.ops.aten.detach.default(detach_21); detach_21 = None
_to_copy_2 = torch.ops.aten._to_copy.default(detach_22, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_22 = None
view_2 = torch.ops.aten.view.default(_to_copy_2, [8]); _to_copy_2 = None
detach_23 = torch.ops.aten.detach.default(detach_18); detach_18 = None
detach_24 = torch.ops.aten.detach.default(detach_23); detach_23 = None
_to_copy_3 = torch.ops.aten._to_copy.default(detach_24, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_24 = None
view_3 = torch.ops.aten.view.default(_to_copy_3, [8]); _to_copy_3 = None
return [view_3, view_2]
```
Some of the stuff in this graph looks kinda of silly though (e.g. an unnecessary split() + cat(), and all the extra detach() calls).
Stuff that's broken:
- functionalization is pretty horribly broken. In particular, the original strategy I used in this stack was to have functionalization run **above** subclass desugaring. But that doesn't play well with with the way we want to compile DTensor. DTensor has a few API's like `.redistribute()`, `.to_local()`, and the `DTensor()` constructor, that we want to put directly into the graph so that we can compile them (e.g. redistribute() will desugar into collective ops). Doing this requires functionalization to run **underneath** the subclass though. I hacked around this for now, by forcing these functions to run functionalization first if they need to.
- the backward test that I have is... wrong. The backward graph that we trace out looks kind of reasonable, but it gives incorrect gradients on one of the two inputs. This needs further debugging (presumably we should be able to stare at the graph and identify which part of it is wrong?).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105236
Approved by: https://github.com/wanchaol
skip tensor.to in from_local and distribute_tensor when device_type of
device mesh matches tensor.device type, since from_local on the critial
path of TP, this might also reduce some overhead
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110774
Approved by: https://github.com/fduwjj
When we convert to local tensor, dtensor can't track autograd or
gradient layout of the local tensor anymore, if user do sth not expected, there
needs to be a way for user to hint about the gradient layout of the
local tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110629
Approved by: https://github.com/zdevito
resolves https://github.com/pytorch/pytorch/issues/109101
The problem is essentially because we were hashing all the arguments, including
the scalar too (i.e. aten.div(tensor, scalar)), in the optimizer, the scalar might
change everytime we call the op, thus cache miss everytime we call the op
This PR improves the sharding cache behavior by introducing a
RuntimeSchemaInfo, used to record some runtime necessary hashing
information during op registration time. This enable us to:
* only hash arguments that are tensor or have static_argnum, this is to
enable many cases like aten.div.Tensor(tensor, 0.23231) hit the cache.
as we currently hashing all args which exclude those cases
* with the correct cache behavior, optimizers will hit the cache again
and resolve the high cpu overhead issue.
simple MLP shows all cache hit and for a single addmm -> 0.319ms (from 0.341ms), shows some hashing improvements:
<img width="1172" alt="Screenshot 2023-09-14 at 11 06 07 AM" src="https://github.com/pytorch/pytorch/assets/9443650/3406d673-dd8d-4ad9-9b80-9d4721c430e3">
Adam optimizer shows aten.div hit sharding cache again
<img width="1016" alt="Screenshot 2023-09-14 at 11 02 10 AM" src="https://github.com/pytorch/pytorch/assets/9443650/4280e8e3-af44-4fc2-8360-ea80b768f1d9">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109306
Approved by: https://github.com/fduwjj
This PR switches the usage of fx's shape prop TensorMetadata to
dtensor's own dedicated defined TensorMeta, this is because DTensor
only cares three fields: shape/stride/dtype, all other fields are not
necessary and can be inferred from local_tensor directly. This would
help significantly simplify how we deal with the tensor metadata by not
caring other fields.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108261
Approved by: https://github.com/fduwjj
ghstack dependencies: #107306
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:
(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests
(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.
(3) `_make_wrapper_subclass()` marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second.. does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
This PR fixes the requires_grad set when calling distribute_tensor, we
should set the requires_grad of the local tensor after the detach call
to make sure we create the leaf correctly, otherwise it would raise
warnings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107606
Approved by: https://github.com/fduwjj
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.
This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms

after (with this change), aten.addmm latency: 0.341ms

overall one layer of mlp time reduced from 13.535 -> 9.665ms
Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107305
Approved by: https://github.com/fduwjj
Move the remaining collectives to a separate file to prepare device mesh
to become a public distributed API
For those remaining utils, we need to upstream them to functional
collectives with proper implementation, added TODO there for a follow up
PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107012
Approved by: https://github.com/fduwjj
This PR adds initial dynamo support for DTensor, in particular, it:
- allows DTensor be passed into a compiled function, and allow fakify
DTensor during dynamo tracing by turning the inner local tensor to meta
tensor.
- We use `allow_in_graph` to include `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The dtensor created becomes a normal `TensorVariable` and it would insert any tensor operations to the output graph just like torch.Tensor
- note that dtensor have a new instance method `redistribute` compare to plain tensor, and we currently special handle it in `TensorVariable`
`from_local` and `redistribute` both accepts some non-trival metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encoded the metadata into a new_function (like `functools.partial`) and the new function only accepts prim args (i.e. tensor), then we put `call_function` with this new_function to the graph. This is suggested by @ezyang. The underlying rationale here is that the metadata will not change across the graph invocations so it's safe to encode them.
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
to_local = prim_redistribute.to_local(); prim_redistribute = None
add = to_local + 2; to_local = None
return (add,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
This PR adds necessary plumbing through torchdynamo to allow tensor
subclasses with certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to goes through the dynamo fakification pass by
fakifying the tensor subclass internal components.
Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540
Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
This PR canonicalize the detach callsite to only call the detach
from `distribute_tensor`. Change other callsite to view_as and remove the
tensor constructor detach call
This is so that we don't detach local tensor for every op run when
rewrapping the DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
# Change
This PR adds two classes to DTensor:
1. `CudaRNGStateTracker`: `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).
2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.
# Warning
- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads(ranks) and cause issue. We need to figure out a compatible solution for that.
- The RNG state may be asynchronous outside of participating ranks. It is harmless in our current use case of submesh though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
This PR changes the context manager behavior of device mesh, now we use
a mesh env to track the current mesh and save the mesh to a stack so
that we can allow nested context manager
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337