[SymmMem] Find NVSHMEM from system installation (#157513)
Previously we only searched for NVSHMEM in the pip install location. This PR adds a search of the system locations that CMake deems default. Related: #157453 untars NVSHMEM into `/usr/local` on our CI machines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157513
Approved by: https://github.com/atalman, https://github.com/Skylion007
commit 99c1a6bdd9
parent 4ed1b03f72
@@ -989,23 +989,30 @@ elseif(USE_CUDA)
   # Compile with NVSHMEM
   # Default value of `USE_NVSHMEM` is set in CMakeLists.txt under root, to ON.
-  # If user has specified NVSHMEM_HOME, we use it;
-  # Otherwise, NVSHMEM_HOME is auto detected in tools/setup_helpers/cmake.py
-  if($ENV{NVSHMEM_HOME})
-    set(NVSHMEM_HOME $ENV{NVSHMEM_HOME} CACHE PATH "Path to NVSHMEM build dir")
-  endif()
 
-  if(NOT DEFINED NVSHMEM_HOME)
-    message(WARNING "NVSHMEM_HOME not found. Please run `pip install nvidia-nvshmem-<version>`, or set NVSHMEM_HOME to its location.")
-    # Disable nvshmem if NVSHMEM_HOME is not found
-    set(USE_NVSHMEM FALSE)
-  endif()
 
   if(USE_NVSHMEM)
-    message(STATUS "Building with NVSHMEM support: '${NVSHMEM_HOME}'")
-    set(NVSHMEM_INCLUDE_DIR "${NVSHMEM_HOME}/include")
-    set(NVSHMEM_LIB_DIR "${NVSHMEM_HOME}/lib")
+    message(STATUS "NVSHMEM_HOME set to: '$ENV{NVSHMEM_HOME}'")
+    message(STATUS "NVSHMEM wheel installed at: '${NVSHMEM_PY_DIR}'")
+
+    # Search order:
+    # 1. If user has specified `NVSHMEM_HOME`, we use it;
+    # 2. If NVSHMEM wheel has been installed, we use it, see
+    #    tools/setup_helpers/cmake.py, where we set `NVSHMEM_PY_DIR` to the wheel
+    #    location, e.g.
+    #    `/path/to/conda/lib/python3.10/site-packages/nvidia/nvshmem`,
+    # 3. Let CMake find it in the default system paths, e.g. /usr/local.
+    find_path(NVSHMEM_LIB_DIR
+      # In pip install case, the lib suffix is `.so.3` instead of `.so`
+      NAMES libnvshmem_host.so libnvshmem_host.so.3
+      PATHS $ENV{NVSHMEM_HOME}/lib ${NVSHMEM_PY_DIR}/lib
+      DOC "The location of NVSHMEM library.")
+
+    find_path(NVSHMEM_INCLUDE_DIR
+      NAMES nvshmem.h
+      PATHS $ENV{NVSHMEM_HOME}/include ${NVSHMEM_PY_DIR}/include
+      DOC "The location of NVSHMEM headers.")
   endif()
 
   # If NVSHMEM_LIBRARY is found, we build torch_cuda with NVSHMEM support.
   if(NVSHMEM_LIB_DIR AND NVSHMEM_INCLUDE_DIR)
     message(STATUS "Building with NVSHMEM support: '${NVSHMEM_LIB_DIR}'")
     include_directories(${NVSHMEM_INCLUDE_DIR})
 
     # Linking with nvshmem requires the source binary to be built with -rdc
@@ -1031,6 +1038,8 @@ elseif(USE_CUDA)
     target_compile_definitions(nvshmem_extension PUBLIC USE_NVSHMEM)
     target_link_libraries(torch_cuda PRIVATE nvshmem_extension)
     install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib)
+  else()
+    message(STATUS "NVSHMEM not found, not building with NVSHMEM support.")
   endif()
 
   if(USE_UCC)
@@ -172,7 +172,7 @@ function(caffe2_print_configuration_summary)
   if(${USE_NCCL})
     message(STATUS "  USE_SYSTEM_NCCL     : ${USE_SYSTEM_NCCL}")
   endif()
-  message(STATUS "  USE_NVSHMEM         : ${USE_NVSHMEM}")
+  message(STATUS "  NVSHMEM_LIB_DIR     : ${NVSHMEM_LIB_DIR}")
   message(STATUS "  USE_NNPACK          : ${USE_NNPACK}")
   message(STATUS "  USE_NUMPY           : ${USE_NUMPY}")
   message(STATUS "  USE_OBSERVERS       : ${USE_OBSERVERS}")
@@ -346,9 +346,9 @@ class CMake:
 
         # Detect build dependencies from python lib path (in order to set *_HOME variables)
         # NVSHMEM
-        nvshmem_home = py_lib_path + "/nvidia/nvshmem"
-        if os.path.exists(nvshmem_home):
-            build_options["NVSHMEM_HOME"] = nvshmem_home
+        nvshmem_py_dir = py_lib_path + "/nvidia/nvshmem"
+        if os.path.exists(nvshmem_py_dir):
+            build_options["NVSHMEM_PY_DIR"] = nvshmem_py_dir
 
         # Options starting with CMAKE_
         cmake__options = {
@@ -22,7 +22,7 @@ bool deviceSupportsMulticast(int device_idx) {
   // - Device support: Determined by querying
   //   CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED at runtime.
   auto driver_api = c10::cuda::DriverAPI::get();
-  int multicast_supported;
+  int multicast_supported = 0;
   C10_CUDA_DRIVER_CHECK(driver_api->cuDeviceGetAttribute_(
       &multicast_supported,
       CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
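The fix above initializes the out-parameter so `multicast_supported` carries a defined value before the driver writes to it. For reference, a minimal standalone sketch of the same query against the raw CUDA driver API, without PyTorch's `DriverAPI` wrapper (assumes CUDA 12+, where `CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED` is defined; link with -lcuda):

    #include <cuda.h>
    #include <cstdio>

    int main() {
      cuInit(0);
      CUdevice dev;
      cuDeviceGet(&dev, 0);
      // Mirror the fix: initialize the out-parameter so a failed query still
      // leaves a defined value (0 = multicast unsupported).
      int multicast_supported = 0;
      CUresult res = cuDeviceGetAttribute(
          &multicast_supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev);
      if (res != CUDA_SUCCESS) {
        std::printf("attribute query failed (%d)\n", static_cast<int>(res));
      }
      std::printf("multicast supported: %d\n", multicast_supported);
      return 0;
    }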
@@ -10,7 +10,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/util/error.h>
 
-#include <nvshmem.h>
+#include <nvshmem_host.h>
 
 namespace c10d {
 namespace symmetric_memory {
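Swapping `<nvshmem.h>` for `<nvshmem_host.h>` keeps this host-compiled translation unit away from NVSHMEM's device-side declarations. A minimal host-only sketch of that split, assuming an NVSHMEM 3.x installation where the dedicated host header is available:

    // Host-only translation unit: include the host API header rather than
    // the combined <nvshmem.h>, so no device compilation is required.
    #include <nvshmem_host.h>
    #include <cstdio>

    int main() {
      nvshmem_init();
      std::printf("PE %d of %d\n", nvshmem_my_pe(), nvshmem_n_pes());
      nvshmem_finalize();
      return 0;
    }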
@@ -6,9 +6,18 @@
 #include <torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp>
 #include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
 
 #include <cuda_awbarrier_primitives.h>
 // Use torch's cub wrapper instead of CUDA's <cub/cub.cuh>, see #55292
 #include <ATen/cuda/cub.cuh>
 
+// NVSHMEM minimum SM arch
+#define _NVSHMEM_MIN_SM_ARCH 700
+
+// Some NVSHMEM device APIs do not compile on older SM archs
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH)
+// Only include host APIs. See nvshmem.h for details.
+#define NVSHMEM_HOSTLIB_ONLY
+#endif // Must be done before nvshmem.h is included
+
 #include <nvshmem.h>
 
 namespace c10d::nvshmem_extension {
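The `_NVSHMEM_MIN_SM_ARCH` guard works at two levels: before the include, it defines `NVSHMEM_HOSTLIB_ONLY` so device-side NVSHMEM declarations are skipped on old-arch compilation passes, and inside each kernel (see the hunks below) it turns the body into a runtime assert. A standalone sketch of the per-kernel half of the pattern, using a plain device-side assert in place of PyTorch's `CUDA_KERNEL_ASSERT_MSG` and a hypothetical `MIN_SM_ARCH` (compile with nvcc):

    #include <cassert>

    #define MIN_SM_ARCH 700  // illustrative stand-in for _NVSHMEM_MIN_SM_ARCH

    __global__ void guardedKernel(int* out) {
    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < MIN_SM_ARCH)
      // Compilation pass for an older SM: trap at runtime instead of
      // referencing device APIs that do not compile on this arch.
      assert(false && "SM arch too old");
    #else
      out[threadIdx.x] = threadIdx.x;
    #endif
    }

    int main() {
      int* d_out = nullptr;
      cudaMalloc(&d_out, 32 * sizeof(int));
      guardedKernel<<<1, 32>>>(d_out);
      cudaDeviceSynchronize();
      cudaFree(d_out);
      return 0;
    }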
@@ -244,6 +253,9 @@ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
 // - output splits (OUT) and
 // - source offsets (OUT).
 __global__ void exchangeSplitAndOffset(int64_t* in_out_splits, int mype, int npes) {
+#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
+  CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
+#else
   auto input_splits = in_out_splits;
   auto output_splits = in_out_splits + npes;
   auto source_offsets = in_out_splits + npes * 2;
@@ -263,12 +275,16 @@ __global__ void exchangeSplitAndOffset(int64_t* in_out_splits, int mype, int npe
   }
   // This barrier ensures that all remote PEs see the updated values
   nvshmemx_barrier_all_block();
+#endif
 }
 
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, in bytes.
 __global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes) {
+#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
+  CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
+#else
   auto output_splits = in_out_splits + npes;
   auto source_offsets = in_out_splits + npes * 2;
   int bid = blockIdx.x;
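For orientation, both kernels above carve three logical sections out of the single `in_out_splits` buffer: input splits, output splits, and source offsets, each `npes` entries long. A host-side sketch of that packing (sizes and names are illustrative, not from this commit):

    #include <cstdint>
    #include <vector>

    int main() {
      const int npes = 4;  // illustrative number of PEs
      // One flat buffer holding the three npes-sized sections the kernels index:
      std::vector<int64_t> in_out_splits(3 * npes);
      int64_t* input_splits   = in_out_splits.data();             // [0, npes): IN
      int64_t* output_splits  = in_out_splits.data() + npes;      // [npes, 2*npes): OUT
      int64_t* source_offsets = in_out_splits.data() + npes * 2;  // [2*npes, 3*npes): OUT
      (void)input_splits; (void)output_splits; (void)source_offsets;
      return 0;
    }

The `_2d` variants below generalize the same layout, with `nsplits = npes * ne` entries per section.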
@@ -303,6 +319,7 @@ __global__ void allToAllV(void *send_data, void *recv_data, int64_t* in_out_spli
   if (bid == 0 && tid < npes) {
     source_offsets[tid] = peer_offsets[tid];
   }
+#endif
 }
 
 at::Tensor all_to_all_vdev(
@@ -398,6 +415,9 @@ at::Tensor all_to_all_vdev(
 // - output splits (OUT) and
 // - source offsets (OUT).
 __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne, size_t input_dim0) {
+#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
+  CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
+#else
   int nsplits = npes * ne;
   auto input_splits = in_out_splits;
   auto output_splits = in_out_splits + nsplits;
@@ -424,6 +444,7 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   }
   // This barrier ensures that all remote PEs see the updated values
   nvshmemx_barrier_all_block();
+#endif
 }
 
 // This is a warp-scope, exclusive prefix sum. When called by a block of
@@ -467,6 +488,9 @@ __device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
 // `stride` is the stride at dim 0, in bytes.
 // For meaning of `mype` and `npes`, see the docstring of `all_to_all_vdev_2d`.
 __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
+#if __CUDA_ARCH__ < _NVSHMEM_MIN_SM_ARCH
+  CUDA_KERNEL_ASSERT_MSG(false, "SM arch too old for NVSHMEM");
+#else
   int nsplits = npes * ne;
   auto output_splits = in_out_splits + nsplits;
   auto source_offsets = in_out_splits + nsplits * 2;
@@ -540,6 +564,7 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   if (bid == 0 && tid < nsplits) {
     source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
   }
+#endif
 }
 
 at::Tensor all_to_all_vdev_2d(