[SymmMem] Find NVSHMEM from system installation (#157513)

Previously we only searched for NVSHMEM at the pip install location.
This PR adds a search of the system locations that CMake considers default.
Related: #157453 untars NVSHMEM into `/usr/local` on our CI machines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157513
Approved by: https://github.com/atalman, https://github.com/Skylion007
This commit is contained in:
Ke Wen 2025-07-03 20:32:13 -07:00 committed by PyTorch MergeBot
parent 4ed1b03f72
commit 99c1a6bdd9
6 changed files with 56 additions and 22 deletions

View File

@ -989,23 +989,30 @@ elseif(USE_CUDA)
# Compile with NVSHMEM
# Default value of `USE_NVSHMEM` is set in CMakeLists.txt under root, to ON.
# If user has specified NVSHMEM_HOME, we use it;
# Otherwise, NVSHMEM_HOME is auto detected in tools/setup_helpers/cmake.py
if($ENV{NVSHMEM_HOME})
set(NVSHMEM_HOME $ENV{NVSHMEM_HOME} CACHE PATH "Path to NVSHMEM build dir")
endif()
if(NOT DEFINED NVSHMEM_HOME)
message(WARNING "NVSHMEM_HOME not found. Please run `pip install nvidia-nvshmem-<version>`, or set NVSHMEM_HOME to its location.")
# Disable nvshmem if NVSHMEM_HOME is not found
set(USE_NVSHMEM FALSE)
endif()
if(USE_NVSHMEM)
message(STATUS "Building with NVSHMEM support: '${NVSHMEM_HOME}'")
set(NVSHMEM_INCLUDE_DIR "${NVSHMEM_HOME}/include")
set(NVSHMEM_LIB_DIR "${NVSHMEM_HOME}/lib")
message(STATUS "NVSHMEM_HOME set to: '$ENV{NVSHMEM_HOME}'")
message(STATUS "NVSHMEM wheel installed at: '${NVSHMEM_PY_DIR}'")
# Search order:
# 1. If user has specified `NVSHMEM_HOME`, we use it;
# 2. If NVSHMEM wheel has been installed, we use it, see
# tools/setup_helpers/cmake.py, where we set `NVSHMEM_PY_DIR` to the wheel
# location, e.g.
# `/path/to/conda/lib/python3.10/site-packages/nvidia/nvshmem`,
# 3. Let CMake find it in the default system paths, e.g. /usr/local.
find_path(NVSHMEM_LIB_DIR
# In pip install case, the lib suffix is `.so.3` instead of `.so`
NAMES libnvshmem_host.so libnvshmem_host.so.3
PATHS $ENV{NVSHMEM_HOME}/lib ${NVSHMEM_PY_DIR}/lib
DOC "The location of NVSHMEM library.")
find_path(NVSHMEM_INCLUDE_DIR
NAMES nvshmem.h
PATHS $ENV{NVSHMEM_HOME}/include ${NVSHMEM_PY_DIR}/include
DOC "The location of NVSHMEM headers.")
endif()
# If NVSHMEM_LIBRARY is found, we build torch_cuda with NVSHMEM support.
if(NVSHMEM_LIB_DIR AND NVSHMEM_INCLUDE_DIR)
message(STATUS "Building with NVSHMEM support: '${NVSHMEM_LIB_DIR}'")
include_directories(${NVSHMEM_INCLUDE_DIR})
# Linking with nvshmem requires the source binary to be built with -rdc
@ -1031,6 +1038,8 @@ elseif(USE_CUDA)
target_compile_definitions(nvshmem_extension PUBLIC USE_NVSHMEM)
target_link_libraries(torch_cuda PRIVATE nvshmem_extension)
install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib)
else()
message(STATUS "NVSHMEM not found, not building with NVSHMEM support.")
endif()
if(USE_UCC)

View File

@ -172,7 +172,7 @@ function(caffe2_print_configuration_summary)
if(${USE_NCCL})
message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}")
endif()
message(STATUS " USE_NVSHMEM : ${USE_NVSHMEM}")
message(STATUS " NVSHMEM_LIB_DIR : ${NVSHMEM_LIB_DIR}")
message(STATUS " USE_NNPACK : ${USE_NNPACK}")
message(STATUS " USE_NUMPY : ${USE_NUMPY}")
message(STATUS " USE_OBSERVERS : ${USE_OBSERVERS}")

View File

@ -346,9 +346,9 @@ class CMake:
# Detect build dependencies from python lib path (in order to set *_HOME variables)
# NVSHMEM
nvshmem_home = py_lib_path + "/nvidia/nvshmem"
if os.path.exists(nvshmem_home):
build_options["NVSHMEM_HOME"] = nvshmem_home
nvshmem_py_dir = py_lib_path + "/nvidia/nvshmem"
if os.path.exists(nvshmem_py_dir):
build_options["NVSHMEM_PY_DIR"] = nvshmem_py_dir
# Options starting with CMAKE_
cmake__options = {

View File

@ -22,7 +22,7 @@ bool deviceSupportsMulticast(int device_idx) {
// - Device support: Determined by querying
// CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
auto driver_api = c10::cuda::DriverAPI::get();
int multicast_supported;
int multicast_supported = 0;
C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_(
&multicast_supported,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,

View File

@ -10,7 +10,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/error.h>
#include <nvshmem.h>
#include <nvshmem_host.h>
namespace c10d {
namespace symmetric_memory {

View File

@ -6,9 +6,18 @@
#include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
#include <cuda_awbarrier_primitives.h>
// Use torch's cub wrapper instead of CUDA's <cub/cub.cuh>, see #55292
#include <ATen/cuda/cub.cuh>
// NVSHMEM minimum SM arch
#define _NVSHMEM_MIN_SM_ARCH 700
// Some NVSHMEM device APIs do not compile on older SM archs
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH)
// Only include host APIs. See nvshmem.h for details.
#define NVSHMEM_HOSTLIB_ONLY
#endif // Must be done before nvshmem.h is included
#include <nvshmem.h>
namespace c10d::nvshmem_extension {
@ -244,6 +253,9 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
// - output splits (OUT) and
// - source offsets (OUT).
__global__ void exchangeSplitAndOffset(int64_t* in_out_splits, int mype, int npes) {
#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
#else
auto input_splits = in_out_splits;
auto output_splits = in_out_splits + npes;
auto source_offsets = in_out_splits + npes * 2;
@ -263,12 +275,16 @@ __global__ void exchangeSplitAndOffset(int64_t* in_out_splits, int mype, int npe
}
// This barrier ensures that all remote PEs see the updated values
nvshmemx_barrier_all_block();
#endif
}
// This kernel is used to do the actual data exchange.
// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
// `stride` is the stride at dim 0, unit in byte.
__global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes) {
#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
#else
auto output_splits = in_out_splits + npes;
auto source_offsets = in_out_splits + npes * 2;
int bid = blockIdx.x;
@ -303,6 +319,7 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_spli
if (bid == 0 && tid < npes) {
source_offsets[tid] = peer_offsets[tid];
}
#endif
}
at::Tensor all_to_all_vdev(
@ -398,6 +415,9 @@ at::Tensor all_to_all_vdev(
// - output splits (OUT) and
// - source offsets (OUT).
__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne, size_t input_dim0) {
#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
#else
int nsplits = npes * ne;
auto input_splits = in_out_splits;
auto output_splits = in_out_splits + nsplits;
@ -424,6 +444,7 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
}
// This barrier ensures that all remote PEs see the updated values
nvshmemx_barrier_all_block();
#endif
}
// This is an warp-scope, exclusive prefix sum. When called by a block of
@ -467,6 +488,9 @@ __device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
// `stride` is the stride at dim 0, unit in byte.
// For meaning of `mype` and `npes`, see the docstring of `all_to_all_vdev_2d`.
__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
#else
int nsplits = npes * ne;
auto output_splits = in_out_splits + nsplits;
auto source_offsets = in_out_splits + nsplits * 2;
@ -540,6 +564,7 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
if (bid == 0 && tid < nsplits) {
source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
}
#endif
}
at::Tensor all_to_all_vdev_2d(