Commit Graph

33 Commits

Author SHA1 Message Date
Will Feng
4ee514144b [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager (with `async_op=True`) but waited in compiled region.

This is important for internal use cases such as TorchRec, where we issue collectives in eager for SparseArch all_to_all but want to wait for them in compiled region at beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.

----

**Update**: Did two items to prevent regression to existing use cases:

1. Added memory-stressed test case to test_c10d_nccl.py `test_unwaited` to cover existing user's "not calling work.wait() for non-functional collective" use case
2. Gated all new `register_work()` / `unregister_work()` calls with `c10d::allow_inflight_collective_as_graph_input()` check, which is a new context manager that requires explicit user enablement (i.e. not on by default, so should not affect existing users).

The risk of this new version of PR causing regression should be very low.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-29 03:31:19 +00:00
PyTorch MergeBot
e5595f10c8 Revert "[c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)"
This reverts commit a688c57033.

Reverted https://github.com/pytorch/pytorch/pull/137763 on behalf of https://github.com/yf225 due to Seems to have bad interaction with latest commits on trunk, reverting to be safe ([comment](https://github.com/pytorch/pytorch/pull/137763#issuecomment-2442527696))
2024-10-28 20:13:46 +00:00
Will Feng
a688c57033 [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager (with `async_op=True`) but waited in compiled region.

This is important for internal use cases such as TorchRec, where we issue collectives in eager for SparseArch all_to_all but want to wait for them in compiled region at beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-28 18:11:23 +00:00
PyTorch MergeBot
e7f1e306df Revert "[c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager async_op=True collective (#137763)"
This reverts commit 362ca54f03.

Reverted https://github.com/pytorch/pytorch/pull/137763 on behalf of https://github.com/wdvr due to this change is breaking our prod training pipeline (verified with bisect) by increasing memory consumption 4x and causing OOM ([comment](https://github.com/pytorch/pytorch/pull/137763#issuecomment-2435962833))
2024-10-24 17:46:09 +00:00
Will Feng
362ca54f03 [c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager async_op=True collective (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y
```
where the collective is issued in eager (with `async_op=True`) but waited in compiled region.

This is important for internal use cases such as TorchRec, where we issue collectives in eager for SparseArch all_to_all but want to wait for them in compiled region at beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_work_registry`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_work_registry`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D64511994](https://our.internmc.facebook.com/intern/diff/D64511994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-21 06:02:57 +00:00
cyy
94e12f97dc [Distributed] [16/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#137404)
Follows #137072

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137404
Approved by: https://github.com/Skylion007
2024-10-10 18:05:34 +00:00
Yifu Wang
a136a7d623 [Functional Collective] enable custom work registration from python (#130354)
This PR does two things:
- Allow tensor -> work registration in Python via `torch._C._distributed_c10d.register_work`. Calling `torch.ops._c10d_functional.wait_tensor` on a tensor would trigger `.wait()` on the registered work object.
- Allow user-defined work object in Python to work with functional collectives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130354
Approved by: https://github.com/wanchaol, https://github.com/fegin, https://github.com/wconstab
2024-07-22 21:45:19 +00:00
cyy
3f9b8446cf [8/N] Remove unused functions (#128499)
Follows #128407

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128499
Approved by: https://github.com/malfet
2024-06-13 01:15:11 +00:00
cyy
be7be9fa16 [Distributed] [8/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#125102)
This PR continues to clean clang-tidy warnings in torch/csrc/distributed/c10d, following https://github.com/pytorch/pytorch/pull/124987.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125102
Approved by: https://github.com/ezyang
2024-05-30 16:19:53 +00:00
Will Feng
4333e122d4 [Traceable FSDP2] Add all_gather_into_tensor out variant (#126334)
This PR adds `torch.ops._c10d_functional.all_gather_into_tensor_out`.

It's important for tracing FSDP2, because FSDP2 pre-allocates the output buffer of AllGather, and makes input buffer an alias of the output buffer, and expects both of them to be used to achieve lower memory usage. If we don't preserve this behavior and instead functionalize the AllGather op, AllGather op will then create a brand-new output buffer (instead of reusing), thus significantly increasing the memory usage.

The expectation is that we will "re-inplace" the AllGather op by switching to the out variant in Inductor post-grad stage via an FX pass, so this API is not expected to be directly used by users.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126334
Approved by: https://github.com/yifuwang, https://github.com/wanchaol
2024-05-16 10:27:06 +00:00
Wanchao Liang
00df0d3e94 [dtensor] implement shard dim change with alltoall (#124872)
as titled, we implement a dedicated communication op to allow efficient
sharding dimension change using alltoall, to replace our previous
allgather + local chunk

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124872
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #124871
2024-04-30 18:30:34 +00:00
PyTorch MergeBot
f1d1e3246f Revert "[dtensor] implement shard dim change with alltoall (#124872)"
This reverts commit 6b79469d24.

Reverted https://github.com/pytorch/pytorch/pull/124872 on behalf of https://github.com/clee2000 due to broke distributed/tensor/parallel/test_tp_examples.py::DistTensorParallelExampleTest::test_transformer_training_is_seq_parallel_True https://github.com/pytorch/pytorch/actions/runs/8882762411/job/24389191482 f7f018a0ed.  Bad TD ([comment](https://github.com/pytorch/pytorch/pull/124872#issuecomment-2083599445))
2024-04-29 20:26:16 +00:00
Wanchao Liang
6b79469d24 [dtensor] implement shard dim change with alltoall (#124872)
as titled, we implement a dedicated communication op to allow efficient
sharding dimension change using alltoall, to replace our previous
allgather + local chunk

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124872
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #124871
2024-04-29 17:22:30 +00:00
cyy
ea61c9cb29 [Distributed] [5/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124043)
This PR continues to fix some clang-tidy warnings in distributed/c10d code, following https://github.com/pytorch/pytorch/pull/124032.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124043
Approved by: https://github.com/ezyang
2024-04-23 00:43:50 +00:00
Tristan Rice
ddd0ed1b43 distributed: templated ring attention (#124215)
This adds a templated version of the ring attention forwards function as well as tests it with memory efficient attention. This doesn't add support for memory efficient attention in DTensor. That will be added in a follow up PR.

This templating is also a POC of how to support other attention ops such as Jagged/nested tensor and as well how to implement striped attention in a scalable way.

Misc changes:

* Fixes all_to_all_single autograd implementation with CUDA + adds NCCL test
* Adds compile support to the ring attention implementations (required some tweaks to process groups)

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
2024-04-19 00:57:08 +00:00
Tristan Rice
1ec05c769b all_gather and reduce_scatter autograd (#123989)
This adds `all_gather_tensor_autograd` and `reduce_scatter_tensor_autograd` to the functional_collectives library.

This only supports `sum` mode for `reduce_scatter` but should be easy to extend in the future.

The backwards implementations match the behavior in https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/comm_ops.py

This follows the pattern of #123599 .

Test plan:

```sh
pytest test/distributed/test_functional_api.py -k Autograd
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123989
Approved by: https://github.com/wanchaol
2024-04-17 21:32:22 +00:00
cyy
c2596fd3e0 [Distributed] [4/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124032)
This PR continues to fix some clang-tidy warnings in distributed/c10d code, following https://github.com/pytorch/pytorch/pull/123312.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124032
Approved by: https://github.com/Skylion007
2024-04-16 00:42:18 +00:00
cyy
77a45883ce [Reland] [Distributed] [2/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#123821)
Reland of #122892 with problematic changes reverted.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123821
Approved by: https://github.com/Skylion007
2024-04-13 00:57:03 +00:00
Tristan Rice
358ace1a1b functional_collectives: add first differentiable collective -- all_to_all_single_grad (#123599)
This adds the differentiable collective -- all_to_all_single_grad. This is the initial proof of concept PR and I will be adding the remaining collectives in follow up PRs.

This adds a new function called `all_to_all_single_autograd` which is the autograd variant of `all_to_all_single`. For backwards compatibility + initial testing we wanted to make the autograd variant separate to avoid regressions.

This uses `autograd::Function` to register an Autograd op that calls the original `_c10d_functional::all_to_all_single` via the dispatcher. This works with compile and inductor as opposed to the previous Python implementation that had issues. As this uses the existing `_c10d_functional` ops we don't need to register any meta functions or lowering.

To avoid cudaStream issues this explicitly calls `wait_tensor` in the backward method to ensure it runs under the same stream as the async operation. This hurts performance but can be alleviated potentially using `compile`.

Related work: https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/comm_ops.py

Test plan:

```
pytest test/distributed/test_functional_api.py -k test_all_to_all_single_compile
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123599
Approved by: https://github.com/yifuwang
2024-04-12 01:48:49 +00:00
PyTorch MergeBot
54801e6fd6 Revert "[Distributed] [2/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#122892)"
This reverts commit 0ba16ffd35.

Reverted https://github.com/pytorch/pytorch/pull/122892 on behalf of https://github.com/atalman due to broke cuda tests ([comment](https://github.com/pytorch/pytorch/pull/122892#issuecomment-2037207036))
2024-04-04 13:22:22 +00:00
cyy
0ba16ffd35 [Distributed] [2/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#122892)
This PR continues to fix some clang-tidy warnings in distributed code, following #122884.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122892
Approved by: https://github.com/Skylion007
2024-04-04 00:39:31 +00:00
Yifu Wang
22cd2658b4 Disable GroupRegistry's thread isolation by default (#121457)
Today `GroupRegistry` employs thread isolation by default, i.e. every thread sees its own process group registry. This is intended to work for one-device-per-process (for python use cases) and one-device-per-thread case (for custom native runtimes).

However, there's a problem - there are python use cases that initializes/registers process groups in one thread, and runs collectives in another thread. This use case should be supported. However, since `GroupRegistry` employs thread isolation by default, collectives in different threads can't find the registered process groups.

This PR fixes the issue by:
- Make `GroupRegistry` work in non-thread isolation mode by default. This would match the behavior w/o the native process group registry.
- Introduces `set_thread_isolation_mode` so one-device-per-thread runtimes can enable thread isolation mode explicitly.

Differential Revision: [D54658515](https://our.internmc.facebook.com/intern/diff/D54658515)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121457
Approved by: https://github.com/wanchaol
2024-03-08 19:31:24 +00:00
Yifu Wang
1c9fc720ae Change the .clone() in native funcol's all_reduce to use at::MemoryFormat::Contiguous (#120042)
Summary:
While I think it probably makes more sense to only require `all_reduce` input to be non-overlapping and dense, today `ProcessGroupNCCL` requires it to be contiguous. This is also what the `all_reduce` in non-native funcol does.

Also marking a test affected by this with `@run_with_both_funcol_impls`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120042
Approved by: https://github.com/wanchaol
2024-02-22 20:24:15 +00:00
Yifu Wang
40786ca509 Handle unwaited work objects on process termination (#119881)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119881
Approved by: https://github.com/wconstab
2024-02-19 02:46:02 +00:00
Yifu Wang
4ac857f94e Support broadcast in native funcol (#119229)
### Summary

@LucasLLC recently implemented `broadcast` in funcol. This is not yet available in the native funcol ops. This PR adds support for broadcast for native funcol.

- Added `_c10d_functional::broadcast` and `_c10d_functional::broadcast_`
- Integrated with python functol broadcast and `AsyncCollectiveTensor`
- Implemented Inductor lowering. Verified correctness and buffer reuse behavior
- Validated dynamo traceability
- Validated AOTInductor compile-ability

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119229
Approved by: https://github.com/wanchaol
ghstack dependencies: #119104
2024-02-16 21:01:34 +00:00
Yifu Wang
5086e1cf3f Remove distributed/c10d/Functional.hpp (#119138)
This file is useless and was accidentally checked in.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119138
Approved by: https://github.com/Skylion007
2024-02-05 21:58:08 +00:00
Yifu Wang
b778f44e97 Allow using native c10d_functional via _functional_collectives (#113057)
This diff introduces an env var `_USE_NATIVE_C10D_FUNCTIONAL` that tells `_functional_collective` to use native `c10d_functional` ops. The Python version and the native version will co-exist until we completely switch to the native version after more testing and verification.

NOTE: `DeviceMesh` support for native `c10d_functional` will be added in a subsequent PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113057
Approved by: https://github.com/LucasLLC, https://github.com/wconstab, https://github.com/wanchaol
2024-01-30 02:34:25 +00:00
Yifu Wang
7d0ad6e870 Make native c10d_functional ops work with AOTInductor (#113735)
Summary:
- Revised `c10d_functional` ops to conform to https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#func
- Modifed `get_cpp_op_schema()` to handle mutable args and aliasing returns

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113735
Approved by: https://github.com/desertfire
ghstack dependencies: #113438
2023-12-22 08:12:13 +00:00
Yifu Wang
718b576e2c Port all_to_all_single to native c10d_functional (#113438)
Summary:
- Ported `all_to_all_single` to native c10d_functional
- Added Inductor support for the native `all_to_all_single` via the new collective IR's `create_out_of_place()`
- Since the new collective IR derives from `FallbackKernel` which implements a generic `free_unbacked_symbols`, no additional unbacked symbol handling for all_to_all_single is required

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113438
Approved by: https://github.com/yf225, https://github.com/ezyang
2023-12-22 08:12:13 +00:00
rzou
a06832f911 Grandfather in c10d_functional ops to pt2_compliant (#113049)
This PR also adds the ability to specify Tags for more `m.def(`
overloads.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113049
Approved by: https://github.com/williamwen42
2023-11-07 12:55:05 +00:00
PyTorch MergeBot
1fea599d9a Revert "Grandfather in c10d_functional ops to pt2_compliant (#113049)"
This reverts commit fe8570a1fe.

Reverted https://github.com/pytorch/pytorch/pull/113049 on behalf of https://github.com/clee2000 due to something in the stack broke distributed and inductor, pretty sure its this one ([comment](https://github.com/pytorch/pytorch/pull/113049#issuecomment-1797298969))
2023-11-07 02:34:13 +00:00
rzou
fe8570a1fe Grandfather in c10d_functional ops to pt2_compliant (#113049)
This PR also adds the ability to specify Tags for more `m.def(`
overloads.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113049
Approved by: https://github.com/williamwen42
ghstack dependencies: #113036
2023-11-06 23:43:23 +00:00
Yifu Wang
ec18ef62f4 Native c10d_functional ops (#110570)
This PR introduces a native version of c10d_functional ops. The main goal is to add collective support in AOTInductor and allow collective ops to work in multi-threaded native runtimes.

The native version also incorporated API improvements we wished to implement in Python c10d_functional:

- Removed `ranks` and `group_size` from collective op signatures which were proven to be redundant.
- Use tensor storage as opposed to `void*` to resolve in-flight work.

The native process group registration/resolution mechansim is only used for native c10d_functional in the PR. It will become the single source of truth in upcoming PRs.

The upcoming PRs will implement Inductor/AOTInductor support for c10d_functional, after which native c10d_functional will replace Python c10d_functional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110570
Approved by: https://github.com/wanchaol
2023-10-25 22:56:06 +00:00