pytorch/torch/_C
Will Feng 4ee514144b [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)
This PR aims to support the following use case:
```python
import torch
import torch.distributed as dist

# Assumed import path for the new context manager introduced in this PR.
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,
)

def all_reduce_eager(x):
    y = x * x
    # async_op=True issues the collective without blocking and returns a Work handle.
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    # Wait on the in-flight eager collective from inside the compiled region.
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

# Assumes an initialized process group; `self.rank` comes from the multi-process test harness.
x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but is waited on inside the compiled region.

This is important for internal use cases such as TorchRec, where we issue the SparseArch all_to_all collectives in eager mode but want to wait on them in the compiled region at the beginning of OverArch, so that the all_to_all can overlap with the DenseArch compute that runs in parallel. A simplified sketch of this overlap pattern is shown below.
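
For illustration only, here is a minimal sketch of that overlap pattern, assuming an initialized NCCL process group; `sparse_arch`, `dense_arch`, and `over_arch` are hypothetical stand-ins rather than TorchRec APIs, and the context manager's import path is an assumption:

```python
import torch
import torch.distributed as dist

# Assumed import path for the context manager (same assumption as above).
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,
)

def sparse_arch(x):
    # Issue the all_to_all eagerly and return its output without waiting.
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, async_op=True)
    return out

def dense_arch(x):
    # Independent dense compute that can run while the all_to_all is in flight.
    return x @ x

@torch.compile(fullgraph=True)
def over_arch(sparse_out, dense_out):
    # Wait for the eager all_to_all only when its result is actually consumed.
    torch.ops.c10d_functional.wait_tensor(sparse_out)
    return sparse_out.sum() + dense_out.sum()

x = torch.ones(1280, 1280, device="cuda") + dist.get_rank()
with allow_inflight_collective_as_graph_input_ctx():
    sparse_out = sparse_arch(x)                # collective in flight
    dense_out = dense_arch(x)                  # overlapped DenseArch compute
    result = over_arch(sparse_out, dense_out)  # waits inside the compiled region
```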

----

**Update**: Two changes were made to prevent regressions for existing use cases:

1. Added a memory-stressed test case `test_unwaited` to test_c10d_nccl.py to cover the existing pattern of not calling `work.wait()` on a non-functional collective.
2. Gated all new `register_work()` / `unregister_work()` calls behind the `c10d::allow_inflight_collective_as_graph_input()` check, which is enabled only through the new context manager and therefore requires explicit user opt-in (i.e. it is off by default, so existing users are unaffected); see the sketch below.
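
To make the opt-in behavior of item 2 concrete, here is a minimal user-level sketch, assuming an initialized NCCL process group and the same assumed import path for the context manager as above:

```python
import torch
import torch.distributed as dist

# Assumed import path; see the example at the top of this description.
from torch.distributed._functional_collectives import (
    allow_inflight_collective_as_graph_input_ctx,
)

t = torch.ones(8, device="cuda")

# Default (gate off): nothing changes for existing users -- an async_op=True
# collective is still waited on through the returned Work handle.
work = dist.all_reduce(t, async_op=True)
work.wait()

# Explicit opt-in: only inside the context manager is the in-flight collective
# tracked so that wait_tensor() on its output tensor can wait on it.
with allow_inflight_collective_as_graph_input_ctx():
    dist.all_reduce(t, async_op=True)
    torch.ops.c10d_functional.wait_tensor(t)
```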

The risk of this version of the PR causing a regression should be very low.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-29 03:31:19 +00:00
| File | Last commit | Date |
| --- | --- | --- |
| `_dynamo` | Back out "[compiled autograd] tls access helpers (#138061)" and Back out "[compiled autograd] Compiled autograd configs in TLS (#137821)" (#139086) | 2024-10-28 23:37:05 +00:00 |
| `__init__.pyi.in` | reimport pr137735 due to merging check issues (#138959) | 2024-10-27 16:31:34 +00:00 |
| `_aoti.pyi` | [aoti] Add cpp loader (#135374) | 2024-09-11 03:00:01 +00:00 |
| `_autograd.pyi` | [Profiler] Add API for Dynamic Activity Toggling [2/n] (#133035) | 2024-08-09 21:54:54 +00:00 |
| `_cpu.pyi` | Extend vectorization with SVE(ARM) with Torch Compile (Inductor) (#134672) | 2024-10-10 13:20:40 +00:00 |
| `_cudnn.pyi` |  |  |
| `_cusparselt.pyi` | [sparse] Add cuSPARSELt as a backend (#128534) | 2024-08-21 22:06:07 +00:00 |
| `_distributed_autograd.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_distributed_c10d.pyi` | [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763) | 2024-10-29 03:31:19 +00:00 |
| `_distributed_rpc_testing.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_distributed_rpc.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_functions.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_functorch.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_instruction_counter.pyi` | Add compile time instruction count metric (#133834) | 2024-08-27 23:29:02 +00:00 |
| `_itt.pyi` |  |  |
| `_lazy_ts_backend.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_lazy.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_monitor.pyi` | Use Generic TypeAlias (PEP 585) and Union Type (PEP 604) in .pyi stub files (#129419) | 2024-06-29 09:23:39 +00:00 |
| `_nn.pyi.in` | Add padding_side to pad_sequence with "left" and "right" options ("right" as default) (#131884) | 2024-08-07 15:53:07 +00:00 |
| `_nvtx.pyi` | Flip default value for mypy disallow_untyped_defs [1/11] (#127838) | 2024-06-08 18:16:33 +00:00 |
| `_onnx.pyi` | [1/N] [Caffe2] Remove caffe2_aten_fallback code (#128675) | 2024-06-17 21:25:59 +00:00 |
| `_profiler.pyi` | [BE] Format uncategorized Python files with ruff format (#132576) | 2024-08-04 17:13:31 +00:00 |
| `_VariableFunctions.pyi.in` | Flip default value for mypy disallow_untyped_defs [1/11] (#127838) | 2024-06-08 18:16:33 +00:00 |
| `_verbose.pyi` |  |  |
| `build.bzl` |  |  |
| `return_types.pyi.in` | [torchgen] reference generated comment to actual location of the generator and template (#130020) | 2024-07-05 21:47:14 +00:00 |