diff --git a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
index 9136f63acab..c4c3af83ccd 100644
--- a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
@@ -49,8 +49,8 @@ struct JittedVecKernelCache {
   at::cuda::jit::NvrtcFunction vec1;
   at::cuda::jit::NvrtcFunction vec2;
   at::cuda::jit::NvrtcFunction vec4;
-#ifdef USE_ROCM
   at::cuda::jit::NvrtcFunction vec8;
+#ifdef USE_ROCM
   at::cuda::jit::NvrtcFunction vec16;
 #endif
@@ -131,6 +131,18 @@ void launch_jitted_vectorized_kernel(
   int vec_size = at::cuda::jit::can_vectorize_up_to(
       desc, c10::ArrayRef<char*>(data.data(), data.size()));
+#ifndef USE_ROCM
+  const auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
+  const int optimal_vec_size = 16 / static_cast<int>(input_size);
+  vec_size = std::min(optimal_vec_size, vec_size);
+  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
+  // that causes some numerical mismatches with uint8 on sm80 and sm90.
+  // TODO: Revisit this after CUDA 12.8 update.
+  if (input_size < 2) {
+    vec_size = std::min(vec_size, 4);
+  }
+#endif
+
   // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
   // fn_ptr is set to the appropriate function based on the vec size and GPU used
   at::cuda::jit::NvrtcFunction* fn_ptr = nullptr;
@@ -138,11 +150,11 @@ void launch_jitted_vectorized_kernel(
 #ifdef USE_ROCM
   if (vec_size == 16) {
     fn_ptr = &fn_cache.vec16;
-  } else if (vec_size == 8) {
-    fn_ptr = &fn_cache.vec8;
   } else
 #endif
-  if (vec_size == 4) {
+  if (vec_size == 8) {
+    fn_ptr = &fn_cache.vec8;
+  } else if (vec_size == 4) {
     fn_ptr = &fn_cache.vec4;
   } else if (vec_size == 2) {
     fn_ptr = &fn_cache.vec2;
diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index 44d2aa6ecc2..6c4ad7a48a0 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -61,6 +61,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }
 
+#ifdef USE_ROCM
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -71,6 +72,16 @@ constexpr auto elems_per_thread(){
     return 4;
   }
 }
+#else
+template <int io_sizes>
+constexpr auto elems_per_thread(){
+  if constexpr (io_sizes == 1) {
+    return 16;
+  } else {
+    return 8;
+  }
+}
+#endif
 
 template <int io_sizes>
 constexpr auto io_block_work_size() {
@@ -191,8 +202,20 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
+#ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
-
+#else
+  using cpp_type = typename function_traits<func_t>::result_type;
+  const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
+  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
+  vec_size = std::min(vec_size, max_vec_size);
+  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
+  // that causes some numerical mismatches with uint8 on sm80 and sm90.
+  // TODO: Revisit this after CUDA 12.8 update.
+  if constexpr (sizeof(cpp_type) < 2) {
+    vec_size = std::min<uint16_t>(vec_size, 4);
+  }
+#endif
   switch (vec_size) {
 #ifdef USE_ROCM
     case 16:
@@ -200,12 +223,12 @@ static inline void launch_vectorized_kernel(
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
+#endif
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
-#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
diff --git a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu
index e37fa1b1835..9c1a6e046de 100644
--- a/aten/src/ATen/native/cuda/Dropout.cu
+++ b/aten/src/ATen/native/cuda/Dropout.cu
@@ -217,8 +217,11 @@ int get_vector_size(at::Tensor self, at::Tensor ret, at::Tensor mask) {
     // make sure we don't break assumption that we can't have > 16 elements / thread
     TORCH_INTERNAL_ASSERT(vec_size <= 16, "Value of VEC must be in [2, 4, 8, 16]");
 #else
+    const int optimal_vec_size = 16 / static_cast<int>(sizeof(scalar_t));
+    vec_size = std::min(optimal_vec_size, vec_size);
+
     // make sure we don't break assumption that we can't have > 4 elements / thread
-    TORCH_INTERNAL_ASSERT(vec_size <= 4, "Value of VEC must be in [2, 4]");
+    TORCH_INTERNAL_ASSERT(vec_size <= 8, "Value of VEC must be in [2, 4, 8]");
 #endif
   }
diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh
index cda0e10a1fa..4a00d714a0c 100644
--- a/aten/src/ATen/native/cuda/MemoryAccess.cuh
+++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -351,8 +351,8 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
-#ifdef USE_ROCM
   constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
+#ifdef USE_ROCM
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
   if (type_size == 1 && (address % vec16_alignment == 0)) {
@@ -360,6 +360,10 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
+#else
+  if (address % vec8_alignment == 0) {
+    return 8;
+  } else
 #endif
   if (address % vec4_alignment == 0) {
     return 4;
diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp
index 7a6f0d0be47..83007cc7159 100644
--- a/aten/src/ATen/native/cuda/jit_utils.cpp
+++ b/aten/src/ATen/native/cuda/jit_utils.cpp
@@ -932,7 +932,6 @@ void initializeCudaContext() {
   }
 }
 
-#ifdef USE_ROCM
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
@@ -952,7 +951,6 @@ int calc_io_size(
   return 0;
 }
-#endif
 
 int calc_thread_work_size(
     const int nInputs,
@@ -971,7 +969,14 @@ int calc_thread_work_size(
   }
   return io_size;
 #else
-  return JIT_THREAD_WORK_SIZE;
+  auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type);
+  TORCH_INTERNAL_ASSERT(io_size > 0);
+  if (io_size == 1) {
+    return 16;
+  } else {
+    return 8;
+  }
+  return io_size;
 #endif
 }
diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h
index c7dcbd8cf72..d971df01396 100644
--- a/aten/src/ATen/native/cuda/jit_utils.h
+++ b/aten/src/ATen/native/cuda/jit_utils.h
@@ -60,6 +60,10 @@ inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
   if ((default_alignment <= 2) && (ip % (8 * default_alignment) == 0)) {
     return 8;
   }
+#else
+  if (ip % (8 * default_alignment) == 0) {
+    return 8;
+  }
 #endif
   if (ip % (4 * default_alignment) == 0) {
     return 4;
@@ -88,15 +92,17 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*> pointers) {
 }
 
 //FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
-#define JIT_THREAD_WORK_SIZE 4
-
 #ifdef USE_ROCM
+#define JIT_THREAD_WORK_SIZE 4
+#else
+#define JIT_THREAD_WORK_SIZE 8
+#endif
+
 int calc_io_size(
     const int nInputs,
     const int nOutputs,
     const c10::ScalarType& inputs_type,
     const c10::ScalarType& result_type);
-#endif
 
 int calc_thread_work_size(
     const int nInputs,
diff --git a/aten/src/ATen/native/cuda/thread_constants.h b/aten/src/ATen/native/cuda/thread_constants.h
index 651053d663e..bcc797a26e1 100644
--- a/aten/src/ATen/native/cuda/thread_constants.h
+++ b/aten/src/ATen/native/cuda/thread_constants.h
@@ -12,11 +12,14 @@
 constexpr int num_threads() {
   return 256;
 }
+
+constexpr int thread_work_size() { return 4; }
 #else
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
+
+constexpr int thread_work_size() { return 8; }
 #endif
 
-constexpr int thread_work_size() { return 4; }
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }
diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu
index c1143fabb27..6b120f7eb30 100644
--- a/aten/src/ATen/test/cuda_vectorized_test.cu
+++ b/aten/src/ATen/test/cuda_vectorized_test.cu
@@ -47,11 +47,11 @@ TEST(TestLoops, HasSameArgTypes) {
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
 
   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -65,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);
 
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 4);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 4);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);
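
Reviewer note (illustrative, not part of the patch): on the non-ROCm path the changes above boil down to "target 16-byte loads, clamp by what the pointer alignment allows, and hold 1-byte types at vec4 because of the NVCC issue noted in the comments." The sketch below restates that selection rule as a standalone host-side helper; pick_vec_size() is a hypothetical name, not a PyTorch API.

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical helper mirroring the non-ROCm vec-size selection above.
    int pick_vec_size(std::size_t elem_size, int alignment_max_vec) {
      int vec = static_cast<int>(16 / elem_size);  // aim for 16-byte (128-bit) accesses
      vec = std::min(vec, alignment_max_vec);      // never exceed what alignment permits
      if (elem_size < 2) {
        vec = std::min(vec, 4);                    // omit vec8 for 1-byte data (NVCC mismatch on sm80/sm90)
      }
      return vec;
    }

    int main() {
      std::printf("%d\n", pick_vec_size(sizeof(float), 8));  // prints 4
      std::printf("%d\n", pick_vec_size(sizeof(short), 8));  // prints 8
      std::printf("%d\n", pick_vec_size(1, 8));               // prints 4 (capped)
      return 0;
    }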