Revert "[PGNCCL] Launch kernel on current stream & remove record_stream entirely (#148590)"

This reverts commit ef6296e7f2.

Reverted https://github.com/pytorch/pytorch/pull/148590 on behalf of https://github.com/izaitsevfb due to reverted internally, see D71292427 ([comment](https://github.com/pytorch/pytorch/pull/148590#issuecomment-2731114626))
PyTorch MergeBot 2025-03-17 22:43:15 +00:00
parent a16ada41b9
commit afa1eda901
11 changed files with 362 additions and 411 deletions
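The diffs below revert the kernel-launch-stream and tensor-lifetime changes; the user-facing collective API itself is unchanged. For orientation only (not part of the commit), a minimal sketch of the Python call path these files sit behind, assuming a NCCL process group has already been initialized:

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has run and each rank owns one GPU.
t = torch.ones(1024, device="cuda")

# async_op=True returns a Work handle backed by ProcessGroupNCCL::WorkNCCL.
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)

# wait() blocks the current CUDA stream on the NCCL stream
# (WorkNCCL::synchronizeStream in the diff below), after which t is safe to read.
work.wait()
```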

View File

@@ -363,9 +363,6 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter {
 };

 TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
-  // Note (kwen2501) 03/07/2025
-  // TODO: re-enable
-  GTEST_SKIP() << "Skipping test as the trace write seems unstable.";
   int heartBeatIntervalInSec = 2;
   std::string timeInterval = std::to_string(heartBeatIntervalInSec);
   ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0);

View File

@@ -126,9 +126,6 @@ ALLOW_LIST = [
     ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)),
     ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)),
     ("aten::all_reduce", datetime.date(9999, 1, 30)),
-    # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp
-    # TODO: add back restriction when c10d ops can be exported
-    ("c10d::.*", datetime.date(9999, 1, 1)),
 ]

 ALLOW_LIST_COMPILED = [

View File

@@ -2,7 +2,7 @@
 # mypy: disable-error-code="type-arg"
 from datetime import timedelta
 from enum import Enum
-from typing import Any, Optional, overload
+from typing import Any, overload

 import torch
 from torch import Tensor
@@ -139,8 +139,6 @@ class BroadcastOptions:
 class AllreduceOptions:
     reduceOp: ReduceOp
     timeout: timedelta
-    asyncOp: bool
-    sparseIndices: Optional[Tensor]

 class AllreduceCoalescedOptions(AllreduceOptions): ...
@@ -149,7 +147,6 @@ class ReduceOptions:
     rootRank: int
     rootTensor: int
     timeout: timedelta
-    asyncOp: bool

 class AllgatherOptions:
     timeout: timedelta
@@ -158,7 +155,6 @@ class AllgatherOptions:
 class GatherOptions:
     rootRank: int
     timeout: timedelta
-    asyncOp: bool

 class ScatterOptions:
     rootRank: int
@@ -174,11 +170,9 @@ class BarrierOptions:
     device_ids: list[int]
     device: torch.device
     timeout: timedelta
-    asyncOp: bool

 class AllToAllOptions:
     timeout: timedelta
-    asyncOp: bool

 class Store:
     def set(self, key: str, value: str): ...
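Illustrative only (not from the commit): with the stub entries above removed, the c10d options objects no longer declare asyncOp (or sparseIndices on AllreduceOptions). A sketch, assuming the runtime bindings in torch._C._distributed_c10d match this .pyi:

```python
from datetime import timedelta

from torch._C._distributed_c10d import AllreduceOptions, ReduceOp

opts = AllreduceOptions()
opts.reduceOp = ReduceOp.SUM          # still declared in the stub
opts.timeout = timedelta(seconds=60)  # still declared in the stub
# opts.asyncOp and opts.sparseIndices are no longer part of the stub after this revert.
```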

View File

@@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) {
       .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
   m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
   m.def(
-      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
+      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
+      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
+      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
+      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
+      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
+      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
   m.def(
-      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
+      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
+      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
+      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
+      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
+      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
   m.def(
       "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
   m.def(
@@ -118,7 +118,6 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
       const c10::intrusive_ptr<ReduceOp>& reduce_op, \
       int64_t root_rank, \
       int64_t root_tensor, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto tensor_vec = tensors.vec(); \
     return process_group->getBackend(c10::DeviceType::DEV) \
@@ -128,8 +127,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1)
               *reduce_op.get(), \
               root_rank, \
               root_tensor, \
-              std::chrono::milliseconds(timeout), \
-              asyncOp}); \
+              std::chrono::milliseconds(timeout)}); \
   }

 IMPL_REDUCE(CPU)
@@ -171,13 +169,12 @@ IMPL_BROADCAST(PrivateUse1)
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       const c10::intrusive_ptr<ReduceOp>& reduce_op, \
       const std::optional<at::Tensor>& sparse_indices, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto tensor_vec = tensors.vec(); \
     auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \
         tensor_vec, \
         AllreduceOptions{ \
-            *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \
+            *reduce_op.get(), std::chrono::milliseconds(timeout)}); \
     return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
         std::move(tensor_vec), work); \
   }
@@ -191,13 +188,11 @@ IMPL_ALLREDUCE(PrivateUse1)
       at::TensorList tensors, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       const c10::intrusive_ptr<ReduceOp>& reduce_op, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto tensor_vec = tensors.vec(); \
     AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \
     opts.reduceOp = *reduce_op.get(); \
     opts.timeout = std::chrono::milliseconds(timeout); \
-    opts.asyncOp = asyncOp; \
     return process_group->getBackend(c10::DeviceType::DEV) \
         ->allreduce_coalesced(tensor_vec, opts); \
   }
@@ -214,13 +209,12 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1)
       const std::vector<std::vector<at::Tensor>>& output_tensors, \
       at::TensorList input_tensors, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto input_tensors_vec = input_tensors.vec(); \
     auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather( \
         const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
         input_tensors_vec, \
-        AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \
+        AllgatherOptions{std::chrono::milliseconds(timeout)}); \
     return std:: \
         tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>( \
             output_tensors, work); \
@@ -255,16 +249,12 @@ IMPL__ALLGATHER_BASE(PrivateUse1)
   c10::intrusive_ptr<Work> allgather_coalesced_##DEV( \
       const std::vector<std::vector<at::Tensor>>& output_lists, \
       const at::TensorList& input_list, \
-      const c10::intrusive_ptr<ProcessGroup>& process_group, \
-      bool asyncOp) { \
+      const c10::intrusive_ptr<ProcessGroup>& process_group) { \
     auto input_list_vec = input_list.vec(); \
-    auto opts = AllgatherOptions{}; \
-    opts.asyncOp = asyncOp; \
     return process_group->getBackend(c10::DeviceType::DEV) \
         ->allgather_coalesced( \
             const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists), \
-            input_list_vec, \
-            opts); \
+            input_list_vec); \
   }

 IMPL_ALLGATHER_COALESCED(CPU)
@@ -275,14 +265,11 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1)
   c10::intrusive_ptr<c10d::Work> allgather_into_tensor_coalesced_##DEV( \
       at::TensorList outputs, \
       at::TensorList inputs, \
-      const c10::intrusive_ptr<ProcessGroup>& process_group, \
-      bool asyncOp) { \
+      const c10::intrusive_ptr<ProcessGroup>& process_group) { \
     auto output_vec = outputs.vec(); \
     auto input_vec = inputs.vec(); \
-    auto opts = AllgatherOptions{}; \
-    opts.asyncOp = asyncOp; \
     return process_group->getBackend(c10::DeviceType::DEV) \
-        ->allgather_into_tensor_coalesced(output_vec, input_vec, opts); \
+        ->allgather_into_tensor_coalesced(output_vec, input_vec); \
   }

 IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU)
@@ -296,7 +283,6 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
       const std::vector<std::vector<at::Tensor>>& input_tensors, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       const c10::intrusive_ptr<ReduceOp>& reduce_op, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto output_tensors_vec = output_tensors.vec(); \
     auto work = \
@@ -304,9 +290,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1)
             output_tensors_vec, \
             const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors), \
             ReduceScatterOptions{ \
-                *reduce_op.get(), \
-                std::chrono::milliseconds(timeout), \
-                asyncOp}); \
+                *reduce_op.get(), std::chrono::milliseconds(timeout)}); \
     return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
         output_tensors_vec, work); \
   }
@@ -345,7 +329,6 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
       at::TensorList inputs, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       const c10::intrusive_ptr<ReduceOp>& reduce_op, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto output_vec = outputs.vec(); \
     auto input_vec = inputs.vec(); \
@@ -354,9 +337,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1)
             output_vec, \
             input_vec, \
             ReduceScatterOptions{ \
-                *reduce_op.get(), \
-                std::chrono::milliseconds(timeout), \
-                asyncOp}); \
+                *reduce_op.get(), std::chrono::milliseconds(timeout)}); \
   }

 IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU)
@@ -369,15 +350,13 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1)
       const at::TensorList& input_tensors, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       int64_t root_rank, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto input_tensors_vec = input_tensors.vec(); \
     return process_group->getBackend(c10::DeviceType::DEV) \
         ->gather( \
             const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors), \
             input_tensors_vec, \
-            GatherOptions{ \
-                root_rank, std::chrono::milliseconds(timeout), asyncOp}); \
+            GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); \
   }

 IMPL_GATHER(CPU)
@@ -412,14 +391,13 @@ IMPL_SCATTER(PrivateUse1)
       const at::TensorList& output_tensors, \
       const at::TensorList& input_tensors, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
-      bool asyncOp, \
       int64_t timeout) { \
     auto output_tensors_vec = output_tensors.vec(); \
     auto input_tensors_vec = input_tensors.vec(); \
     auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \
         output_tensors_vec, \
         input_tensors_vec, \
-        AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
+        AllToAllOptions{std::chrono::milliseconds(timeout)}); \
     return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>( \
         std::move(output_tensors_vec), work); \
   }
@@ -428,22 +406,21 @@ IMPL_ALLTOALL(CPU)
 IMPL_ALLTOALL(CUDA)
 IMPL_ALLTOALL(PrivateUse1)
 #define IMPL_ALLTOALL_BASE(DEV) \
   c10::intrusive_ptr<Work> alltoall_base_##DEV( \
       at::Tensor& output, \
       at::Tensor& input, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       std::vector<int64_t> output_split_sizes, \
       std::vector<int64_t> input_split_sizes, \
-      bool asyncOp, \
       int64_t timeout) { \
     return process_group->getBackend(c10::DeviceType::DEV) \
         ->alltoall_base( \
             output, \
             input, \
             output_split_sizes, \
             input_split_sizes, \
-            AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \
+            AllToAllOptions{std::chrono::milliseconds(timeout)}); \
   }

 IMPL_ALLTOALL_BASE(CPU)
@@ -451,18 +428,15 @@ IMPL_ALLTOALL_BASE(CUDA)
 IMPL_ALLTOALL_BASE(PrivateUse1)
 // NOLINTBEGIN(performance-unnecessary-value-param)
 #define IMPL_BARRIER(DEV) \
   c10::intrusive_ptr<Work> barrier##DEV( \
       at::Tensor /* unused */, \
       const c10::intrusive_ptr<ProcessGroup>& process_group, \
       const std::vector<int64_t>& device_ids, \
-      bool asyncOp, \
       int64_t timeout) { \
-    auto opts = BarrierOptions{}; \
-    opts.device_ids = device_ids; \
-    opts.timeout = std::chrono::milliseconds(timeout); \
-    opts.asyncOp = asyncOp; \
-    return process_group->getBackend(c10::DeviceType::DEV)->barrier(opts); \
+    return process_group->getBackend(c10::DeviceType::DEV) \
+        ->barrier( \
+            BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \
   }

 IMPL_BARRIER(CPU)
@@ -490,7 +464,6 @@ allreduce_sparse_cuda_(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const c10::intrusive_ptr<ReduceOp>& reduce_op,
     const std::optional<at::Tensor>& sparse_indices,
-    bool asyncOp,
     int64_t timeout) {
   auto tensor_vec = tensors.vec();
   auto work = process_group->getBackend(c10::DeviceType::CUDA)
@@ -499,7 +472,6 @@ allreduce_sparse_cuda_(
                       AllreduceOptions{
                           *reduce_op,
                           std::chrono::milliseconds(timeout),
-                          asyncOp,
                           sparse_indices});
   return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
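For reference (not part of the commit), the schemas registered above can be inspected from Python once torch.distributed has loaded the c10d library; the exact attribute path is an assumption about the dispatcher's Python bindings:

```python
import torch
import torch.distributed  # loads the c10d ops registered in Ops.cpp (when built with distributed)

# OpOverload objects expose their registered schema; after this revert the
# allreduce_ schema takes a plain `int timeout` with no async_op/timeout defaults.
print(torch.ops.c10d.allreduce_.default._schema)
```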

View File

@@ -224,7 +224,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const c10::intrusive_ptr<::c10d::ReduceOp>&,
                 const std::optional<at::Tensor>& sparse_indices,
-                bool,
                 int64_t)>();
     auto work = std::get<1>(op.call(
@@ -232,7 +231,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.sparseIndices,
-        opts.asyncOp,
         opts.timeout.count()));
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -252,14 +250,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 at::TensorList,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const c10::intrusive_ptr<::c10d::ReduceOp>&,
-                bool,
                 int64_t)>();
     auto work = op.call(
         tensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
-        opts.asyncOp,
         opts.timeout.count());
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -281,7 +277,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const c10::intrusive_ptr<::c10d::ReduceOp>&,
                 int64_t,
                 int64_t,
-                bool,
                 int64_t)>();
     auto work = op.call(
         tensors,
@@ -289,7 +284,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
         c10::make_intrusive<ReduceOp>(opts.reduceOp),
         opts.rootRank,
         opts.rootTensor,
-        opts.asyncOp,
         opts.timeout.count());
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -312,14 +306,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const std::vector<std::vector<at::Tensor>>&,
                 at::TensorList,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
-                bool,
                 int64_t)>();
     auto work = std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
-        opts.asyncOp,
         opts.timeout.count()));
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -371,19 +363,18 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       std::vector<std::vector<at::Tensor>>& outputTensorLists,
       std::vector<at::Tensor>& inputTensors,
       const AllgatherOptions& opts = AllgatherOptions()) {
-    static auto op = c10::Dispatcher::singleton()
-                         .findSchemaOrThrow("c10d::allgather_coalesced_", "")
-                         .typed<c10::intrusive_ptr<Work>(
-                             const std::vector<std::vector<at::Tensor>>&,
-                             const at::TensorList&,
-                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
-                             bool)>();
+    static auto op =
+        c10::Dispatcher::singleton()
+            .findSchemaOrThrow("c10d::allgather_coalesced_", "")
+            .typed<c10::intrusive_ptr<Work>(
+                const std::vector<std::vector<at::Tensor>>&,
+                const at::TensorList&,
+                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
     auto work = op.call(
         outputTensorLists,
         inputTensors,
-        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
-        opts.asyncOp);
+        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
     if (c10d::allow_inflight_collective_as_graph_input()) {
       for (const auto& tensor_list : outputTensorLists) {
@@ -408,14 +399,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
             .typed<c10::intrusive_ptr<Work>(
                 const at::TensorList,
                 const at::TensorList,
-                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
-                bool)>();
+                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
     auto work = op.call(
         outputTensors,
         inputTensors,
-        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
-        opts.asyncOp);
+        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
     if (c10d::allow_inflight_collective_as_graph_input()) {
       for (const auto& tensor : outputTensors) {
@@ -436,14 +425,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const at::TensorList&,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 int64_t,
-                bool,
                 int64_t)>();
     auto work = op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.rootRank,
-        opts.asyncOp,
         opts.timeout.count());
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -500,14 +487,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const std::vector<std::vector<at::Tensor>>&,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const c10::intrusive_ptr<::c10d::ReduceOp>&,
-                bool,
                 int64_t)>();
     auto work = std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
-        opts.asyncOp,
         opts.timeout.count()));
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -561,7 +546,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const at::TensorList,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const c10::intrusive_ptr<::c10d::ReduceOp>&,
-                bool,
                 int64_t)>();
     auto work = op.call(
@@ -569,7 +553,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
-        opts.asyncOp,
         opts.timeout.count());
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -594,7 +577,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 std::vector<int64_t>,
                 std::vector<int64_t>,
-                bool,
                 int64_t)>();
     auto work = op.call(
         outputBuffer,
@@ -602,7 +584,6 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         outputSplitSizes,
         inputSplitSizes,
-        opts.asyncOp,
         opts.timeout.count());
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -623,13 +604,11 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 const at::TensorList&,
                 const at::TensorList&,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
-                bool,
                 int64_t)>();
     auto work = std::get<1>(op.call(
         outputTensors,
         inputTensors,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
-        opts.asyncOp,
         opts.timeout.count()));
     if (c10d::allow_inflight_collective_as_graph_input()) {
@@ -799,14 +778,12 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
                 at::Tensor,
                 const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                 const std::vector<int64_t>&,
-                bool,
                 int64_t)>();
     auto work = op.call(
         tensor,
         c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
         opts.device_ids,
-        opts.asyncOp,
         opts.timeout.count());
     if (c10d::allow_inflight_collective_as_graph_input()) {
       c10d::register_work(tensor, work);

View File

@@ -496,8 +496,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
   }
   futureWorkResult_ =
       c10::make_intrusive<at::ivalue::Future>(c10::AnyEnumType::get());
-  // other functions expect an initialized ptr
-  stashed_for_allocator_safety_ = std::make_shared<std::vector<at::Tensor>>();
 }

 ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
@@ -519,11 +517,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
       numelIn_(w.numelIn_),
       numelOut_(w.numelOut_),
       store_(w.store_),
-      // Note: the `work` returned to user and the `work` enqueued to watchdog
-      // share the pointer to the tensor stash. At least one of them should
-      // clean the tensor stash, the earlier the better, i.e. user calling
-      // `work.wait` than watchdog detecting work completion.
-      stashed_for_allocator_safety_(w.stashed_for_allocator_safety_),
       futureWorkResult_(w.futureWorkResult_),
       timingEnabled_(w.timingEnabled_),
       trace_id_(w.trace_id_),
@@ -717,25 +710,14 @@ void ProcessGroupNCCL::WorkNCCL::synchronize() {
   }
 }

-void ProcessGroupNCCL::WorkNCCL::stashTensors(
-    std::vector<at::Tensor>& tensors) {
-  std::lock_guard<std::mutex> lock(stashMutex_);
-  stashed_for_allocator_safety_->insert(
-      stashed_for_allocator_safety_->end(), tensors.begin(), tensors.end());
-}
-
-void ProcessGroupNCCL::WorkNCCL::unstashTensors() {
-  std::lock_guard<std::mutex> lock(stashMutex_);
-  stashed_for_allocator_safety_->clear();
-}
-
 void ProcessGroupNCCL::WorkNCCL::synchronizeStream() {
   auto currentStream = at::cuda::getCurrentCUDAStream(device_.index());
   // Block the current stream on the NCCL stream
   ncclEndEvent_->block(currentStream);
-  // Unstage the stashed tensors so that CachingAllocator can recycle them
-  // THIS MUST HAPPEN AFTER THE BLOCKING CALL ABOVE
-  unstashTensors();
+  if (avoidRecordStreams_) {
+    stashed_for_allocator_safety_->clear();
+  }
 }

 // Same as calling synchronize() when blockingWait_ is false
@@ -951,10 +933,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   enableTiming_.store(
       getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
 #endif // ENABLE_NCCL_ERROR_CHECKING
-  if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) {
-    TORCH_WARN_ONCE(
-        "TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated.");
-  }
+  avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false);
 #ifdef NCCL_HAS_COMM_REGISTER
   useTensorRegisterAllocatorHook_ =
       getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false);
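Not from the commit: with this revert, TORCH_NCCL_AVOID_RECORD_STREAMS goes back to being an opt-in flag read once in the ProcessGroupNCCL constructor (rather than a deprecated no-op). A sketch of enabling it, assuming the usual environment-based configuration:

```python
import os

# Must be set before the process group is created; the constructor reads it via getCvarBool.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

import torch.distributed as dist
# dist.init_process_group("nccl", ...)  # collectives then stash tensors instead of calling recordStream
```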
@@ -2344,12 +2323,6 @@ void ProcessGroupNCCL::watchdogHandler() {
       // Clean up completed work
       if (work.isCompleted()) {
-        // In case user didn't call `work.wait()` with async collectives,
-        // watchdog would unstage the stashed tensors when detecting completion
-        // of the collective, to prevent ProcessGroupNCCL from holding reference
-        // to those tensors forever.
-        work.unstashTensors();
         // Work status logging for desync debug
         desyncDebugger_.logWorkEnd(work);
@@ -3084,7 +3057,6 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
       enableTiming_.load(),
       cudaEventCacheEnabled_.load(),
       dist_debug_level_);
   if (record) {
     bool isP2P = isP2POp(opType);
     // Ideally record every work that we enqueue, rather than every work we
@@ -3242,6 +3214,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
       enqueue);
   work->ncclComm_ = comm;
   work->blockingWait_ = blockingWait_;
+  work->avoidRecordStreams_ = avoidRecordStreams_;
   work->store_ = store_;
   assignTimeoutToWork(work, options_);
@@ -3260,16 +3233,19 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing(OpType optype) {
   // TODO(eqy): is this still necessary if avoidRecordStreams_ is set?
   work->ncclEndEvent_->record(ncclStream);
+  if (avoidRecordStreams_) {
+    // other functions expect an initialized ptr if avoidRecordStreams_ is set
+    work->stashed_for_allocator_safety_ =
+        std::make_shared<std::vector<at::Tensor>>();
+  }
   if (enqueue) {
     workEnqueue(work);
   }
-  // Reset coalescing state
   coalescing_state_ = 0;
   coalescedComm_ = nullptr;
-  // If in async mode, return work; otherwise, kernel is enqueued on current
-  // stream, no need to return work
-  return coalescedAsync_ ? work : nullptr;
+  return work;
 }

 c10::intrusive_ptr<Work> ProcessGroupNCCL::endCoalescing() {
@@ -3285,10 +3261,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     PreProcess pre,
     PostProcess post,
     OpType opType,
-    bool asyncOp,
     const char* profilingTitle,
+    bool avoidRecordStreams,
     bool nanCheck) {
   // Environment setting by the user may add onto collective call's option
+  avoidRecordStreams |= avoidRecordStreams_;
   nanCheck &= enableNanCheck_;
   auto device = getDevice(inputs[0]);
@@ -3329,17 +3306,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     } else {
       TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
     }
-    coalescedAsync_ = asyncOp;
   }
-  // in asyncOp=false [default] mode, we use currentStream as ncclStream
-  // otherwise, we use separate ncclStream and let it sync on currentStream
-  auto ncclStream = asyncOp ? ncclStreams_.at(key)
-                            : at::cuda::getCurrentCUDAStream(device.index());
-  if (asyncOp) {
-    // First let NCCL streams wait for input tensors allocation streams
-    syncStream(device, ncclEvents_[key], ncclStream);
-  }
+  // Used many times below, so we stash the unordered_map lookup
+  auto ncclStream = ncclStreams_.at(key);
+  // First let NCCL streams wait for input tensors allocation streams
+  syncStream(device, ncclEvents_[key], ncclStream);
   bool enqueue =
       !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None;
@@ -3349,12 +3322,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // Store references to outputs to be used by WorkNCCL::result and operator<<.
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
-  // If we are performing sync operations, i.e. equeuing kernel onto "current"
-  // stream, we don't need to do anything for tensor lifetime management.
-  // Otherwise, we need to stage the tensors will `work.wait()`.
-  if (asyncOp) {
-    work->stashTensors(inputs);
-    work->stashTensors(outputs);
+  if (avoidRecordStreams) {
+    work->stashed_for_allocator_safety_ =
+        std::make_shared<std::vector<at::Tensor>>(inputs);
   }
   if (nanCheck) {
@@ -3380,6 +3350,21 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // operations where `inputs' and `outputs' are not the same.
   //
   // See [Sync Streams].
+  if (!avoidRecordStreams) {
+    for (const auto& input : inputs) {
+      if (!input.is_sparse()) {
+        c10::cuda::CUDACachingAllocator::recordStream(
+            input.storage().data_ptr(), ncclStream);
+      } else {
+        // for sparse input case record streams on both index and value
+        // tensors
+        c10::cuda::CUDACachingAllocator::recordStream(
+            input.values().storage().data_ptr(), ncclStream);
+        c10::cuda::CUDACachingAllocator::recordStream(
+            input.indices().storage().data_ptr(), ncclStream);
+      }
+    }
+  }
   // Not all collectives have the same signature, e.g, all-reduce take in a Tensor
   // as the input and output while all-to-all take in a vector of Tensors as input
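Not from the commit: the recordStream calls restored above are the allocator-level counterpart of torch.Tensor.record_stream. A minimal sketch of the pattern they implement, using the public Python API:

```python
import torch

side = torch.cuda.Stream()
x = torch.empty(1 << 20, device="cuda")  # allocated on the current stream

with torch.cuda.stream(side):
    y = x * 2  # x is consumed on the side stream

# Tell the caching allocator that x is in use on `side`, so its memory is not
# handed out to another allocation before the side-stream work finishes.
x.record_stream(side)
```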
@@ -3431,6 +3416,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
   // Set appropriate work parameters.
   work->blockingWait_ = blockingWait_;
+  work->avoidRecordStreams_ = avoidRecordStreams;
   work->store_ = store_;
   assignTimeoutToWork(work, options_);
   // Record size info for debug. We only record the size on the first device as
@@ -3448,7 +3434,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     workEnqueue(work);
   }
-  return asyncOp ? work : nullptr;
+  return work;
 }

 template <typename Fn>
@@ -3457,8 +3443,11 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
     std::vector<at::Tensor>& outputs,
     Fn fn,
     OpType opType,
-    bool asyncOp,
-    const char* profilingTitle) {
+    const char* profilingTitle,
+    bool avoidRecordStreams) {
+  // Environment setting by the user may add onto collective call's option
+  avoidRecordStreams |= avoidRecordStreams_;
   // Currently, the API permits one scenario where inputs.size() and
   // outputs.size() are > 0.
   // 1. If the call was a _coalesced call, all inputs must be on the same
@@ -3504,17 +3493,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
     } else {
       TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
     }
-    coalescedAsync_ = asyncOp;
   }
-  // in asyncOp=false [default] mode, we use currentStream as ncclStream
-  // otherwise, we use separate ncclStream and let it sync on currentStream
-  auto ncclStream = asyncOp ? ncclStreams_.at(key)
-                            : at::cuda::getCurrentCUDAStream(device.index());
-  if (asyncOp) {
-    // First let NCCL streams wait for input tensors allocation streams
-    syncStream(device, ncclEvents_[key], ncclStream);
-  }
+  // Used many times below, so we stash the unordered_map lookup
+  auto ncclStream = ncclStreams_.at(key);
+  // First let NCCL streams wait for input tensors allocation streams
+  syncStream(device, ncclEvents_[key], ncclStream);
   auto work = initWork(
       device,
@@ -3529,12 +3514,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
   // Store references to outputs to be used by WorkNCCL::result and operator<<.
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
-  // If we are performing sync operations, i.e. equeuing kernel onto "current"
-  // stream, we don't need to do anything for tensor lifetime management.
-  // Otherwise, we need to stage the tensors will `work.wait()`.
-  if (asyncOp) {
-    work->stashTensors(inputs);
-    work->stashTensors(outputs);
+  if (avoidRecordStreams) {
+    work->stashed_for_allocator_safety_ =
+        std::make_shared<std::vector<at::Tensor>>(inputs);
   }
   // Start event should only be recorded before the ncclGroupStart() (which
@@ -3560,6 +3542,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
   {
     torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking());
     for (const auto i : c10::irange(inputs.size())) {
+      // Both `inputs' and `outputs' are created on a worker stream and used in
+      // different ncclStreams. Hence, both must record the ncclStream to
+      // prevent being freed before the collective finishes.
+      //
+      // We only record `inputs' here, and leave recording `outputs' to `fn' for
+      // operations where `inputs' and `outputs' are not the same.
+      //
+      // See [Sync Streams].
+      if (!avoidRecordStreams) {
+        if (!inputs[i].is_sparse()) {
+          c10::cuda::CUDACachingAllocator::recordStream(
+              inputs[i].storage().data_ptr(), ncclStream);
+        } else {
+          // for sparse input case record streams on both index and value
+          // tensors
+          c10::cuda::CUDACachingAllocator::recordStream(
+              inputs[i].values().storage().data_ptr(), ncclStream);
+          c10::cuda::CUDACachingAllocator::recordStream(
+              inputs[i].indices().storage().data_ptr(), ncclStream);
+        }
+      }
 #ifndef NCCL_HAS_COMM_NONBLOCKING
       C10D_NCCL_CHECK(
           fn(inputs[i], outputs[i], comm, ncclStream),
@@ -3600,6 +3603,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
   // Set appropriate work parameters.
   work->blockingWait_ = blockingWait_;
+  work->avoidRecordStreams_ = avoidRecordStreams;
   work->store_ = store_;
   assignTimeoutToWork(work, options_);
   // Record size info for debug. We only record the size on the first device as
@@ -3630,7 +3634,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collectiveCoalesced(
   // it, since interactions with it by usercode won't behave normally - they
   // won't observe work completion, for instance. Will this lead to silent
   // problems during capture?
-  return asyncOp ? work : nullptr;
+  return work;
 }

 template <typename Fn, typename PreProcess, typename PostProcess>
@@ -3648,8 +3652,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
   // to wait() on the returned handle, so ProcessGroupNCCL can't know
   // when it's safe to release the input back to the allocator,
   // and the present call has no way to know it's not an isend.
-  // Therefore, we warn and fall back to the typical recordStream logic.
-  // TODO( kwen2501 ): revisit this when we have a better solution.
+  // Therefore, we warn and fall back to the typical recordStream logic:
+  if (avoidRecordStreams_) {
+    TORCH_WARN_ONCE(
+        "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point "
+        "collectives.");
+  }
   auto device = getDevice(tensor);
   at::cuda::OptionalCUDAGuard gpuGuard(device);
@@ -3704,8 +3713,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint(
     } else {
       TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG);
     }
-    // For now, P2P ops are always put on internal stream
-    coalescedAsync_ = true;
   }
   // Used many times below, so we stash the unordered_map lookup
@@ -3877,8 +3884,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     PreProcess pre,
     PostProcess post,
     OpType opType,
-    bool asyncOp,
     const char* profilingTitle,
+    bool avoidRecordStreams,
     bool nanCheck) {
   auto inputs = std::vector<at::Tensor>{input};
   auto outputs = std::vector<at::Tensor>{output};
@@ -3889,8 +3896,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       pre,
       post,
       opType,
-      asyncOp,
       profilingTitle,
+      avoidRecordStreams,
       nanCheck);
 }
@@ -3900,8 +3907,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
     at::Tensor& output,
     Fn fn,
     OpType opType,
-    bool asyncOp,
     const char* profilingTitle,
+    bool avoidRecordStreams,
     bool nanCheck) {
   auto inputs = std::vector<at::Tensor>{input};
   auto outputs = std::vector<at::Tensor>{output};
@@ -3914,8 +3921,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
       [](at::cuda::CUDAStream&,
          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
       opType,
-      asyncOp,
       profilingTitle,
+      avoidRecordStreams,
       nanCheck);
 }
@@ -3967,8 +3974,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
         auto recvIndices = indices[0] * colSize;
         // prevent output and recvIndices from being freed
-        // TODO: not changing the lifetime management of outputs this time,
-        // revisit later
         c10::cuda::CUDACachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         c10::cuda::CUDACachingAllocator::recordStream(
@@ -4000,7 +4005,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
         }
       },
       OpType::_ALLREDUCE_SPARSE,
-      opts.asyncOp,
       "nccl:all_reduce_sparse");
   return work;
 #else
@@ -4035,7 +4039,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl(
             stream.stream());
       },
       OpType::ALLREDUCE,
-      opts.asyncOp,
       profilingTitle);
 }
@@ -4136,7 +4139,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
            stream.stream());
       },
      OpType::COALESCED,
-      opts.asyncOp,
      "nccl:allreduce_coalesced");
 }
@@ -4168,10 +4170,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
       globalRankStride_, // globalRankStride_
       this->getSize()); // worldSize
+  // avoidRecordStreams_ note: collective() will stash tensors.
+  bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
   const auto root = opts.rootRank + opts.rootTensor;
   bool nanCheck = (root == rank_);
-  // avoidRecordStreams_ note: collective() will stash tensors.
   return collective(
       tensor,
       tensor,
@@ -4188,8 +4192,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
            stream.stream());
       },
       OpType::BROADCAST,
-      opts.asyncOp,
       "nccl:broadcast",
+      avoidRecordStreams,
       nanCheck);
 }
@@ -4228,8 +4232,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
            stream.stream());
       },
       OpType::BROADCAST,
-      opts.asyncOp,
       "nccl:_broadcast_oop",
+      /*avoidRecordStreams=*/false,
       nanCheck);
 }
@@ -4288,7 +4292,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
            stream.stream());
       },
       OpType::REDUCE,
-      opts.asyncOp,
       "nccl:reduce");
 }
@@ -4330,7 +4333,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
            stream.stream());
       },
       OpType::REDUCE,
-      opts.asyncOp,
       "nccl:_reduce_oop");
 }
@@ -4374,7 +4376,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
           at::Tensor& output,
           ncclComm_t comm,
           at::cuda::CUDAStream& stream) {
-        // See [We actually don't need to stash anything here].
+        if (!avoidRecordStreams_) {
+          c10::cuda::CUDACachingAllocator::recordStream(
+              output.storage().data_ptr(), stream);
+        }
         return ncclAllGather(
             input.data_ptr(),
             output.data_ptr(),
@ -4390,27 +4395,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
// - inputTensors is stashed onto work->stashed_for_allocator_safety_ // - inputTensors is stashed onto work->stashed_for_allocator_safety_
// in collective(). // in collective().
// - outputFlattened is stashed onto work->outputs_ in collective(). // - outputFlattened is stashed onto work->outputs_ in collective().
// - User-facing outputTensors should be held by the user until after
// waiting on work_, or the call makes no sense.
// So all participating tensors are accounted for, and won't be
// released back to their allocation streams until after work_ is
// waited on.
}, },
[&](at::cuda::CUDAStream& ncclStream, [&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) { c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// User-facing outputTensors should be held by the user until after
// waiting on work_, or the call makes no sense. We stash them here
// in case the user doesn't hold the outputTensors in downstream code,
// which can cause an early recycle by the CachingAllocator and, in turn,
// a segfault or data corruption.
if (opts.asyncOp) {
work->stashTensors(outputTensors_);
}
// Copy the flattened output tensors to the outputs. // Copy the flattened output tensors to the outputs.
at::cuda::CUDAStreamGuard guard(ncclStream); at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(outputTensors_.size())) { for (const auto j : c10::irange(outputTensors_.size())) {
// See [We actually don't need to stash anything here]. // See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors_[j].storage().data_ptr(), ncclStream);
}
outputTensors_[j].copy_( outputTensors_[j].copy_(
outputFlattened[static_cast<int64_t>(j)], true); outputFlattened[static_cast<int64_t>(j)], true);
} }
}, },
OpType::ALLGATHER, OpType::ALLGATHER,
opts.asyncOp,
"nccl:all_gather"); "nccl:all_gather");
} else { } else {
const auto num_reduces = outputTensors_.size(); const auto num_reduces = outputTensors_.size();
@ -4418,8 +4423,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
for (const int64_t i : c10::irange(static_cast<int64_t>(num_reduces))) { for (const int64_t i : c10::irange(static_cast<int64_t>(num_reduces))) {
auto& output = outputTensors_[i]; auto& output = outputTensors_[i];
auto& input = (i == rank_) ? inputTensor : output; auto& input = (i == rank_) ? inputTensor : output;
auto broadcastOpts = auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout};
BroadcastOptions{i, int64_t(0), opts.timeout, opts.asyncOp};
_broadcast_oop(output, input, broadcastOpts); _broadcast_oop(output, input, broadcastOpts);
} }
auto work = endCoalescing(OpType::ALLGATHER); auto work = endCoalescing(OpType::ALLGATHER);
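
The comments in the hunk above spell out the contract around asynchronous all_gather: the caller must keep the output tensors referenced until the returned work has been waited on, otherwise the caching allocator may recycle their memory while NCCL is still writing into it. A minimal user-side sketch of that contract (illustrative names; assumes a process group already initialized with the NCCL backend, e.g. under torchrun):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run; names are illustrative.
def gather_chunks(local_chunk: torch.Tensor) -> list[torch.Tensor]:
    world_size = dist.get_world_size()
    # The outputs must stay referenced until wait() returns: the backend
    # records (or stashes) them so the caching allocator cannot hand their
    # memory to another allocation while NCCL is still writing into it.
    outputs = [torch.empty_like(local_chunk) for _ in range(world_size)]
    work = dist.all_gather(outputs, local_chunk, async_op=True)
    work.wait()  # only now is it safe to drop the references
    return outputs
```
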
@ -4475,7 +4479,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
stream.stream()); stream.stream());
}, },
OpType::COALESCED, OpType::COALESCED,
opts.asyncOp,
"nccl:all_gather_into_tensor_coalesced"); "nccl:all_gather_into_tensor_coalesced");
} }
@ -4521,6 +4524,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
at::Tensor& output, at::Tensor& output,
ncclComm_t comm, ncclComm_t comm,
at::cuda::CUDAStream& stream) { at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
const auto ncclDataType = getNcclDataType(input.scalar_type()); const auto ncclDataType = getNcclDataType(input.scalar_type());
const auto ncclReduceOp = const auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4535,18 +4542,27 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
}, },
[&](at::cuda::CUDAStream& ncclStream, [&](at::cuda::CUDAStream& ncclStream,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) { c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
// We only need to stash inputTensors. if (avoidRecordStreams_) {
// - inputFlattened is stashed onto // We only need to stash inputTensors.
// work->stashed_for_allocator_safety_ in collective(). // - inputFlattened is stashed onto
// - User-facing outputTensors is stashed onto work->outputs_ in // work->stashed_for_allocator_safety_
// collective(), and should also be held by the user until after // in collective().
// waiting on work_. // - User-facing outputTensors is stashed onto work->outputs_ in
if (opts.asyncOp) { // collective(),
work->stashTensors(inputTensors_); // and should also be held by the user until after waiting on
// work_.
auto& v = work->stashed_for_allocator_safety_;
v->insert(v->end(), inputTensors_.begin(), inputTensors_.end());
} }
// Copy the input tensors to the flattened inputs. // Copy the input tensors to the flattened inputs.
at::cuda::CUDAStreamGuard guard(ncclStream); at::cuda::CUDAStreamGuard guard(ncclStream);
for (const auto j : c10::irange(inputTensors_.size())) { for (const auto j : c10::irange(inputTensors_.size())) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors_[j].storage().data_ptr(), ncclStream);
}
inputFlattened[static_cast<int64_t>(j)].copy_( inputFlattened[static_cast<int64_t>(j)].copy_(
inputTensors_[j], true); inputTensors_[j], true);
} }
@ -4554,7 +4570,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
[&](at::cuda::CUDAStream&, [&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {}, c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::REDUCE_SCATTER, OpType::REDUCE_SCATTER,
opts.asyncOp,
"nccl:reduce_scatter"); "nccl:reduce_scatter");
} else { } else {
const auto num_reduces = inputTensors_.size(); const auto num_reduces = inputTensors_.size();
@ -4566,8 +4581,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
opts.reduceOp, opts.reduceOp,
static_cast<int64_t>(i), static_cast<int64_t>(i),
static_cast<int64_t>(0), static_cast<int64_t>(0),
opts.timeout, opts.timeout};
opts.asyncOp};
_reduce_oop(output, input, reduceOpts); _reduce_oop(output, input, reduceOpts);
} }
auto work = endCoalescing(OpType::REDUCE_SCATTER); auto work = endCoalescing(OpType::REDUCE_SCATTER);
@ -4621,6 +4635,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
// stream so that the caching allocator can reuse memory pool for this stream // stream so that the caching allocator can reuse memory pool for this stream
// in a clever way. This setting is added for libraries like FSDP which use // in a clever way. This setting is added for libraries like FSDP which use
// `reduce_scatter_tensor`. // `reduce_scatter_tensor`.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
return collective( return collective(
inputTensor, inputTensor,
@ -4629,6 +4644,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
at::Tensor& output, at::Tensor& output,
ncclComm_t comm, ncclComm_t comm,
at::cuda::CUDAStream& stream) { at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp = auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4642,8 +4661,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
stream.stream()); stream.stream());
}, },
OpType::_REDUCE_SCATTER_BASE, OpType::_REDUCE_SCATTER_BASE,
opts.asyncOp, "nccl:_reduce_scatter_base",
"nccl:_reduce_scatter_base"); avoidRecordStreams);
} }
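
The note at the top of this hunk explains why recordStream is avoided for `reduce_scatter_tensor`: recording the output on the NCCL stream ties that memory to the stream's allocator pool, which hurts libraries like FSDP. The behaviour restored here hangs off the avoidRecordStreams_ flag, which in recent PyTorch releases is driven by the TORCH_NCCL_AVOID_RECORD_STREAMS environment variable. A hedged sketch of the user-side setup (tensor shapes and names are illustrative; launch under torchrun is assumed):

```python
import os
import torch
import torch.distributed as dist

# Must be set before the NCCL process group is created; this is the knob
# behind the avoidRecordStreams_ member in the diff (name per recent releases).
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

world_size = dist.get_world_size()
inp = torch.ones(world_size * 1024, device="cuda")
out = torch.empty(1024, device="cuda")

work = dist.reduce_scatter_tensor(out, inp, async_op=True)
# With recordStream avoided, inp and out must stay referenced until wait().
work.wait()
dist.destroy_process_group()
```
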
c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced( c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
@ -4680,6 +4699,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
at::Tensor& output, at::Tensor& output,
ncclComm_t comm, ncclComm_t comm,
at::cuda::CUDAStream& stream) { at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclDataType = getNcclDataType(input.scalar_type());
auto ncclReduceOp = auto ncclReduceOp =
getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
@ -4693,7 +4716,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
stream.stream()); stream.stream());
}, },
OpType::COALESCED, OpType::COALESCED,
opts.asyncOp,
"nccl:reduce_scatter_tensor_coalesced"); "nccl:reduce_scatter_tensor_coalesced");
} }
@ -4772,28 +4794,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat));
// All reduce to achieve the barrier // All reduce to achieve the barrier
AllreduceOptions arOpts = AllreduceOptions(); auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier");
arOpts.asyncOp = opts.asyncOp;
auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier", arOpts);
if (opts.asyncOp) { // Work will take over barrierTensors
// Work will take over barrierTensors auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get());
auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get()); TORCH_CHECK(ncclWork);
// If user specified async, the work should not be nullptr ncclWork->isBarrierOp_ = true;
TORCH_CHECK(ncclWork); return work;
// Put a marker here so that `work.wait()` issue by users does
// barrier-specific thing: CPU sync
ncclWork->isBarrierOp_ = true;
return work;
}
// Otherwise, we are in sync mode, we directly wait here.
// (It is a CPU wait for barrier)
auto currentStream = at::cuda::getCurrentCUDAStream(barDevIdx);
// CUDAStream wrapper will correctly use a DeviceGuard here
currentStream.synchronize();
// No work to return
return nullptr;
} }
c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base( c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
@ -4801,7 +4808,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& inputTensor, at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes, std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes, std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts) { const AllToAllOptions& /* unused */) {
check_gpu_single_tensor(outputTensor); check_gpu_single_tensor(outputTensor);
check_gpu_single_tensor(inputTensor); check_gpu_single_tensor(inputTensor);
if (outputSplitSizes.empty() && inputSplitSizes.empty()) { if (outputSplitSizes.empty() && inputSplitSizes.empty()) {
@ -4832,12 +4839,16 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
at::Tensor& output, at::Tensor& output,
ncclComm_t comm, ncclComm_t comm,
at::cuda::CUDAStream& stream) { at::cuda::CUDAStream& stream) {
// See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_equal_split( torch::cuda::nccl::all2all_single_equal_split(
input, output, this->getSize(), comm, stream); input, output, this->getSize(), comm, stream);
return ncclSuccess; return ncclSuccess;
}, },
OpType::ALLTOALL_BASE, OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all"); "nccl:all_to_all");
} else { } else {
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
@ -4879,6 +4890,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10d::computeLengthsAndOffsets( c10d::computeLengthsAndOffsets(
outputSplitSizes, output, &recv_lengths, &recv_offsets); outputSplitSizes, output, &recv_lengths, &recv_offsets);
// See [Sync Streams]. // See [Sync Streams].
if (!avoidRecordStreams_) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
torch::cuda::nccl::all2all_single_unequal_split( torch::cuda::nccl::all2all_single_unequal_split(
input.data_ptr(), input.data_ptr(),
send_lengths.data(), send_lengths.data(),
@ -4893,7 +4908,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
return ncclSuccess; return ncclSuccess;
}, },
OpType::ALLTOALL_BASE, OpType::ALLTOALL_BASE,
opts.asyncOp,
"nccl:all_to_all"); "nccl:all_to_all");
} }
} }
@ -4901,7 +4915,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall( c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
std::vector<at::Tensor>& outputTensors, std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors, std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts) { const AllToAllOptions& /* unused */) {
std::vector<int64_t> inSplitSizes; std::vector<int64_t> inSplitSizes;
std::vector<int64_t> outSplitSizes; std::vector<int64_t> outSplitSizes;
int64_t total_numel = 0; int64_t total_numel = 0;
@ -4948,11 +4962,18 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
return ncclSuccess; return ncclSuccess;
}, },
[&](at::cuda::CUDAStream&, [&](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {}, c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
if (avoidRecordStreams_) {
// inputTensor0 and outputTensor0 are stashed redundantly by
// collective(), but that's ok.
auto& v = work->stashed_for_allocator_safety_;
v->insert(v->end(), inputTensors.begin(), inputTensors.end());
v->insert(v->end(), outputTensors.begin(), outputTensors.end());
}
},
[](at::cuda::CUDAStream&, [](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {}, c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::ALLTOALL, OpType::ALLTOALL,
opts.asyncOp,
"nccl:all_to_all"); "nccl:all_to_all");
} }
@ -5150,6 +5171,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
ncclComm_t comm, ncclComm_t comm,
at::cuda::CUDAStream& stream) { at::cuda::CUDAStream& stream) {
const auto root = opts.rootRank; const auto root = opts.rootRank;
if (getRank() == root) {
if (!avoidRecordStreams_) {
for (auto const& output : outputs) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::gather( torch::cuda::nccl::gather(
inputTensor, outputs, comm, stream, static_cast<int32_t>(root)); inputTensor, outputs, comm, stream, static_cast<int32_t>(root));
return ncclSuccess; return ncclSuccess;
@ -5159,7 +5188,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
[](at::cuda::CUDAStream&, [](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {}, c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::GATHER, OpType::GATHER,
opts.asyncOp,
"nccl:gather"); "nccl:gather");
} }
@ -5228,6 +5256,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
// avoidRecordStreams_ note: collective() will stash outputTensors and // avoidRecordStreams_ note: collective() will stash outputTensors and
// inputs, which == inputTensors[0] on the root rank where it matters. // inputs, which == inputTensors[0] on the root rank where it matters.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
const auto root = opts.rootRank; const auto root = opts.rootRank;
bool nanCheck = (rank_ == root); bool nanCheck = (rank_ == root);
@ -5239,6 +5269,14 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
at::Tensor& /* unused */, at::Tensor& /* unused */,
ncclComm_t comm, ncclComm_t comm,
at::cuda::CUDAStream& stream) { at::cuda::CUDAStream& stream) {
if (getRank() == root) {
if (!avoidRecordStreams) {
for (auto const& input : inputs) {
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), stream);
}
}
}
torch::cuda::nccl::scatter( torch::cuda::nccl::scatter(
inputs, outputTensor, comm, stream, static_cast<int32_t>(root)); inputs, outputTensor, comm, stream, static_cast<int32_t>(root));
return ncclSuccess; return ncclSuccess;
@ -5248,8 +5286,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
[](at::cuda::CUDAStream&, [](at::cuda::CUDAStream&,
c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {}, c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
OpType::SCATTER, OpType::SCATTER,
opts.asyncOp,
"nccl:scatter", "nccl:scatter",
avoidRecordStreams,
nanCheck); nanCheck);
} }
@ -5305,6 +5343,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
// stream so that the caching allocator can reuse memory pool for this stream // stream so that the caching allocator can reuse memory pool for this stream
// in a clever way. This setting is added for libraries like FSDP which use // in a clever way. This setting is added for libraries like FSDP which use
// `all_gather_into_tensor`. // `all_gather_into_tensor`.
bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp);
return collective( return collective(
input_tensor, input_tensor,
@ -5313,6 +5352,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
at::Tensor& output, at::Tensor& output,
ncclComm_t comm, ncclComm_t comm,
at::cuda::CUDAStream& stream) { at::cuda::CUDAStream& stream) {
if (!avoidRecordStreams) {
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}
return ncclAllGather( return ncclAllGather(
input.data_ptr(), input.data_ptr(),
output.data_ptr(), output.data_ptr(),
@ -5322,8 +5365,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
stream.stream()); stream.stream());
}, },
OpType::_ALLGATHER_BASE, OpType::_ALLGATHER_BASE,
opts.asyncOp, "nccl:_all_gather_base",
"nccl:_all_gather_base"); avoidRecordStreams);
} }
// Create a memory allocator for NCCL. This allocator is used to allocate memory // Create a memory allocator for NCCL. This allocator is used to allocate memory
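
The recordStream calls restored throughout ProcessGroupNCCL mark a tensor's storage as in-use on the NCCL stream, so the CUDA caching allocator defers reusing that memory until the stream has caught up. The same primitive is exposed in Python as `Tensor.record_stream`; the stand-alone sketch below (not part of this patch) shows the guarantee it provides when a tensor allocated on one stream is consumed on another:

```python
import torch

assert torch.cuda.is_available()
comm_stream = torch.cuda.Stream()        # stands in for the NCCL stream

x = torch.randn(1 << 20, device="cuda")  # allocated on the default stream
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    y = x.sum()                          # consumed on comm_stream
    # Mark x as in use on comm_stream: the caching allocator will not reuse
    # x's memory for a later default-stream allocation until comm_stream has
    # passed this point.
    x.record_stream(comm_stream)

del x                                    # safe: reuse of the block is deferred
torch.cuda.current_stream().wait_stream(comm_stream)
print(y.item())
```
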

View File

@ -382,6 +382,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Clone of blockingWait_ from ProcessGroupNCCL. // Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_{false}; bool blockingWait_{false};
// Clone of avoidRecordStreams_ from ProcessGroupNCCL.
bool avoidRecordStreams_{false};
// Clone of opTimeout_ from ProcessGroupNCCL. // Clone of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_{}; std::chrono::milliseconds opTimeout_{};
@ -428,13 +431,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// exception_ptr. // exception_ptr.
bool finishedGPUExecutionInternal() const; bool finishedGPUExecutionInternal() const;
// Stash tensors so that CachingAllocator cannot recycle them prematurely.
// Used in case of async ops.
void stashTensors(std::vector<at::Tensor>& tensors);
// Unstash the stashed tensors so that the CachingAllocator can recycle them // Unstash the stashed tensors so that the CachingAllocator can recycle them
void unstashTensors();
// Reference to the store so that we can write aborted communicators // Reference to the store so that we can write aborted communicators
// to the store. // to the store.
c10::intrusive_ptr<Store> store_; c10::intrusive_ptr<Store> store_;
@ -454,9 +450,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// For in-place collectives, some refs stashed here may alias outputs_, // For in-place collectives, some refs stashed here may alias outputs_,
// but that doesn't do any harm. // but that doesn't do any harm.
std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_; std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_;
// Need a mutex to protect stashed_for_allocator_safety_ because it can be
// accessed from both main thread and watchdog thread.
std::mutex stashMutex_;
// The future returned by getFuture. // The future returned by getFuture.
c10::intrusive_ptr<at::ivalue::Future> future_; c10::intrusive_ptr<at::ivalue::Future> future_;
@ -885,8 +878,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
at::Tensor& output, at::Tensor& output,
Fn fn, Fn fn,
OpType opType, OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr, const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true); bool nanCheck = true);
template <typename Fn, typename PreProcess, typename PostProcess> template <typename Fn, typename PreProcess, typename PostProcess>
@ -897,8 +890,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre, PreProcess pre,
PostProcess post, PostProcess post,
OpType opType, OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr, const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true); bool nanCheck = true);
template <typename Fn, typename PreProcess, typename PostProcess> template <typename Fn, typename PreProcess, typename PostProcess>
@ -909,8 +902,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
PreProcess pre, PreProcess pre,
PostProcess post, PostProcess post,
OpType opType, OpType opType,
bool asyncOp,
const char* profilingTitle = nullptr, const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true); bool nanCheck = true);
template <typename Fn> template <typename Fn>
@ -919,8 +912,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::vector<at::Tensor>& output, std::vector<at::Tensor>& output,
Fn fn, Fn fn,
OpType opType, OpType opType,
bool asyncOp, const char* profilingTitle = nullptr,
const char* profilingTitle = nullptr); bool avoidRecordStreams = false);
// Helper that encapsulates work shared across point-to-point communication // Helper that encapsulates work shared across point-to-point communication
// primitives. It is the same structure as the helper used for collective // primitives. It is the same structure as the helper used for collective
@ -1229,9 +1222,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Stores communicators for all collectives run inside a coalescing block // Stores communicators for all collectives run inside a coalescing block
std::shared_ptr<NCCLComm> coalescedComm_ = nullptr; std::shared_ptr<NCCLComm> coalescedComm_ = nullptr;
// Whether the coalesced calls are sync or async.
bool coalescedAsync_;
// Whether or not wait() and synchronize() are blocking operations that wait // Whether or not wait() and synchronize() are blocking operations that wait
// for the operation to complete. // for the operation to complete.
bool blockingWait_ = false; bool blockingWait_ = false;

View File

@ -122,7 +122,6 @@ struct BroadcastOptions {
struct AllreduceOptions { struct AllreduceOptions {
ReduceOp reduceOp = ReduceOp::SUM; ReduceOp reduceOp = ReduceOp::SUM;
std::chrono::milliseconds timeout = kUnsetTimeout; std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
std::optional<at::Tensor> sparseIndices = std::nullopt; std::optional<at::Tensor> sparseIndices = std::nullopt;
}; };
@ -133,7 +132,6 @@ struct ReduceOptions {
int64_t rootRank = 0; int64_t rootRank = 0;
int64_t rootTensor = 0; int64_t rootTensor = 0;
std::chrono::milliseconds timeout = kUnsetTimeout; std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
}; };
struct AllgatherOptions { struct AllgatherOptions {
@ -144,7 +142,6 @@ struct AllgatherOptions {
struct GatherOptions { struct GatherOptions {
int64_t rootRank = 0; int64_t rootRank = 0;
std::chrono::milliseconds timeout = kUnsetTimeout; std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
}; };
struct ScatterOptions { struct ScatterOptions {
@ -161,14 +158,12 @@ struct ReduceScatterOptions {
struct AllToAllOptions { struct AllToAllOptions {
std::chrono::milliseconds timeout = kUnsetTimeout; std::chrono::milliseconds timeout = kUnsetTimeout;
bool asyncOp = true;
}; };
struct BarrierOptions { struct BarrierOptions {
std::vector<int64_t> device_ids; std::vector<int64_t> device_ids;
std::chrono::milliseconds timeout = kUnsetTimeout; std::chrono::milliseconds timeout = kUnsetTimeout;
std::optional<at::Device> device; std::optional<at::Device> device;
bool asyncOp = true;
}; };
struct DistributedBackendOptions { struct DistributedBackendOptions {

View File

@ -999,23 +999,20 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions") py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp) .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceOptions::timeout) .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout);
.def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp);
py::class_<::c10d::AllreduceCoalescedOptions>( py::class_<::c10d::AllreduceCoalescedOptions>(
module, "AllreduceCoalescedOptions") module, "AllreduceCoalescedOptions")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp) .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp)
.def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout) .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout);
.def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp);
py::class_<::c10d::ReduceOptions>(module, "ReduceOptions") py::class_<::c10d::ReduceOptions>(module, "ReduceOptions")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp) .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp)
.def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank) .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank)
.def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor) .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor)
.def_readwrite("timeout", &::c10d::ReduceOptions::timeout) .def_readwrite("timeout", &::c10d::ReduceOptions::timeout);
.def_readwrite("asyncOp", &::c10d::ReduceOptions::asyncOp);
py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions") py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions")
.def(py::init<>()) .def(py::init<>())
@ -1025,8 +1022,7 @@ This class does not support ``__members__`` property.)");
py::class_<::c10d::GatherOptions>(module, "GatherOptions") py::class_<::c10d::GatherOptions>(module, "GatherOptions")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("rootRank", &::c10d::GatherOptions::rootRank) .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank)
.def_readwrite("timeout", &::c10d::GatherOptions::timeout) .def_readwrite("timeout", &::c10d::GatherOptions::timeout);
.def_readwrite("asyncOp", &::c10d::GatherOptions::asyncOp);
py::class_<::c10d::ScatterOptions>(module, "ScatterOptions") py::class_<::c10d::ScatterOptions>(module, "ScatterOptions")
.def(py::init<>()) .def(py::init<>())
@ -1044,13 +1040,11 @@ This class does not support ``__members__`` property.)");
.def(py::init<>()) .def(py::init<>())
.def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids) .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids)
.def_readwrite("timeout", &::c10d::BarrierOptions::timeout) .def_readwrite("timeout", &::c10d::BarrierOptions::timeout)
.def_readwrite("device", &::c10d::BarrierOptions::device) .def_readwrite("device", &::c10d::BarrierOptions::device);
.def_readwrite("asyncOp", &::c10d::BarrierOptions::asyncOp);
py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions") py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions")
.def(py::init<>()) .def(py::init<>())
.def_readwrite("timeout", &::c10d::AllToAllOptions::timeout) .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout);
.def_readwrite("asyncOp", &::c10d::AllToAllOptions::asyncOp);
py::class_<::c10d::DistributedBackendOptions>( py::class_<::c10d::DistributedBackendOptions>(
module, "_DistributedBackendOptions") module, "_DistributedBackendOptions")

View File

@ -2500,7 +2500,7 @@ class _CoalescingManager:
def __init__(self) -> None: def __init__(self) -> None:
self.works: list[Work] = [] self.works: list[Work] = []
def append(self, work: Optional[Work] = None): def append(self, work: Work):
if work: if work:
self.works.append(work) self.works.append(work)
@ -2513,7 +2513,7 @@ class _CoalescingManager:
def _coalescing_manager( def _coalescing_manager(
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
async_ops: bool = False, async_ops: Optional[bool] = False,
): ):
""" """
Context manager used to coalesce collectives or P2P operations when possible. Context manager used to coalesce collectives or P2P operations when possible.
@ -2552,7 +2552,6 @@ def _coalescing_manager(
group._start_coalescing(device) group._start_coalescing(device)
cm = _CoalescingManager() cm = _CoalescingManager()
yield cm yield cm
work = None
op_list = _world.pg_coalesce_state.pop(group) op_list = _world.pg_coalesce_state.pop(group)
if op_list: if op_list:
# Collectives supporting "Fast Path" coalescing are captured. # Collectives supporting "Fast Path" coalescing are captured.
@ -2566,7 +2565,6 @@ def _coalescing_manager(
tensors = [op.tensor for op in op_list] tensors = [op.tensor for op in op_list]
all_reduce_opts = AllreduceCoalescedOptions() all_reduce_opts = AllreduceCoalescedOptions()
all_reduce_opts.reduceOp = not_none(op_list[0].redop) all_reduce_opts.reduceOp = not_none(op_list[0].redop)
all_reduce_opts.asyncOp = async_ops
work = group.allreduce_coalesced(tensors, all_reduce_opts) work = group.allreduce_coalesced(tensors, all_reduce_opts)
elif op0 == all_gather_into_tensor: elif op0 == all_gather_into_tensor:
inputs = [] inputs = []
@ -2574,8 +2572,6 @@ def _coalescing_manager(
for op in op_list: for op in op_list:
inputs.append(op.tensor) inputs.append(op.tensor)
outputs.append(not_none(op.dst_tensor)) outputs.append(not_none(op.dst_tensor))
all_gather_opts = AllgatherOptions()
all_gather_opts.asyncOp = async_ops
work = group.allgather_into_tensor_coalesced(outputs, inputs) work = group.allgather_into_tensor_coalesced(outputs, inputs)
elif op0 == reduce_scatter_tensor: elif op0 == reduce_scatter_tensor:
inputs = [] inputs = []
@ -2585,7 +2581,6 @@ def _coalescing_manager(
outputs.append(not_none(op.dst_tensor)) outputs.append(not_none(op.dst_tensor))
reduce_opts = ReduceScatterOptions() reduce_opts = ReduceScatterOptions()
reduce_opts.reduceOp = not_none(op_list[0].redop) reduce_opts.reduceOp = not_none(op_list[0].redop)
reduce_opts.asyncOp = async_ops
work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts) work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
else: else:
raise AssertionError( raise AssertionError(
@ -2598,12 +2593,9 @@ def _coalescing_manager(
work = group._end_coalescing(device) work = group._end_coalescing(device)
if async_ops: if async_ops:
cm.append(work) cm.append(work) # type: ignore[possibly-undefined]
elif ( else:
work is not None work.wait() # type: ignore[possibly-undefined]
): # Backward compatible with backends that don't sync at CPP level
work.wait()
# Otherwise, the backend has sync'ed at CPP level
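
For context, `_coalescing_manager` (a private API, so details may vary across releases) is used roughly as follows; with `async_ops=True` the restored code path appends the single coalesced work to the manager and the caller waits on it explicitly:

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run under torchrun.
device = torch.device("cuda", torch.cuda.current_device())
tensors = [torch.ones(1024, device=device) for _ in range(4)]

# Several all_reduce calls issued inside the context are coalesced into a
# single fused launch; with async_ops=True the manager collects the Work.
with dist._coalescing_manager(device=device, async_ops=True) as cm:
    for t in tensors:
        dist.all_reduce(t)
cm.wait()  # wait on the coalesced work appended by the manager
```
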
def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]: def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
@ -2728,11 +2720,8 @@ def broadcast(
work = group.broadcast([tensor], opts) work = group.broadcast([tensor], opts)
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -2812,7 +2801,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
opts = AllreduceOptions() opts = AllreduceOptions()
opts.reduceOp = op opts.reduceOp = op
opts.asyncOp = async_op
if group is None: if group is None:
group = _get_default_group() group = _get_default_group()
@ -2829,11 +2817,8 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
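
The control flow reinstated here — return the Work when `async_op=True`, otherwise `wait()` inside the wrapper — is what lets callers overlap independent compute with the collective. A minimal sketch of the call-site pattern (assumes an initialized NCCL process group and a selected CUDA device):

```python
import torch
import torch.distributed as dist

# Assumes an initialized NCCL process group and a current CUDA device.
grad = torch.randn(4096, device="cuda")

work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)
other = torch.relu(torch.randn(4096, device="cuda"))  # overlaps with the collective
work.wait()  # subsequent ops on the current stream now see the reduced values
grad /= dist.get_world_size()
```
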
@_exception_logger @_exception_logger
@ -2892,17 +2877,13 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
opts = AllreduceCoalescedOptions() opts = AllreduceCoalescedOptions()
opts.reduceOp = op opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group() group = group or _get_default_group()
work = group.allreduce_coalesced(tensors, opts) work = group.allreduce_coalesced(tensors, opts)
if async_op: if async_op:
return work.get_future() return work.get_future()
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -2947,15 +2928,11 @@ def reduce(
opts = ReduceOptions() opts = ReduceOptions()
opts.reduceOp = op opts.reduceOp = op
opts.rootRank = group_dst opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.reduce([tensor], opts) work = group.reduce([tensor], opts)
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
def _object_to_tensor(obj, device, group): def _object_to_tensor(obj, device, group):
@ -3754,17 +3731,12 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)
group = group or _get_default_group() group = group or _get_default_group()
opts = AllgatherOptions() work = group.allgather([tensor_list], [tensor])
opts.asyncOp = async_op
work = group.allgather([tensor_list], [tensor], opts)
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -3871,11 +3843,8 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -3985,17 +3954,12 @@ def all_gather_coalesced(
] ]
group = group or _get_default_group() group = group or _get_default_group()
opts = AllgatherOptions() work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
opts.asyncOp = async_op
work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts)
if async_op: if async_op:
return work.get_future() return work.get_future()
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
def _validate_output_list_for_rank(my_rank, dst, gather_list): def _validate_output_list_for_rank(my_rank, dst, gather_list):
@ -4082,16 +4046,12 @@ def gather(
opts = GatherOptions() opts = GatherOptions()
opts.rootRank = group_dst opts.rootRank = group_dst
opts.asyncOp = async_op
work = group.gather(output_tensors, input_tensors, opts) work = group.gather(output_tensors, input_tensors, opts)
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -4193,11 +4153,8 @@ def scatter(
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -4229,18 +4186,14 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
opts = ReduceScatterOptions() opts = ReduceScatterOptions()
opts.reduceOp = op opts.reduceOp = op
opts.asyncOp = async_op
group = group or _get_default_group() group = group or _get_default_group()
work = group.reduce_scatter([output], [input_list], opts) work = group.reduce_scatter([output], [input_list], opts)
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -4340,11 +4293,8 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@deprecated( @deprecated(
@ -4497,7 +4447,6 @@ def all_to_all_single(
return return
opts = AllToAllOptions() opts = AllToAllOptions()
opts.asyncOp = async_op
_check_single_tensor(output, "output") _check_single_tensor(output, "output")
_check_single_tensor(input, "input") _check_single_tensor(input, "input")
_ensure_all_tensors_same_dtype(output, input) _ensure_all_tensors_same_dtype(output, input)
@ -4517,11 +4466,8 @@ def all_to_all_single(
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -4622,7 +4568,6 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
return return
opts = AllToAllOptions() opts = AllToAllOptions()
opts.asyncOp = async_op
_check_tensor_list(output_tensor_list, "output_tensor_list") _check_tensor_list(output_tensor_list, "output_tensor_list")
_check_tensor_list(input_tensor_list, "input_tensor_list") _check_tensor_list(input_tensor_list, "input_tensor_list")
_ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)
@ -4639,11 +4584,8 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
@_exception_logger @_exception_logger
@ -4674,7 +4616,6 @@ def barrier(
opts = BarrierOptions() opts = BarrierOptions()
opts.device = torch.device(_get_object_coll_device(group)) opts.device = torch.device(_get_object_coll_device(group))
opts.asyncOp = async_op
if device_ids is not None: if device_ids is not None:
if isinstance(device_ids, list): if isinstance(device_ids, list):
opts.device_ids = device_ids opts.device_ids = device_ids
@ -4688,11 +4629,8 @@ def barrier(
if async_op: if async_op:
return work return work
elif ( else:
work is not None
): # Backward compatible with backends that don't sync at CPP level
work.wait() work.wait()
# Otherwise, the backend has sync'ed at CPP level
def monitored_barrier( def monitored_barrier(

View File

@ -96,7 +96,7 @@ try:
import torchvision import torchvision
HAS_TORCHVISION = True HAS_TORCHVISION = True
except Exception: # Covering both ImportError and RuntimeError except ImportError:
HAS_TORCHVISION = False HAS_TORCHVISION = False
if sys.platform == "win32": if sys.platform == "win32":
@ -8310,14 +8310,50 @@ class DistributedTest:
def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self): def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self):
self._test_compute_bucket_assignment_by_size(use_logger=True) self._test_compute_bucket_assignment_by_size(use_logger=True)
def _determine_expected_error_verify_model_across_rank(
self, group_to_use, diff_num_params=False
):
# When running with the NCCL backend, we don't expect an error on rank 0;
# rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When
# running with Gloo or with the debug-mode wrapper, we expect the error
# to be caught inline.
# All ranks report the same error when there is a mismatch in the number
# of parameters, since we use allgather in the impl.
if diff_num_params:
expected_err = "DDP expects same model across all ranks"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err
is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL
if self.rank == 0:
if (
dist.get_backend(group_to_use) == dist.Backend.NCCL
and not is_detail_dbg_mode
):
expected_err = "caught collective operation timeout"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
else:
expected_err = None
ctx = self.assertRaises(RuntimeError)
else:
expected_err = "appears not to match"
ctx = self.assertRaisesRegex(RuntimeError, expected_err)
return ctx, expected_err
def _test_verify_model_across_rank(self, use_logger): def _test_verify_model_across_rank(self, use_logger):
group_gloo = dist.new_group( group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
) )
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group( group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=5) backend=dist.get_backend(), timeout=timedelta(seconds=5)
) )
torch.cuda.set_device(self.rank) torch.cuda.set_device(self.rank)
ctx, expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)
# Create a valid model. The constructor initializes the logger that we use later. # Create a valid model. The constructor initializes the logger that we use later.
net = EmbeddingNetDifferentParams(0) net = EmbeddingNetDifferentParams(0)
@ -8335,8 +8371,7 @@ class DistributedTest:
net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1) net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1)
# if we pass a logger we can verify that it was logged # if we pass a logger we can verify that it was logged
caught = 0 with ctx:
try:
if use_logger: if use_logger:
_verify_param_shape_across_processes( _verify_param_shape_across_processes(
net.process_group, list(net.parameters()), net.logger net.process_group, list(net.parameters()), net.logger
@ -8345,13 +8380,18 @@ class DistributedTest:
_verify_param_shape_across_processes( _verify_param_shape_across_processes(
net.process_group, list(net.parameters()) net.process_group, list(net.parameters())
) )
except Exception: # Should only be run by rank 0, and blocking_wait catches and
caught = 1 # reports exception.
dist.barrier(group_to_use)
# As long as there is one rank catching the exception # We don't check when self.rank != 0 because the logger doesn't log
t = torch.Tensor([caught]) # the error "Caught collective operation" as that is not thrown in the reducer.
dist.all_reduce(t, group=group_gloo) if use_logger and self.rank != 0:
self.assertGreater(t, 0) verify_ddp_error_logged(net, expected_err)
# Perform a gloo-based barrier to ensure one rank doesn't exit the test
# early, which causes a failure with Barrier.sync.
dist.barrier(group_gloo)
@require_backend_is_available(DistTestCases.backend_feature["gpu"]) @require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if( @skip_but_pass_in_sandcastle_if(
@ -8369,19 +8409,20 @@ class DistributedTest:
def test_verify_model_across_rank_without_logger(self): def test_verify_model_across_rank_without_logger(self):
self._test_verify_model_across_rank(use_logger=False) self._test_verify_model_across_rank(use_logger=False)
def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo): def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo):
caught = 0 with ctx:
try:
net = torch.nn.parallel.DistributedDataParallel( net = torch.nn.parallel.DistributedDataParallel(
net.to(self.rank), device_ids=[self.rank], process_group=ddp_group net.to(self.rank), device_ids=[self.rank], process_group=ddp_group
) )
except Exception: # Should only be run by rank 0, and blocking_wait catches and
caught = 1 # reports exception.
dist.barrier(ddp_group)
# As long as there is one rank catching the exception # can't use verify_ddp_error_logged here because net was never properly constructed
t = torch.Tensor([caught])
dist.all_reduce(t, group=group_gloo) # Perform gloo-based barrier to ensure one rank doesn't exit test
self.assertGreater(t, 0) # early which causes failure with Barrier.sync.
dist.barrier(group_gloo)
@require_backend_is_available(DistTestCases.backend_feature["gpu"]) @require_backend_is_available(DistTestCases.backend_feature["gpu"])
@skip_but_pass_in_sandcastle_if( @skip_but_pass_in_sandcastle_if(
@ -8392,15 +8433,21 @@ class DistributedTest:
group_gloo = dist.new_group( group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
) )
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group( group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10) backend=dist.get_backend(), timeout=timedelta(seconds=10)
) )
torch.cuda.set_device(self.rank) torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use
)
# Creates network with different sized embedding table on different # Creates network with different sized embedding table on different
# ranks. This should throw an error during DDP init. # ranks. This should throw an error during DDP init.
net = EmbeddingNetDifferentParams(self.rank) net = EmbeddingNetDifferentParams(self.rank)
self._run_test_ddp_model_with_diff_params( self._run_test_ddp_model_with_diff_params(
net, group_to_use, group_gloo ctx, net, group_to_use, group_gloo
) )
@require_backend_is_available(DistTestCases.backend_feature["gpu"]) @require_backend_is_available(DistTestCases.backend_feature["gpu"])
@ -8412,10 +8459,16 @@ class DistributedTest:
group_gloo = dist.new_group( group_gloo = dist.new_group(
timeout=timedelta(seconds=60), backend=dist.Backend.GLOO timeout=timedelta(seconds=60), backend=dist.Backend.GLOO
) )
# Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test
# determinism.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
group_to_use = dist.new_group( group_to_use = dist.new_group(
backend=dist.get_backend(), timeout=timedelta(seconds=10) backend=dist.get_backend(), timeout=timedelta(seconds=10)
) )
torch.cuda.set_device(self.rank) torch.cuda.set_device(self.rank)
ctx, _expected_err = self._determine_expected_error_verify_model_across_rank(
group_to_use, diff_num_params=True
)
# Creates network with diff # of param across ranks, reducer should # Creates network with diff # of param across ranks, reducer should
# recognize this and throw appropriate error. # recognize this and throw appropriate error.
@ -8424,6 +8477,7 @@ class DistributedTest:
) )
self._run_test_ddp_model_with_diff_params( self._run_test_ddp_model_with_diff_params(
ctx,
net, net,
group_to_use, group_to_use,
group_gloo, group_gloo,