#ifdef USE_C10D_NCCL

#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>
#include <algorithm>
#include <torch/csrc/distributed/c10d/NanCheck.hpp>
#include <stdint.h>

namespace c10d {

// CUDA kernel to check if data has NaN; a device-side assert
// is raised if NaN is found

// Using ulong2 as a "byte pack", with 16 bytes, for efficient data load
union BytePack16 {
  ulong2 ul2;
  uint64_t ul[2];
};

typedef union BytePack16 BytePack;
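
// Illustrative note (added for clarity, not part of the original check logic):
// one 16-byte BytePack holds 2 doubles, 4 floats, 8 half/bfloat16 values, or
// 16 FP8 values, which is exactly the EltPerPack arithmetic used by the
// CheckBytePack specializations below.
static_assert(sizeof(BytePack) == 16, "BytePack is expected to be 16 bytes");
static_assert(sizeof(BytePack) / sizeof(double) == 2, "2 doubles per pack");
static_assert(sizeof(BytePack) / sizeof(float) == 4, "4 floats per pack");
static_assert(sizeof(BytePack) / sizeof(uint16_t) == 8, "8 16-bit elements per pack");
static_assert(sizeof(BytePack) / sizeof(uint8_t) == 16, "16 FP8 elements per pack");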

// AMD HIP doesn't define `__trap()`, using `assert` instead
#ifdef USE_ROCM
#define __trap() assert(0)
#endif
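
// Note (added): a device-side `__trap()` aborts the running kernel; the
// failure surfaces on the host asynchronously, typically as a CUDA
// "unspecified launch failure" at the next CUDA API call or stream
// synchronization on that device.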

//// Start of templated functions for checking NaNs inside a BytePack

// (i) General implementation (aka fallback)
// We use a for loop to iterate over the elements in a BytePack.
// EltPerPack would be greater than 8 if we fall into this case.

template <typename T, int EltPerPack>
struct CheckBytePack {
  static __device__ __forceinline__ void check(BytePack* tmp) {
    T* data = (T*)tmp;
#pragma unroll 8
    for (int i = 0; i < EltPerPack; i++) {
      if (isnan(data[i])) __trap();
    }
  }
};
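
// Dispatch illustration (added; assumes the element types dispatched later in
// this file): EltPerPack is always sizeof(BytePack) / sizeof(T), so
//   CheckBytePack<double, 2>   -> specialization (ii)
//   CheckBytePack<float, 4>    -> specialization (iii)
//   CheckBytePack<c10::Half, 8> -> specialization (iv)
// and the generic loop above is only reached by element sizes that have no
// dedicated specialization.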

// (ii) Template specialization for 8-byte data types, e.g. double
// EltPerPack = 16 / 8 = 2

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/2> {
  static __device__ __forceinline__ void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (isnan(data[0]) || isnan(data[1])) __trap();
  }
};

// (iii) Template specialization for 4-byte data types, e.g. float32
// EltPerPack = 16 / 4 = 4

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/4> {
  static __device__ __forceinline__ void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3])) __trap();
  }
};

// (iv) Template specialization for 2-byte data types, e.g. float16, bfloat16, half.
// EltPerPack = 16 / 2 = 8

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/8> {
  static __device__ __forceinline__ void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (isnan(data[0]) || isnan(data[1]) || isnan(data[2]) || isnan(data[3]) ||
        isnan(data[4]) || isnan(data[5]) || isnan(data[6]) || isnan(data[7])) {
      __trap();
    }
  }
};

// (v) Template specialization for Float8 types.
// EltPerPack = 16 / 1 = 16

// We want to check 8 x FP8 simultaneously, hence this template definition.
template<typename T>
struct HasNanFP8x8 {
  static __device__ __forceinline__ bool check(uint64_t fp8x8) = delete;
  /*
  {
    // `static_assert` in a template definition requires C++23 onwards.
    // But the error message still applies if you find yourself here.
    static_assert(
        false,
        "You should never call this template definition because it is empty. You "
        "can follow the example of Float8_e4m3fn below to implement the check for "
        "your new datatype."
    );
  }
  */
};

// isnan condition for Float8_e4m3fn:
// (x & 0b01111111) == 0b01111111
// i.e.
// (x & 0x7f) == 0x7f

// The algorithm is as follows:
// (1) Mask out the most significant bit with mask 0x7f.
// (2) If the result is 0x7f (is nan), the following arithmetic would cause the
//     8th bit to be 1: x[i] = x[i] + 0x01
// (3) Only leave the 8th bit by masking with 0x80.
// (4) If any x[i] is nan, then the whole x != 0.

template<>
struct HasNanFP8x8<c10::Float8_e4m3fn> {
  static __device__ __forceinline__ bool check(uint64_t fp8x8) {
    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
    auto incremented = t + 0x0101010101010101ULL;
    auto overflow = incremented & 0x8080808080808080ULL;
    return overflow != 0;
  }
};
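
// Worked example (added for illustration), looking at a single byte lane of
// the 64-bit word; the same arithmetic happens in all 8 lanes at once:
//   NaN byte    0x7F: (0x7F & 0x7F) = 0x7F; 0x7F + 0x01 = 0x80; 0x80 & 0x80 != 0 -> trap
//   Normal byte 0x3D: (0x3D & 0x7F) = 0x3D; 0x3D + 0x01 = 0x3E; 0x3E & 0x80 == 0 -> ok
// Because every masked lane is at most 0x7F, the +0x01 can never carry into a
// neighboring byte, so the 8 lanes stay independent within one 64-bit add.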

// isnan condition for Float8_e5m2:
// (x & 0x7f) > 0x7c
// This case does not overflow: 0x7c + 0x03 == 0x7f, but adding 0x03 to anything
// greater than 0x7c will overflow.

template<>
struct HasNanFP8x8<c10::Float8_e5m2> {
  static __device__ __forceinline__ bool check(uint64_t fp8x8) {
    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
    auto incremented = t + 0x0303030303030303ULL;
    auto overflow = incremented & 0x8080808080808080ULL;
    return overflow != 0;
  }
};
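
// Worked example (added for illustration), again per byte lane:
//   NaN byte 0x7D (e5m2): (0x7D & 0x7F) = 0x7D; 0x7D + 0x03 = 0x80; bit 0x80 set   -> trap
//   Inf byte 0x7C (e5m2): (0x7C & 0x7F) = 0x7C; 0x7C + 0x03 = 0x7F; bit 0x80 clear -> ok
// Only the NaN encodings (masked value greater than 0x7C) push a lane past 0x7F.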

template<typename T>
struct CheckBytePack<T, /*EltPerPack*/16> {
  static __device__ __forceinline__ void check(BytePack* tmp) {
    if (HasNanFP8x8<T>::check(tmp->ul[0]) || HasNanFP8x8<T>::check(tmp->ul[1]))
      __trap();
  }
};

//// End of templated functions for checking NaNs inside a BytePack


// Fast-path check routine:
// each thread will load and check 8 BytePacks in this routine

// Create a tmp buffer of size 8, also unroll the for loop by 8
#define UNROLL 8

template <typename T>
__device__ __forceinline__ void checkChunk(BytePack* ptr) {
  BytePack tmp[UNROLL];
  int nWorkers = blockDim.x * gridDim.x;
  // First load values from global memory into the tmp buffer
#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {
    tmp[j] = ptr[nWorkers * j];
  }
  // Then check each BytePack in the tmp buffer
#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {
    CheckBytePack<T, sizeof(BytePack)/sizeof(T)>::check(tmp + j);
  }
  // Note: we separate the check from the load for efficient loading
}
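
// Access-pattern sketch (added; assumes a hypothetical grid of 2 blocks x 4
// threads, so nWorkers == 8): the worker whose chunk base is `ptr` reads the
// packs at indices 0, 8, 16, ..., 56 into tmp[0..7]. Within each iteration j,
// adjacent workers read adjacent BytePacks, so the 16-byte loads coalesce into
// wide memory transactions before any checking work is done.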

// Align address of `ptr` up to the alignment of `T`
#define ALIGN_UP(ptr, T) (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T))
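
// Example (added): with sizeof(BytePack) == 16, ALIGN_UP(0x1004, BytePack)
// yields 0x1010, while an already-aligned 0x1010 is returned unchanged.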

// This is the host-facing kernel

template <typename T>
__global__ void checkForNaN(T* data, size_t size) {
  constexpr int EltPerPack = sizeof(BytePack) / sizeof(T);
  // Offset of current thread
  size_t offset = blockIdx.x * blockDim.x + threadIdx.x;

  // Align input address up to BytePack in case it is not
  T* ptrAlign = (T*)ALIGN_UP(data, BytePack);
  // Pre-process the data before the alignment boundary
  size_t preProcElts = min(ptrAlign - data, size);
  // Read memory by T (slow). One iteration is enough because the number of
  // threads would be bigger than `preProcElts`
  if (offset < preProcElts) {
    if (isnan(data[offset])) __trap();
  }
  // We have processed this amount of data
  size -= preProcElts;

  // Start BytePack processing
  BytePack* ptr = (BytePack*)ptrAlign;
  // Size of input data in units of BytePack
  size_t sizeInBP = size * sizeof(T) / sizeof(BytePack);
  // Number of BytePacks processed in one fast-path iteration
  size_t loopSize = blockDim.x * gridDim.x * UNROLL;

  // Fast path
  // The condition below makes sure there is enough data to process (`loopSize`)
  for (; offset + loopSize <= sizeInBP; offset += loopSize) {
    checkChunk<T>(ptr + offset);
  }

  // The rest of the data goes on the slow path:
  // we just do a regular load and check
  for (; offset < sizeInBP; offset += blockDim.x * gridDim.x) {
    BytePack tmp = ptr[offset];
    CheckBytePack<T, EltPerPack>::check(&tmp);
  }

  // We can still have a tail smaller than 1 BytePack
  // TODO: merge this tail check with head check to make them concurrent
  if (threadIdx.x < size % EltPerPack) {
    T* tailPtr = (T*)(ptr + sizeInBP);
    if (isnan(tailPtr[threadIdx.x])) __trap();
  }
}
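
// Worked example (added; hypothetical input, T = half so EltPerPack == 8):
// suppose `data` starts 4 elements (8 bytes) before a 16-byte boundary and
// size == 10000. Then preProcElts == 4 (head handled element-wise), the
// remaining 9996 halfs give sizeInBP == 1249 full BytePacks for the fast and
// slow paths, and size % EltPerPack == 4 trailing elements are checked
// element-wise by the tail branch.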

// Check if a Tensor contains NaN in any of its elements
void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) {
  // skip check for non-float types
  if (!torch::is_floating_point(tensor)) {
    return;
  }
  const size_t maxNumThreadsPerBlock = 512;
  const size_t maxNumBlocks = 24;
  const size_t numThreadsPerBlock =
      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());

  if (!(numThreadsPerBlock > 0)) {
    return;
  }

  const size_t numBlocks = std::min<size_t>(
      maxNumBlocks,
      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);

  AT_DISPATCH_FLOATING_TYPES_AND4(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      at::ScalarType::Float8_e4m3fn,
      at::ScalarType::Float8_e5m2,
      tensor.scalar_type(),
      "checkForNaN",
      [&] {
        checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock, 0, stream>>>(
            tensor.data_ptr<scalar_t>(), tensor.numel());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}
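
// Usage sketch (added; illustrative only, assuming the caller already runs on
// a CUDA device; the actual call sites live in the c10d NCCL process group
// when NaN checking is enabled):
//
//   at::Tensor t = at::randn({1024}, at::kCUDA);
//   auto stream = at::cuda::getCurrentCUDAStream();
//   c10d::checkForNan(t, stream);  // device-side trap fires if t has a NaN
//
// Note the early return above when numThreadsPerBlock == 0: an empty tensor
// becomes a no-op instead of an invalid <<<numBlocks, 0>>> launch.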

} // namespace c10d

#endif // USE_C10D_NCCL