[SymmMem] Increase minimum nthreads to cover sync needs in NVL72 (#161983)

`sync_remote_blocks` maps threads to peers. Previously the minimum nthreads was the warp size (32), which is too small to cover NVL72 (72 ranks). Bump it so that every peer gets a thread.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161983
Approved by: https://github.com/ngimel
Author: Ke Wen
Date: 2025-09-02 13:32:39 -07:00
Committed by: PyTorch MergeBot
parent 5a2da090ed
commit ab643e4dbb

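The kernels changed below synchronize with remote ranks by mapping one thread to each peer (see the `sync_remote_blocks` comment in the diff). A minimal, hypothetical CUDA sketch of that pattern, with invented names (`sync_sketch`, `per_peer_counter`) rather than the actual PyTorch implementation:

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical sketch (not the actual PyTorch kernel) of the "one thread per
// peer" pattern that `sync_remote_blocks` relies on. Each thread just bumps a
// per-peer counter here; the real kernel signals and waits on remote blocks.
__global__ void sync_sketch(uint32_t* per_peer_counter, int world_size) {
  // Thread i is responsible for peer i. With a one-warp block (32 threads)
  // and world_size == 72 on NVL72, peers 32..71 have no thread and are skipped.
  if (static_cast<int>(threadIdx.x) < world_size) {
    atomicAdd(&per_peer_counter[threadIdx.x], 1u);
  }
}

int main() {
  constexpr int kWorldSize = 72;  // one NVL72 NVLink domain
  uint32_t* counters = nullptr;
  cudaMalloc(&counters, kWorldSize * sizeof(uint32_t));
  cudaMemset(counters, 0, kWorldSize * sizeof(uint32_t));
  sync_sketch<<<1, 32>>>(counters, kWorldSize);  // covers only peers 0..31
  sync_sketch<<<1, 96>>>(counters, kWorldSize);  // 96 >= 72: covers every peer
  cudaDeviceSynchronize();
  cudaFree(counters);
  return 0;
}
```

With a one-warp launch, only peers 0..31 get a thread; guaranteeing at least `world_size` threads (rounded up to a whole warp) is what the change below does.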

@@ -104,7 +104,8 @@ void init_elementwise_launch_config(
     size_t max_num_blocks,
     size_t max_num_threads,
     int& num_blocks,
-    int& num_threads) {
+    int& num_threads,
+    int world_size) {
   // Align to preserve alignment in each split
   const size_t aligned_numel = at::round_up(numel, alignment * splits);
   const size_t numel_per_split = aligned_numel / splits;
@@ -112,9 +113,11 @@
   if (numel_per_split <= max_num_threads * numel_per_thread) {
     num_blocks = 1;
-    num_threads = at::round_up(
-        at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(at::cuda::warp_size()));
+    num_threads = at::ceil_div(numel_per_split, numel_per_thread);
+    // `sync_remote_blocks` maps threads to peers, so we need to make sure there
+    // are enough threads
+    num_threads = max(num_threads, world_size);
+    num_threads = at::round_up(num_threads, at::cuda::warp_size());
   } else {
     num_blocks = std::min(
         at::ceil_div(numel_per_split, max_num_threads * numel_per_thread),
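For a concrete sense of the sizing: suppose the split is small enough that only 16 threads' worth of work exist and world_size is 72 (one NVL72 domain). The old formula picks round_up(16, 32) = 32 threads, leaving peers 32..71 without a syncing thread; the new formula picks round_up(max(16, 72), 32) = 96. A standalone C++ sketch of that arithmetic, re-implementing `round_up`/`ceil_div` locally for illustration (not the `at::` helpers):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// Local stand-ins for at::round_up / at::ceil_div, for illustration only.
size_t round_up(size_t x, size_t m) { return (x + m - 1) / m * m; }
size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  const size_t warp_size = 32;        // warp size on current NVIDIA GPUs
  const size_t numel_per_split = 64;  // small input
  const size_t numel_per_thread = 4;  // e.g. 16-byte alignment / 4-byte float
  const size_t world_size = 72;       // NVL72: 72 ranks in one NVLink domain

  // Old sizing: round the work-based thread count up to a warp -> 32 threads.
  const size_t old_threads =
      round_up(ceil_div(numel_per_split, numel_per_thread), warp_size);

  // New sizing: also guarantee one thread per peer before rounding -> 96 threads.
  size_t new_threads = ceil_div(numel_per_split, numel_per_thread);
  new_threads = std::max(new_threads, world_size);
  new_threads = round_up(new_threads, warp_size);

  std::cout << "old num_threads = " << old_threads    // prints 32
            << ", new num_threads = " << new_threads  // prints 96
            << std::endl;
  return 0;
}
```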
@@ -185,7 +188,8 @@ at::Tensor multimem_all_reduce_(
       8,
       1024,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());
   AT_DISPATCH_FLOAT_AND_BFLOAT16(
       input.scalar_type(), "multimem_all_reduce_", [&]() {
@@ -271,7 +275,8 @@ at::Tensor multimem_one_shot_all_reduce_out(
       8,
       1024,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());
   AT_DISPATCH_FLOAT_AND_BFLOAT16(
       input.scalar_type(), "multimem_one_shot_all_reduce", [&]() {
@@ -378,7 +383,8 @@ at::Tensor multimem_all_gather_out(
       8,
       1024,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());
   DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() {
     multimem_all_gather_kernel<k_alignment>
@@ -493,7 +499,8 @@ at::Tensor one_shot_all_reduce_out_impl(
       one_shot_all_reduce_max_num_blocks,
       one_shot_all_reduce_max_num_threads,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());
   AT_DISPATCH_FLOAT_AND_BFLOAT16(
       input.scalar_type(), "one_shot_all_reduce", [&]() {
@@ -748,7 +755,8 @@ at::Tensor two_shot_all_reduce_impl(
       two_shot_all_reduce_max_num_blocks,
       two_shot_all_reduce_max_num_threads,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());
   if (!output.has_value()) {
     AT_DISPATCH_FLOAT_AND_BFLOAT16(
@@ -895,7 +903,8 @@ at::Tensor reduce_scatter_out(
       two_shot_all_reduce_max_num_blocks,
       two_shot_all_reduce_max_num_threads,
       num_blocks,
-      num_threads);
+      num_threads,
+      symm_mem->get_world_size());
   if (split_last_dim) {
     AT_DISPATCH_FLOAT_AND_BFLOAT16(
         input.scalar_type(), "two_shot_all_reduce", [&]() {