Implement scatter primitive for ProcessGroupNCCL (#70029)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70029

This PR implements scatter on top of NCCL and adds scatter support to ProcessGroupNCCL.

NCCL does not directly provide a scatter primitive, so it has to be implemented on top of NCCL’s send/recv API.

1. In ProcessGroupNCCL.cpp, scatter() validates the output tensors and the root rank's input tensor list, then the collective helper passes the outputTensors and inputs to the scatter() function in nccl.cpp.
2. In nccl.cpp, scatter() is implemented with ncclSend/ncclRecv inside an NCCL group: the root rank loops over all ranks and sends each rank its input tensor (copying its own slot directly into the output), while every non-root rank receives its tensor from the root.
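
A minimal usage sketch of the new capability (assuming the usual env:// rendezvous variables are set and one GPU per rank; the helper name run_scatter is only illustrative):

import torch
import torch.distributed as dist

def run_scatter(rank, world_size):
    # NCCL backend; each rank drives one GPU in this sketch.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    output = torch.tensor([-1], device="cuda")
    # Only the root provides the scatter list; other ranks pass None.
    scatter_list = (
        [torch.tensor([r], device="cuda") for r in range(world_size)]
        if rank == 0
        else None
    )
    dist.scatter(output, scatter_list, src=0)
    # After the call, rank r holds tensor([r]).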
ghstack-source-id: 147754837

Test Plan:
test_scatter_ops
test_scatter_stress
test_scatter_checks

Reviewed By: pritamdamania87

Differential Revision: D33154823

fbshipit-source-id: 4513e7eaf7d47a60eb67da99dc6c2e9a2882f3fd
(cherry picked from commit 93201f9d4a)
Wanchao Liang 2022-01-27 11:32:48 -08:00 committed by PyTorch MergeBot
parent 9b53d3194c
commit 6feba4bc7e
6 changed files with 270 additions and 14 deletions

View File

@@ -628,6 +628,126 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
opts.rootRank = 0
pg.gather(output_ts, tensors2, opts)
@requires_nccl()
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_scatter_ops(self):
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
local_device_ids = self.rank_to_GPU[self.rank]
num_gpus = len(local_device_ids)
def scatter(output_t, input_t, rootRank):
opts = c10d.ScatterOptions()
opts.rootRank = rootRank
if rootRank == self.rank:
work = pg.scatter(output_t, input_t, opts)
else:
work = pg.scatter(output_t, [], opts)
work.wait()
# init output
tensors = []
for device_id in local_device_ids:
tensors.append(torch.tensor([-1]).cuda(device_id))
# init input
scatter_list = []
for idx in range(num_gpus):
gpu_idx = local_device_ids[idx]
scatter_list.append([])
for rank in range(self.world_size):
scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))
# test each rank to scatter
expected = [torch.tensor([self.rank])]
for rank in range(self.world_size):
scatter(tensors, scatter_list, rank)
self.assertEqual(expected, tensors)
@requires_nccl()
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_scatter_stress(self):
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
local_device_ids = self.rank_to_GPU[self.rank]
num_gpus = len(local_device_ids)
def scatter(output_t, input_t, rootRank):
opts = c10d.ScatterOptions()
opts.rootRank = rootRank
if rootRank == self.rank:
work = pg.scatter(output_t, input_t, opts)
else:
work = pg.scatter(output_t, [], opts)
work.wait()
stress_length = 1000
# init output
tensors = []
for i in range(stress_length):
tensors.append([])
for device_id in local_device_ids:
tensors[i].append(torch.tensor([-1]).cuda(device_id))
# init input
scatter_list = []
for i in range(stress_length):
scatter_list.append([[] for _ in range(num_gpus)])
for idx, ls in enumerate(scatter_list[i]):
gpu_idx = local_device_ids[idx]
for rank in range(self.world_size):
ls.append(torch.tensor([rank]).cuda(gpu_idx))
# test each rank to scatter
expected = [torch.tensor([self.rank])]
for i in range(stress_length):
for rank in range(self.world_size):
scatter(tensors[i], scatter_list[i], rank)
# Verification
self.assertEqual(tensors[i], expected)
@requires_nccl()
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
def test_scatter_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
pg = self._create_process_group_nccl(store, self.opts())
local_device_ids = self.rank_to_GPU[self.rank]
num_gpus = len(local_device_ids)
# init output
tensors = []
for device_id in local_device_ids:
tensors.append(torch.tensor([-1]).cuda(device_id))
# init input
scatter_list = []
for idx in range(num_gpus):
gpu_idx = local_device_ids[idx]
scatter_list.append([])
for rank in range(self.world_size):
scatter_list[idx].append(torch.tensor([rank]).cuda(gpu_idx))
with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
opts = c10d.ScatterOptions()
opts.rootRank = -1
pg.scatter(tensors, scatter_list, opts)
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
pg.scatter(tensors, scatter_list, 0)
with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
opts = c10d.ScatterOptions()
opts.rootRank = self.world_size
pg.scatter(tensors, scatter_list, opts)
with self.assertRaisesRegex(
RuntimeError, "Tensor list must be nonempty"
):
opts = c10d.ScatterOptions()
opts.rootRank = 0
pg.scatter([], scatter_list, opts)
@requires_nccl()
@sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")

View File

@@ -856,6 +856,52 @@ void gather(
#endif
}
void scatter(
const std::vector<at::Tensor>& inputs,
at::Tensor& outputs,
ncclComm_t _comm,
at::cuda::CUDAStream& stream,
int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;
auto comm = to_nccl_comm(_comm);
int numranks, cur_rank;
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
NCCL_CHECK(ncclGroupStart());
if (cur_rank == root)
{
for (int r = 0; r < numranks; r++)
{
if (r != root) {
size_t send_count = inputs[r].numel();
auto send_type = to_nccl_data_type(inputs[r]);
const auto* sendbuff = reinterpret_cast<char*>(inputs[r].data_ptr());
NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
} else {
// on its own rank, simply copy it to the output
outputs.copy_(inputs[r]);
}
}
} else {
size_t recv_count = outputs.numel();
auto recv_type = to_nccl_data_type(outputs);
auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
} // namespace nccl
} // namespace cuda
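
For reference, the same root-driven pattern expressed with the Python point-to-point API (a sketch only, assuming an already-initialized process group; the NCCL path above batches these calls between ncclGroupStart/ncclGroupEnd):

import torch.distributed as dist

def scatter_via_send_recv(output, scatter_list, root):
    # Mirrors the loop above: the root sends one tensor per peer and copies
    # its own slot locally; every other rank posts a single recv from the root.
    rank = dist.get_rank()
    if rank == root:
        for r in range(dist.get_world_size()):
            if r == root:
                output.copy_(scatter_list[r])
            else:
                dist.send(scatter_list[r], dst=r)
    else:
        dist.recv(output, src=root)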

View File

@@ -144,6 +144,13 @@ TORCH_CUDA_CPP_API void reduce_scatter(
const stream_list& streams = {},
const comm_list& user_comms = {});
TORCH_CUDA_CPP_API void scatter(
const std::vector<at::Tensor>& inputs,
at::Tensor& outputs,
ncclComm_t comm,
at::cuda::CUDAStream& stream,
int32_t root = 0);
TORCH_CUDA_CPP_API void all_gather(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,

View File

@@ -2286,10 +2286,78 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::gather(
}
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::scatter(
std::vector<at::Tensor>& /* unused */,
std::vector<std::vector<at::Tensor>>& /* unused */,
const ScatterOptions& /* unused */) {
TORCH_CHECK(false, "ProcessGroupNCCL does not support scatter");
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts) {
static auto invalidArgument = [](const std::string& msg) {
TORCH_CHECK(false, "ProcessGroupNCCL::scatter: " + msg);
};
assertRootRank(invalidArgument, opts.rootRank, size_);
check_gpu_tensors_different_devices(outputTensors);
assertSingleElementInput(invalidArgument, outputTensors);
// @lint-ignore CLANGTIDY
auto tensor = outputTensors.back();
RECORD_PARAM_COMMS(
rank_, // rank
"scatter", // colName
tensor.numel(), // inSize
tensor.numel() *
this->getSize(), // outSize
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>()); // outSplitSize
std::vector<at::Tensor> inputs;
if (getRank() == opts.rootRank) {
if (inputTensors.size() != 1) {
std::stringstream ss;
ss << "requires a single-element input list containing a list with "
<< getSize() << " tensors.";
invalidArgument(ss.str());
} else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
std::stringstream ss;
ss << "Incorrect input list size " << inputTensors[0].size()
<< ". Input list size should be " << getSize()
<< ", same as size of the process group.";
invalidArgument(ss.str());
}
const auto& options = outputTensors[0].options();
const auto& sizes = outputTensors[0].sizes();
assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes);
inputs = inputTensors[0];
} else {
// if not in the root rank, initialize inputTensors as empty place holder
// with an empty list
if (inputTensors.size() != 0) {
invalidArgument("requires empty input on non-root");
}
inputs = {};
}
return collective(
outputTensors,
inputs,
[&](at::Tensor& /* unused */,
at::Tensor& /* unused */,
ncclComm_t comm,
at::cuda::CUDAStream& stream) {
const auto root = opts.rootRank;
if (getRank() == root) {
for(auto input: inputs) {
c10::cuda::CUDACachingAllocator::recordStream(
input.storage().data_ptr(), stream);
}
}
torch::cuda::nccl::scatter(inputs, outputTensors[0], comm, stream, root);
return ncclSuccess;
},
OpType::SCATTER,
"nccl:scatter");
}
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::recvAnysource(

View File

@@ -69,7 +69,6 @@ class DistTestCases:
# Backends that do not support a specific collective
skip_collective = {}
skip_collective["allgather_coalesced"] = {"nccl", "mpi"}
skip_collective["scatter"] = {"nccl"}
skip_collective["reduce"] = set()
skip_collective["sendrecv anysource"] = {"nccl"}
skip_collective["cpu barrier"] = {"nccl"}

View File

@@ -2639,7 +2639,9 @@ class DistributedTest:
)
# SCATTER
def _test_scatter_helper(self, group, group_id, rank, dtype=torch.float):
def _test_scatter_helper(
self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
):
for dest in group:
tensor = _build_tensor(dest + 1, -1, dtype=dtype)
expected_tensor = _build_tensor(dest + 1, rank, dtype=dtype)
@@ -2648,6 +2650,9 @@ class DistributedTest:
if rank == dest
else []
)
if cuda:
tensor = tensor.cuda(rank_to_GPU[rank][0])
tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
if dtype == torch.complex64:
tensor_shapes = [torch.view_as_real(t).shape for t in tensors]
else:
@@ -2660,6 +2665,7 @@ class DistributedTest:
src=dest,
scatter_list=tensors,
group=group_id,
expect_event=False,
tensor_shapes=tensor_shapes,
)
self.assertEqual(tensor, expected_tensor)
@@ -2667,7 +2673,6 @@ class DistributedTest:
self._barrier()
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
def test_scatter_checks(self):
group, group_id, rank = self._init_global_test()
one = torch.ones([1])
@@ -2690,23 +2695,37 @@ class DistributedTest:
dist.scatter(output)
self.assertEqual(output, one * rank)
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
def test_scatter(self):
group, group_id, rank = self._init_global_test()
self._test_scatter_helper(group, group_id, rank)
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA scatter")
@skip_if_no_gpu
def test_scatter_cuda(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
def test_scatter_complex(self):
group, group_id, rank = self._init_global_test()
self._test_scatter_helper(group, group_id, rank, dtype=torch.cfloat)
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA scatter")
@skip_if_no_gpu
def test_scatter_cuda_complex(self):
group, group_id, rank = self._init_global_test()
rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
self._test_scatter_helper(group, group_id, rank, True, rank_to_GPU, dtype=torch.cfloat)
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
@skip_if_small_worldsize
def test_scatter_group(self):
group, group_id, rank = self._init_group_test()
self._test_scatter_helper(group, group_id, rank)
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
@sandcastle_skip_if(BACKEND == "nccl", "Nccl does not support CPU tensors")
def test_scatter_full_group(self):
group, group_id, rank = self._init_full_group_test()
self._test_scatter_helper(group, group_id, rank)
@@ -2825,7 +2844,6 @@ class DistributedTest:
self._test_all_gather_helper(group, group_id, rank)
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
@sandcastle_skip_if(BACKEND == "nccl", "CUDA all gather skipped for NCCL")
@skip_if_no_gpu
def test_all_gather_cuda(self):
group, group_id, rank = self._init_global_test()
@@ -2838,7 +2856,6 @@ class DistributedTest:
self._test_all_gather_helper(group, group_id, rank, dtype=torch.cfloat)
@sandcastle_skip_if(BACKEND != "nccl", "Only Nccl supports CUDA all gather")
@sandcastle_skip_if(BACKEND == "nccl", "CUDA all gather skipped for NCCL")
@skip_if_no_gpu
def test_all_gather_cuda_complex(self):
group, group_id, rank = self._init_global_test()
@@ -6925,7 +6942,6 @@ class DistributedTest:
dist.barrier()
@require_backend({"gloo"})
@sandcastle_skip_if(BACKEND in DistTestCases.skip_collective["scatter"], f"{BACKEND} does not support scatter")
def test_scatter_object_list(self):
src_rank = 0
scatter_list = (