Previously, new tensors produced by the "new factory" ops all became replicated.
With this PR, if the new tensor has the same shape as the old tensor **and** the shape can be evenly sharded, then the old spec is inherited and preferred.
To accommodate this when the old tensor has sharded placements, the input args for the local computation (size, stride) need to be adjusted.
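A minimal sketch of the new behavior (a hedged illustration, not the PR's test: it assumes 4 ranks with an initialized default process group, and uses `new_zeros` as a representative "new factory" op):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cpu", (4,))
old = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])

# Same shape as `old` and evenly shardable across 4 ranks, so the Shard(0)
# spec should now be inherited instead of falling back to Replicate.
new = old.new_zeros(8, 16)
print(new.placements)
```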
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122995
Approved by: https://github.com/wanchaol
This PR uses str for reduce_op directly instead of the c10d enum. Since
our functional collectives already use str, there's no reason we need the
c10d enum anymore, as it requires a conversion.
Also, the str hash + eq performance is significantly faster than the c10d
type, so this somewhat reduces the CPU overhead too.
Some local CPU benchmarks on `1000000` hash operations:
```
Hash performance for string type: 0.039897 seconds
Hash performance for integer type: 0.304665 seconds
```
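For reference, a rough sketch of how such a micro-benchmark could look (illustrative only; the exact objects hashed are an assumption here, and the numbers above come from the author's local run):
```py
import timeit

import torch.distributed as dist

N = 1_000_000
str_time = timeit.timeit(lambda: hash("sum"), number=N)
c10d_time = timeit.timeit(lambda: hash(dist.ReduceOp.SUM), number=N)
print(f"Hash performance for string type: {str_time:.6f} seconds")
print(f"Hash performance for c10d type:   {c10d_time:.6f} seconds")
```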
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125172
Approved by: https://github.com/awgu, https://github.com/XilunWu, https://github.com/tianyu-l
Ring attention support for _scaled_dot_product_flash_attention with DTensor.
This assumes the query and key/value are sharded along the sequence length dimension. See the tests for example usage with PT Transformer as well as direct usage with _scaled_dot_product_flash_attention.
## Notable caveats
* Numerical accuracy: The backward pass doesn't match the non-chunked version numerically, but the forward pass does. I assume this is due to accumulated errors. I've added a chunked version that uses autograd to verify that the distributed version matches the chunked version.
* nn.Linear has incorrect behavior when running on a sharded tensor of size (bs, heads, seq_len, dim) with `Shard(2)`: it does an unnecessary accumulate. Working around the issue requires `Replicate()` on QKV when using `nn.MultiheadAttention`.
* If enabled, it forces sequence parallelism and doesn't interoperate with tensor parallelism.
## SDPA usage
```py
with attention_context_parallel(), sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    dquery = distribute_tensor(query, device_mesh, [Shard(2)])
    dkey = distribute_tensor(key, device_mesh, [Shard(2)])
    dvalue = distribute_tensor(value, device_mesh, [Shard(2)])
    dout: DTensor = torch.nn.functional.scaled_dot_product_attention(
        dquery, dkey, dvalue, is_causal=is_causal
    )
    out = dout.to_local()
```
## Transformer usage
```py
with attention_context_parallel(), sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=dim,
        nhead=nheads,
        dim_feedforward=dim,
        batch_first=True,
    ).to(dtype)
    encoder_layer = parallelize_module(
        module=encoder_layer,
        device_mesh=device_mesh,
        parallelize_plan={
            "self_attn": ContextParallel(),
        },
    )
    model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
```
## Test plan
```
pytest test/distributed/_tensor/test_attention.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122460
Approved by: https://github.com/drisspg, https://github.com/wanchaol
This PR refactors `schema_suggestions` in `OutputSharding` to be a single
OpSchema instead of a list of schemas. In practice we only ever have one,
and the multiple-resharding case has moved to OpStrategy, so there is no
remaining case that needs a list.
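Schematically (an illustrative sketch of the type change only, not the actual dataclass definition from the DTensor source):
```py
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class OutputShardingBefore:
    # previously: a list of suggested schemas, of which only one was ever used
    schema_suggestions: Optional[List["OpSchema"]] = None


@dataclass
class OutputShardingAfter:
    # now: at most a single suggested OpSchema
    schema_suggestions: Optional["OpSchema"] = None
```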
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122929
Approved by: https://github.com/tianyu-l
Enable ASGD foreach optimizer and add DTensor optimizer unit test for ASGD.
Note that we still need to investigate why ASGD requires higher atol and rtol when comparing model parameters; listing it as a TODO for now.
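A minimal sketch of the kind of usage the new unit test exercises (a hedged example, not the actual test code; assumes 4 ranks with an initialized default process group):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cpu", (4,))
# A DTensor parameter sharded on dim 0; its gradient below is a DTensor too.
w = torch.nn.Parameter(distribute_tensor(torch.randn(16, 8), mesh, [Shard(0)]))
optim = torch.optim.ASGD([w], lr=1e-2, foreach=True)  # foreach path now enabled

(w * 2).sum().backward()
optim.step()
optim.zero_grad()
```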
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121942
Approved by: https://github.com/wanchaol
This PR adds support for 2D `clip_grad_norm_` (`foreach=True`).
- This PR changes `OpSchema.args_spec` to use pytree if the runtime schema info specifies it.
- This PR includes a unit test for 2D FSDP2 + SP with `clip_grad_norm_` enabled, which serves as a complete numerics test for 2D.
Note: With this PR patched, 2-way SP + 4-way FSDP matches 8-way FSDP numerics on Llama-7B (doubling local batch size for the 2-way SP run).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121945
Approved by: https://github.com/wanchaol
ghstack dependencies: #121747, #121869
This PR rewrites the stack strategy to be more generalized. Basically, the
follow pattern for stack/cat-like strategies needs to be smarter, i.e. it
should be able to identify:
1. PR, PP, RP -> follow PP
2. RR, SR, RS -> follow SS
So this PR refactors how the follow strategy works and makes sure we start
following the strategy that incurs the lowest cost, i.e. for multiple PR, RP
placements we should be able to further delay the pending sum reductions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121869
Approved by: https://github.com/awgu
This PR adds support for `clip_grad_norm_(foreach=True)` by implementing `aten._foreach_norm.Scalar` and `aten._foreach_mul_.Tensor`. `foreach=True` is required to get competitive performance with `DTensor`.
`foreach=True` reduces CPU overhead for Llama-7B from 388 ms to 63 ms. Existing flat-parameter FSDP's `clip_grad_norm_` takes 3 ms on CPU 😢 .
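A hedged sketch of the call pattern this makes fast (illustrative only; a real use case would pass an FSDP2/TP model's DTensor parameters, and this assumes 4 ranks with an initialized default process group):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cpu", (4,))
params = [
    torch.nn.Parameter(distribute_tensor(torch.randn(64, 8), mesh, [Shard(0)]))
    for _ in range(3)
]
sum(p.sum() for p in params).backward()  # populate DTensor .grad fields

# foreach=True now dispatches aten._foreach_norm.Scalar and
# aten._foreach_mul_.Tensor on the DTensor gradients.
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, foreach=True)
```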
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120910
Approved by: https://github.com/wanchaol, https://github.com/janeyx99
ghstack dependencies: #120238
This PR adds `DTensor` support for `aten.linalg_vector_norm.default` and `aten.stack.default` so that we can run `clip_grad_norm_` (with `foreach=False`).
To implement `linalg_vector_norm`, we introduce a `_NormPartial` placement since the reduction op for norm is the norm itself.
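The "reduction op for norm is the norm itself" property can be sanity-checked with plain tensors:
```py
import torch

# For a p-norm, reducing the per-shard norms with the same norm recovers the
# global norm: ||x||_2 == || [ ||shard_1||_2, ..., ||shard_k||_2 ] ||_2.
x = torch.randn(12)
per_shard = torch.stack([chunk.norm(2) for chunk in x.chunk(4)])
assert torch.allclose(per_shard.norm(2), x.norm(2))
```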
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120238
Approved by: https://github.com/wanchaol
This PR refactors the tuple strategy handling logic and allows
TupleStrategy to have both input and output specs for each OpStrategy child,
so that we can further enable operators like foreach norm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120695
Approved by: https://github.com/awgu
This fixes an internal DTensor enablement bug (I don't have an OSS issue for it)
I finally root-caused this as follows:
(1) we were fakefying a DTensor graph input, that was an autograd non-leaf (it had a grad_fn)
(2) that caused it to go through this `clone()` call during fakeification: https://github.com/pytorch/pytorch/blob/main/torch/_subclasses/meta_utils.py#L549
(3) `clone(torch.preserve_format)` is supposed to return another DTensor with the same strides as the input, but I noticed we were returning a DTensor with contiguous strides incorrectly.
(4) It turns out that DTensor was hashing on the sharding strategy for `aten.clone`, regardless of the `memory_format` kwarg that was passed in.
I could have manually updated the `clone` sharding strategy registration to take `memory_format` into account. But instead, I figured that every aten op with a sharding strategy needs to handle the memory_format kwarg specially - so I tried to generically force DTensor to consider all ATen ops that take a `memory_format` kwarg during hashing.
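For reference, step (3) on plain tensors (this is the stride-preserving behavior the memory_format-agnostic hashing was breaking for DTensor):
```py
import torch

x = torch.randn(4, 8).t()  # non-contiguous: shape (8, 4), strides (1, 8)

y = x.clone(memory_format=torch.preserve_format)
assert y.stride() == x.stride()  # strides are preserved

z = x.clone(memory_format=torch.contiguous_format)
assert z.is_contiguous()  # strides become (4, 1)
```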
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118667
Approved by: https://github.com/wanchaol
ghstack dependencies: #117667, #117666, #118209, #118191
As titled. This is a followup to PR #118917 on nll_loss_forward. It also fixes an issue in that PR: the forward function produces two return values, the loss `result` and the `total_weight`, and the previous PR didn't explicitly deal with the `total_weight` part.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119256
Approved by: https://github.com/wanchaol
This is part of the work to support cross entropy in DTensor.
This PR doesn't yet support nll_loss computation with the input sharded on the channel dimension; in that case, a redistribution to Replicate is needed during sharding propagation.
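A hedged sketch of the supported case, i.e. input sharded on the batch dimension (not the actual test code; assumes 4 ranks with an initialized default process group):
```py
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cpu", (4,))
logits = distribute_tensor(torch.randn(16, 10), mesh, [Shard(0)])  # batch-sharded
target = distribute_tensor(torch.randint(0, 10, (16,)), mesh, [Shard(0)])

loss = F.nll_loss(F.log_softmax(logits, dim=-1), target)  # a DTensor loss
```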
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118917
Approved by: https://github.com/wanchaol
Adding an `OpInfo` test for `split_with_sizes_copy` so we can use it to test the [CUDA fast path for split_with_sizes_copy.out](https://github.com/pytorch/pytorch/pull/117203). Since the `OpInfo` test doesn't exist yet and introducing it requires modifications to the `CompositeExplicitAutograd` impl, we are adding the `OpInfo` test in a separate PR to establish a healthy baseline.
Changes made:
- Registered a batching rule for `split_with_sizes_copy`.
- Registered a decomposition for `split_with_sizes_copy`.
- Registered a DTensor prop rule for `split_with_sizes_copy`.
- Added required dtype and device checks to the composite impl.
- Added output resize to the composite impl.
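For reference, a minimal example of the op itself (assuming the standard functional and `out=` variants of the `*_copy` ops):
```py
import torch

x = torch.arange(10.0)

# Functional variant: returns copies of the splits rather than views.
a, b = torch.split_with_sizes_copy(x, [4, 6], dim=0)

# out= variant: writes into preallocated tensors; this is the overload the
# follow-up CUDA fast-path PR targets.
out = [torch.empty(4), torch.empty(6)]
torch.split_with_sizes_copy(x, [4, 6], dim=0, out=out)
```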
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118512
Approved by: https://github.com/albanD
This PR adds support for rowwise sharded embedding by adding a
MaskPartial placement that inherits from the default partial placement
and overrides the Partial contracts to construct the mask and release
it after the reduction.
The MaskPartial placement has the potential to support other ops whose
sharding computation requires a mask for semantic correctness.
Currently it lives in the embedding ops, but we can move it to a
common place if needed.
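A hedged sketch of the rowwise (vocabulary-dim) sharded embedding this enables (not the actual test code; assumes 4 ranks with an initialized default process group):
```py
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cpu", (4,))
weight = distribute_tensor(torch.randn(1024, 32), mesh, [Shard(0)])  # rowwise
indices = distribute_tensor(torch.randint(0, 1024, (8, 16)), mesh, [Replicate()])

out = F.embedding(indices, weight)
# Locally, rows outside each rank's shard are masked out; the output is
# expected to carry the MaskPartial placement until it is reduced/redistributed.
full = out.full_tensor()
```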
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118080
Approved by: https://github.com/tianyu-l
ghstack dependencies: #118079
This PR rewrites the sharded embedding rule to use OpStrategy instead of a
rule, one step further toward getting rid of rules and consolidating the
embedding operator implementation, in preparation for the rowwise embedding
implementation coming in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118079
Approved by: https://github.com/tianyu-l
**Summary**
Previously, the DTensor sharding plan filter (i.e. `is_tensor_shardable()`) could not correctly handle the case where the input `DTensor` has a 0-sized dimension. The filter should return `True` if the sharding placement on that dimension is `Replicate`, even though `tensor dim < num of shards` on that dimension (in this case `tensor dim == 0` and `num of shards == 1`).
In this PR we also noticed a behavior discrepancy of `torch.addmm`. See #118131
**Test Plan**
```
pytest test/distributed/_tensor/test_dtensor_ops.py -s -k addmm
pytest test/distributed/_tensor/test_dtensor_ops.py -s -k mm_cpu_float32
CUDA_VISIBLE_DEVICES="" pytest test/distributed/_tensor/test_matrix_ops.py -s -k empty_operand
pytest test/distributed/_tensor/test_matrix_ops.py -s -k empty_operand
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117726
Approved by: https://github.com/wanchaol
**Summary**
This PR switches the softmax and log_softmax ops to use OpStrategy instead of rules. It also adds support for the case where the softmax dimension is sharded: a replication is performed before computation.
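A hedged sketch of the newly supported sharded-dim case (illustrative only; assumes 4 ranks with an initialized default process group):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cpu", (4,))
x = distribute_tensor(torch.randn(8, 16), mesh, [Shard(1)])

# dim 1 is sharded, so per the description above the input is replicated
# before the softmax is computed.
y = torch.softmax(x, dim=1)
```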
**Test**
`python test/distributed/_tensor/test_math_ops.py -k test_softmax_fwd`
`python test/distributed/_tensor/test_math_ops.py -k test_softmax_with_bwd`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117723
Approved by: https://github.com/XilunWu
**Summary**:
Ops like `native_layer_norm_backward` return a tuple of optional torch.Tensors.
This PR allows using OpStrategy to represent the sharding of
`native_layer_norm_backward`'s return values.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115682
Approved by: https://github.com/wanchaol
Summary: This diff is only a prototype to unblock the TP work. The PyTorch distributed team is working on a more generic backward op for `aten.layer_norm`. We will remove this op from the experimental file once it is ready.
Test Plan:
**Local Test**:
Accuracy:
- Dtensor + Checkpoint: first run loss: P884569822 (on-par with baseline: P884213363)
- 2nd by loading saved checkpoint: P884583429 (on-par with baseline: P884271869)
Trace:
- Collective functions are inserted automatically.
- Example: https://fburl.com/perfdoctor/l567ww1x
**MAST Test**:
With: trainer = 128, batch_size=512
- NE on-par:
(see: 4441_ep_bs512_2fsdp_tp_sp_dtensor)
Differential Revision: D51490868
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115398
Approved by: https://github.com/wanchaol