Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (#150129)

Per the title, we want to be able to use one-shot all-reduce even when the inputs are not registered with symmetric memory. A separate copy would add latency, and one-shot is all about the lowest possible latency.
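A minimal usage sketch of the new op (not part of this commit): it assumes `dist.init_process_group` has already been called on every rank and that every rank runs the same code; sizes, dtype, and variable names are illustrative.

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes dist.init_process_group(...) has been called and each rank has
# already selected its own CUDA device via torch.cuda.set_device(...).
group_name = dist.group.WORLD.group_name

# Staging buffer: allocated from symmetric memory and rendezvous'd once.
symm_buf = symm_mem.empty(8192, dtype=torch.bfloat16, device="cuda")
symm_mem.rendezvous(symm_buf, group=group_name)

# An ordinary CUDA tensor that is *not* registered with symmetric memory.
local_inp = torch.randn(8192, dtype=torch.bfloat16, device="cuda")

# Previously: stage explicitly, paying for a separate copy kernel.
#   symm_buf.copy_(local_inp)
#   res = torch.ops.symm_mem.one_shot_all_reduce(symm_buf, "sum", group_name)

# With this PR: the copy into the symmetric buffer is fused into the
# one-shot all-reduce kernel itself.
res = torch.ops.symm_mem.one_shot_all_reduce_copy(
    symm_buf, local_inp, "sum", group_name
)
```

The staging buffer only has to be rendezvous'd once and can be reused across calls, as long as it is at least as large as the local input and has the same alignment (both conditions are checked in the CUDA implementation below).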

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150129
Approved by: https://github.com/xw285cornell
Author: Natalia Gimelshein
Date: 2025-03-28 02:14:23 +00:00 (committed by PyTorch MergeBot)
parent 1e55b9c0b5
commit 8a872261dc
3 changed files with 107 additions and 23 deletions


@@ -1,5 +1,6 @@
 # Owner(s): ["module: c10d"]
+import itertools
 import os
 from unittest import skipIf
@@ -860,22 +861,32 @@ class SymmMemCollectiveTest(MultiProcessTestCase):
     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_one_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_one_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name
-        inp = symm_mem.empty(
-            size_bytes // dtype.itemsize, dtype=dtype, device=self.device
-        ).normal_()
-        symm_mem.rendezvous(inp, group=group_name)
-        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
-        self._verify_all_reduce_result(inp, res)
+        for dtype, size_bytes, align_bytes, copy, offset in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],
+            [True, False],
+            [0, 16],
+        ):
+            inp = symm_mem.empty(
+                size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device
+            )
+            symm_mem.rendezvous(inp, group=group_name)
+            if not copy:
+                inp.normal_()
+                res = torch.ops.symm_mem.one_shot_all_reduce(
+                    inp[offset:], "sum", group_name
+                )
+            if copy:
+                local_inp = torch.randn_like(inp[offset:])
+                res = torch.ops.symm_mem.one_shot_all_reduce_copy(
+                    inp[offset:], local_inp, "sum", group_name
+                )
+            self._verify_all_reduce_result(local_inp if copy else inp[offset:], res)
         dist.destroy_process_group()


@@ -387,7 +387,7 @@ at::Tensor multimem_all_gather_out(
 // One-shot all-reduce is register-intensive because it stages values loaded
 // from peers in registers before performing reduction. Setting the thread
 // count to 512 to prevent/alleviate register spill.
-constexpr size_t one_shot_all_reduce_max_num_blocks = 8;
+constexpr size_t one_shot_all_reduce_max_num_blocks = 24;
 constexpr size_t one_shot_all_reduce_max_num_threads = 512;

 template <typename T, int alignment, int k_world_size>
@@ -395,6 +395,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
     void one_shot_all_reduce_kernel(
         T** input_ptrs,
         T* output_ptr,
+        T* input_ptr,
         size_t input_offset,
         size_t numel,
         uint32_t** signal_pads,
@@ -402,12 +403,18 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
         size_t world_size) {
   static_assert(alignment % sizeof(T) == 0);
   constexpr size_t numel_per_thread = alignment / sizeof(T);
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
-  __syncthreads();
+  // copy input to shared ptr
   auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
   auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  if (input_ptr) {
+    for (size_t i = offset; i < numel; i += stride) {
+      Vec<alignment> vec_st = ld_vec<alignment>(input_ptr + i);
+      st_vec<alignment>(input_ptrs[rank] + input_offset + i, vec_st);
+    }
+  }
+  // TODO make it sync with one block for no-copy case
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
   for (size_t i = offset; i < numel; i += stride) {
     auto vec = load_and_reduce<T, alignment, k_world_size>(
@@ -416,11 +423,12 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
   }
   __syncthreads();
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
 }

-at::Tensor one_shot_all_reduce_out(
+at::Tensor one_shot_all_reduce_out_impl(
     const at::Tensor& input,
+    const c10::optional<at::Tensor>& local_input,
     std::string reduce_op,
     std::string group_name,
     at::Tensor out) {
@@ -430,11 +438,21 @@ at::Tensor one_shot_all_reduce_out(
       out.is_contiguous(), "one_shot_all_reduce: output must be contiguous.");
   TORCH_CHECK(
       out.sizes() == input.sizes(),
-      "one_shot_all_reduce: input/output size mismatch.");
+      "one_shot_all_reduce: input/output size mismatch, input.sizes(): ",
+      input.sizes(),
+      ", output.sizes(): ",
+      out.sizes());
   TORCH_CHECK(
       reduce_op == "sum",
       "one_shot_all_reduce: only sum is supported for now.");
+  if (local_input.has_value()) {
+    TORCH_CHECK(
+        local_input->is_contiguous(),
+        "one_shot_all_reduce: local input must be contiguous.");
+    TORCH_CHECK(
+        local_input->numel() <= input.numel(),
+        "one_shot_all_reduce: local input size must be smaller than symm buffer size.");
+  }
   auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
   TORCH_CHECK(
       symm_mem != nullptr,
@@ -442,6 +460,13 @@ at::Tensor one_shot_all_reduce_out(
   const size_t alignment =
       get_and_verify_alignment(input, "one_shot_all_reduce");
+  if (local_input.has_value()) {
+    const size_t local_alignment =
+        get_and_verify_alignment(*local_input, "one_shot_all_reduce");
+    TORCH_CHECK(
+        alignment == local_alignment,
+        "one_shot_all_reduce: local input and symm buffer must have the same alignment.");
+  }
   int num_blocks = 0, num_threads = 0;
   init_elementwise_launch_config(
@@ -466,6 +491,8 @@ at::Tensor one_shot_all_reduce_out(
             reinterpret_cast<scalar_t**>(
                 symm_mem->get_buffer_ptrs_dev()),
             out.data_ptr<scalar_t>(),
+            local_input.has_value() ? local_input->data_ptr<scalar_t>()
+                                    : nullptr,
             input.storage_offset(),
             input.numel(),
             reinterpret_cast<uint32_t**>(
@@ -479,12 +506,42 @@
   return out;
 }

+at::Tensor one_shot_all_reduce_out(
+    const at::Tensor& input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, c10::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy_out(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
+}
+
 at::Tensor one_shot_all_reduce(
     const at::Tensor& input,
     std::string reduce_op,
     std::string group_name) {
   auto out = at::empty_like(input);
-  return one_shot_all_reduce_out(input, reduce_op, group_name, out);
+  return one_shot_all_reduce_out_impl(
+      input, c10::nullopt, reduce_op, group_name, out);
 }
+
+at::Tensor one_shot_all_reduce_copy(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  auto out = at::empty_like(local_input);
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
+}

 constexpr size_t two_shot_all_reduce_max_num_blocks = 24;
@@ -712,6 +769,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("multimem_all_gather_out", ::multimem_all_gather_out);
   m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
   m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
+  m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
+  m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
   m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
 #endif


@@ -217,6 +217,14 @@ at::Tensor one_shot_all_reduce_meta(
   return at::empty_like(input);
 }

+at::Tensor one_shot_all_reduce_copy_meta(
+    const at::Tensor& symm_buffer,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  return at::empty_like(local_input);
+}
+
 TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
@@ -230,6 +238,11 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor");
   m.def(
       "one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "one_shot_all_reduce_copy(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name) -> Tensor");
+  m.def(
+      "one_shot_all_reduce_copy_out(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
@@ -252,6 +265,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
   m.impl("one_shot_all_reduce", one_shot_all_reduce_meta);
+  m.impl("one_shot_all_reduce_copy", one_shot_all_reduce_copy_meta);
 }

 } // namespace