Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (#150129)

Per the title, we want to be able to use one-shot all-reduce even when the input tensors are not symm-mem registered. A separate copy into the symmetric buffer would add latency, and one-shot is all about the lowest possible latency, so the copy is fused into the all-reduce kernel itself.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150129
Approved by: https://github.com/xw285cornell
This commit is contained in:
parent 1e55b9c0b5
commit 8a872261dc
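
For orientation, a minimal usage sketch of the new op, adapted from the test changed below (the torchrun-style environment variables and the 8192-element size are illustrative assumptions; the symm_mem helpers and the op signature come from the existing API plus this PR). Note that only the staging buffer is rendezvoused; the tensor being reduced is an ordinary CUDA allocation:

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    # Assumes a torchrun-style launch (RANK/LOCAL_RANK/WORLD_SIZE set) on a
    # multi-GPU node; the test below runs with 4 GPUs.
    device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    group_name = dist.group.WORLD.group_name

    # Symm-mem staging buffer: allocated and rendezvoused once, reusable.
    symm_buf = symm_mem.empty(8192, dtype=torch.bfloat16, device=device)
    symm_mem.rendezvous(symm_buf, group=group_name)

    # The actual input is a plain CUDA tensor -- no registration required.
    local_inp = torch.randn(8192, dtype=torch.bfloat16, device=device)

    # New op: stages local_inp into symm_buf and reduces in a single kernel.
    res = torch.ops.symm_mem.one_shot_all_reduce_copy(
        symm_buf, local_inp, "sum", group_name
    )

    dist.destroy_process_group()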
@@ -1,5 +1,6 @@
 # Owner(s): ["module: c10d"]
 
+import itertools
 import os
 from unittest import skipIf
 
@@ -860,22 +861,32 @@ class SymmMemCollectiveTest(MultiProcessTestCase):
 
     @skipIfRocm
     @skip_if_lt_x_gpu(4)
-    @parametrize("dtype", [torch.float, torch.bfloat16])
-    @parametrize("align_bytes", [4, 8, 16])
-    @parametrize("size_bytes", [4, 8192, 8196])
-    def test_one_shot_all_reduce(
-        self, dtype: torch.dtype, size_bytes: int, align_bytes: int
-    ) -> None:
+    def test_one_shot_all_reduce(self) -> None:
         self._init_process()
         group_name = dist.group.WORLD.group_name
 
-        inp = symm_mem.empty(
-            size_bytes // dtype.itemsize, dtype=dtype, device=self.device
-        ).normal_()
-        symm_mem.rendezvous(inp, group=group_name)
-
-        res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name)
-        self._verify_all_reduce_result(inp, res)
+        for dtype, size_bytes, align_bytes, copy, offset in itertools.product(
+            [torch.float, torch.bfloat16],
+            [4, 8192, 8196],
+            [4, 8, 16],
+            [True, False],
+            [0, 16],
+        ):
+            inp = symm_mem.empty(
+                size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device
+            )
+            symm_mem.rendezvous(inp, group=group_name)
+            if not copy:
+                inp.normal_()
+                res = torch.ops.symm_mem.one_shot_all_reduce(
+                    inp[offset:], "sum", group_name
+                )
+            if copy:
+                local_inp = torch.randn_like(inp[offset:])
+                res = torch.ops.symm_mem.one_shot_all_reduce_copy(
+                    inp[offset:], local_inp, "sum", group_name
+                )
+            self._verify_all_reduce_result(local_inp if copy else inp[offset:], res)
 
         dist.destroy_process_group()
@@ -387,7 +387,7 @@ at::Tensor multimem_all_gather_out(
 // One-shot all-reduce is register-intensive because it stages values loaded
 // from peers in registers before performing reduction. Setting the thread
 // count to 512 to prevent/alleviate register spill.
-constexpr size_t one_shot_all_reduce_max_num_blocks = 8;
+constexpr size_t one_shot_all_reduce_max_num_blocks = 24;
 constexpr size_t one_shot_all_reduce_max_num_threads = 512;
 
 template <typename T, int alignment, int k_world_size>
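
As a back-of-the-envelope sketch of what the block-count bump buys (my arithmetic, not from the PR): with the widest 16-byte vectorized access, one grid-stride pass now covers

    # Launch-footprint arithmetic for the constants above.
    max_num_blocks = 24    # raised from 8 in this change
    max_num_threads = 512  # unchanged; kept at 512 to limit register spill
    alignment = 16         # bytes per thread per iteration in the widest case

    bytes_per_pass = max_num_blocks * max_num_threads * alignment
    print(bytes_per_pass)  # 196608 bytes, i.e. 192 KiB per pass (vs 64 KiB before)

so inputs up to roughly 192 KiB are handled in a single grid-stride iteration.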
@@ -395,6 +395,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
     void one_shot_all_reduce_kernel(
         T** input_ptrs,
         T* output_ptr,
+        T* input_ptr,
         size_t input_offset,
         size_t numel,
         uint32_t** signal_pads,
@@ -402,12 +403,18 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
     size_t world_size) {
   static_assert(alignment % sizeof(T) == 0);
   constexpr size_t numel_per_thread = alignment / sizeof(T);
-
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
-  __syncthreads();
-
+  // copy input to shared ptr
   auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread;
   auto stride = blockDim.x * gridDim.x * numel_per_thread;
+  if (input_ptr) {
+    for (size_t i = offset; i < numel; i += stride) {
+      Vec<alignment> vec_st = ld_vec<alignment>(input_ptr + i);
+      st_vec<alignment>(input_ptrs[rank] + input_offset + i, vec_st);
+    }
+  }
+  // TODO make it sync with one block for no-copy case
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
+  __syncthreads();
 
   for (size_t i = offset; i < numel; i += stride) {
     auto vec = load_and_reduce<T, alignment, k_world_size>(
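
To make the kernel's phase structure concrete, here is a single-process Python emulation (an illustrative sketch only: the `symm_buffers` list stands in for the per-rank symmetric buffers, and the barriers are implicit because everything runs sequentially). Each rank stages its unregistered input into its own buffer, the acq_rel barrier makes the copies visible to peers, and then every rank reduces across all buffers:

    import torch

    world_size = 4
    numel = 8192

    # Stand-ins for per-rank symmetric buffers and unregistered local inputs.
    symm_buffers = [torch.empty(numel) for _ in range(world_size)]
    local_inputs = [torch.randn(numel) for _ in range(world_size)]

    # Phase 1 -- copy: each rank stages its local input into its own symm
    # buffer (the `if (input_ptr)` vectorized loop in the kernel above).
    for rank in range(world_size):
        symm_buffers[rank].copy_(local_inputs[rank])

    # Phase 2 -- barrier: sync_remote_blocks with acq_rel ordering guarantees
    # every rank's copy is visible before any rank starts reading peers.

    # Phase 3 -- reduce: each rank loads from every peer's buffer and sums.
    results = [sum(symm_buffers) for _rank in range(world_size)]

    torch.testing.assert_close(results[0], torch.stack(local_inputs).sum(dim=0))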
@@ -416,11 +423,12 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__
   }
 
   __syncthreads();
-  sync_remote_blocks<std::memory_order_relaxed>(signal_pads, rank, world_size);
+  sync_remote_blocks<std::memory_order_acq_rel>(signal_pads, rank, world_size);
 }
 
-at::Tensor one_shot_all_reduce_out(
+at::Tensor one_shot_all_reduce_out_impl(
     const at::Tensor& input,
+    const c10::optional<at::Tensor>& local_input,
     std::string reduce_op,
     std::string group_name,
     at::Tensor out) {
@@ -430,11 +438,21 @@ at::Tensor one_shot_all_reduce_out(
       out.is_contiguous(), "one_shot_all_reduce: output must be contiguous.");
   TORCH_CHECK(
       out.sizes() == input.sizes(),
-      "one_shot_all_reduce: input/output size mismatch.");
+      "one_shot_all_reduce: input/output size mismatch, input.sizes(): ",
+      input.sizes(),
+      ", output.sizes(): ",
+      out.sizes());
   TORCH_CHECK(
       reduce_op == "sum",
       "one_shot_all_reduce: only sum is supported for now.");
-
+  if (local_input.has_value()) {
+    TORCH_CHECK(
+        local_input->is_contiguous(),
+        "one_shot_all_reduce: local input must be contiguous.");
+    TORCH_CHECK(
+        local_input->numel() <= input.numel(),
+        "one_shot_all_reduce: local input size must be smaller than symm buffer size.");
+  }
   auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name);
   TORCH_CHECK(
       symm_mem != nullptr,
@@ -442,6 +460,13 @@ at::Tensor one_shot_all_reduce_out(
 
   const size_t alignment =
       get_and_verify_alignment(input, "one_shot_all_reduce");
+  if (local_input.has_value()) {
+    const size_t local_alignment =
+        get_and_verify_alignment(*local_input, "one_shot_all_reduce");
+    TORCH_CHECK(
+        alignment == local_alignment,
+        "one_shot_all_reduce: local input and symm buffer must have the same alignment.");
+  }
 
   int num_blocks = 0, num_threads = 0;
   init_elementwise_launch_config(
@@ -466,6 +491,8 @@ at::Tensor one_shot_all_reduce_out(
                 reinterpret_cast<scalar_t**>(
                     symm_mem->get_buffer_ptrs_dev()),
                 out.data_ptr<scalar_t>(),
+                local_input.has_value() ? local_input->data_ptr<scalar_t>()
+                                        : nullptr,
                 input.storage_offset(),
                 input.numel(),
                 reinterpret_cast<uint32_t**>(
@@ -479,12 +506,42 @@ at::Tensor one_shot_all_reduce_out(
   return out;
 }
 
+at::Tensor one_shot_all_reduce_out(
+    const at::Tensor& input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, c10::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy_out(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name,
+    at::Tensor out) {
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
+}
+
 at::Tensor one_shot_all_reduce(
     const at::Tensor& input,
     std::string reduce_op,
     std::string group_name) {
   auto out = at::empty_like(input);
-  return one_shot_all_reduce_out(input, reduce_op, group_name, out);
+  return one_shot_all_reduce_out_impl(
+      input, c10::nullopt, reduce_op, group_name, out);
+}
+
+at::Tensor one_shot_all_reduce_copy(
+    const at::Tensor& input,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  auto out = at::empty_like(local_input);
+  return one_shot_all_reduce_out_impl(
+      input, local_input, reduce_op, group_name, out);
 }
 
 constexpr size_t two_shot_all_reduce_max_num_blocks = 24;
@@ -712,6 +769,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("multimem_all_gather_out", ::multimem_all_gather_out);
   m.impl("one_shot_all_reduce", ::one_shot_all_reduce);
   m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out);
+  m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy);
+  m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out);
   m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_);
   m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm);
 #endif
@@ -217,6 +217,14 @@ at::Tensor one_shot_all_reduce_meta(
   return at::empty_like(input);
 }
 
+at::Tensor one_shot_all_reduce_copy_meta(
+    const at::Tensor& symm_buffer,
+    const at::Tensor& local_input,
+    std::string reduce_op,
+    std::string group_name) {
+  return at::empty_like(local_input);
+}
+
 TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
@@ -230,6 +238,11 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor");
   m.def(
       "one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)");
+  m.def(
+      "one_shot_all_reduce_copy(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name) -> Tensor");
+  m.def(
+      "one_shot_all_reduce_copy_out(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)");
+
   m.def(
       "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)");
@@ -252,6 +265,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
   m.impl("one_shot_all_reduce", one_shot_all_reduce_meta);
+  m.impl("one_shot_all_reduce_copy", one_shot_all_reduce_copy_meta);
 }
 
 } // namespace
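
Because the Meta kernel registered above simply mirrors `local_input`, shape inference under fake/meta tensors takes the output size from `local_input`, not from the (possibly larger) staging buffer. A small sketch of that shape contract (assuming a CUDA build of torch so the symm_mem ops are compiled in; the group name is ignored by the meta kernel):

    import torch
    import torch.distributed._symmetric_memory  # noqa: F401  (loads the symm_mem helpers)

    # Meta tensors carry only shape/dtype, so no process group or rendezvous needed.
    symm_buffer = torch.empty(8192, device="meta")  # staging buffer, may be oversized
    local_input = torch.empty(4096, device="meta")  # the tensor actually reduced

    out = torch.ops.symm_mem.one_shot_all_reduce_copy(
        symm_buffer, local_input, "sum", "unused_group"
    )
    assert out.shape == local_input.shape  # per one_shot_all_reduce_copy_meta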