Fixes#142058
## Summary
DTensor `convolution_backward` op throws exception when the input Tensor has `requires_grad=False` which happens if the conv layer is the first layer in the model.
ATEN convolution_backward op Usually returns 3 Tensors (grad_input, grad_weight, grad_bias) and the `grad_input` is actually an Optional[Tensor] which can be `None` in the case mentioned above.
However, the DTensor sharding propagation rule and corresponding TP conv backward implementation both assume that the `grad_input` would be existent.
## Fix
allow the `grad_input` to be `None` for `convolution_backward` op.
## Test
`pytest test/distributed/tensor/test_convolution_ops.py`
## Follow-up
The current implementation of DTensor conv op also ignores `output_mask` and this may need further care.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142278
Approved by: https://github.com/bdhirsh
**Description**
This is an example of how FlexAttention can be used in a context parallel fashion. Right now it's only a flex_attention call with collectives added and has no load balancer, but we're about to add the missing parts step by step:
1. backward pass
2. static load balancing for causal masking
3. dynamic load balancing for other general maskings
4. automatic collective insertion solution
5. non-intrusive context parallel APIs
**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/tensor/examples/flex_attention_cp.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145896
Approved by: https://github.com/fegin, https://github.com/Skylion007
This is done by copying the one for a regular mm, and enforcing that the scales have the same sharding scheme as their respective operands. This works because scales are 2-d tensors that must "broadcast" to the operands. This broadcasting is trivial when scales have dimensions of 1 or N, which is the only options we currently support.
Note, however, that after this PR scales will be allowed to have the mesh's world size as a dimension (in certain cases). This works because, when mapped to the local shard, it becomes a dimension of 1, which can be handled by the operator. Note that when using row-wise _scaled_mm for tensor (sequence) parallelism, this situation arises naturally!
Because of these specificities, the test is rather complex, as it specifically tests all these behaviors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143760
Approved by: https://github.com/tianyu-l
as titled, this PR expose this dunder method as a public API in the doc,
so that different checkpoint implementations can leverage this protocol,
instead of exposing a separate API
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144100
Approved by: https://github.com/awgu
ghstack dependencies: #144099
as titled, this PR propagates the src_data_rank in the TP API, so that
module level APIs could leverage the flexibility to choose
src_data_rank, and avoid the communication if it does not need to
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144005
Approved by: https://github.com/tianyu-l
ghstack dependencies: #143883
As titled, this PR add a kwarg src_data_rank to the distribute_tensor
API, to allow user specify a specific rank as the full tensor source
data. Previously we by default specify group_rank=0 as the source of
truth for single device semantic, this new option:
* gives advanced user flexiblity to choose the source data rank
* allow user to specify None explicity, which means we will skip the
communications needed (scatter/broadcast) for the cases that does not
care about single device semantic (i.e. loading from a checkpoint)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143883
Approved by: https://github.com/XilunWu, https://github.com/tianyu-l
Changes:
1. Bump `ruff` from 0.7.4 to 0.8.4
2. Change `%`-formatted strings to f-string
3. Change arguments with the `__`-prefix to positional-only arguments with the `/` separator in function signature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143753
Approved by: https://github.com/Skylion007
**Summary**
This PR adds a new experimental API `set_rotate_method` for Context Parallel. This API allows user to choose the desired communication method (between all-to-all and all-gather) for shards rotation.
**Test**
`pytest test/distributed/_tensor/test_attention.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142093
Approved by: https://github.com/fegin
Locally shards a full tensor based on indicated sharding arrangement, and returns a DTensor containing the local shard.
warning: This is a private API purposed to skip the communication otherwise required by `distribute_tensor`. It is only applicable to a case where all ranks have the same `full_tensor`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142288
Approved by: https://github.com/wz337
**Summary**
DTensor RNG code raises error if the seed passed in is beyong `torch.int64` range (e.g. `torch.tensor([2**64-1])` raises error). The solution is to specify the `dtype=torch.uint64` in the `torch.tensor()` call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141532
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220, #141223
**Summary**
This PR proposes 4 changes to DTensor RNG management:
1. DTensor allows users to eagerly initialize the RNG tracker by calling `torch.distributed.tensor._random.manual_seed`.
2. DTensor `manual_seed` no longer checks the integrity of the `seed` argument. Users are responsible for setting the same seed on all ranks within an SPMD group, but if there are multiple separate SPMD groups (e.g. across pipeline stages), users should set a _different_ seed for each SPMD group. For cases like Pipeline Parallel, users can set different initial seed for pipelining stages by calling
```
world_mesh = init_device_mesh(
device_type="cuda",
mesh_shape=(2, 2, 2),
mesh_dim_names=("pp", "dp", "tp"),
)
pp_mesh = world_mesh["pp"]
pp_rank = pp_mesh.get_local_rank()
spmd_mesh = world_mesh["dp", "tp"]._flatten("spmd") # this flattening is only needed if you need to call collective over this mesh
torch.distributed.tensor._random.manual_seed(123+pp_rank, spmd_mesh)
```
In other word, if users want to call `torch.distributed.tensor._random.manual_seed`, they will be responsible for passing in the right value and DTensor won't perform any checks on it. If the current rank is not a part of the mesh, it will use the current device RNG state to initialize.
3. `OffsetBasedRNGTracker` still performs RNG state synchronization by broadcasting the RNG state on rank 0 to `WORLD`. However, calling `torch.distributed.tensor._random.manual_seed` is an exception. In this case, no broadcast will happen.
4. Enforce that the `manual_seed` call only accept "full mesh" i.e. the DTensor RNG state on every rank must be set through the call. This makes sure that no rank has its RNG state left uninitialized and the SPMD ranks have their RNG state synchronous.
**Motivation**
tl;dr
1. Lazily initializing DTensor RNG tracker causes hang in non-SPMD code such as Pipeline Parallel.
2. Users may want to set different seed on ranks in one device mesh.
3. We want to keep the old behavior if users prefer not curating the RNG state and want to have DTensor take care of it.
see detail in https://github.com/pytorch/pytorch/issues/140301
**Test**
`pytest test/distributed/_tensor/test_random_ops.py`
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141223
Approved by: https://github.com/wconstab
ghstack dependencies: #141731, #141220
**Summary**
The ad-hoc DTensor RNG tracker was used to mimic Megatron DDP+TP RNG behavior but it turns out not compatible with PyTorch Distributed FSDP2+TP so we decide to deprecate it and use `OffsetBasedRNGTracker` to replace, which follows the SPMD semantics (replicas get the same random sampling result, shards get different results).
**Motivation**
`TensorParallelRNGTracker` was designed for DDP+TP where the random operators produce the same result along the data parallel mesh dimension and different results along the tensor parallel dimension. However this does not apply to the new FSDP+TP composable combination where the model weights are sharded along data parallel mesh dimension as well. Therefore we decide to remove this outdated RNG tracker type for now. If users have demands for exact match between PyTorch Distributed and Megatron on Random Number generation result, feel free to file an issue.
**Impact**
`TensorParallelRNGTracker` was only used when Tensor Parallel is used (i.e. calling `parallelize_module`).
For non-FSDP users, the "replicas get the same random numbers and shards get different ones" remains unchanged. Unlike `TensorParallelRNGTracker` which sets different seeds (`base_seed + 2718 + TP_rank`) within the TP group, DTensor now sets the same seed (default value is 1234 but users can call `torch.distributed.tensor._random.manual_seed` to modify) on all ranks but choose the right RNG offset based on DTensor placements to enforce the "replicas get the same random numbers and shards get different ones" invariant.
For FSDP2 users, improvement should be observed in a way that DTensor sharded within DP group now gets different random number sampling which `TensorParallelRNGTracker` failed to do, though we're not sure how much this change will improve the eventual training loss convergence.
**Test**
1-d model weight meta init:
`pytest test/distributed/_tensor/test_random_ops.py -s -k test_tp_model_meta_init`
2-d model weight meta init:
`pytest test/distributed/_tensor/test_random_ops.py -s -k test_fsdp_tp_model_meta_init`
TP model weight init test:
`pytest test/distributed/tensor/parallel/test_tp_random_state.py`
FSDP+TP model weight init test:
`pytest test/distributed/_composable/fsdp/test_fully_shard_init.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141220
Approved by: https://github.com/wconstab
ghstack dependencies: #141731
**Summary**
Added tests for model meta init on 1-d mesh (TP) and 2-d mesh (FSDP+TP). This exploits the issue where DTensor RNG failed to initialize weights differently across FSDP ranks.
**Test**
`pytest test/distributed/_tensor/test_random_ops.py -s -k meta_init`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141731
Approved by: https://github.com/wconstab
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Based on discussion here: https://github.com/pytorch/pytorch/pull/138731
Introducing ability for subclass implement type convertion to expected_type.
```
def __coerce_same_metadata_as_tangent__(
self, expected_metadata: Any, expected_type: Optional[Type] = None
):
```
Here if `expected_type=None` means `SubclassClass` is expected.
E.g. for `DTensor` we may find tangent `AsyncCollectiveTensor` where we expected `Tensor` - in this case
`expected_type=Tensor` will be called during runtime
Adding implementation to AsyncCollectiveTensor, that just triggers `wait()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139095
Approved by: https://github.com/bdhirsh
Fixes#138742. In the issue, the matrix multiplication with DTensor failed when the size of one of mesh dimension is 1 when the mesh is > 1D. We are missing tests for covering this corner case where mesh_shape is (n, 1) or (1, n). The DTensor mm op is correct when the 1D mesh is of shape (self.world_size, ) or 2D mesh with none of the mesh_dimension has a size of 1.
In this PR, we fixed the corner case by updating `gen_einsum_strategies` in `_einsum_strategy.py`. Specifically, we cannot skip generating `mesh_dim_strategies` when `mesh_dim <= 1`, as this is not valid for nD mesh with one of the mesh dimension sizes being 1.
Without the fix, the OpStrategy generated for 2D mesh with mesh_shape of (1,n) or (n,1) is wrong, as the OpStrategy generated is 1D.
```
all_mesh_dim_strategies=[[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]]
OpStrategy(all_strategies)::: [(R, R) -> R, (S(1), S(0)) -> P, (S(0), R) -> S(0), (R, S(1)) -> S(1)] @ mesh: (4, 1)[(R, R) -> R, (S(1), S(0)) -> P, (S(0), R) -> S(0), (R, S(1)) -> S(1)] @ mesh: (4, 1)
```
After the fix, we can see the OpStrategy generated is correct with 2D strategy.
```
all_mesh_dim_strategies=[[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]][[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]]
OpStrategy(all_strategies) = [(RR, RR) -> RR, (RS(1), RS(0)) -> RP, (RS(0), RR) -> RS(0), (RR, RS(1)) -> RS(1), (S(1)R, S(0)R) -> PR, (S(1)S(1), S(0)S(0)) -> PP, (S(1)S(0), S(0)R) -> PS(0), (S(1)R, S(0)S(1)) -> PS(1), (S(0)R, RR) -> S(0)R, (S(0)S(1), RS(0)) -> S(0)P, (S(0)S(0), RR) -> S(0)S(0), (S(0)R, RS(1)) -> S(0)S(1), (RR, S(1)R) -> S(1)R, (RS(1), S(1)S(0)) -> S(1)P, (RS(0), S(1)R) -> S(1)S(0), (RR, S(1)S(1)) -> S(1)S(1)] @ mesh: (4, 1)
```
*******
As a follow up, we should add more test coverage for DTensor op with 2D mesh and 2D mesh with one of the size of mesh dimension being 1.
*******
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139134
Approved by: https://github.com/fegin
Summary:
This implementation does not utilize the benefit that after allgather we can directly perform the SDPA without doing the ring-based SDPA, but we can overlap the communication with the first sharded kv computation. This implementation shows some performance benefit and memory saving compared to the original alltoall implementation in certain cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132820
Approved by: https://github.com/XilunWu