Previously, new tensors produced by the "new factory" ops always became replicated.
With this PR, if the new tensor has the same shape as the old tensor **and** the shape can be evenly sharded, then the old tensor's spec is inherited and preferred.
To accommodate this when the old tensor has sharded placements, the input args for the local computation (size, stride) need to be adjusted.
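A minimal sketch of the behavior, assuming a 4-rank `torchrun` launch and using `new_empty` as a representative "new factory" op (the exact op set covered is an assumption here):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (4,))
x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

# Same shape as `x` and evenly shardable across 4 ranks: the new tensor
# should now inherit x's [Shard(0)] spec instead of coming out replicated.
y = x.new_empty(8, 8)
print(y.placements)
```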
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122995
Approved by: https://github.com/wanchaol
As titled: for meta tensor ops, we should avoid calling the RNGTracker,
which could potentially alter the current RNG state. Meta tensor ops
should be no-ops; the post-`to_empty` init is what should really alter the RNG
state.
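A hedged sketch of the deferred-init flow this protects, assuming a 4-rank `torchrun` launch; the module and parallel style below are chosen purely for illustration:
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel

mesh = init_device_mesh("cuda", (4,))
with torch.device("meta"):
    model = torch.nn.Linear(16, 16)

# Sharding meta parameters runs DTensor ops on meta tensors; these should be
# no-ops for the RNG tracker and leave the RNG state untouched.
model = parallelize_module(model, mesh, ColwiseParallel())
model.to_empty(device="cuda")  # materialize storages only
model.reset_parameters()       # the real init is what should consume the RNG state
```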
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125693
Approved by: https://github.com/XilunWu
This PR uses str for reduce_op directly instead of the c10d enum. Since
our functional collectives already use str, there's no reason to keep the
c10d enum, which requires a conversion.
Also, str hash + eq performance is significantly faster than
the c10d type, so this somewhat reduces the CPU overhead too.
Some local CPU benchmarks on `1000000` hash operations:
```
Hash performance for string type: 0.039897 seconds
Hash performance for integer type: 0.304665 seconds
```
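For reference, a rough sketch of such a micro-benchmark (not the exact script used above; the choice of enum value is an assumption, and `ReduceOp` requires a distributed-enabled build):
```py
import timeit
import torch.distributed as dist

N = 1_000_000
str_key = "sum"
enum_key = dist.ReduceOp.SUM

print("Hash performance for string type:", timeit.timeit(lambda: hash(str_key), number=N), "seconds")
print("Hash performance for enum type:  ", timeit.timeit(lambda: hash(enum_key), number=N), "seconds")
```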
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125172
Approved by: https://github.com/awgu, https://github.com/XilunWu, https://github.com/tianyu-l
**Summary**
This PR is an attempt to land an experimental feature designed in #103686. `local_map` is designed to allow users to apply a function that was written for `torch.Tensor` to `DTensor` objects.
As a function, `local_map` takes 2 required arguments (`func` and `out_placements`) and 3 optional arguments (`device_mesh`, `in_placements`, `redistribute_inputs`). `func` is the function to be applied to each local shard of the input `DTensor`s. `out_placements` is the sharding specification of the output `DTensor`s (see the usage sketch below).
`local_map` returns a new function that does the following:
1. Infer `device_mesh` and `in_placements` from the `DTensor` inputs if they're not provided. If `device_mesh` is provided, it must be identical to the device mesh of every `DTensor` input. If `in_placements` is provided, it serves as the required sharding specification of the corresponding `DTensor` input before its local shard is fed into `func`. If it differs from the `DTensor`'s actual sharding specification, an exception is raised when `redistribute_inputs=False`; otherwise the input is resharded to the required specification.
2. Call `func` with the arguments passed in, along with `device_mesh`, except that each `DTensor` argument is replaced by its local shard. This `func` may include collectives.
3. For each output of `func` that has a valid (i.e. not `None`) sharding specification in `out_placements`, construct a new `DTensor` from the output and that specification, and use this `DTensor` as the output.
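A hedged usage sketch of the flow above (the import path and exact argument formats are assumptions and may differ from the landed API; assumes a 4-rank `torchrun` launch):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard, Replicate
from torch.distributed._tensor.experimental import local_map  # assumed location

def local_mm(x, w):
    # Written against plain torch.Tensor; receives local shards at runtime
    # and may itself issue collectives.
    return torch.mm(x, w)

mesh = init_device_mesh("cuda", (4,))
x = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
w = distribute_tensor(torch.randn(8, 8), mesh, [Replicate()])

sharded_mm = local_map(
    local_mm,
    out_placements=[Shard(0)],                  # sharding spec of the output DTensor
    in_placements=([Shard(0)], [Replicate()]),  # required spec per DTensor input
    device_mesh=mesh,
    redistribute_inputs=True,                   # reshard inputs whose spec differs
)
out = sharded_mm(x, w)  # a DTensor constructed from the local results
```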
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123676
Approved by: https://github.com/wanchaol
This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. This doesn't add support for memory-efficient attention in DTensor; that will be added in a follow-up PR.
The templating is also a POC of how to support other attention ops, such as jagged/nested tensors, as well as how to implement striped attention in a scalable way.
Misc changes:
* Fixes all_to_all_single autograd implementation with CUDA + adds NCCL test
* Adds compile support to the ring attention implementations (required some tweaks to process groups)
Test plan:
```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
**Summary**
We wrap DTensor's local tensor in `LocalShardsWrapper` for TorchRec's table-wise sharding. The exception is non-participating ranks: there, the local tensor is an empty `torch.Tensor` object. The reason for this design is to avoid the complexity of supporting the empty-tensor case in `LocalShardsWrapper`.
**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e table-wise`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122853
Approved by: https://github.com/wz337
ghstack dependencies: #120265, #121392, #122843
**Summary**
Always wrap the local tensor in a `LocalShardsWrapper`. This is for uniformity and makes it easier to adopt DTensor as a wrapper for the local shard(s) representation. To support more tensor ops on `LocalShardsWrapper`, users need to extend its `__torch_dispatch__` (see the sketch below).
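A minimal sketch of that extension pattern (not the actual `LocalShardsWrapper` code; class and handler names below are hypothetical): each newly supported op gets one handler consulted from `__torch_dispatch__`.
```py
import torch

class ShardsWrapperSketch(torch.Tensor):
    """Toy stand-in for a shards wrapper; holds a list of local shard tensors."""

    @staticmethod
    def __new__(cls, shards):
        # Use the first shard's metadata for the wrapper tensor (illustrative only).
        s = shards[0]
        return torch.Tensor._make_wrapper_subclass(cls, s.shape, dtype=s.dtype, device=s.device)

    def __init__(self, shards):
        self.shards = list(shards)

    _OP_TABLE = {}  # aten op -> handler

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        handler = cls._OP_TABLE.get(func)
        if handler is None:
            raise NotImplementedError(f"{func} is not supported on ShardsWrapperSketch")
        return handler(args, kwargs)

# Supporting one more tensor op means registering one more handler, e.g. aten.detach:
def _detach_handler(args, kwargs):
    wrapper = args[0]
    return ShardsWrapperSketch([s.detach() for s in wrapper.shards])

ShardsWrapperSketch._OP_TABLE[torch.ops.aten.detach.default] = _detach_handler
```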
**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e row-wise-even`
**Result**
```
Row-wise even sharding example in DTensor
Col 0-15
------- ----------
Row 0-1 cuda:0
Row 2-3 cuda:1
Row 4-5 cuda:2
Row 6-7 cuda:3
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122843
Approved by: https://github.com/wz337
ghstack dependencies: #120265, #121392
**Summary**
This PR serves as a starting point for this effort by adding an example test that represents TorchRec's `ShardingType.TABLE_WISE` using DTensor.
**Test**
`torchrun --standalone --nnodes=1 --nproc-per-node=4 torch/distributed/_tensor/examples/torchrec_sharding_example.py -e table-wise`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120265
Approved by: https://github.com/wanchaol
Fixes https://github.com/pytorch/pytorch/issues/122459, https://github.com/pytorch/torchtrain/issues/61
Even with the previous PR ("support DTensor/subclass constructors directly in the graph"), I still see some errors when running the repro above, with logs showing that dynamo is inlining `__new__`.
I noticed that putting `@torch._dynamo.disable` on DTensor's `__new__` makes the entire repro pass.
Why does having dynamo try to inline `Subclass.__new__` run into problems? Morally, dynamo probably shouldn't be inlining `__new__` ("creating a subclass" is a blackbox operation that AOTAutograd can trace through anyway). But concretely, we can end up with a node in the dynamo FX graph that has a "partially initialized tensor subclass" as its example value, because the subclass has been created but its fields have not been assigned to yet.
This breaks a bunch of invariants throughout dynamo: there are many places where if we have a tensor subclass node, we want to look at its inner tensors, to see if they are FakeTensors, what their FakeTensorMode is, and if they have dynamic shapes.
One option is to decide that "uninitialized subclass" is a first-class thing that anyone looking at the FX node example values on the dynamo graph needs to handle, but this seems like a lot of work when in reality we don't need dynamo to trace the `__new__` at all. Hence the `torch._dynamo.disable`.
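For illustration, a minimal sketch of the workaround on a hypothetical wrapper subclass (not DTensor's actual `__new__`):
```py
import torch

class MyWrapperSubclass(torch.Tensor):
    @staticmethod
    @torch._dynamo.disable  # keep dynamo from trying to inline/compile __new__
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self._elem = elem
```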
I still wasn't very satisfied, since it was unclear to me **why** dynamo was inlining the `__new__` call, instead of interposing on the `DTensor()` constructor directly. After a long chat with @anijain2305, he explained that with code like this:
```
@torch._dynamo.disable(recursive=False)
def f(x):
    out = SubclassConstructor(x)
```
Dynamo will never get the chance to interpose on the subclass constructor. Instead, what will happen is:
(1) Dynamo hands back control to cpython to run `f()`, since we disabled that frame
(2) `SubclassConstructor(x)` is run in eager mode
(3) `SubclassConstructor(x)` eventually calls `SubclassConstructor__new__`
(4) this is a new frame, that cpython then allows dynamo to intercept and start compiling
So it looks like we are basically forced to handle the situation where dynamo might directly start compiling `Subclass.__new__`.
All of the above does not explain the story for `__torch_dispatch__` though. Empirically, I have a repro in torchtrain where, looking at the dynamo logs, we see dynamo trying to inline `__torch_dispatch__`.
```
[rank0]:DEBUG: Skipping frame because no content in function call _prepare_output_fn /data/users/hirsheybar/b/pytorch/torch/distributed/tensor/parallel/style.py 318
[rank0]:DEBUG: torchdynamo start compiling __torch_dispatch__ /data/users/hirsheybar/b/pytorch/torch/distributed/_tensor/api.py:297, stack (elided 5 frames):
```
I haven't been able to create a smaller repro of the problem (even using `_dynamo.disable(recursive=False)`), although in theory, if there is a `torch.*` op that you were to inline (where one of the inputs is a subclass), the next frame would likely be `__torch_dispatch__`. Dynamo always treats `torch.*` operations as not-inlinable though, so in theory we shouldn't ever see dynamo inline `__torch_dispatch__`, but a `_dynamo.disable()` fixes the problem.
I asked Animesh if we can have dynamo automatically apply this behavior to subclasses instead of needing it to be added explicitly. He pointed out that for `disable(recursive=False)`, we can't really do this within dynamo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123347
Approved by: https://github.com/zou3519
ghstack dependencies: #122502, #122751, #123348
Automated fixes that replace certain list comprehensions with generator expressions where appropriate, so that they are immediately consumed. This is preview functionality in ruff for rule C419, and it was applied automatically.
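For example, the C419 fix rewrites patterns like the following (illustrative):
```py
values = [-1, 0, 3]

# Before: builds an intermediate list just to feed any()
print(any([x > 0 for x in values]))

# After the C419 fix: the generator is consumed immediately, no intermediate list
print(any(x > 0 for x in values))
```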
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123960
Approved by: https://github.com/malfet
Ring attention support for _scaled_dot_product_flash_attention with DTensor.
This assumes the query and key/value are sharded along the sequence length dimension. See the tests for example usage with PT Transformer as well as direct usage with _scaled_dot_product_flash_attention.
## Notable caveats
* Numerical accuracy: the backward pass doesn't match the non-chunked version numerically, but the forward pass does. I assume this is due to accumulated error. I've added a chunked version that uses autograd to verify that the distributed version matches the chunked version.
* nn.Linear behaves incorrectly when running on a tensor of size (bs, heads, seq_len, dim) sharded with `Shard(2)`: it does an unnecessary accumulate. Working around the issue requires `Replicate()` on QKV when using `nn.MultiheadAttention`.
* If enabled, it forces sequence parallelism and doesn't interoperate with tensor parallelism.
## SDPA usage
```py
with attention_context_parallel(), sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    dquery = distribute_tensor(query, device_mesh, [Shard(2)])
    dkey = distribute_tensor(key, device_mesh, [Shard(2)])
    dvalue = distribute_tensor(value, device_mesh, [Shard(2)])
    dout: DTensor = torch.nn.functional.scaled_dot_product_attention(
        dquery, dkey, dvalue, is_causal=is_causal
    )
    out = dout.to_local()
```
## Transformer usage
```py
with attention_context_parallel(), sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=dim,
        nhead=nheads,
        dim_feedforward=dim,
        batch_first=True,
    ).to(dtype)
    encoder_layer = parallelize_module(
        module=encoder_layer,
        device_mesh=device_mesh,
        parallelize_plan={
            "self_attn": ContextParallel(),
        },
    )
    model = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
```
## Test plan
```
pytest test/distributed/_tensor/test_attention.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122460
Approved by: https://github.com/drisspg, https://github.com/wanchaol
As titled: previously we could return an expected input spec that is
shared by multiple args. This is not OK, since different args might
have different tensor metas; it only worked before because
redistribute in these cases became a no-op.
This PR fixes it by making each expected input spec shallow-clone the
corresponding input's metadata.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122949
Approved by: https://github.com/tianyu-l
ghstack dependencies: #122929
This PR refactors schema_suggestions in OutputSharding to be a single
OpSchema instead of a list of schemas; in practice we only ever have one.
The multiple-resharding case has also moved to OpStrategy, so there is
no case left that needs it to be a list.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122929
Approved by: https://github.com/tianyu-l
This PR is enough to fix https://github.com/pytorch/pytorch/issues/118600.
More description of the problem is in the issue, but the high-level problem is similar to the "tangents might be non-contiguous" problem that we handle today, via forcing all tangents to be contiguous. There, the problem was something like:
"We guessed the tangent strides incorrectly, because strides on the runtime tangents were different from strides on the forward outputs, which we used to generate tangents"
Here, the problem is similar:
"We guessed the tangent tensor subclass's metadata incorrectly, because the runtime tangent was a subclass with different metadata than the forward output subclass".
This happened in an internal DTensor issue, where the metadata in question was the `placements` (shard vs. replicate vs. Partial).
One option is to solve this problem via backward guards. This is needed to unblock internal though, so I figured handling this similarly to how we handle non-contiguous tangents would be reasonable. I did this by:
(1) Assert that the metadata on subclass tangents is the same as what we guessed, and if not raise a loud error
(2) In the error message, provide the name of an optional method that the subclass must implement to handle this case:
`def __force_same_metadata__(self, metadata_tensor):`: If the forward output had a `Replicate()` placement, but the runtime tangent had a `Shard(1)` placement, this method allows a subclass to take the tangent and "convert" it to one with a `Replicate()` placement.
`__force_standard_metadata__(self)`: One issue is that there is another placement called `_Partial`, and its semantics are such that DTensor is **unable** to convert a DTensor with some placement type into another DTensor with a `_Partial` placement.
`__force_standard_metadata__` is now called on all (fake) subclass forward outs at trace-time to generate tangents, and gives subclasses a chance to "fix" any outputs with metadata that they cannot convert to later. Morally, this is similar to the fact that we force a `contiguous()` call on all tangents at trace-time.
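To make the proposal concrete, here is a hedged sketch on a toy wrapper subclass; the two method names come from this PR, while the metadata field (`tag`) and the bodies are purely illustrative:
```py
import torch

class TaggedTensor(torch.Tensor):
    """Toy wrapper subclass whose only extra metadata is a `tag` string
    (standing in for something like DTensor placements)."""

    @staticmethod
    def __new__(cls, elem, tag):
        return torch.Tensor._make_wrapper_subclass(cls, elem.shape, dtype=elem.dtype)

    def __init__(self, elem, tag):
        self._elem = elem
        self.tag = tag

    def __force_same_metadata__(self, metadata_tensor):
        # Convert `self` so its metadata matches the trace-time guess
        # (e.g. a Shard(1) runtime tangent converted to the guessed Replicate()).
        return TaggedTensor(self._elem, metadata_tensor.tag)

    def __force_standard_metadata__(self):
        # Called on (fake) forward outputs at trace time: replace metadata that we
        # cannot convert *to* later (the `_Partial`-like case) with a convertible default.
        tag = "replicate" if self.tag == "partial" else self.tag
        return TaggedTensor(self._elem, tag)
```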
I'm interested in thoughts/feedback! Two new dunder methods on traceable subclasses are definitely a contentious change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118670
Approved by: https://github.com/ezyang
We would like to improve consistency for nn_module_stack metadata in torch.export.
This PR ensures that all tests in test/export/test_export.py satisfy the following constraints (a sketch of the resulting checks follows the list):
- Remove nn_module_stack for all placeholder & output nodes, for all modules and submodules
- Ensure nn_module_stack is present for all other node types for the top-level module (there is still an issue with torch.cond submodules having empty fields)
- Add these checks to _export() in _trace.py (we would add this in the Verifier, but downstream apps construct ExportedPrograms separate from _export(), and metadata may not be maintained there)
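A hedged sketch of what these checks amount to, run over a trivially exported module (node kinds inserted by export, e.g. assertions or constants, may need carve-outs in practice):
```py
import torch

ep = torch.export.export(torch.nn.Linear(4, 4), (torch.randn(2, 4),))
for node in ep.graph.nodes:
    if node.op in ("placeholder", "output"):
        assert "nn_module_stack" not in node.meta   # constraint 1
    elif node.op == "call_function":
        assert "nn_module_stack" in node.meta       # constraint 2 (top-level module)
```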
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120661
Approved by: https://github.com/avikchaudhuri
Enable the ASGD foreach optimizer and add a DTensor optimizer unit test for ASGD.
Note that we still need to investigate why ASGD requires higher atol and rtol when comparing model parameters; listing it as a TODO for now.
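A minimal sketch of the covered path, assuming a 4-rank `torchrun` launch (shapes and hyperparameters are arbitrary):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (4,))
p = torch.nn.Parameter(distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)]))
p.grad = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

opt = torch.optim.ASGD([p], lr=1e-2, foreach=True)  # the foreach path enabled here
opt.step()
```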
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121942
Approved by: https://github.com/wanchaol
This PR adds support for 2D `clip_grad_norm_` (`foreach=True`).
- This PR changes `OpSchema.args_spec` to use pytree if the runtime schema info specifies it.
- This PR includes a unit test for 2D FSDP2 + SP with `clip_grad_norm_` enabled, which serves as a complete numerics test for 2D.
Note: With this PR patched, 2-way SP + 4-way FSDP matches 8-way FSDP numerics on Llama-7B (doubling local batch size for the 2-way SP run).
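A hedged sketch of the enabled call path on DTensor gradients, with a single sharded parameter standing in for a real 2D-parallel model (assumes a 4-rank `torchrun` launch):
```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (4,))
p = torch.nn.Parameter(distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)]))
p.grad = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])

# foreach=True is the path this PR makes work for DTensor gradients.
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0, foreach=True)
```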
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121945
Approved by: https://github.com/wanchaol
ghstack dependencies: #121747, #121869
This PR rewrites the stack strategy to be more general. Basically, the
follow pattern for stack/cat-like strategies needs to be smarter, i.e. it
should be able to identify:
1. PR, PP, RP -> follow PP
2. RR, SR, RS -> follow SS
So this PR refactors how the follow strategy works, and makes sure
we start following the strategy that incurs the lowest cost, i.e. for
multiple PR, RP placements we should be able to further delay the
pending sum reductions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121869
Approved by: https://github.com/awgu