77 Commits

Author SHA1 Message Date
Yifu Wang
4ac857f94e Support broadcast in native funcol (#119229)
### Summary

@LucasLLC recently implemented `broadcast` in funcol. This is not yet available in the native funcol ops. This PR adds support for broadcast for native funcol.

- Added `_c10d_functional::broadcast` and `_c10d_functional::broadcast_`
- Integrated with Python funcol broadcast and `AsyncCollectiveTensor`
- Implemented Inductor lowering. Verified correctness and buffer reuse behavior
- Validated dynamo traceability
- Validated AOTInductor compile-ability

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119229
Approved by: https://github.com/wanchaol
ghstack dependencies: #119104
2024-02-16 21:01:34 +00:00
Yifu Wang
8f82a44a5b Run device mesh tests with native funcol enabled (#118437)
### Summary

Run the relevant tests in `test/distributed/_tensor/test_dtensor_compile.py` and `test/distributed/test_device_mesh.py` with native funcol enabled, in addition to running them with it disabled.

All tests except `test_tp_compile_comm_reordering` pass. This is expected because the native funcols have slightly different IRs, so the reordering pass needs to be adjusted. This test is disabled for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118437
Approved by: https://github.com/LucasLLC
ghstack dependencies: #118910, #118911
2024-02-04 04:11:11 +00:00
Yifu Wang
697ca4f292 Preliminary DeviceMesh + native c10d functional integration (#118423)
### Summary
- Added `group_name` as the third field in `dim_group_infos`.
- `DeviceMeshTest` now runs both w/ and w/o `_USE_NATIVE_C10D_FUNCTIONAL=1` in CI.

### Other fixes
- Convert `reduceOp` to lower case before passing it into c10d_functional ops.
- Added a finalizer to handle unwaited collectives (this mirrors the treatment for Python functional collective ops).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118423
Approved by: https://github.com/wanchaol, https://github.com/LucasLLC, https://github.com/wconstab
2024-01-31 04:36:12 +00:00
atalman
15702a8027 Fix lint after #118533 (#118633)
Fixes lint after https://github.com/pytorch/pytorch/pull/118533
Adds ignore ``possibly-undefined`` to more places

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118633
Approved by: https://github.com/DanilBaibak
2024-01-30 14:07:16 +00:00
Yifu Wang
b778f44e97 Allow using native c10d_functional via _functional_collectives (#113057)
This diff introduces an env var `_USE_NATIVE_C10D_FUNCTIONAL` that tells `_functional_collectives` to use native `c10d_functional` ops. The Python version and the native version will co-exist until we completely switch to the native version after more testing and verification.

NOTE: `DeviceMesh` support for native `c10d_functional` will be added in a subsequent PR.
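
As a rough illustration of how an environment-variable switch like this can select between two co-existing implementations, here is a minimal sketch. Only the variable name `_USE_NATIVE_C10D_FUNCTIONAL` comes from this PR; the helpers `use_native_funcol` and `pick_backend`, and the choice of `"1"` as the enabling value, are assumptions for illustration and not the actual implementation.

```python
import os

# Name taken from the PR description; everything else below is hypothetical.
ENV_VAR = "_USE_NATIVE_C10D_FUNCTIONAL"


def use_native_funcol(environ=None):
    """Return True when the switch is set to "1" (assumed enabling value)."""
    env = os.environ if environ is None else environ
    return env.get(ENV_VAR, "0") == "1"


def pick_backend(environ=None):
    # Both implementations co-exist; the flag only selects between them.
    return "native" if use_native_funcol(environ) else "python"
```

Passing an explicit mapping, e.g. `pick_backend({ENV_VAR: "1"})`, makes the switch easy to exercise in tests without mutating the process environment.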

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113057
Approved by: https://github.com/LucasLLC, https://github.com/wconstab, https://github.com/wanchaol
2024-01-30 02:34:25 +00:00
Roger Lam
2c5488d719 Match all_gather_into_tensor args names in remapping (#117224)
Fixes #114179

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117224
Approved by: https://github.com/wanchaol, https://github.com/wconstab
2024-01-17 03:50:29 +00:00
Yifu Wang
718b576e2c Port all_to_all_single to native c10d_functional (#113438)
Summary:
- Ported `all_to_all_single` to native c10d_functional
- Added Inductor support for the native `all_to_all_single` via the new collective IR's `create_out_of_place()`
- Since the new collective IR derives from `FallbackKernel` which implements a generic `free_unbacked_symbols`, no additional unbacked symbol handling for all_to_all_single is required

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113438
Approved by: https://github.com/yf225, https://github.com/ezyang
2023-12-22 08:12:13 +00:00
Lucas Pasqualin
d749b4a152 Implements permute_tensor in functional collectives (#115078)
Implementation of `permute_tensor` as per @yifuwang's suggestion.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115078
Approved by: https://github.com/wanchaol, https://github.com/yifuwang
2023-12-19 18:33:28 +00:00
Lucas Pasqualin
8452f41305 Adds allreduce to inductor remap (#115950)
Fixes #115728

Implements a rewrite path for allreduce

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115950
Approved by: https://github.com/wconstab
2023-12-18 22:00:22 +00:00
Chien-Chin Huang
54d552e991 [funcol] Directly import DeviceMesh to avoid circular dependency (#115649)
This diff aims to directly import DeviceMesh from torch.distributed.device_mesh instead of importing it from dist._tensor. This is done to avoid a circular dependency issue. The code changes in each file of the diff are as follows:

- torch/distributed/_functional_collectives.py: import DeviceMesh from torch.distributed.device_mesh instead of dist._tensor.

Overall, this diff aims to improve the code by avoiding circular dependencies and improving the import statements.

==
The above summary is generated by LLM with minor manual fixes. The following summary is by me.

The original import causes some issues when compiling DDP with compiled_autograd. The root cause of compilation failure is not identified but it is good to fix the lazy initialization, which indirectly fixes the compilation issues for DDP.

Differential Revision: [D51857246](https://our.internmc.facebook.com/intern/diff/D51857246/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115649
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: #115523, #115302, #115648
2023-12-13 20:44:58 +00:00
Chien-Chin Huang
50db2aa70a [funcol][BE] Apply ufmt to _functional_collectives.py and turn on lintrunner for functional_collective (#115648)
No logic change, just formatting.

Differential Revision: [D51857236](https://our.internmc.facebook.com/intern/diff/D51857236/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115648
Approved by: https://github.com/wconstab, https://github.com/wz337
ghstack dependencies: #115523, #115302
2023-12-13 11:19:29 +00:00
Wanchao Liang
b6de337d16 [funcol] a few optimizations to funcol (#113324)
Apply a few optimizations to funcol:

- allgather on non-0 dim: the resulting tensor already needs to access
data in order to do torch.cat, so we sync wait here so that we don't
need to go through ACT dispatch for chunk + cat altogether
- have a fast return logic for aten.view, as it's a commonly hit op for
view related ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113324
Approved by: https://github.com/XilunWu
2023-12-06 19:25:35 +00:00
Joel Schlosser
22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors
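
To make the flatten/unflatten round trip above concrete, here is a toy sketch that uses no real tensors. `ToyTensor` and `ToyWrapper` are hypothetical stand-ins invented for this example; only the method names and the argument order of `__tensor_unflatten__` follow the signatures quoted above, and the asserts imitate the call-site check the summary describes.

```python
class ToyTensor:
    """Stand-in for a real tensor: it only carries a shape and strides."""

    def __init__(self, size, stride):
        self.shape = tuple(size)
        self._stride = tuple(stride)

    def stride(self):
        return self._stride


class ToyWrapper:
    """Hypothetical wrapper subclass holding a single inner tensor."""

    def __init__(self, inner, label):
        self.inner = inner
        self.label = label
        self.shape = inner.shape

    def stride(self):
        return self.inner.stride()

    def __tensor_flatten__(self):
        # attrs: names of inner tensor attributes; ctx: state needed to rebuild
        return ["inner"], {"label": self.label}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        out = ToyWrapper(inner_tensors["inner"], ctx["label"])
        # Imitates the call-site check: the rebuilt object must match
        # the expected outer size and stride.
        assert out.shape == tuple(outer_size)
        assert out.stride() == tuple(outer_stride)
        return out


x = ToyWrapper(ToyTensor((2, 4), (4, 1)), label="demo")
attrs, ctx = x.__tensor_flatten__()
inner_tensors = {attr: getattr(x, attr) for attr in attrs}
y = ToyWrapper.__tensor_unflatten__(inner_tensors, ctx, x.shape, x.stride())
```

In this toy the outer size and stride simply equal the inner ones; for real subclasses they can differ, which is why they are passed in separately.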

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00
PyTorch MergeBot
4534cf102a Revert "[funcol] a few optimizations to funcol (#113324)"
This reverts commit 7117bffff9.

Reverted https://github.com/pytorch/pytorch/pull/113324 on behalf of https://github.com/huydhn due to Sorry for reverting your change here, but it is failing internal test ([comment](https://github.com/pytorch/pytorch/pull/113324#issuecomment-1813317913))
2023-11-15 21:53:23 +00:00
Wanchao Liang
7117bffff9 [funcol] a few optimizations to funcol (#113324)
Apply a few optimizations to funcol:

- allgather on non-0 dim: the resulting tensor already needs to access
data in order to do torch.cat, so we sync wait here so that we don't
need to go through ACT dispatch for chunk + cat altogether
- have a fast return logic for aten.view, as it's a commonly hit op for
view related ops

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113324
Approved by: https://github.com/XilunWu
ghstack dependencies: #113323
2023-11-14 09:28:09 +00:00
Wanchao Liang
b16e3b5373 [funcol] add two APIs: wait() and numpy() (#113323)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113323
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/wconstab
2023-11-14 09:27:45 +00:00
PyTorch MergeBot
23e0923c74 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit eeeb40b327.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to Reverting this pr as the one under it in the stack is causing regressions in torchrec ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1806044435))
2023-11-10 16:30:36 +00:00
Xuehai Pan
eeeb40b327 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would make it easier to implement the abstract `PyTreeAPI` class with two implementations, and much easier for users to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-10 05:41:32 +00:00
PyTorch MergeBot
bf452dcde6 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit fa895da968.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1804870560))
2023-11-10 00:12:52 +00:00
Lucas Pasqualin
1d56e7b5af Adds broadcast to functional collectives (#112668)
Adds `broadcast` to functional collectives, including inductor support.

Test with `python test_inductor_collectives.py -- TestCollectivesMultiProc.test_broadcast_inductor`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112668
Approved by: https://github.com/wanchaol, https://github.com/wconstab
2023-11-09 15:47:52 +00:00
Yifu Wang
625958d8bc Inductor support for native c10d_functional (#112439)
This PR adds Inductor support for [native c10d_functional ops](https://github.com/pytorch/pytorch/pull/110570).

The Inductor IRs introduced in this PR will replace the existing `CollectiveKernel` IR hierarchy. Compared to the existing collective IRs, the new IRs:
- Are target language agnostic and support AOTInductor.
- Express the constraints solely with read/write deps. This maximizes the potential for buffer reuse.
- Address an issue where out-of-place collective's input buffers could be mutated while being volatilely read.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112439
Approved by: https://github.com/Chillee
2023-11-08 23:40:21 +00:00
Xuehai Pan
fa895da968 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would make it easier to implement the abstract `PyTreeAPI` class with two implementations, and much easier for users to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-08 06:05:39 +00:00
rzou
a06832f911 Grandfather in c10d_functional ops to pt2_compliant (#113049)
This PR also adds the ability to specify Tags for more `m.def(`
overloads.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113049
Approved by: https://github.com/williamwen42
2023-11-07 12:55:05 +00:00
Yifu Wang
ec18ef62f4 Native c10d_functional ops (#110570)
This PR introduces a native version of c10d_functional ops. The main goal is to add collective support in AOTInductor and allow collective ops to work in multi-threaded native runtimes.

The native version also incorporated API improvements we wished to implement in Python c10d_functional:

- Removed `ranks` and `group_size` from collective op signatures which were proven to be redundant.
- Use tensor storage as opposed to `void*` to resolve in-flight work.

The native process group registration/resolution mechanism is only used for native c10d_functional in this PR. It will become the single source of truth in upcoming PRs.

The upcoming PRs will implement Inductor/AOTInductor support for c10d_functional, after which native c10d_functional will replace Python c10d_functional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110570
Approved by: https://github.com/wanchaol
2023-10-25 22:56:06 +00:00
Brian Hirsh
4d29b40299 torch.compile DTensor E2E (#105236)
This PR updates DTensor to support torch.compile

Cool stuff: there are some new tests in `test_dtensor.py` that show both the forward and backward graphs that we can send to inductor, when running a matmul with DTensor's. In particular, for this user code:
```
        def fn(x, y):
            dt = DTensor.from_local(x.reshape(2, 4), mesh, [Shard(0)], run_check=False)
            dt2 = DTensor.from_local(y.reshape(4, 2), mesh, [Shard(1)], run_check=False)
            dt_out = torch.matmul(dt, dt2)
            dt_out_redistribute = dt_out.redistribute(mesh, [Replicate()])
            return dt_out.to_local()
```

We generate the following fw and backward graphs.

Forward graph:
```
def forward(self, primals_1, primals_2):
    view = torch.ops.aten.view.default(primals_1, [2, 4]);  primals_1 = None
    _to_copy = torch.ops.aten._to_copy.default(view, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0));  view = None
    detach = torch.ops.aten.detach.default(_to_copy);  _to_copy = None
    detach_1 = torch.ops.aten.detach.default(detach);  detach = None
    view_1 = torch.ops.aten.view.default(primals_2, [4, 2]);  primals_2 = None
    _to_copy_1 = torch.ops.aten._to_copy.default(view_1, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0));  view_1 = None
    detach_2 = torch.ops.aten.detach.default(_to_copy_1);  _to_copy_1 = None
    detach_3 = torch.ops.aten.detach.default(detach_2);  detach_2 = None
    detach_4 = torch.ops.aten.detach.default(detach_1)
    all_gather_into_tensor = torch.ops.c10d_functional.all_gather_into_tensor.default(detach_3, 'ptd:0', [0, 1], 2)
    wait_tensor = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor);  all_gather_into_tensor = None
    split = torch.ops.aten.split.Tensor(wait_tensor, 4);  wait_tensor = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    cat = torch.ops.aten.cat.default([getitem, getitem_1], 1);  getitem = getitem_1 = None
    detach_5 = torch.ops.aten.detach.default(cat);  cat = None
    mm = torch.ops.aten.mm.default(detach_4, detach_5);  detach_4 = detach_5 = None
    detach_6 = torch.ops.aten.detach.default(mm);  mm = None
    detach_9 = torch.ops.aten.detach.default(detach_6);  detach_6 = None
    detach_10 = torch.ops.aten.detach.default(detach_9);  detach_9 = None
    t = torch.ops.aten.t.default(detach_1);  detach_1 = None
    detach_13 = torch.ops.aten.detach.default(t);  t = None
    t_1 = torch.ops.aten.t.default(detach_3);  detach_3 = None
    detach_15 = torch.ops.aten.detach.default(t_1);  t_1 = None
    clone = torch.ops.aten.clone.default(detach_15, memory_format = torch.contiguous_format);  detach_15 = None
    return [detach_10, detach_13, clone]
```

Backward graph:
```
def forward(self, detach_13, clone, tangents_1):
    detach_11 = torch.ops.aten.detach.default(tangents_1);  tangents_1 = None
    detach_12 = torch.ops.aten.detach.default(detach_11);  detach_11 = None
    mm_1 = torch.ops.aten.mm.default(detach_13, detach_12);  detach_13 = None
    detach_14 = torch.ops.aten.detach.default(mm_1);  mm_1 = None
    detach_16 = torch.ops.aten.detach.default(detach_12);  detach_12 = None
    all_gather_into_tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor.default(clone, 'ptd:0', [0, 1], 2);  clone = None
    wait_tensor_2 = torch.ops.c10d_functional.wait_tensor.default(all_gather_into_tensor_2);
    detach_17 = torch.ops.aten.detach.default(wait_tensor_2);  wait_tensor_2 = None
    mm_2 = torch.ops.aten.mm.default(detach_16, detach_17);  detach_16 = detach_17 = None
    detach_18 = torch.ops.aten.detach.default(mm_2);  mm_2 = None
    split_1 = torch.ops.aten.split.Tensor(detach_14, 2, 1);  detach_14 = None
    getitem_2 = split_1[0]
    getitem_3 = split_1[1];  split_1 = None
    cat_1 = torch.ops.aten.cat.default([getitem_2, getitem_3]);  getitem_2 = getitem_3 = None
    reduce_scatter_tensor = torch.ops.c10d_functional.reduce_scatter_tensor.default(cat_1, 'SUM', 'ptd:0', [0, 1], 2);  cat_1 = None
    wait_tensor_3 = torch.ops.c10d_functional.wait_tensor.default(reduce_scatter_tensor);  reduce_scatter_tensor = None
    detach_19 = torch.ops.aten.detach.default(wait_tensor_3);  wait_tensor_3 = None
    detach_20 = torch.ops.aten.detach.default(detach_19);  detach_19 = None
    detach_21 = torch.ops.aten.detach.default(detach_20);  detach_20 = None
    detach_22 = torch.ops.aten.detach.default(detach_21);  detach_21 = None
    _to_copy_2 = torch.ops.aten._to_copy.default(detach_22, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  detach_22 = None
    view_2 = torch.ops.aten.view.default(_to_copy_2, [8]);  _to_copy_2 = None
    detach_23 = torch.ops.aten.detach.default(detach_18);  detach_18 = None
    detach_24 = torch.ops.aten.detach.default(detach_23);  detach_23 = None
    _to_copy_3 = torch.ops.aten._to_copy.default(detach_24, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'));  detach_24 = None
    view_3 = torch.ops.aten.view.default(_to_copy_3, [8]);  _to_copy_3 = None
    return [view_3, view_2]
```

Some of the stuff in this graph looks kind of silly though (e.g. an unnecessary split() + cat(), and all the extra detach() calls).

Stuff that's broken:
- functionalization is pretty horribly broken. In particular, the original strategy I used in this stack was to have functionalization run **above** subclass desugaring. But that doesn't play well with the way we want to compile DTensor. DTensor has a few APIs like `.redistribute()`, `.to_local()`, and the `DTensor()` constructor, that we want to put directly into the graph so that we can compile them (e.g. redistribute() will desugar into collective ops). Doing this requires functionalization to run **underneath** the subclass though. I hacked around this for now, by forcing these functions to run functionalization first if they need to.
- the backward test that I have is... wrong. The backward graph that we trace out looks kind of reasonable, but it gives incorrect gradients on one of the two inputs. This needs further debugging (presumably we should be able to stare at the graph and identify which part of it is wrong?).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105236
Approved by: https://github.com/wanchaol
2023-10-11 21:55:27 +00:00
Wanchao Liang
459cef8649 switch dtensor and functional collective to use optree (#110670)
optree recently landed and provides quite good perf, so conditionally import
optree if it is installed

Some numbers testing mlp layer with TP + func collective:
before this PR: 10.390ms
after this PR: 9.189ms

so roughly a 10% reduction in e2e CPU overhead (about 11.6% by these numbers)
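
For illustration, a conditional dependency check and the arithmetic behind the numbers above can be sketched as follows. The helper name `optree_available` is hypothetical, and probing with `importlib.util.find_spec` is just one common pattern; this is not necessarily how the PR performs the check.

```python
import importlib.util


def optree_available():
    """Return True if the optional `optree` package can be imported."""
    return importlib.util.find_spec("optree") is not None


# Timings quoted in the commit message (MLP layer with TP + func collective).
before_ms = 10.390
after_ms = 9.189

# Relative reduction in measured time.
reduction = (before_ms - after_ms) / before_ms
print(f"optree installed: {optree_available()}, reduction: {reduction:.1%}")
```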

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110670
Approved by: https://github.com/fegin
2023-10-08 03:05:39 +00:00
Edward Z. Yang
f274c7b32c Add functional collective all_to_all_single and support it in Inductor (#110195)
Copy of https://github.com/pytorch/pytorch/pull/106655 from yf225
rebased on top of item() support changes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110195
Approved by: https://github.com/Skylion007
2023-10-05 23:11:51 +00:00
Edward Z. Yang
ec8b58f5ba Add support for tolist on AsyncCollectiveTensor (#109377)
This has to be done by hand because tolist isn't supported on tensor subclasses.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109377
Approved by: https://github.com/wconstab, https://github.com/fduwjj
2023-09-15 21:48:13 +00:00
Brian Hirsh
5efd63b1b8 better support for fakeifying and dynamoing through torch_dispatch subclasses (with dynamic shapes) (#107415)
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.

(3) Marginally updated `_make_wrapper_subclass()` to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second... does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
2023-08-29 02:36:48 +00:00
Antoni Viros i Martin
2c45a579ca Add wait_tensor so print always has a correct result for AsyncCollectiveTensor (#107808)
As the title says, I was trying to test the functional collectives, and, when printing the resulting tensors, sometimes they wouldn't have finished the Async operation yet. According to the comments in the file, "AsyncTensor wrapper applied to returned tensor, which issues wait_tensor() at the time of first use". This is true in most cases, but not when print() is your first use. This PR fixes that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107808
Approved by: https://github.com/fduwjj
2023-08-24 00:00:23 +00:00
Rodrigo Kumpera
bbf03561a9 [functional collectives] Move back to registering finalizers on wrappers. (#107250)
We cannot use inner tensors for finalizers as they are uncollectable until waited.

This PR adds a bunch of tests for the observable behavior we want, including the
necessary scaffold for us to test code for their waitiness.
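
A minimal sketch of registering a finalizer on a wrapper object rather than on what it wraps, using the standard-library `weakref.finalize`. `Work`, `Wrapper`, and `register_finalizer` are hypothetical names, and having the finalizer call `wait()` is an assumption made for the example, not necessarily what this PR's finalizer does.

```python
import gc
import weakref


class Work:
    """Hypothetical stand-in for an in-flight collective."""

    def __init__(self):
        self.waited = False

    def wait(self):
        self.waited = True


class Wrapper:
    """Hypothetical user-facing wrapper around a collective's result."""

    def __init__(self, work):
        self.work = work


def register_finalizer(wrapper):
    # The callback holds the work, not the wrapper, so it does not
    # keep the wrapper alive.
    work = wrapper.work
    return weakref.finalize(wrapper, work.wait)


work = Work()
wrapper = Wrapper(work)
fin = register_finalizer(wrapper)

del wrapper   # drop the last reference to the wrapper
gc.collect()  # make collection explicit for non-refcounting runtimes
```
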
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107250
Approved by: https://github.com/wconstab
2023-08-17 21:08:28 +00:00
Wanchao Liang
5c48ff20b5 AsyncCollectiveTensor: dont sync on view ops (#105240)
AsyncCollectiveTensor is a tensor subclass that is meant to "delay synchronization" when you call into the functional collectives APIs. It does this (if I understand correctly) by internally holding an "unsynchronized" version of the tensor, which is the result of the communication op, and internally calling `.wait()` to synchronize the data the next time it is used.

Previously, these wait() calls would happen immediately, because `AsyncCollectiveTensor` gets wrapped by `DTensor()`, which calls `.detach()` on its inner tensor, immediately causing the sync (code: 1518d5eec4/torch/distributed/_tensor/api.py (L207))

AsyncCollectiveTensor shouldn't need to do a synchronization if you try to detach() it though - in fact, it should be fine to avoid synchronizing if you perform any view ops on it (which just require viewing metadata, but not actual data). This PR tries to update `AsyncCollectiveTensor` to delay `wait()` calls whenever the subclass encounters a view op.

Added some light testing, that just runs some DTensor compute followed by view ops, and confirms that the output is still an `AsyncCollectiveTensor` when we call `.to_local()`.
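
The idea of deferring synchronization across view ops can be modeled with a small toy that involves no tensors. `FakeWork` and `LazyResult` are hypothetical classes written for this sketch, not the real `AsyncCollectiveTensor`: a metadata-only `view()` shares the pending work, and only reading the data triggers the wait.

```python
class FakeWork:
    """Hypothetical pending communication; counts real synchronizations."""

    def __init__(self):
        self.sync_count = 0
        self.done = False

    def wait(self):
        if not self.done:
            self.sync_count += 1
            self.done = True


class LazyResult:
    """Toy model of a result that synchronizes only when data is read."""

    def __init__(self, data, shape, work):
        self._data = data
        self.shape = tuple(shape)
        self._work = work

    def view(self, *shape):
        # Metadata-only op: share the data and the pending work, no wait.
        return LazyResult(self._data, shape, self._work)

    def tolist(self):
        # Reading actual data forces synchronization first.
        self._work.wait()
        return list(self._data)


work = FakeWork()
result = LazyResult([1, 2, 3, 4], (4,), work)
viewed = result.view(2, 2)        # no synchronization yet
syncs_after_view = work.sync_count
values = viewed.tolist()          # first data access triggers the wait
```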

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105240
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/wconstab
2023-08-11 19:20:25 +00:00
Will Constable
d64bada876 Refactor funcol for readability and dynamo tracing (#104387)
Move eager kernel impls to a separate file, which is easier to read
(since users may be confused about 2 versions of each kernel in the same file)
and makes it easier to set a dynamo policy to trace only the first file currently.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104387
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/kumpera
2023-07-06 23:29:49 +00:00
Rodrigo Kumpera
17ab4f85e9 [c10d] Adopt allgather_into_tensor_coalesced for NCCL. (#103086)
This is done by adding c10d::_allgather_into_tensor_coalesced wrapper.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103086
Approved by: https://github.com/rohan-varma
2023-07-06 15:05:55 +00:00
Wanchao Liang
db1ac4e29b fix functional collective's allgather for gloo (#104681)
Summary: We should explicitly check for the gloo backend instead of relying on the shard's device, because a user might pass a GPU tensor as input and a gloo process group as the pg, and expect that to work.

Differential Revision: D47249172

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104681
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj
2023-07-06 09:52:48 +00:00
Will Constable
d0509fe32d Document how functional collectives work under eager/dynamo (#104386)
Move user facing apis to the top for best visibility
(strictly code-motion in this PR, besides adding comments)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104386
Approved by: https://github.com/voznesenskym, https://github.com/wanchaol
2023-06-30 01:12:55 +00:00
Rodrigo Kumpera
c17bdb3247 [C10D] Add functional collective reduce_scatter_into_tensor_coalesced. (#101023)
Implementation uses a fallback that does no coalescing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101023
Approved by: https://github.com/wanchaol
2023-06-23 19:24:11 +00:00
Rodrigo Kumpera
0beec88c93 Inductor support for all_gather_into_tensor_coalesced. (#98643)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98643
Approved by: https://github.com/wanchaol
2023-06-21 19:25:03 +00:00
Rodrigo Kumpera
63fe26809d Implement all_gather_into_tensor_coalesced. (#98642)
The implementation is suboptimal since it uses c10d's group coalescing, which
is known to be inefficient.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98642
Approved by: https://github.com/wanchaol
2023-06-13 15:06:52 +00:00
PyTorch MergeBot
caecb55223 Revert "Log functional_collectives apis to distributed logger (#103288)"
This reverts commit 37359c36fd.

Reverted https://github.com/pytorch/pytorch/pull/103288 on behalf of https://github.com/malfet due to Broke test_inductor_collectives, see 37359c36fd ([comment](https://github.com/pytorch/pytorch/pull/103288#issuecomment-1587677705))
2023-06-12 16:37:57 +00:00
Will Constable
37359c36fd Log functional_collectives apis to distributed logger (#103288)
This logs functional collectives API calls with debug log level only.

(the `+` in the TORCH_LOGS cmdline enables debug level, otherwise only info level)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103288
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
2023-06-12 06:33:26 +00:00
Wanchao Liang
d31707a257 Get rid of dim_groups attribute from DeviceMesh (#103105)
This PR gets rid of the dim_groups attribute from DeviceMesh. The main
motivation is that c10d, not DeviceMesh, should store the process
groups when they are created; DeviceMesh should just
handle ranks correctly.

This could enable DTensor to become picklable (torch.save/load could be
possible), which I will try in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103105
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
2023-06-09 04:11:15 +00:00
Will Constable
77f97019b7 Dynamo remaps legacy allgather to traceable one (#102232)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102232
Approved by: https://github.com/voznesenskym
2023-05-30 16:45:25 +00:00
albanD
59dff01319 Add top level function to check if running with deploy (#101420)
Also not sure if this should be a public function or not. Leaving it private for now but let me know if you prefer for it to be public.

FYI @nikitaved this will logically conflict with your triton kernel PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101420
Approved by: https://github.com/malfet
2023-05-16 16:05:49 +00:00
Will Constable
793bd6993a Work around torchdynamo import error with functional collectives (#100901)
Summary:
Currently there are build configs where the torchdynamo import trips over a
strange SystemError related to some module's __dict__.items() returning NULL,
while torchdynamo tries to iterate all torch modules and process them for
its allowed functions list.

While this is hard to repro, we should be able to work around it and then fix
it properly.

Test Plan: Rely on others to test this, assuming CI passes.

Reviewed By: anijain2305

Differential Revision: D45663313

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100901
Approved by: https://github.com/yanboliang, https://github.com/malfet
2023-05-09 16:09:42 +00:00
Rodrigo Kumpera
7a15e82388 Fix tensor registration to work with coalescing collectives. (#99763)
We do it by making it possible to register multiple tensors for the same
worker and coordinate waiting/cleanup among them.

This ensures waiting on any number of the output tensors will result in a
single stream sync. This simplifies codegen by inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99763
Approved by: https://github.com/wanchaol
2023-05-05 14:25:35 +00:00
Will Constable
2dca418112 Reland basic dynamo support for traceable collectives (#100476)
Relative to the original land, this also contains:
- Fix torchdeploy import of functional collectives
- Can't import torchdynamo utils due to torch._refs being missing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100476
Approved by: https://github.com/kumpera
2023-05-04 04:25:35 +00:00
Shabab Ayub
287f74c4fc Revert D45387167: Multisect successfully blamed D45387167 for test or build failures (#100424)
Summary:
This diff is reverting D45387167
D45387167: Basic dynamo support for traceable collectives (#94440) by wconstab has been identified to be causing the following test or build failures (internal)

If you believe this diff has been generated in error you may Commandeer and Abandon it.

Test Plan: NA

Reviewed By: s4ayub

Differential Revision: D45448312

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100424
Approved by: https://github.com/rohan-varma, https://github.com/kumpera
2023-05-03 16:10:54 +00:00
Will Constable
100a25d021 Basic dynamo support for traceable collectives (#94440)
Make traceable collectives work with torchdynamo,
bypassing problems with tracing the AsyncTensor subclass.

Accept a suboptimal solution for now, and optimize it later.
For now, wait happens immediately, which generally forces an early sync.

Later, find a way either in dynamo or AOT stack to handle
AsyncCollectiveTensor to get the wait in the optimal place.

Note on implementation:
- Dynamo traces 'user-level' fc apis that are designed to behave differently
  in eager vs compiled.  In eager, there will be work-obj registration and
  a wrapper subclass will insert a 'wait' call at the appropriate time.
  In compile/trace mode, wait will be immediately called, and work obj
  registration is required to be handled by the compile backend at runtime.
- Dynamo needs to trace into some of the helper functions in the 'user-level'
  api, such as '_expand_group' which is essentially a constant transformation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94440
Approved by: https://github.com/kumpera
2023-04-27 05:38:36 +00:00
Rodrigo Kumpera
5b4a523583 Add all_reduce_coalesced to functional collectives (#98640)
This adds all_reduce_coalesced to MTPG to ease testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98640
Approved by: https://github.com/wanchaol
2023-04-26 17:05:54 +00:00