pytorch/torch/csrc/distributed/c10d/Work.cpp
Will Feng 4ee514144b [c10d][Partial-Graph Overlap] Support calling .wait_tensor() on output tensor of eager async_op=True collective if under allow_inflight_collective_as_graph_input_ctx() context manager (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y

x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
    y = all_reduce_eager(x)
    z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but waited on in a compiled region.

This is important for internal use cases such as TorchRec, where we issue collectives eagerly for the SparseArch all_to_all but want to wait on them in a compiled region at the beginning of OverArch, so that the all_to_all can overlap with the DenseArch compute that runs in parallel.

----

**Update**: Two changes were made to prevent regressions in existing use cases:

1. Added a memory-stressed test case to `test_unwaited` in test_c10d_nccl.py to cover the existing pattern of not calling `work.wait()` on a non-functional collective.
2. Gated all new `register_work()` / `unregister_work()` calls behind the `c10d::allow_inflight_collective_as_graph_input()` check, which is only enabled inside a new context manager that requires explicit user opt-in (i.e. it is off by default, so existing users should be unaffected); see the sketch below.
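
For reference, a minimal sketch of the gating pattern, mirroring what `Work::synchronize()` in this file does for `unregister_work()` (the `register_work()` call sites are gated the same way):

```cpp
// No-op unless the user explicitly entered
// allow_inflight_collective_as_graph_input_ctx(); the flag is off by default.
if (c10d::allow_inflight_collective_as_graph_input()) {
  c10d::unregister_work(
      c10::intrusive_ptr<Work>::unsafe_reclaim_from_nonowning(this));
}
```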

With these changes, the risk of this PR causing a regression should be very low.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-29 03:31:19 +00:00


#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>

#include <utility>

namespace c10d {
Work::Work(
    int rank,
    OpType opType,
    const char* profilingTitle,
    const std::optional<std::vector<at::Tensor>>& inputTensors)
    : rank_(rank), opType_(opType) {
  if (profilingTitle != nullptr) {
    auto recordingFunction =
        std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
    if (recordingFunction->isActive()) {
      // Work events follow a future like pattern and can potentially be marked
      // as complete by different threads, so explicitly set as async event.
      recordingFunction->_setAsync();
      // Passing input tensor to recordFunction allows for shape information in
      // profiling output.
      std::vector<c10::IValue> inputs;
      if (inputTensors) {
        inputs.reserve(inputTensors->size());
        for (const auto& tensor : *inputTensors) {
          inputs.emplace_back(tensor);
        }
      }
      recordingFunction->before(
          profilingTitle,
          c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size()));
      std::function<void()> end_handler = [recordingFunction]() {
        recordingFunction->end();
      };
      recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
    }
  }
}
OpType Work::retrieveOpType() const {
  return opType_;
}

Work::~Work() = default;
bool Work::isCompleted() {
  std::lock_guard<std::mutex> lock(mutex_);
  return completed_;
}

bool Work::isSuccess() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return !exception_;
}

std::exception_ptr Work::exception() const {
  std::lock_guard<std::mutex> lock(mutex_);
  return exception_;
}

int Work::sourceRank() const {
  TORCH_CHECK(
      false,
      "sourceRank() may only be called on work objects "
      "that correspond to a recv or recv-from-any call.");
}

std::vector<at::Tensor> Work::result() {
  TORCH_CHECK(false, "result() not implemented.");
}
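
// When allow_inflight_collective_as_graph_input() is enabled (via the
// allow_inflight_collective_as_graph_input_ctx() context manager), works are
// registered with their output tensors when a collective is issued so that
// wait_tensor() can find them later; synchronizing a work unregisters it.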
void Work::synchronize() {
  if (c10d::allow_inflight_collective_as_graph_input()) {
    c10d::unregister_work(
        c10::intrusive_ptr<Work>::unsafe_reclaim_from_nonowning(this));
  }
}
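
// Blocks until the work completes (or the given timeout expires), rethrows
// any captured exception, then calls synchronize().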
bool Work::wait(std::chrono::milliseconds timeout) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (timeout == kNoTimeout) {
    // This waits without a timeout.
    cv_.wait(lock, [&] { return completed_; });
  } else {
    // Waits for the user-provided timeout.
    cv_.wait_for(lock, timeout, [&] { return completed_; });
    if (!completed_) {
      // Throw exception if the wait operation timed out and the work was not
      // completed.
      TORCH_CHECK(false, "Operation timed out!");
    }
  }
  if (exception_) {
    std::rethrow_exception(exception_);
  }
  synchronize();
  // Always return true, because abort API is not implemented.
  return true;
}
void Work::abort() {
  TORCH_CHECK(false, "Work::abort not implemented.");
}

c10::intrusive_ptr<c10::ivalue::Future> Work::getFuture() {
  TORCH_CHECK(false, "Work::getFuture not implemented.");
}

c10::intrusive_ptr<c10::ivalue::Future> Work::getFutureResult() {
  TORCH_CHECK(false, "Work::getFutureResult not implemented.");
}
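
// Marks the work completed, runs the profiler end callback if one was
// installed, and wakes any threads blocked in wait().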
void Work::finish(std::exception_ptr exception) {
  std::unique_lock<std::mutex> lock(mutex_);
  completed_ = true;
  exception_ = std::move(exception);
  if (recordFunctionEndCallback_) {
    recordFunctionEndCallback_();
    recordFunctionEndCallback_ = nullptr;
  }
  lock.unlock();
  cv_.notify_all();
}

void Work::finishAndThrow(std::exception_ptr exception) {
  std::unique_lock<std::mutex> lock(mutex_);
  completed_ = true;
  exception_ = std::move(exception);
  if (recordFunctionEndCallback_) {
    recordFunctionEndCallback_();
    recordFunctionEndCallback_ = nullptr;
  }
  if (exception_) {
    std::rethrow_exception(exception_);
  }
}

float Work::getDuration() const {
  TORCH_CHECK(false, "This Backend doesn't support getDuration.");
}

uint64_t Work::getSequencenumber() const {
  TORCH_CHECK(false, "This Backend doesn't support getSequencenumber.");
}
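
// Adapter that exposes an ivalue::Future through the Work interface, so code
// that expects a Work handle can wait on a Future-backed operation.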
class FutureWrappingWork : public Work {
 public:
  FutureWrappingWork(c10::intrusive_ptr<c10::ivalue::Future> fut)
      : Work(), _fut(std::move(fut)) {}

  ~FutureWrappingWork() override = default;

  bool isCompleted() override {
    return _fut->completed();
  }

  bool isSuccess() const override {
    return _fut->hasValue();
  }

  std::exception_ptr exception() const override {
    return _fut->exception_ptr();
  }

  int sourceRank() const override {
    TORCH_CHECK(false, "FutureWrappingWork::sourceRank() not implemented");
  }

  std::vector<at::Tensor> result() override {
    return _fut->value().toPyObjectHolder()->extractTensors();
  }

  bool wait(std::chrono::milliseconds timeout) override {
    // FIXME
    TORCH_CHECK(
        timeout == kNoTimeout,
        "FutureWrappingWork::wait() with finite timeout not implemented");
    _fut->wait();
    return true;
  }

  void abort() override {
    TORCH_CHECK(false, "FutureWrappingWork::abort() not implemented");
  }

  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
    return _fut;
  }

 private:
  c10::intrusive_ptr<c10::ivalue::Future> _fut;
};
c10::intrusive_ptr<Work> Work::create_from_future(
    const c10::intrusive_ptr<c10::ivalue::Future>& future) {
  return c10::make_intrusive<FutureWrappingWork>(future);
}

} // namespace c10d