This PR adds initial dynamo support for DTensor, in particular, it:
- allows DTensor be passed into a compiled function, and allow fakify
DTensor during dynamo tracing by turning the inner local tensor to meta
tensor.
- We use `allow_in_graph` to include `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The dtensor created becomes a normal `TensorVariable` and it would insert any tensor operations to the output graph just like torch.Tensor
- note that dtensor have a new instance method `redistribute` compare to plain tensor, and we currently special handle it in `TensorVariable`
`from_local` and `redistribute` both accepts some non-trival metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encoded the metadata into a new_function (like `functools.partial`) and the new function only accepts prim args (i.e. tensor), then we put `call_function` with this new_function to the graph. This is suggested by @ezyang. The underlying rationale here is that the metadata will not change across the graph invocations so it's safe to encode them.
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None
# File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
to_local = prim_redistribute.to_local(); prim_redistribute = None
add = to_local + 2; to_local = None
return (add,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
This PR adds necessary plumbing through torchdynamo to allow tensor
subclasses with certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to goes through the dynamo fakification pass by
fakifying the tensor subclass internal components.
Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540
Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
This PR canonicalize the detach callsite to only call the detach
from `distribute_tensor`. Change other callsite to view_as and remove the
tensor constructor detach call
This is so that we don't detach local tensor for every op run when
rewrapping the DTensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
# Change
This PR adds two classes to DTensor:
1. `CudaRNGStateTracker`: `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict`, mapping from a corresponding tag to each state tensor. It also provides a set of convenient utility methods to help access/modify the state tensors. The most important interface is `_distribute_region` which will be used when DTensor executes a random op (an operator that calls RNG).
2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.
# Warning
- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads(ranks) and cause issue. We need to figure out a compatible solution for that.
- The RNG state may be asynchronous outside of participating ranks. It is harmless in our current use case of submesh though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
This PR changes the context manager behavior of device mesh, now we use
a mesh env to track the current mesh and save the mesh to a stack so
that we can allow nested context manager
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337