mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Implement scatter primitive for ProcessGroupNCCL (#70029)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70029
This PR implements NCCL scatter and add scatter to ProcessGroupNCCL.
NCCL doesn’t directly provide primitives for scatter, so we need to be implemented on top of NCCL’s send/recv API.
1. In ProcessGroupNCCL.cpp, the inputTensors are first flattened, then outputTensors and inputFlattened are passed by the collective class to scatter() function in nccl.cpp.
2. In nccl.cpp, scatter is implemented using ncclSend/ncclRecv: the root rank uses a for loop to send(distribute) the inputTensors to each rank, then all the ranks receive the inputTensor from the root rank.
ghstack-source-id: 147754837
Test Plan:
test_scatter_ops
test_scatter_stress
test_scatter_checks
Reviewed By: pritamdamania87
Differential Revision: D33154823
fbshipit-source-id: 4513e7eaf7d47a60eb67da99dc6c2e9a2882f3fd
(cherry picked from commit 93201f9d4a)
This commit is contained in:
parent
9b53d3194c
commit
6feba4bc7e
|
|
@ -628,6 +628,126 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
|
|||
opts.rootRank = 0
|
||||
pg.gather(output_ts, tensors2, opts)
|
||||
|
||||
@requires_nccl()
|
||||
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
def test_scatter_ops(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
local_device_ids = self.rank_to_GPU[self.rank]
|
||||
num_gpus = len(local_device_ids)
|
||||
|
||||
def scatter(output_t, input_t, rootRank):
|
||||
opts = c10d.ScatterOptions()
|
||||
opts.rootRank = rootRank
|
||||
if rootRank == self.rank:
|
||||
work = pg.scatter(output_t, input_t, opts)
|
||||
else:
|
||||
work = pg.scatter(output_t, [], opts)
|
||||
work.wait()
|
||||
|
||||
# init output
|
||||
tensors = []
|
||||
for device_id in local_device_ids:
|
||||
tensors.append(torch.tensor([-1]).cuda(device_id))
|
||||
|
||||
# init input
|
||||
scatter_list = []
|
||||
for idx in range(num_gpus):
|
||||
gpu_idx = local_device_ids[idx]
|
||||
scatter_list.append([])
|
||||
for rank in range(self.world_size):
|
||||
scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))
|
||||
|
||||
# test each rank to scatter
|
||||
expected = [torch.tensor([self.rank])]
|
||||
for rank in range(self.world_size):
|
||||
scatter(tensors, scatter_list, rank)
|
||||
self.assertEqual(expected, tensors)
|
||||
|
||||
@requires_nccl()
|
||||
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
def test_scatter_stress(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
local_device_ids = self.rank_to_GPU[self.rank]
|
||||
num_gpus = len(local_device_ids)
|
||||
|
||||
def scatter(output_t, input_t, rootRank):
|
||||
opts = c10d.ScatterOptions()
|
||||
opts.rootRank = rootRank
|
||||
if rootRank == self.rank:
|
||||
work = pg.scatter(output_t, input_t, opts)
|
||||
else:
|
||||
work = pg.scatter(output_t, [], opts)
|
||||
work.wait()
|
||||
|
||||
stress_length = 1000
|
||||
|
||||
# init output
|
||||
tensors = []
|
||||
for i in range(stress_length):
|
||||
tensors.append([])
|
||||
for device_id in local_device_ids:
|
||||
tensors[i].append(torch.tensor([-1]).cuda(device_id))
|
||||
|
||||
# init input
|
||||
scatter_list = []
|
||||
for i in range(stress_length):
|
||||
scatter_list.append([[] for _ in range(num_gpus)])
|
||||
for idx, ls in enumerate(scatter_list[i]):
|
||||
gpu_idx = local_device_ids[idx]
|
||||
for rank in range(self.world_size):
|
||||
ls.append(torch.tensor([rank]).cuda(gpu_idx))
|
||||
|
||||
|
||||
# test each rank to scatter
|
||||
expected = [torch.tensor([self.rank])]
|
||||
for i in range(stress_length):
|
||||
for rank in range(self.world_size):
|
||||
scatter(tensors[i], scatter_list[i], rank)
|
||||
# Verification
|
||||
self.assertEqual(tensors[i], expected)
|
||||
|
||||
@requires_nccl()
|
||||
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
def test_scatter_checks(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
pg = self._create_process_group_nccl(store, self.opts())
|
||||
local_device_ids = self.rank_to_GPU[self.rank]
|
||||
num_gpus = len(local_device_ids)
|
||||
|
||||
# init output
|
||||
tensors = []
|
||||
for device_id in local_device_ids:
|
||||
tensors.append(torch.tensor([-1]).cuda(device_id))
|
||||
|
||||
# init input
|
||||
scatter_list = []
|
||||
for idx in range(num_gpus):
|
||||
gpu_idx = local_device_ids[idx]
|
||||
scatter_list.append([])
|
||||
for rank in range(self.world_size):
|
||||
scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
|
||||
opts = c10d.ScatterOptions()
|
||||
opts.rootRank = -1
|
||||
pg.scatter(tensors, scatter_list, opts)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
|
||||
pg.scatter(tensors, scatter_list, 0)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
|
||||
opts = c10d.ScatterOptions()
|
||||
opts.rootRank = self.world_size
|
||||
pg.scatter(tensors, scatter_list, opts)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Tensor list must be nonempty"
|
||||
):
|
||||
opts = c10d.ScatterOptions()
|
||||
opts.rootRank = 0
|
||||
pg.scatter([], scatter_list, opts)
|
||||
|
||||
@requires_nccl()
|
||||
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
|
|
|
|||
|
|
@ -856,6 +856,52 @@ void gather(
|
|||
#endif
|
||||
}
|
||||
|
||||
void scatter(
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
at::Tensor& outputs,
|
||||
ncclComm_t _comm,
|
||||
at::cuda::CUDAStream& stream,
|
||||
int32_t root) {
|
||||
#ifdef USE_NCCL
|
||||
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
|
||||
using namespace torch::cuda::nccl::detail;
|
||||
|
||||
auto comm = to_nccl_comm(_comm);
|
||||
int numranks, cur_rank;
|
||||
NCCL_CHECK(ncclCommCount(comm, &numranks));
|
||||
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
|
||||
|
||||
NCCL_CHECK(ncclGroupStart());
|
||||
if (cur_rank == root)
|
||||
{
|
||||
for (int r = 0; r < numranks; r++)
|
||||
{
|
||||
if (r != root) {
|
||||
size_t send_count = inputs[r].numel();
|
||||
auto send_type = to_nccl_data_type(inputs[r]);
|
||||
const auto* sendbuff = reinterpret_cast<char*>(inputs[r].data_ptr());
|
||||
NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
|
||||
} else {
|
||||
// on its own rank, simply copy it to the output
|
||||
outputs.copy_(inputs[r]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
size_t recv_count = outputs.numel();
|
||||
auto recv_type = to_nccl_data_type(outputs);
|
||||
auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
|
||||
NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
|
||||
}
|
||||
NCCL_CHECK(ncclGroupEnd());
|
||||
|
||||
#else
|
||||
AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
|
||||
#endif
|
||||
#else
|
||||
AT_ERROR("PyTorch built without NCCL support");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
} // namespace nccl
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -144,6 +144,13 @@ TORCH_CUDA_CPP_API void reduce_scatter(
|
|||
const stream_list& streams = {},
|
||||
const comm_list& user_comms = {});
|
||||
|
||||
TORCH_CUDA_CPP_API void scatter(
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
at::Tensor& outputs,
|
||||
ncclComm_t comm,
|
||||
at::cuda::CUDAStream& stream,
|
||||
int32_t root = 0);
|
||||
|
||||
TORCH_CUDA_CPP_API void all_gather(
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
std::vector<at::Tensor>& outputs,
|
||||
|
|
|
|||
|
|
@ -2286,10 +2286,78 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::gather(
|
|||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::scatter(
|
||||
std::vector<at::Tensor>& /* unused */,
|
||||
std::vector<std::vector<at::Tensor>>& /* unused */,
|
||||
const ScatterOptions& /* unused */) {
|
||||
TORCH_CHECK(false, "ProcessGroupNCCL does not support scatter");
|
||||
std::vector<at::Tensor>& outputTensors,
|
||||
std::vector<std::vector<at::Tensor>>& inputTensors,
|
||||
const ScatterOptions& opts) {
|
||||
|
||||
static auto invalidArgument = [](const std::string& msg) {
|
||||
TORCH_CHECK(false, "ProcessGroupNCCL::scatter: " + msg);
|
||||
};
|
||||
|
||||
assertRootRank(invalidArgument, opts.rootRank, size_);
|
||||
check_gpu_tensors_different_devices(outputTensors);
|
||||
assertSingleElementInput(invalidArgument, outputTensors);
|
||||
|
||||
// @lint-ignore CLANGTIDY
|
||||
auto tensor = outputTensors.back();
|
||||
RECORD_PARAM_COMMS(
|
||||
rank_, // rank
|
||||
"scatter", // colName
|
||||
tensor.numel(), // inSize
|
||||
tensor.numel() *
|
||||
this->getSize(), // outSize
|
||||
tensor.scalar_type(), // dType
|
||||
std::vector<int64_t>(), // inSplitSizes
|
||||
std::vector<int64_t>()); // outSplitSize
|
||||
|
||||
std::vector<at::Tensor> inputs;
|
||||
|
||||
if (getRank() == opts.rootRank) {
|
||||
if (inputTensors.size() != 1) {
|
||||
std::stringstream ss;
|
||||
ss << "requires a single-element input list containing a list with "
|
||||
<< getSize() << " tensors.";
|
||||
invalidArgument(ss.str());
|
||||
} else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
|
||||
std::stringstream ss;
|
||||
ss << "Incorrect input list size " << inputTensors[0].size()
|
||||
<< ". Input list size should be " << getSize()
|
||||
<< ", same as size of the process group.";
|
||||
invalidArgument(ss.str());
|
||||
}
|
||||
|
||||
const auto& options = outputTensors[0].options();
|
||||
const auto& sizes = outputTensors[0].sizes();
|
||||
assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes);
|
||||
inputs = inputTensors[0];
|
||||
} else {
|
||||
// if not in the root rank, initialize inputTensors as empty place holder
|
||||
// with an empty list
|
||||
if (inputTensors.size() != 0) {
|
||||
invalidArgument("requires empty input on non-root");
|
||||
}
|
||||
inputs = {};
|
||||
}
|
||||
|
||||
return collective(
|
||||
outputTensors,
|
||||
inputs,
|
||||
[&](at::Tensor& /* unused */,
|
||||
at::Tensor& /* unused */,
|
||||
ncclComm_t comm,
|
||||
at::cuda::CUDAStream& stream) {
|
||||
const auto root = opts.rootRank;
|
||||
if (getRank() == root) {
|
||||
for(auto input: inputs) {
|
||||
c10::cuda::CUDACachingAllocator::recordStream(
|
||||
input.storage().data_ptr(), stream);
|
||||
}
|
||||
}
|
||||
torch::cuda::nccl::scatter(inputs, outputTensors[0], comm, stream, root);
|
||||
return ncclSuccess;
|
||||
},
|
||||
OpType::SCATTER,
|
||||
"nccl:scatter");
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::recvAnysource(
|
||||
|
|
|
|||
|
|
@ -69,7 +69,6 @@ class DistTestCases:
|
|||
# Backends that do not support a specific collective
|
||||
skip_collective = {}
|
||||
skip_collective["allgather_coalesced"] = {"nccl", "mpi"}
|
||||
skip_collective["scatter"] = {"nccl"}
|
||||
skip_collective["reduce"] = set()
|
||||
skip_collective["sendrecv anysource"] = {"nccl"}
|
||||
skip_collective["cpu barrier"] = {"nccl"}
|
||||
|
|
|
|||
|
|
@ -2639,7 +2639,9 @@ class DistributedTest:
|
|||
)
|
||||
|
||||
# SCATTER
|
||||
def _test_scatter_helper(self, group, group_id, rank, dtype=torch.float):
|
||||
def _test_scatter_helper(
|
||||
self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
|
||||
):
|
||||
for dest in group:
|
||||
tensor = _build_tensor(dest + 1, -1, dtype=dtype)
|
||||
expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype)
|
||||
|
|
@ -2648,6 +2650,9 @@ class DistributedTest:
|
|||
if rank == dest
|
||||
else []
|
||||
)
|
||||
if cuda:
|
||||
tensor = tensor.cuda(rank_to_GPU[rank][0])
|
||||
tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
|
||||
if dtype == torch.complex64:
|
||||
tensor_shapes = [torch.view_as_real(t).shape for t in tensors]
|
||||
else:
|
||||
|
|
@ -2660,6 +2665,7 @@ class DistributedTest:
|
|||
src=dest,
|
||||
scatter_list=tensors,
|
||||
group=group_id,
|
||||
expect_event=False,
|
||||
tensor_shapes=tensor_shapes,
|
||||
)
|
||||
self.assertEqual(tensor, expected_tensor)
|
||||
|
|
@ -2667,7 +2673,6 @@ class DistributedTest:
|
|||
self._barrier()
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
|
||||
def test_scatter_checks(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
one = torch.ones([1])
|
||||
|
|
@ -2690,23 +2695,37 @@ class DistributedTest:
|
|||
dist.scatter(output)
|
||||
self.assertEqual(output, one * rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
def test_scatter(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
self._test_scatter_helper(group, group_id, rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
|
||||
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA gather")
|
||||
@skip_if_no_gpu
|
||||
def test_scatter_cuda(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
|
||||
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
def test_scatter_complex(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
|
||||
|
||||
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
|
||||
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA gather")
|
||||
@skip_if_no_gpu
|
||||
def test_scatter_cuda_complex(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
|
||||
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat)
|
||||
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
@skip_if_small_worldsize
|
||||
def test_scatter_group(self):
|
||||
group, group_id, rank = self._init_group_test()
|
||||
self._test_scatter_helper(group, group_id, rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
|
||||
def test_scatter_full_group(self):
|
||||
group, group_id, rank = self._init_full_group_test()
|
||||
self._test_scatter_helper(group, group_id, rank)
|
||||
|
|
@ -2825,7 +2844,6 @@ class DistributedTest:
|
|||
self._test_all_gather_helper(group, group_id, rank)
|
||||
|
||||
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "CUDA all gather skipped for NCCL")
|
||||
@skip_if_no_gpu
|
||||
def test_all_gather_cuda(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
|
|
@ -2838,7 +2856,6 @@ class DistributedTest:
|
|||
self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
|
||||
|
||||
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
|
||||
@sandcastle_skip_if(BACKEND == "nccl", "CUDA all gather skipped for NCCL")
|
||||
@skip_if_no_gpu
|
||||
def test_all_gather_cuda_complex(self):
|
||||
group, group_id, rank = self._init_global_test()
|
||||
|
|
@ -6925,7 +6942,6 @@ class DistributedTest:
|
|||
dist.barrier()
|
||||
|
||||
@require_backend({"gloo"})
|
||||
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
|
||||
def test_scatter_object_list(self):
|
||||
src_rank = 0
|
||||
scatter_list = (
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user