A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks. A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally. When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards. A LocalTensor is associated with the list of ranks for which it holds
local tensors.
NB: this is NOT a DataParallel-like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes;
the overhead is almost certainly too high to keep eight GPUs busy (even the C++
autograd needs multithreading to keep up!). (It might potentially be possible
to trace through this with torch.compile and then compile it with CUDA graphs,
but this is currently a non-goal.)
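For intuition, here is a heavily simplified toy sketch of the idea (this is not the actual LocalTensor implementation; the class and attribute names below are made up): a wrapper tensor subclass that keeps one inner tensor per simulated rank and replays each plain op on every shard.
```
import torch
from torch.utils._pytree import tree_map

class ToyLocalTensor(torch.Tensor):
    """Toy illustration only: one inner tensor per simulated rank."""

    @staticmethod
    def __new__(cls, shards):
        any_shard = next(iter(shards.values()))
        return torch.Tensor._make_wrapper_subclass(
            cls, any_shard.shape, dtype=any_shard.dtype, device=any_shard.device
        )

    def __init__(self, shards):
        self._shards = shards  # {rank: local tensor}

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        ranks = next(a for a in args if isinstance(a, ToyLocalTensor))._shards.keys()

        def pick(rank):
            # Swap every ToyLocalTensor argument for that rank's shard.
            return lambda a: a._shards[rank] if isinstance(a, ToyLocalTensor) else a

        # Run the plain op once per simulated rank and re-wrap the results.
        return ToyLocalTensor(
            {r: func(*tree_map(pick(r), args), **tree_map(pick(r), kwargs)) for r in ranks}
        )

x = ToyLocalTensor({0: torch.ones(4), 1: torch.full((4,), 2.0)})
y = x + 1  # the add runs independently on rank 0's and rank 1's shard
```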
In order to handle MPMD, we provide a helper decorator that lets you run a
side-effect-free function for each LocalTensor shard and combine the results
back into a LocalTensor or LocalIntNode.
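Loosely, think of that helper as mapping a pure function over the per-rank values and collecting the per-rank results; a tiny self-contained stand-in (the helper name here is invented, not the real API):
```
import torch

# Hypothetical stand-in for the real helper: apply a side-effect-free function
# to each rank's local value and collect the per-rank results.
def run_per_rank(fn, per_rank_values):
    return {rank: fn(value) for rank, value in per_rank_values.items()}

shards = {0: torch.arange(4), 1: torch.arange(4) + 10}
per_rank_sums = run_per_rank(torch.sum, shards)  # {0: tensor(6), 1: tensor(46)}
```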
Note: this PR converts all DTensor ops and some DTensor tests to illustrate the
intended usage and ensure correctness. Subsequent PRs will convert more tests.
During test conversion we aim to share as much of the test logic as possible
between the multi-process / multi-threaded and local tensor tests, so that
developers can run both flavors of the tests.
Note: This work is based on the original proposal
by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537
Approved by: https://github.com/ezyang
Instead of collecting local results with all_gather_object and then reducing them locally, this change switches to a single all_reduce with a MIN reduction to compute the final equals result.
This change is needed to enable the LocalTensor work (all_gather_object introduces challenges for DTensor and LocalTensor integration).
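A minimal sketch of the new pattern (the helper name is illustrative; it assumes a process group is already initialized):
```
import torch
import torch.distributed as dist

def all_ranks_equal(local_equal: bool, device: torch.device) -> bool:
    # Each rank contributes 1 if its local comparison passed, 0 otherwise;
    # a single all_reduce with MIN leaves 1 only if every rank agreed.
    flag = torch.tensor(1 if local_equal else 0, dtype=torch.int32, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())
```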
topic: not user facing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164999
Approved by: https://github.com/ezyang
`vmap(F.embedding)(DTensor, DTensor)` was failing because F.embedding's
batching rule generates a new tensor via at::arange, at::arange
generates a regular tensor, and DTensor rightfully errors on mixed
DTensor-regular Tensor operations.
This PR fixes the problem by activating DTensor implicit replication on
just the at::arange and the subsequent add operation.
To accomplish this, I moved the DTensor implicit replication flag to C++
(most batching rules are in C++).
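The Python-level counterpart of that flag is the `implicit_replication` context manager; a rough, single-process sketch of the mixed DTensor / regular-tensor situation it unblocks (the gloo/CPU 1-rank setup below exists only to make the snippet self-contained):
```
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor
from torch.distributed.tensor.experimental import implicit_replication

# Toy single-process "world" so the sketch runs as-is.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

dt = distribute_tensor(torch.ones(4), mesh, [Replicate()])
plain = torch.arange(4.0)  # a regular tensor, like the batching rule's at::arange
with implicit_replication():
    out = dt + plain  # the plain tensor is treated as Replicate instead of erroring
dist.destroy_process_group()
```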
Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162117
Approved by: https://github.com/bdhirsh
Mainly, this gives the user more information about the operator that failed
to run when the failure happens during sharding propagation.
Previously, only this exception would be raised:
```
RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')
```
Now you get both the above exception as well as
```
The above exception was the direct cause of the following exception:
RuntimeError: Sharding propagation failed for Op(op=aten.view.default, args_schema=Spec((Replicate(), Shard(dim=0), Shard(dim=1), Shard(dim=2)) on (8, 8, 4)), [64, 4] @ mesh: (1, 2, 2, 2))
```
<stacktrace omitted>
<details><summary>detailed error</summary>
```
======================================================================
ERROR: test_linear (__main__.TestDTensor)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 668, in wrapper
self._join_processes(fn)
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 932, in _join_processes
self._check_return_codes(fn, elapsed_time)
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 972, in _check_return_codes
raise RuntimeError(error)
RuntimeError: Process 4 exited with error code 10 and exception:
Traceback (most recent call last):
File "/data/users/whc/pytorch/torch/distributed/tensor/_dispatch.py", line 150, in dispatch
self.sharding_propagator.propagate(op_info)
File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 309, in propagate
OutputSharding, self.propagate_op_sharding(op_info.schema)
File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 45, in __call__
return self.cache(*args, **kwargs)
File "/data/users/whc/pytorch/torch/distributed/tensor/_sharding_prop.py", line 329, in propagate_op_sharding_non_cached
op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema)
File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 673, in reshape_strategy
input_tgt_placements, output_placements = propagate_shape_and_sharding(
File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 601, in propagate_shape_and_sharding
in_dim = get_in_dim_to_shard(cmd)
File "/data/users/whc/pytorch/torch/distributed/tensor/_ops/_view_ops.py", line 537, in get_in_dim_to_shard
raise RuntimeError(
RuntimeError: ('Attempted to flatten sharded dimension 1, ', 'but only the leftmost dim of a Flatten can be sharded.')
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 816, in run_test
getattr(self, test_name)()
File "/data/users/whc/pytorch/torch/testing/_internal/common_distributed.py", line 670, in wrapper
fn()
File "/data/users/whc/pytorch/torch/testing/_internal/common_utils.py", line 3224, in wrapper
method(*args, **kwargs)
File "/data/users/whc/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 490, in wrapper
raise e
File "/data/users/whc/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 487, in wrapper
func(self, *args, **kwargs) # type: ignore[misc]
File "/data/users/whc/pytorch/test.py", line 60, in test_linear
print("results: ", distributed_linear(distributed_input))
File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/users/whc/pytorch/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/data/users/whc/pytorch/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
File "/data/users/whc/pytorch/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
File "/data/users/whc/pytorch/torch/_dynamo/eval_frame.py", line 1005, in _fn
return fn(*args, **kwargs)
File "/data/users/whc/pytorch/torch/distributed/tensor/_api.py", line 358, in __torch_dispatch__
return DTensor._op_dispatcher.dispatch(
File "/data/users/whc/pytorch/torch/distributed/tensor/_dispatch.py", line 163, in dispatch
raise RuntimeError(
RuntimeError: Sharding propagation failed for Op(op=aten.view.default, args_schema=Spec((Replicate(), Shard(dim=0), Shard(dim=1), Shard(dim=2)) on (8, 8, 4)), [64, 4] @ mesh: (1, 2, 2, 2))
```
</details>
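The "direct cause" chaining shown above is just Python's `raise ... from err`; a minimal self-contained sketch of the pattern (names are illustrative, not the actual DTensor dispatch code):
```
def propagate(op_name: str) -> None:
    # Stand-in for sharding propagation raising the original, low-level error.
    raise RuntimeError("Attempted to flatten sharded dimension 1, "
                       "but only the leftmost dim of a Flatten can be sharded.")

def dispatch(op_name: str) -> None:
    try:
        propagate(op_name)
    except Exception as err:
        # Re-raise with operator context while keeping the original as the cause.
        raise RuntimeError(f"Sharding propagation failed for Op(op={op_name}, ...)") from err

dispatch("aten.view.default")  # the traceback shows both errors, chained as above
```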
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161574
Approved by: https://github.com/zpcore, https://github.com/XilunWu
**Summary**
This PR enables the in-place op `aten.squeeze_.dim` on DTensor with a change to
the DTensor dispatch logic: when processing an in-place operator, we should assign
`output_sharding.output_spec` back to the first argument, because the in-place
op_call on `arg._local_tensor` can also shift the tensor meta.
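A rough usage sketch (single-process gloo/CPU setup only to keep it self-contained; assumes a build that includes this change):
```
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

dt = distribute_tensor(torch.randn(1, 4, 8), mesh, [Replicate()])
dt.squeeze_(0)  # in-place: both the local tensor and the DTensor spec now reflect (4, 8)
assert dt.shape == (4, 8)
dist.destroy_process_group()
```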
**Test**
`pytest test/distributed/tensor/test_view_ops.py -s -k test_squeeze_`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159532
Approved by: https://github.com/zpcore
as titled. PlacementStrategy is sometimes a confusing name: we also have
OpStrategy and TupleStrategy, and the latter two contain the former, so it is
better to make the naming clearer.
Renaming PlacementStrategy -> OpSpec, as it is an operator spec that
contains output_spec + input_specs.
Some utils can also be merged into OpSchema, so they are included in
this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155592
Approved by: https://github.com/awgu
Today, if you run DTensor (or any tensor subclass) under `inference_mode`, you will start seeing `CompositeImplicitAutograd` ops show up in `__torch_dispatch__`.
"handling" these ops is trivial: you can just tell them to decompose into their constituent ops. Normally this decomposing happens in autograd, above DTensor, but inference_mode turns autograd off, forcing the subclass to handle the op directly.
It looks like previously we manually added a few CompositeImplicitAutograd entries to DTensor (e.g. linear), but this PR tries to support these ops a bit more generically.
The main difference is that DTensor now needs to check if a given op is `CompositeImplicitAutograd` before attempting to run sharding prop. I ran a quick microbenchmark for the below code with `timeit`, which gave me overhead on the order of ~1us, which is hopefully not too bad for eager mode:
```
def fast_function():
    return torch._C._dispatch_has_kernel_for_dispatch_key(op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd)
import timeit
time_taken = timeit.timeit(fast_function, number=1000)
# printed 0.12..., aka 1.2us
print(f'func={str(op_call)}, time={str(time_taken)}')
```
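A rough sketch of that generic handling (not the exact DTensor code): if the incoming op has a `CompositeImplicitAutograd` kernel, decompose it into its constituent ops instead of running sharding propagation.
```
import torch

def maybe_decompose(op_call: torch._ops.OpOverload, args, kwargs):
    # If the op is CompositeImplicitAutograd, let it decompose; the resulting
    # constituent ops will re-enter __torch_dispatch__ individually.
    if torch._C._dispatch_has_kernel_for_dispatch_key(
        op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd
    ):
        return op_call.decompose(*args, **kwargs)
    return NotImplemented
```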
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149514
Approved by: https://github.com/kwen2501, https://github.com/albanD, https://github.com/wanchaol
as titled, this PR moves the same-mesh check from the sharding propagation level to each individual operator level.
This gives each individual operator more flexibility to decide whether it can run on the given meshes. For example, before this PR, if a user had two DTensor params that live on different DeviceMeshes and wanted to run a `foreach` op on them individually, it would error out with a cross-mesh error. But for foreach computation the DTensors can live on different meshes, as long as the meshes match in a "zipped" way.
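A minimal sketch of what such a per-operator "zipped" check can look like (illustrative only, not the actual DTensor implementation):
```
def check_zipped_meshes(*dtensor_lists):
    # For foreach-style ops, only require that the DTensors at the same list
    # position share a mesh, rather than forcing one global mesh for all args.
    for position, group in enumerate(zip(*dtensor_lists)):
        meshes = {dt.device_mesh for dt in group}
        if len(meshes) > 1:
            raise RuntimeError(
                f"foreach operands at position {position} live on different meshes: {meshes}"
            )
```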
This should also fix https://github.com/pytorch/pytorch/issues/134212
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147869
Approved by: https://github.com/tianyu-l
Resolves https://github.com/pytorch/pytorch/issues/146767.
May also resolve https://github.com/pytorch/pytorch/issues/147584.
### Summary
This PR removes the RNG tracker init from the `distribute_tensor` call for the following reasons:
1. if the user does not use random ops on DTensor, there's no need to init the DTensor RNG, which currently requires a CUDA device to be present.
2. this complies with the 0-communication semantic of `src_data_rank=None` shard distribution.
Besides, `OffsetBasedRNGTracker` only accepts a `DeviceMesh` argument to its constructor.
### Consequence
DTensor RNG initialization is delayed until the first DTensor random op call or a call to `torch.distributed.tensor.random.manual_seed`.
### Test
`pytest test/distributed/tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`
`pytest test/distributed/tensor/parallel/test_tp_style.py`
Differential Revision: [D70201856](https://our.internmc.facebook.com/intern/diff/D70201856)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147025
Approved by: https://github.com/kwen2501
**Summary**
Added tests for model meta init on a 1-d mesh (TP) and a 2-d mesh (FSDP+TP). These exercise the issue where DTensor RNG failed to initialize weights differently across FSDP ranks.
**Test**
`pytest test/distributed/_tensor/test_random_ops.py -s -k meta_init`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141731
Approved by: https://github.com/wconstab
reland of https://github.com/pytorch/pytorch/pull/133113
I had to create a new PR because the previously reverted PR could neither be rebased nor imported successfully :(
----
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next PRs)
* To preserve BC for users still using torch.distributed._tensor, I added a shim script to redirect old path calls to the new module
The BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes
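For example (sketch), the old private import path keeps working through the shim while new code uses the public one:
```
# Old private path, still redirected by the shim (may emit a deprecation warning):
from torch.distributed._tensor import DTensor as _LegacyDTensor
# New public path added by this move:
from torch.distributed.tensor import DTensor
```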
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
Moving DTensor to be in the public namespace, to formally add the
documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (adding in the next
PRs)
* To preserve BC for users still using `torch.distributed._tensor`,
I added a shim script to redirect old path calls to the new module
The BC preservation is evidenced by the fact that all DTensor tests still
pass without changing the public imports, so it is safe to land the
changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306