mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
y = x * x
req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
assert isinstance(req, torch.distributed.Work)
return y
@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
torch.ops.c10d_functional.wait_tensor(y)
return y * y
x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
y = all_reduce_eager(x)
z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager mode (with `async_op=True`) but waited on in a compiled region.
This is important for internal use cases such as TorchRec, where we issue the SparseArch all_to_all collectives in eager mode but want to wait for them in the compiled region at the beginning of OverArch, so that the all_to_all can overlap with the DenseArch compute that runs in parallel.
----
**Update**: Made two changes to prevent regressions in existing use cases:
1. Added a memory-stressed test case, `test_unwaited`, to test_c10d_nccl.py to cover existing users' "not calling work.wait() for a non-functional collective" use case.
2. Gated all new `register_work()` / `unregister_work()` calls behind a `c10d::allow_inflight_collective_as_graph_input()` check. That flag is controlled by a new context manager, `allow_inflight_collective_as_graph_input_ctx()`, which requires explicit user enablement (i.e. it is off by default, so existing users should not be affected).
The risk of this new version of PR causing regression should be very low.
------
Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`
------
Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
331 lines
9.6 KiB
C++
331 lines
9.6 KiB
C++
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>

#include <c10/util/Logging.h>
#include <fmt/format.h>

#include <algorithm>
#include <mutex>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp>
|
|
|
|
namespace c10d {
|
|
|
|
// Returns the canonical upper-case name of `opType` (for example "ALLREDUCE").
// Any OpType value not listed below trips TORCH_INTERNAL_ASSERT.
std::string opTypeToString(OpType opType) {
  const char* name = nullptr;
  switch (opType) {
    case OpType::BROADCAST: name = "BROADCAST"; break;
    case OpType::ALLREDUCE: name = "ALLREDUCE"; break;
    case OpType::ALLREDUCE_COALESCED: name = "ALLREDUCE_COALESCED"; break;
    case OpType::REDUCE: name = "REDUCE"; break;
    case OpType::ALLGATHER: name = "ALLGATHER"; break;
    case OpType::_ALLGATHER_BASE: name = "_ALLGATHER_BASE"; break;
    case OpType::ALLGATHER_COALESCED: name = "ALLGATHER_COALESCED"; break;
    case OpType::GATHER: name = "GATHER"; break;
    case OpType::SCATTER: name = "SCATTER"; break;
    case OpType::REDUCE_SCATTER: name = "REDUCE_SCATTER"; break;
    case OpType::ALLTOALL_BASE: name = "ALLTOALL_BASE"; break;
    case OpType::ALLTOALL: name = "ALLTOALL"; break;
    case OpType::SEND: name = "SEND"; break;
    case OpType::RECV: name = "RECV"; break;
    case OpType::RECVANYSOURCE: name = "RECVANYSOURCE"; break;
    case OpType::BARRIER: name = "BARRIER"; break;
    case OpType::UNKNOWN: name = "UNKNOWN"; break;
    case OpType::_REDUCE_SCATTER_BASE: name = "_REDUCE_SCATTER_BASE"; break;
    case OpType::COALESCED: name = "COALESCED"; break;
    case OpType::_ALLREDUCE_SPARSE: name = "_ALLREDUCE_SPARSE"; break;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown op type!");
      break;
  }
  // Defensive fallback: only reached if the assertion above does not abort.
  return name != nullptr ? std::string(name) : std::string("UNKNOWN");
}
|
|
|
|
// Reports whether `opType` is a point-to-point op (SEND, RECV or
// RECVANYSOURCE). When `batchP2P` is true the answer is always false,
// whatever the op type.
bool isP2POp(OpType opType, bool batchP2P /*= false*/) {
  const bool isPointToPointType = opType == OpType::SEND ||
      opType == OpType::RECV || opType == OpType::RECVANYSOURCE;
  return !batchP2P && isPointToPointType;
}
|
|
|
|
// Returns the backend that services `deviceType`.
//
// Resolution order:
//   1. A backend already bound to `deviceType` is returned directly.
//   2. Otherwise the backend *type* registered for `deviceType` is looked up;
//      if none is registered, a c10::Error is raised via TORCH_CHECK.
//   3. If a backend of that type already exists, it is cached under
//      `deviceType` (so later calls hit step 1) and returned.
//   4. Otherwise a c10::Error is raised. This method never creates backends.
c10::intrusive_ptr<Backend> ProcessGroup::getBackend(
    c10::DeviceType deviceType) {
  // Single lookup instead of find() followed by at().
  if (auto boundIt = deviceTypeToBackend_.find(deviceType);
      boundIt != deviceTypeToBackend_.end()) {
    return boundIt->second;
  }

  // Use find() rather than at() + catching std::out_of_range, so a missing
  // mapping is an ordinary branch instead of exception-driven control flow.
  const auto typeIt = deviceTypeToBackendType_.find(deviceType);
  TORCH_CHECK(
      typeIt != deviceTypeToBackendType_.end(),
      "No backend type associated with device type ",
      deviceType);
  const ProcessGroup::BackendType backendType = typeIt->second;

  // Reuse an already-initialized backend of this type, binding it to the
  // device type for subsequent lookups.
  if (auto backendIt = backendTypeToBackend_.find(backendType);
      backendIt != backendTypeToBackend_.end()) {
    deviceTypeToBackend_[deviceType] = backendIt->second;
    return backendIt->second;
  }

  TORCH_CHECK(
      false,
      "Could not retrieve or create the backend ",
      backendType,
      " for device type ",
      deviceType);
}
|
|
|
|
ProcessGroup::ProcessGroup(
|
|
c10::intrusive_ptr<::c10d::Store> store,
|
|
int rank,
|
|
int size)
|
|
: store_(std::move(store)),
|
|
rank_(rank),
|
|
size_(size),
|
|
backendType_(BackendType::UNDEFINED),
|
|
dist_debug_level_(debug_level()) {
|
|
C10_LOG_API_USAGE_ONCE("c10d.process_group");
|
|
}
|
|
|
|
ProcessGroup::ProcessGroup(int rank, int size)
|
|
: rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {}
|
|
|
|
// Defaulted out-of-line destructor: members are destroyed by their own
// destructors; no explicit teardown happens here.
ProcessGroup::~ProcessGroup() = default;
|
|
|
|
// Records API-usage telemetry tagged with this group's backend name, i.e. the
// key "c10d.process_group_<getBackendName()>".
// NOTE(review): C10_LOG_API_USAGE_ONCE presumably fires only once per call
// site, so only the first process group reaching this line may be logged
// (and getBackendName() may only be evaluated then) - confirm this is intended.
void ProcessGroup::init() {
|
|
C10_LOG_API_USAGE_ONCE(
|
|
fmt::format("c10d.process_group_{}", getBackendName()));
|
|
}
|
|
|
|
// Returns the group name, read from the group UID of the first backend in
// deviceTypeToBackend_ iteration order (setGroupName() writes the same UID to
// every backend). Raises if no backend has been registered yet.
const std::string& ProcessGroup::getGroupName() const {
  TORCH_CHECK(!deviceTypeToBackend_.empty(), "ProcessGroup name not set");
  const auto& firstBackend = deviceTypeToBackend_.begin()->second;
  return firstBackend->getGroupUid();
}
|
|
|
|
void ProcessGroup::setGroupName(const std::string& name) {
|
|
for (auto& kv : deviceTypeToBackend_) {
|
|
kv.second->setGroupUid(name);
|
|
}
|
|
}
|
|
|
|
// Returns this ProcessGroup's stored description (pg_desc_), which
// setGroupDesc() updates.
const std::string& ProcessGroup::getGroupDesc() const {
|
|
return pg_desc_;
|
|
}
|
|
|
|
void ProcessGroup::setGroupDesc(const std::string& name) {
|
|
pg_desc_ = name;
|
|
// Also set the group desc for all backends
|
|
for (auto& kv : deviceTypeToBackend_) {
|
|
kv.second->setGroupDesc(name);
|
|
}
|
|
}
|
|
|
|
void ProcessGroup::enableCollectivesTiming() {
|
|
for (auto& kv : deviceTypeToBackend_) {
|
|
kv.second->enableCollectivesTiming();
|
|
}
|
|
}
|
|
|
|
// Drops this ProcessGroup's reference to its store and its references to all
// backends, in the order below.
void ProcessGroup::release_resources() {
|
|
store_.reset();
|
|
deviceTypeToBackend_.clear();
|
|
// Backends are also referenced by backend type; clear this map too so no
// backend reference is retained by this object.
backendTypeToBackend_.clear();
|
|
}
|
|
|
|
} // namespace c10d
|
|
|
|
namespace {
|
|
|
|
class WorkRegistry {
|
|
public:
|
|
void register_work(
|
|
const at::Tensor& tensor,
|
|
const c10::intrusive_ptr<c10d::Work>& work) {
|
|
if (!tensor.has_storage()) {
|
|
TORCH_WARN_ONCE(
|
|
"Registering collective work for tensor without storage is not supported. "
|
|
"Calling c10d_functional.wait_tensor() on this tensor will not wait for the collective to complete. "
|
|
"Unsupported tensor type: " +
|
|
tensor.toString());
|
|
return;
|
|
}
|
|
auto storage = tensor.storage().getWeakStorageImpl();
|
|
std::unique_lock lock(lock_);
|
|
|
|
auto it = registry_.find(storage);
|
|
if (it == registry_.end()) {
|
|
registry_.emplace(
|
|
std::move(storage),
|
|
std::vector<c10::intrusive_ptr<c10d::Work>>{work});
|
|
} else {
|
|
// There is no guarantee that the previous work object for this
|
|
// tensor storage is completed before the new work object is registered.
|
|
// Therefore we need to maintain a list of work objects for each tensor
|
|
// storage.
|
|
|
|
// Check if work is already in the list
|
|
bool work_exists = false;
|
|
for (const auto& existing_work : it->second) {
|
|
if (existing_work == work) {
|
|
work_exists = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Only append if work is not already in the list
|
|
if (!work_exists) {
|
|
it->second.push_back(work);
|
|
}
|
|
}
|
|
}
|
|
|
|
std::vector<c10::intrusive_ptr<c10d::Work>> pop_works(
|
|
const at::Tensor& tensor) {
|
|
const auto storage = tensor.storage().getWeakStorageImpl();
|
|
std::unique_lock lock(lock_);
|
|
auto it = registry_.find(storage);
|
|
if (it == registry_.end()) {
|
|
return {};
|
|
}
|
|
auto works = it->second;
|
|
registry_.erase(it);
|
|
return works;
|
|
}
|
|
|
|
void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
|
|
std::unique_lock lock(lock_);
|
|
for (auto it = registry_.begin(); it != registry_.end();) {
|
|
std::vector<c10::intrusive_ptr<c10d::Work>> nonmatching_works;
|
|
for (const auto& _work : it->second) {
|
|
if (_work != work) {
|
|
nonmatching_works.push_back(_work);
|
|
}
|
|
}
|
|
if (nonmatching_works.empty()) {
|
|
it = registry_.erase(it);
|
|
} else {
|
|
it->second = std::move(nonmatching_works);
|
|
++it;
|
|
}
|
|
}
|
|
}
|
|
|
|
size_t get_work_registry_size() {
|
|
std::unique_lock lock(lock_);
|
|
size_t total_size = 0;
|
|
for (const auto& [storage, works] : registry_) {
|
|
total_size += works.size();
|
|
}
|
|
return total_size;
|
|
}
|
|
|
|
void set_allow_inflight_collective_as_graph_input(bool value) {
|
|
std::unique_lock lock(lock_);
|
|
allow_inflight_collective_as_graph_input_ = value;
|
|
}
|
|
|
|
bool allow_inflight_collective_as_graph_input() {
|
|
std::unique_lock lock(lock_);
|
|
return allow_inflight_collective_as_graph_input_;
|
|
}
|
|
|
|
~WorkRegistry() {
|
|
// If there are still unwaited work objects, their corresponding process
|
|
// groups should have already been destroyed at this stage. Any attempts to
|
|
// wait for these work objects or to destroy them will only result in
|
|
// confusing errors. Therefore, we simply issue a warning and intentionally
|
|
// allow the unwaited work objects to leak.
|
|
size_t registry_size = get_work_registry_size();
|
|
if (registry_size > 0) {
|
|
TORCH_WARN(
|
|
"At the time of process termination, there are still ",
|
|
registry_size,
|
|
" unwaited collective calls. "
|
|
"Please review your program to ensure that:\n"
|
|
"1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,\n"
|
|
"2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective "
|
|
"called under `with allow_inflight_collective_as_graph_input_ctx():`,\n"
|
|
"before the output tensors of the collective are used.");
|
|
}
|
|
for (auto& it : registry_) {
|
|
for (auto& work : it.second) {
|
|
work.release();
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
std::unordered_map<
|
|
c10::weak_intrusive_ptr<c10::StorageImpl>,
|
|
std::vector<c10::intrusive_ptr<c10d::Work>>>
|
|
registry_;
|
|
bool allow_inflight_collective_as_graph_input_ = false;
|
|
std::mutex lock_;
|
|
};
|
|
|
|
// NOTE(review): this instance is not referenced by the visible code; the
// c10d:: wrappers use RankLocal<WorkRegistry>::get() instead. Confirm it is
// unused before removing it.
static WorkRegistry process_registry;
|
|
|
|
} // namespace
|
|
|
|
namespace c10d {
|
|
|
|
void register_work(
|
|
const at::Tensor& tensor,
|
|
const c10::intrusive_ptr<c10d::Work>& work) {
|
|
RankLocal<WorkRegistry>::get().register_work(tensor, work);
|
|
}
|
|
|
|
// Calls wait() on every Work registered for `tensor`'s storage, then returns
// `tensor` unchanged. The works are popped from the registry first, so a
// second wait_tensor() on the same tensor does not wait on them again.
at::Tensor wait_tensor(const at::Tensor& tensor) {
  const auto pending = RankLocal<WorkRegistry>::get().pop_works(tensor);
  for (const auto& pendingWork : pending) {
    pendingWork->wait();
  }
  return tensor;
}
|
|
|
|
// Removes `work` from the current rank's WorkRegistry, wherever it is
// registered.
void unregister_work(const c10::intrusive_ptr<c10d::Work>& work) {
  auto& registry = RankLocal<WorkRegistry>::get();
  registry.unregister_work(work);
}
|
|
|
|
size_t get_work_registry_size() {
|
|
return RankLocal<WorkRegistry>::get().get_work_registry_size();
|
|
}
|
|
|
|
void set_allow_inflight_collective_as_graph_input(bool value) {
|
|
return RankLocal<WorkRegistry>::get()
|
|
.set_allow_inflight_collective_as_graph_input(value);
|
|
}
|
|
|
|
bool allow_inflight_collective_as_graph_input() {
|
|
return RankLocal<WorkRegistry>::get()
|
|
.allow_inflight_collective_as_graph_input();
|
|
}
|
|
|
|
} // namespace c10d
|