This PR adds initial dynamo support for DTensor. In particular, it:
- allows DTensor to be passed into a compiled function, and allows fakifying DTensor during dynamo tracing by turning the inner local tensor into a meta tensor.
- We use `allow_in_graph` so that `DTensor` and `DTensor.from_local` are represented as `TorchVariable`.
- The created DTensor becomes a normal `TensorVariable`, and tensor operations on it are inserted into the output graph just like operations on a plain torch.Tensor.
- Note that DTensor has an extra instance method `redistribute` compared to a plain tensor; we currently special-case it in `TensorVariable`.
`from_local` and `redistribute` both accept some non-trivial metadata as arguments (i.e. DeviceMesh, Placement), which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encode the metadata into a new function (like `functools.partial`) that only accepts prim args (i.e. tensors), and then put a `call_function` with this new function into the graph. This was suggested by @ezyang. The underlying rationale is that the metadata will not change across graph invocations, so it is safe to encode it.
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None
    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
    prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
    to_local = prim_redistribute.to_local(); prim_redistribute = None
    add = to_local + 2; to_local = None
    return (add,)
```
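For illustration, a minimal sketch of the metadata-encoding trick; the helper below is hypothetical (the real wrappers live under `torch/_dynamo/variables`, as the mangled names in the graph suggest):
```python
from torch.distributed._tensor import DTensor

def make_prim_from_local(device_mesh, placements):
    # Bake the non-traceable metadata (DeviceMesh, Placement) into a closure so
    # the resulting function only takes prim/tensor args that fx.Graph supports.
    def prim_from_local(local_tensor, run_check=False):
        return DTensor.from_local(local_tensor, device_mesh, placements, run_check=run_check)

    return prim_from_local
```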
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
This PR adds the necessary plumbing through torchdynamo to allow tensor
subclasses with a certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to go through the dynamo fakification pass by
fakifying the tensor subclass's internal components.
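As a rough illustration of the contract (a sketch only; the exact hook signatures have evolved across PyTorch versions), a wrapper subclass might look like:
```python
import torch

class WrapperTensor(torch.Tensor):
    # Illustrative-only subclass; the real contract details may differ.
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )
        r._inner = inner
        return r

    def __tensor_flatten__(self):
        # Expose the inner tensor attribute names (plus optional metadata)
        # so dynamo can fakify each inner component individually.
        return ["_inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta):
        # Rebuild the subclass from the (fakified) inner components.
        return WrapperTensor(inner_tensors["_inner"])
```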
Some of the tensor subclass contract logic is borrowed from
https://github.com/pytorch/pytorch/pull/97540
Added some tests to verify that simply passing a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
This PR canonicalizes the detach call sites so that detach is only called
from `distribute_tensor`. Other call sites are changed to view_as, and the
detach call in the tensor constructor is removed.
This is so that we don't detach the local tensor on every op run when
rewrapping the DTensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105239
Approved by: https://github.com/albanD
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
that were reverted due to a conflict with the internal source repo.
Mostly fixes for PEP-484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional).
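For example (illustrative only), the typical fix looks like:
```python
from typing import Optional

# Before: `def foo(x: int = None)` violates PEP 484 and fails under mypy 1.4.1,
# because a parameter defaulting to None must be typed as Optional.
def foo(x: Optional[int] = None) -> int:
    return 0 if x is None else x
```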
Plus a few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export.deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` to squash the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where this is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
When creating a DeviceMesh, _init_process_group() validates that all calling ranks pass in the same `mesh` argument. In FSDP, we currently create the DeviceMesh based on the pg of the root state, so the mesh will always be valid. This PR adds the flag to DeviceMesh so we can skip the all_gather_tensor used for validation at construction time.
_validate_mesh defaults to True, but we manually flip it to False when initializing the device mesh in FSDP's _runtime_utils.py.
In a follow-up PR, we will skip pg creation if the pgs already exist for both the 1D and 2D cases, and then delete the _init_process_groups flag.
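A rough sketch of the intended usage; the `_validate_mesh` keyword comes from this PR, while the rest of the call is illustrative:
```python
import torch
from torch.distributed._tensor import DeviceMesh

# Sketch: FSDP builds its DeviceMesh from an existing, known-valid pg, so the
# cross-rank all_gather_tensor validation can be skipped at construction time.
world_size = torch.distributed.get_world_size()
mesh = DeviceMesh(
    "cuda",
    torch.arange(world_size),
    _validate_mesh=False,  # skip mesh validation during construction
)
```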
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104807
Approved by: https://github.com/wanchaol
Not sure how it worked before, but arguments must be annotated as Optional if they default to None.
Towards enabling mypy-1.4.1 in lintrunner
### <samp>🤖 Generated by Copilot at 5e1b9f4</samp>
> _We annotate the arguments of doom_
> _To show the `None` values of gloom_
> _We improve the type checking and readability_
> _With `Optional` annotations of metal-ity_
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105022
Approved by: https://github.com/izaitsevfb, https://github.com/huydhn, https://github.com/Skylion007
Originally, we didn't enable BWD for colwise embedding because we thought it was just for inference, but it turns out that we do need it for training. So let's enable it for now; a unit test is also added.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104820
Approved by: https://github.com/fegin
# Change
This PR adds two classes to DTensor:
1. `CudaRNGStateTracker`: `CudaRNGStateTracker` stores Random Number Generator (RNG) state (a `ByteTensor` object) in a `dict` that maps a tag to each state tensor. It also provides a set of convenient utility methods to access/modify the state tensors. The most important interface is `_distribute_region`, which is used when DTensor executes a random op (an operator that calls the RNG).
2. `OffsetBasedRNGTracker`: This subclass of `CudaRNGStateTracker` defines the default policy of how RNG states should be shared and synchronized among all ranks to respect the semantics of DTensor random operators.
# Warning
- With `Multi-threaded ProcessGroup`, the global variable `_rng_tracker` will be shared among threads (ranks) and cause issues. We need to figure out a compatible solution for that.
- The RNG state may go out of sync on ranks outside the participating ranks. This is harmless in our current submesh use case, though.
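For illustration, a minimal sketch of how a random op might consult the tracker; `_rng_tracker` and `_distribute_region` are named above, but the module path and surrounding code are assumptions:
```python
from torch.distributed._tensor import random as dtensor_random  # assumed module path

def run_random_op(op, spec, *args, **kwargs):
    # Inside the region, the tracker adjusts each rank's CUDA RNG offset so that
    # shards draw non-overlapping random values, matching single-device semantics.
    with dtensor_random._rng_tracker._distribute_region(spec):
        return op(*args, **kwargs)
```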
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103235
Approved by: https://github.com/wanchaol
This PR gets rid of the dim_groups attribute from DeviceMesh. The main
motivation is that c10d, not DeviceMesh, should store the process groups
during their creation; DeviceMesh should just handle ranks correctly.
This could enable DTensor to become picklable (torch.save/load could be
possible), which I will try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
Allow DTensor to support cuda-like devices; fixes https://github.com/pytorch/pytorch/issues/102442
Currently, DTensor supports cuda and cpu. There are other efforts to make DTensor support third-party devices, for example https://github.com/pytorch/pytorch/pull/101914 and https://github.com/pytorch/pytorch/issues/101911. However, that support only covers a portion of third-party devices and does not cover third-party cuda-like devices well. Therefore, we would like to extend DTensor to support cuda-like devices; after all, cuda is so popular!
1. Similar to what is done here, we need to initialize the communication backend for the device set by DeviceMesh. So `_default_backend_for_device` is added to `Backend`. It is worth noting that when we register a new backend for a device other than cpu and cuda, we also need to add a new default backend for this device.
2. Adding `_device_handle` to `DeviceMesh` for cuda-like devices, similar to what is set in FSDP. When `_device_handle` is not None, the device behaves similarly to `cuda`. In this way, calls like `torch.cuda.device_count()` become `device_mesh._device_handle.device_count()` (see the sketch below).
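A minimal sketch of the `_device_handle` idea, assuming the handle is simply the device's torch module; the helper name below is hypothetical:
```python
import torch

def _get_device_handle(device_type: str):
    # cpu has no cuda-like device module; cuda-like devices (e.g. "cuda" or a
    # registered third-party backend) expose a module with device_count(),
    # set_device(), current_device(), etc.
    if device_type == "cpu":
        return None
    return getattr(torch, device_type, None)

handle = _get_device_handle("cuda")
if handle is not None:
    num_devices = handle.device_count()  # replaces a hard-coded torch.cuda.device_count()
```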
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102468
Approved by: https://github.com/wanchaol
This PR switches DeviceMesh to use a dispatchable process group instead.
This could enable easier backend integration, as users only need to
integrate with a c10d process group custom backend, without needing to
change DeviceMesh to plug in the backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102336
Approved by: https://github.com/fduwjj
When tensor.size(self.dim) < num_chunks, we will fill the empty chunks with empty tensors (https://github.com/pytorch/pytorch/pull/98722). Therefore, we no longer need this assert.
For example, when sharding a tensor with 1 element on 2 ranks along dim 0, results would be as follows:
```
rank:0, dtensor:DTensor(local_tensor=tensor([0.4963], device='cuda:0'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
rank:1, dtensor:DTensor(local_tensor=tensor([], device='cuda:1'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101218
Approved by: https://github.com/wanchaol
This PR changes the context manager behavior of DeviceMesh: we now use
a mesh env to track the current mesh and save meshes on a stack so
that nested context managers are allowed.
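A simplified, illustrative sketch of the stack-based tracking (not the actual DeviceMesh internals):
```python
class _MeshEnv:
    # Tracks the "current" mesh; a stack makes nested `with mesh:` blocks work.
    def __init__(self):
        self._mesh_stack = []

    def get_current_mesh(self):
        return self._mesh_stack[-1] if self._mesh_stack else None

_mesh_env = _MeshEnv()

class Mesh:
    def __enter__(self):
        _mesh_env._mesh_stack.append(self)
        return self

    def __exit__(self, *exc):
        _mesh_env._mesh_stack.pop()  # restores the outer mesh on exit
```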
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101202
Approved by: https://github.com/wz337
This is the first in a series of PRs that adapt operator impls to a
strategy-based approach: each op utilizes OpStrategy and PlacementStrategy
to generate its own strategy. By utilizing the strategy-based
approach along with the op graph, we could enable more advanced op
implementations (decomp is possible) and turn sharding prop into something
more like a constraint satisfaction problem.
This PR alone only adds some basic tensor op strategies, and it directly
works on the op graph that was used for metadata propagation. The tensor ops
added in this PR mainly follow the strategy of one of their args (see the
sketch below). The next set of PRs will add strategies for more ops.
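A self-contained, simplified sketch of the "follow one arg's strategy" idea; these classes are stand-ins for the OpStrategy/PlacementStrategy named above, not the real DTensor definitions:
```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class PlacementStrategy:
    output_spec: object  # sharding spec chosen for the op's output

@dataclass
class OpStrategy:
    strategies: List[PlacementStrategy] = field(default_factory=list)

def follow_arg_strategy(arg_strategy: OpStrategy) -> OpStrategy:
    # A simple tensor op just mirrors the candidate placements of one of its
    # tensor args: whatever sharding the input supports, the output follows.
    return OpStrategy([PlacementStrategy(s.output_spec) for s in arg_strategy.strategies])
```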
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
DTensor was reusing `einop_rule` to propagate sharding for torch.cat.
However, einsum only supports up to 52 subscripts (i.e., input tensors).
We have encountered use cases where one cat operator has more than 60
input tensors. Therefore, this commit reimplements the sharding prop
rule for cat without using einsum.
Differential Revision: [D45435232](https://our.internmc.facebook.com/intern/diff/D45435232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100251
Approved by: https://github.com/wanchaol
This PR refactors the current StrategyList. It introduces StrategyType,
which is the base class of the strategies, with two sub-strategies:
1. OpStrategy: the previous StrategyList, refactored and renamed.
2. TupleStrategy: a new strategy added to deal with tuple cases, where
an op could return multiple different OpStrategy objects.
This helps support more complicated ops and unblocks compile-mode FSDP.
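A simplified stand-in sketch of the resulting hierarchy (the real DTensor classes carry more state, e.g. placement specs and redistribute costs):
```python
class StrategyType:
    """Base class for the sharding strategies an op can declare."""

class OpStrategy(StrategyType):
    # Candidate placement strategies for an op with a single tensor output.
    def __init__(self, strategies):
        self.strategies = strategies

class TupleStrategy(StrategyType):
    # One child strategy per element for ops that return a tuple of tensors.
    def __init__(self, children):
        self.children = children
```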
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99435
Approved by: https://github.com/mrshenli
## What's in this PR
DeviceMesh's `__init__` function now requires all calling ranks to pass the same `mesh` argument.
## Why
We want to enforce an SPMD style of programming with DTensor. Before this PR, the 2-D parallel API (e.g. _create_1d_device_mesh) defined a different DeviceMesh on different ranks. After this PR, it defines each sub-mesh on every rank and simply performs communications on the one that the rank is associated with.
Differential Revision: [D45165511](https://our.internmc.facebook.com/intern/diff/D45165511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99094
Approved by: https://github.com/wanchaol
As functional collectives are being updated, using tensor_split() as the underlying sharding algorithm would require padding and unpadding on multiple ranks. Therefore, we are changing the sharding algorithm to be in line with `torch.chunk()`, which confines padding to the last two ranks in most scenarios.
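For illustration, the difference between the two splitting schemes on an uneven tensor:
```python
import torch

t = torch.arange(10)
# tensor_split spreads the remainder over the leading ranks, so several ranks
# can end up with a non-maximal shard size and need padding/unpadding.
print([c.numel() for c in torch.tensor_split(t, 4)])  # [3, 3, 2, 2]
# chunk uses ceil-sized pieces first, so only the trailing ranks are short
# (or empty), confining any padding to the last ranks.
print([c.numel() for c in torch.chunk(t, 4)])         # [3, 3, 3, 1]
```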
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98722
Approved by: https://github.com/wanchaol