### Summary
@LucasLLC recently implemented `broadcast` in funcol, but it is not yet available as a native funcol op. This PR adds `broadcast` support to native funcol.
- Added `_c10d_functional::broadcast` and `_c10d_functional::broadcast_`
- Integrated with the Python funcol `broadcast` and `AsyncCollectiveTensor` (see the usage sketch after this list)
- Implemented Inductor lowering. Verified correctness and buffer reuse behavior
- Validated dynamo traceability
- Validated AOTInductor compile-ability
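A minimal usage sketch of the Python-level entry point, assuming an already-initialized process group (the group handle and device below are illustrative):
```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

t = torch.ones(4, device="cuda") * dist.get_rank()
# Functional broadcast from rank 0; returns an AsyncCollectiveTensor
# that is waited on at its first real use.
out = funcol.broadcast(t, src=0, group=dist.group.WORLD)
print(out)  # the first use triggers wait_tensor() under the hood
```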
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119229
Approved by: https://github.com/wanchaol
ghstack dependencies: #119104
### Summary
Run the relevant tests in `test/distributed/_tensor/test_dtensor_compile.py` and `test/distributed/test_device_mesh.py` with native funcol enabled, in addition to the existing runs with it disabled.
All tests except `test_tp_compile_comm_reordering` pass. This is expected: the native funcol ops have slightly different IRs, so the reordering pass needs to be adjusted. That test is disabled for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118437
Approved by: https://github.com/LucasLLC
ghstack dependencies: #118910, #118911
### Summary
- Added `group_name` as the third field in `dim_group_infos`.
- `DeviceMeshTest` now runs both with and without `_USE_NATIVE_C10D_FUNCTIONAL=1` in CI.
### Other fixes
- Convert `reduceOp` to lower case before passing it into c10d_functional ops.
- Added a finalizer to handle unwaited collectives (this mirrors the treatment for Python functional collective ops).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118423
Approved by: https://github.com/wanchaol, https://github.com/LucasLLC, https://github.com/wconstab
This diff introduces an env var `_USE_NATIVE_C10D_FUNCTIONAL` that tells `_functional_collectives` to use the native `c10d_functional` ops. The Python version and the native version will co-exist until we completely switch to the native version after more testing and verification.
NOTE: `DeviceMesh` support for native `c10d_functional` will be added in a subsequent PR.
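A sketch of how one might opt in, assuming the flag is read when the functional collectives are used (exactly where the flag is consulted is an implementation detail):
```python
import os

# Set before the functional collectives are used.
os.environ["_USE_NATIVE_C10D_FUNCTIONAL"] = "1"

import torch.distributed._functional_collectives as funcol
# Subsequent funcol calls should dispatch to the native _c10d_functional ops.
```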
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113057
Approved by: https://github.com/LucasLLC, https://github.com/wconstab, https://github.com/wanchaol
Summary:
- Ported `all_to_all_single` to native c10d_functional (see the usage sketch after this list)
- Added Inductor support for the native `all_to_all_single` via the new collective IR's `create_out_of_place()`
- Since the new collective IR derives from `FallbackKernel`, which implements a generic `free_unbacked_symbols`, no additional unbacked-symbol handling is required for `all_to_all_single`
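A hedged usage sketch of the Python-level `all_to_all_single` (the rank list is illustrative, and `None` split sizes assume the input divides evenly across the group):
```python
import torch
import torch.distributed._functional_collectives as funcol

def exchange(x: torch.Tensor) -> torch.Tensor:
    # None split sizes mean equal splits across the group.
    out = funcol.all_to_all_single(
        x, output_split_sizes=None, input_split_sizes=None, group=[0, 1]
    )
    # Returned as an AsyncCollectiveTensor; synchronized at first real use.
    return out
```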
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113438
Approved by: https://github.com/yf225, https://github.com/ezyang
This diff aims to directly import DeviceMesh from torch.distributed.device_mesh instead of importing it from dist._tensor. This is done to avoid a circular dependency issue. The code changes in each file of the diff are as follows:
- torch/distributed/_functional_collectives.py: import DeviceMesh from torch.distributed instead of dist._tensor.
Overall, this avoids the circular dependency and cleans up the import (sketched below).
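For illustration only, the intended import shape (the exact statement in the diff may differ):
```python
# Import DeviceMesh from its home module instead of going through dist._tensor,
# which is what created the circular dependency.
from torch.distributed.device_mesh import DeviceMesh
```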
==
The above summary is generated by LLM with minor manual fixes. The following summary is by me.
The original import caused some issues when compiling DDP with compiled_autograd. The root cause of the compilation failure has not been identified, but it is good to fix the lazy initialization anyway, which indirectly fixes the compilation issues for DDP.
Differential Revision: [D51857246](https://our.internmc.facebook.com/intern/diff/D51857246/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115649
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: #115523, #115302, #115648
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).
Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
* Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
* Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
* Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
* Signatures now (a minimal subclass sketch follows after this list):
```python
# attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
# ctx is anything useful for rebuilding the class we want to guard on
attrs, ctx = x.__tensor_flatten__()
...
# inner_tensors is a dict of {attr -> tensor}
# ctx is taken unmodified from flattening and (eventually) guarded on
# outer_size is the expected size of the output; possibly symbolic
# outer_stride is the expected strides of the output; possibly symbolic
y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
# at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
# the assert simplifies symbols when there are relationships between outer and inner symbols
```
* Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
* Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
* Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
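A minimal, hypothetical wrapper subclass illustrating the new `__tensor_unflatten__` signature; the class name, attribute layout, and the omission of `__torch_dispatch__` are all simplifications for the sketch:
```python
import torch

class PairTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, a, b):
        # Wrapper subclass whose outer metadata mirrors the inner tensor `a`.
        return torch.Tensor._make_wrapper_subclass(
            cls, a.shape, dtype=a.dtype, device=a.device
        )

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __tensor_flatten__(self):
        # attrs name the inner tensor attributes; ctx is extra metadata to guard on.
        return ["a", "b"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # outer_size / outer_stride may be symbolic under PT2; the returned
        # tensor's shape and strides must compare equal to them.
        return PairTensor(inner_tensors["a"], inner_tensors["b"])
```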
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
Apply a few optimizations to funcol:
- For all-gather on a non-0 dim, the resulting tensor already needs to access its data in order to do torch.cat, so we wait (sync) eagerly here so that we don't need to go through ACT dispatch for chunk + cat altogether (see the sketch after this list).
- Add a fast-return path for aten.view, as it's a commonly hit op among view-related ops.
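A hedged sketch of the case the first optimization targets (the rank list is illustrative and assumes a two-rank group):
```python
import torch
import torch.distributed._functional_collectives as funcol

x = torch.randn(4, 8)
# Gathering along a non-zero dim requires a chunk + cat on the gathered result,
# so the implementation can wait eagerly instead of routing chunk/cat through
# the AsyncCollectiveTensor (ACT) dispatch path.
y = funcol.all_gather_tensor(x, gather_dim=1, group=[0, 1])
```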
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113324
Approved by: https://github.com/XilunWu
ghstack dependencies: #113323
Reorganized the two pytree submodules (C++ and Python) into a subpackage. I think this will make it easier to implement an abstract `PyTreeAPI` class with two implementations, and much easier for users to switch between the two implementations.
Before:
```text
torch
├── utils
│ ├── _pytree.py
│ ├── _cxx_pytree.py
│ ...
...
```
After:
```text
torch
├── utils
│ ├── _pytree
│ │ ├── __init__.py
│ │ └── api
│ │ ├── __init__.py
│ │ ├── cxx.py
│ │ └── python.py
│ ...
...
```
The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
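For illustration, existing call sites keep working unchanged after the move; a minimal sketch using the flatten/unflatten helpers from the current surface:
```python
import torch.utils._pytree as pytree

leaves, spec = pytree.tree_flatten({"a": 1, "b": [2, 3]})
rebuilt = pytree.tree_unflatten(leaves, spec)
assert rebuilt == {"a": 1, "b": [2, 3]}
```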
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
This PR adds Inductor support for [native c10d_functional ops](https://github.com/pytorch/pytorch/pull/110570).
The Inductor IRs introduced in this PR will replace the existing `CollectiveKernel` IR hierarchy. Compared to the existing collective IRs, the new IRs:
- Are target language agnostic and support AOTInductor.
- Express the constraints solely with read/write deps. This maximizes the potential for buffer reuse.
- Address an issue where an out-of-place collective's input buffers could be mutated while the collective was still reading them. (A compile-level usage sketch follows this list.)
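A hedged sketch of the kind of program these IRs cover when compiled with Inductor (assumes an initialized process group; the group handle and shapes are illustrative):
```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

@torch.compile
def step(x: torch.Tensor) -> torch.Tensor:
    # The collective and its wait are captured in the graph and lowered by Inductor.
    y = funcol.reduce_scatter_tensor(x, "sum", 0, dist.group.WORLD)  # scatter along dim 0
    return y * 2
```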
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112439
Approved by: https://github.com/Chillee
This PR introduces a native version of c10d_functional ops. The main goal is to add collective support in AOTInductor and allow collective ops to work in multi-threaded native runtimes.
The native version also incorporates API improvements we wished to make in Python c10d_functional:
- Removed `ranks` and `group_size` from collective op signatures, which proved redundant (see the schema sketch after this list).
- Use tensor storage as opposed to `void*` to resolve in-flight work.
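For illustration, a call-level sketch of the signature difference (treat the schemas as approximate; the graph dumps later in this document show the legacy tag/ranks/group_size form). The group name used below is hypothetical and assumes an initialized, registered process group:
```python
import torch

t = torch.ones(4)

# Legacy Python c10d_functional op (sketch): carries a tag, rank list, and group size.
# out = torch.ops.c10d_functional.all_reduce(t, "sum", "tag", [0, 1], 2)

# Native op (sketch): resolves the process group by its registered name.
out = torch.ops._c10d_functional.all_reduce(t, "sum", "default_pg")
out = torch.ops._c10d_functional.wait_tensor(out)
```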
The native process group registration/resolution mechanism is only used for native c10d_functional in this PR. It will become the single source of truth in upcoming PRs.
The upcoming PRs will implement Inductor/AOTInductor support for c10d_functional, after which native c10d_functional will replace Python c10d_functional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110570
Approved by: https://github.com/wanchaol
This PR updates DTensor to support torch.compile.
Cool stuff: there are some new tests in `test_dtensor.py` that show both the forward and backward graphs that we can send to Inductor when running a matmul with DTensors. In particular, for this user code:
```python
def fn(x, y):
    dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
    dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
    dt_out = torch.matmul(dt, dt2)
    dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
    return dt_out.to_local()
```
We generate the following fw and backward graphs.
Forward graph:
```python
def forward(self, primals_1, primals_2):
    view = torch.ops.aten.view.default(primals_1, [2, 4]); primals_1 = None
    _to_copy = torch.ops.aten._to_copy.default(view, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view = None
    detach = torch.ops.aten.detach.default(_to_copy); _to_copy = None
    detach_1 = torch.ops.aten.detach.default(detach); detach = None
    view_1 = torch.ops.aten.view.default(primals_2, [4, 2]); primals_2 = None
    _to_copy_1 = torch.ops.aten._to_copy.default(view_1, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); view_1 = None
    detach_2 = torch.ops.aten.detach.default(_to_copy_1); _to_copy_1 = None
    detach_3 = torch.ops.aten.detach.default(detach_2); detach_2 = None
    detach_4 = torch.ops.aten.detach.default(detach_1)
    all_gather_into_tensor = torch.ops.c10d_functional.all_gather_into_tensor.default(detach_3, 'ptd:0', [0, 1], 2)
    wait_tensor = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
    split = torch.ops.aten.split.Tensor(wait_tensor, 4); wait_tensor = None
    getitem = split[0]
    getitem_1 = split[1]; split = None
    cat = torch.ops.aten.cat.default([getitem, getitem_1], 1); getitem = getitem_1 = None
    detach_5 = torch.ops.aten.detach.default(cat); cat = None
    mm = torch.ops.aten.mm.default(detach_4, detach_5); detach_4 = detach_5 = None
    detach_6 = torch.ops.aten.detach.default(mm); mm = None
    detach_9 = torch.ops.aten.detach.default(detach_6); detach_6 = None
    detach_10 = torch.ops.aten.detach.default(detach_9); detach_9 = None
    t = torch.ops.aten.t.default(detach_1); detach_1 = None
    detach_13 = torch.ops.aten.detach.default(t); t = None
    t_1 = torch.ops.aten.t.default(detach_3); detach_3 = None
    detach_15 = torch.ops.aten.detach.default(t_1); t_1 = None
    clone = torch.ops.aten.clone.default(detach_15, memory_format = torch.contiguous_format); detach_15 = None
    return [detach_10, detach_13, clone]
```
Backward graph:
```python
def forward(self, detach_13, clone, tangents_1):
    detach_11 = torch.ops.aten.detach.default(tangents_1); tangents_1 = None
    detach_12 = torch.ops.aten.detach.default(detach_11); detach_11 = None
    mm_1 = torch.ops.aten.mm.default(detach_13, detach_12); detach_13 = None
    detach_14 = torch.ops.aten.detach.default(mm_1); mm_1 = None
    detach_16 = torch.ops.aten.detach.default(detach_12); detach_12 = None
    all_gather_into_tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor.default(clone, 'ptd:0', [0, 1], 2); clone = None
    wait_tensor_2 = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor_2);
    detach_17 = torch.ops.aten.detach.default(wait_tensor_2); wait_tensor_2 = None
    mm_2 = torch.ops.aten.mm.default(detach_16, detach_17); detach_16 = detach_17 = None
    detach_18 = torch.ops.aten.detach.default(mm_2); mm_2 = None
    split_1 = torch.ops.aten.split.Tensor(detach_14, 2, 1); detach_14 = None
    getitem_2 = split_1[0]
    getitem_3 = split_1[1]; split_1 = None
    cat_1 = torch.ops.aten.cat.default([getitem_2, getitem_3]); getitem_2 = getitem_3 = None
    reduce_scatter_tensor = torch.ops.c10d_functional.reduce_scatter_tensor.default(cat_1, 'SUM', 'ptd:0', [0, 1], 2); cat_1 = None
    wait_tensor_3 = torch.ops.c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None
    detach_19 = torch.ops.aten.detach.default(wait_tensor_3); wait_tensor_3 = None
    detach_20 = torch.ops.aten.detach.default(detach_19); detach_19 = None
    detach_21 = torch.ops.aten.detach.default(detach_20); detach_20 = None
    detach_22 = torch.ops.aten.detach.default(detach_21); detach_21 = None
    _to_copy_2 = torch.ops.aten._to_copy.default(detach_22, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_22 = None
    view_2 = torch.ops.aten.view.default(_to_copy_2, [8]); _to_copy_2 = None
    detach_23 = torch.ops.aten.detach.default(detach_18); detach_18 = None
    detach_24 = torch.ops.aten.detach.default(detach_23); detach_23 = None
    _to_copy_3 = torch.ops.aten._to_copy.default(detach_24, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); detach_24 = None
    view_3 = torch.ops.aten.view.default(_to_copy_3, [8]); _to_copy_3 = None
    return [view_3, view_2]
```
Some of the stuff in this graph looks kind of silly though (e.g. an unnecessary split() + cat(), and all the extra detach() calls).
Stuff that's broken:
- functionalization is pretty horribly broken. In particular, the original strategy I used in this stack was to have functionalization run **above** subclass desugaring. But that doesn't play well with the way we want to compile DTensor. DTensor has a few APIs, like `.redistribute()`, `.to_local()`, and the `DTensor()` constructor, that we want to put directly into the graph so that we can compile them (e.g. redistribute() will desugar into collective ops). Doing this requires functionalization to run **underneath** the subclass though. I hacked around this for now by forcing these functions to run functionalization first if they need to.
- the backward test that I have is... wrong. The backward graph that we trace out looks kind of reasonable, but it gives incorrect gradients on one of the two inputs. This needs further debugging (presumably we should be able to stare at the graph and identify which part of it is wrong?).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105236
Approved by: https://github.com/wanchaol
optree recently landed and provides quite good perf; conditionally import optree if it is installed (see the sketch below).
Some numbers testing an MLP layer with TP + func collectives:
before this PR: 10.390ms
after this PR: 9.189ms
so around a 10% e2e CPU overhead reduction
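A minimal sketch of the conditional import described above (the flag and helper names are illustrative, not the actual ones used in the diff):
```python
try:
    import optree  # fast C-extension pytree implementation
    HAS_OPTREE = True
except ImportError:
    HAS_OPTREE = False

def flatten_args(tree):
    # Prefer optree when available; otherwise fall back to the pure-Python pytree.
    if HAS_OPTREE:
        return optree.tree_flatten(tree, none_is_leaf=True)
    import torch.utils._pytree as pytree
    return pytree.tree_flatten(tree)
```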
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110670
Approved by: https://github.com/fegin
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:
(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests
(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.
(3) `_make_wrapper_subclass()`: marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
As the title says, I was trying to test the functional collectives, and, when printing the resulting tensors, sometimes they wouldn't have finished the Async operation yet. According to the comments in the file, "AsyncTensor wrapper applied to returned tensor, which issues wait_tensor() at the time of first use". This is true in most cases, but not when print() is your first use. This PR fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107808
Approved by: https://github.com/fduwjj
We cannot use inner tensors for finalizers, as they are not collectable until waited on.
This PR adds a bunch of tests for the observable behavior we want, including the necessary scaffolding for us to test whether tensors have been waited on.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107250
Approved by: https://github.com/wconstab
AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives API's. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and internally calling `.wait()` to synchronize the data the next time it is used.
Previously, these wait() calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor, immediately causing the sync (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207))
AsyncCollectiveTensor shouldn't need to do a synchronization if you try to detach() it though - in fact, it should be fine to avoid synchronizing if you perform any view ops on it (which just require viewing metadata, but not actual data). This PR tries to update `AsyncCollectiveTensor` to delay `wait()` calls whenever the subclass encounters a view op.
Added some light testing that just runs some DTensor compute followed by view ops, and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()` (see the sketch below).
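A hedged sketch of the behavior this enables (assumes an initialized two-rank process group; the rank list is illustrative):
```python
import torch
import torch.distributed._functional_collectives as funcol

x = torch.randn(4, 4)
out = funcol.all_reduce(x, "sum", group=[0, 1])  # AsyncCollectiveTensor
v = out.view(-1)   # view ops no longer force a wait
result = v + 1     # the first data access triggers wait_tensor()
```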
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
This PR gets rid of the dim_groups attribute from DeviceMesh. The main motivation is that c10d should store the process groups during their creation instead of DeviceMesh; DeviceMesh should just handle ranks correctly.
This could make DTensor picklable! (torch.save/load could become possible), which I will give a try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
Also not sure if this should be a public function or not. Leaving it private for now but let me know if you prefer for it to be public.
FYI @nikitaved this will logically conflict with your triton kernel PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101420
Approved by: https://github.com/malfet
Summary:
Currently there are build configs where the torchdynamo import trips over a
strange SystemError related to some module's __dict__.items() returning NULL,
while torchdynamo tries to iterate all torch modules and process them for
its allowed functions list.
While this is hard to repro, we should be able to work around it and then fix
it properly.
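A hedged sketch of the kind of defensive iteration this implies; the loop below is illustrative only, not the actual torchdynamo code:
```python
import sys

def iter_module_members():
    # Walk loaded modules for allowed-functions processing, skipping modules
    # whose __dict__ cannot be iterated (some extension modules raise here).
    for name, module in list(sys.modules.items()):
        if module is None:
            continue
        try:
            members = list(module.__dict__.items())
        except (SystemError, AttributeError):
            continue
        yield name, members
```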
Test Plan: Rely on others to test this, assuming CI passes.
Reviewed By: anijain2305
Differential Revision: D45663313
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100901
Approved by: https://github.com/yanboliang, https://github.com/malfet
We do it by making it possible to register multiple tensors for the same work and coordinate waiting/cleanup among them.
This ensures that waiting on any number of the output tensors results in a single stream sync. This simplifies Inductor codegen.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99763
Approved by: https://github.com/wanchaol
Summary:
This diff is reverting D45387167
D45387167: Basic dynamo support for traceable collectives (#94440) by wconstab has been identified to be causing the following test or build failures (internal)
If you believe this diff has been generated in error you may Commandeer and Abandon it.
Test Plan: NA
Reviewed By: s4ayub
Differential Revision: D45448312
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100424
Approved by: https://github.com/rohan-varma, https://github.com/kumpera
Make traceable collectives work with torchdynamo,
bypassing problems with tracing the AsyncTensor subclass.
Accept a suboptimal solution for now, and optimize it later.
For now, wait happens immediately, which generally forces an early sync.
Later, find a way either in dynamo or AOT stack to handle
AsyncCollectiveTensor to get the wait in the optimal place.
Note on implementation:
- Dynamo traces the 'user-level' fc APIs, which are designed to behave differently in eager vs. compiled mode. In eager, there is work-obj registration and a wrapper subclass inserts a 'wait' call at the appropriate time. In compile/trace mode, wait is called immediately, and work-obj registration must be handled by the compile backend at runtime (see the sketch after this note).
- Dynamo needs to trace into some of the helper functions in the 'user-level' API, such as `_expand_group`, which is essentially a constant transformation.
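A hedged sketch of the eager-vs-compiled behavior this note describes (the rank list is illustrative; the exact wait placement is what later work will optimize):
```python
import torch
import torch.distributed._functional_collectives as funcol

def fn(x):
    # Eager: returns an AsyncCollectiveTensor; the wait is inserted lazily at first use.
    # Compiled (this PR): the wait is issued immediately in the traced graph, and
    # work-obj registration is left to the compile backend at runtime.
    return funcol.all_reduce(x, "sum", group=[0, 1]) * 2

compiled_fn = torch.compile(fn)
```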
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94440
Approved by: https://github.com/kumpera