[SymmMem] Increase minimum nthreads to cover sync needs in NVL72 (#161983)
`sync_remote_blocks` maps threads to peers. Previously, the minimum nthreads was the warp size, which is too small to cover NVL72. This bumps it so there is at least one thread per peer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161983
Approved by: https://github.com/ngimel
parent 5a2da090ed
commit ab643e4dbb
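To make the new sizing rule concrete before the diff, here is a minimal, standalone C++ sketch of the single-block branch of `init_elementwise_launch_config` after this change. It assumes the usual NVIDIA warp size of 32; `ceil_div`, `round_up`, and the name `pick_num_threads` are local stand-ins for the `at::` helpers and are not the PyTorch implementation itself.

// Standalone sketch of the new thread-count rule (assumptions noted above).
#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr size_t kWarpSize = 32;  // assumption: NVIDIA warp size

constexpr size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t round_up(size_t a, size_t b) { return ceil_div(a, b) * b; }

// Mirrors the single-block branch after this change: size threads for the
// data, but never below world_size (so sync_remote_blocks has one thread
// per peer), then round up to a whole number of warps.
int pick_num_threads(size_t numel_per_split,
                     size_t numel_per_thread,
                     int world_size) {
  size_t num_threads = ceil_div(numel_per_split, numel_per_thread);
  num_threads = std::max(num_threads, static_cast<size_t>(world_size));
  return static_cast<int>(round_up(num_threads, kWarpSize));
}

int main() {
  // Tiny payload on an NVL72-sized job: the data alone needs only 1 thread,
  // but 72 peers require at least 72, rounded up to 96 (3 warps).
  std::printf("%d\n", pick_num_threads(/*numel_per_split=*/8,
                                       /*numel_per_thread=*/8,
                                       /*world_size=*/72));  // prints 96
  return 0;
}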
@@ -104,7 +104,8 @@ void init_elementwise_launch_config(
     size_t max_num_blocks,
     size_t max_num_threads,
     int& num_blocks,
-    int& num_threads) {
+    int& num_threads,
+    int world_size) {
   // Align to preserve alignment in each split
   const size_t aligned_numel = at::round_up(numel, alignment * splits);
   const size_t numel_per_split = aligned_numel / splits;
@@ -112,9 +113,11 @@ void init_elementwise_launch_config(

   if (numel_per_split <= max_num_threads * numel_per_thread) {
     num_blocks = 1;
-    num_threads = at::round_up(
-        at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(at::cuda::warp_size()));
+    num_threads = at::ceil_div(numel_per_split, numel_per_thread);
+    // `sync_remote_blocks` maps threads to peers, so we need to make sure there
+    // are enough threads
+    num_threads = max(num_threads, world_size);
+    num_threads = at::round_up(num_threads, at::cuda::warp_size());
   } else {
     num_blocks = std::min(
         at::ceil_div(numel_per_split, max_num_threads * numel_per_thread),
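A note on the ordering in the new branch above: the data-driven thread count is clamped up to world_size first (so `sync_remote_blocks` gets at least one thread per peer), and only then rounded up to a multiple of the warp size, so the block still launches whole warps. Assuming the usual warp size of 32, a split whose data needs only one thread on a 72-rank NVL72 job now gets max(1, 72) = 72, rounded up to 96 threads, where the old rule would have stopped at 32.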
@@ -185,7 +188,8 @@ at::Tensor multimem_all_reduce_(
       8,
       1024,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());

   AT_DISPATCH_FLOAT_AND_BFLOAT16(
       input.scalar_type(), "multimem_all_reduce_", [&]() {
@@ -271,7 +275,8 @@ at::Tensor multimem_one_shot_all_reduce_out(
       8,
       1024,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());

   AT_DISPATCH_FLOAT_AND_BFLOAT16(
       input.scalar_type(), "multimem_one_shot_all_reduce", [&]() {
@@ -378,7 +383,8 @@ at::Tensor multimem_all_gather_out(
       8,
       1024,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());

   DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
     multimem_all_gather_kernel<k_alignment>
@@ -493,7 +499,8 @@ at::Tensor one_shot_all_reduce_out_impl(
       one_shot_all_reduce_max_num_blocks,
       one_shot_all_reduce_max_num_threads,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());

   AT_DISPATCH_FLOAT_AND_BFLOAT16(
       input.scalar_type(), "one_shot_all_reduce", [&]() {
@@ -748,7 +755,8 @@ at::Tensor two_shot_all_reduce_impl(
       two_shot_all_reduce_max_num_blocks,
       two_shot_all_reduce_max_num_threads,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());

   if (!output.has_value()) {
     AT_DISPATCH_FLOAT_AND_BFLOAT16(
@@ -895,7 +903,8 @@ at::Tensor reduce_scatter_out(
       two_shot_all_reduce_max_num_blocks,
       two_shot_all_reduce_max_num_threads,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());
   if (split_last_dim) {
     AT_DISPATCH_FLOAT_AND_BFLOAT16(
         input.scalar_type(), "two_shot_all_reduce", [&]() {