Revert "[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705)"

This reverts commit 5228986c39.

Reverted https://github.com/pytorch/pytorch/pull/150705 on behalf of https://github.com/atalman because it broke periodic tests ([comment](https://github.com/pytorch/pytorch/pull/150705#issuecomment-2787017751))
PyTorch MergeBot 2025-04-08 16:29:05 +00:00
parent 97f34f0125
commit 4447352e64
4 changed files with 11 additions and 23 deletions

aten/src/ATen/native/cuda/CUDALoops.cuh

@@ -78,7 +78,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
   }
 }
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
+#ifdef USE_ROCM
 template <int io_sizes>
 constexpr auto elems_per_thread(){
   if constexpr (io_sizes == 1) {
@@ -219,7 +219,7 @@ static inline void launch_vectorized_kernel(
   constexpr auto io_size = calc_io_size<func_t>();
   int64_t grid = (N + io_block_work_size<io_size>() - 1) / io_block_work_size<io_size>();
   auto stream = at::cuda::getCurrentCUDAStream();
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080)
+#ifdef USE_ROCM
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
 #else
   using cpp_type = typename function_traits<func_t>::result_type;
@@ -241,13 +241,11 @@ static inline void launch_vectorized_kernel(
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
 #endif
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
     case 8:
       vectorized_elementwise_kernel<8, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);
       C10_CUDA_KERNEL_LAUNCH_CHECK();
       break;
-#endif
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t>
           <<<grid, num_threads(), 0, stream>>>(N, f, data);

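For context, the `case 8:` arm above is the 128-bit (vec128) path this revert makes unconditional again on CUDA: the launcher picks the widest vector width the data pointers allow and dispatches to a kernel instantiated for that width. Below is a minimal sketch of that dispatch pattern, with illustrative names only; PyTorch's actual kernel applies an arbitrary functor, loads through aligned_vector, and handles the non-divisible tail separately.

#include <cstdint>

// Each thread processes vec_size contiguous elements of a simple
// elementwise op (here: doubling).
template <int vec_size>
__global__ void elementwise_kernel(int64_t n, const float* in, float* out) {
  int64_t base =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * vec_size;
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    if (base + i < n) {
      out[base + i] = in[base + i] * 2.0f;
    }
  }
}

// Host-side dispatch: one switch arm per supported vector width, mirroring
// the structure of the hunk above.
void launch_elementwise(int64_t n, const float* in, float* out, int vec_size) {
  constexpr int threads = 128;
  int64_t per_block = static_cast<int64_t>(threads) * vec_size;
  unsigned int grid = static_cast<unsigned int>((n + per_block - 1) / per_block);
  switch (vec_size) {
    case 8:
      elementwise_kernel<8><<<grid, threads>>>(n, in, out);
      break;
    case 4:
      elementwise_kernel<4><<<grid, threads>>>(n, in, out);
      break;
    case 2:
      elementwise_kernel<2><<<grid, threads>>>(n, in, out);
      break;
    default:
      elementwise_kernel<1><<<grid, threads>>>(n, in, out);
      break;
  }
}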
aten/src/ATen/native/cuda/MemoryAccess.cuh

@@ -486,9 +486,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
   constexpr int vec2_alignment = std::alignment_of_v<aligned_vector<scalar_t, 2>>;
   constexpr int vec4_alignment = std::alignment_of_v<aligned_vector<scalar_t, 4>>;
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080)
   constexpr int vec8_alignment = std::alignment_of_v<aligned_vector<scalar_t, 8>>;
-#endif
 #ifdef USE_ROCM
   constexpr int vec16_alignment = std::alignment_of_v<aligned_vector<scalar_t, 16>>;
   constexpr int type_size = sizeof(scalar_t);
@@ -497,7 +495,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) {
   } else if (type_size <= 2 && (address % vec8_alignment == 0)) {
     return 8;
   } else
-#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+#else
   if (address % vec8_alignment == 0) {
     return 8;
   } else

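The logic being toggled here reduces to one rule: the widest usable vector is the largest element count whose required alignment evenly divides the pointer's address. A standalone host-side sketch of that rule follows; `widest_vector` is an illustrative helper, not PyTorch's API, and it assumes an n-element vector of T must be aligned to n * sizeof(T) bytes, mirroring aligned_vector<T, n>. Its assertions reproduce several of the expectations in the test file at the end of this diff.

#include <cassert>
#include <cstdint>

// Try the widest vector first and fall back to narrower widths until the
// address satisfies the alignment requirement.
template <typename T>
int widest_vector(const void* p) {
  auto addr = reinterpret_cast<uintptr_t>(p);
  for (int elems = 8; elems > 1; elems /= 2) {
    if (addr % (static_cast<uintptr_t>(elems) * sizeof(T)) == 0) {
      return elems;
    }
  }
  return 1;
}

int main() {
  alignas(64) char buf[64];
  assert(widest_vector<int8_t>(buf) == 8);       // start of an aligned buffer
  assert(widest_vector<int8_t>(buf + 1) == 1);   // odd address: scalar only
  assert(widest_vector<int16_t>(buf + 8) == 4);  // 8-byte aligned, 2-byte T
  assert(widest_vector<int32_t>(buf + 8) == 2);  // 8 bytes = two ints
  return 0;
}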
aten/src/ATen/native/cuda/thread_constants.h

@@ -18,11 +18,8 @@ constexpr int thread_work_size() { return 4; }
 constexpr uint32_t num_threads() {
   return C10_WARP_SIZE * 4;
 }
-#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
-constexpr int thread_work_size() { return 4; }
-#else
 constexpr int thread_work_size() { return 8; }
-#endif
 #endif
 constexpr int block_work_size() { return thread_work_size() * num_threads(); }

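The net effect in this header: with thread_work_size() back to 8 unconditionally on CUDA, each 128-thread block covers 1024 elements instead of 512. A quick compile-time check of that arithmetic, assuming C10_WARP_SIZE is 32 on NVIDIA hardware:

constexpr int warp_size = 32;                  // C10_WARP_SIZE on CUDA
constexpr int num_threads = warp_size * 4;     // 128 threads per block
constexpr int block_work_4 = 4 * num_threads;  // 512 elements (gated path)
constexpr int block_work_8 = 8 * num_threads;  // 1024 elements (restored path)
static_assert(block_work_4 == 512, "");
static_assert(block_work_8 == 1024, "");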
aten/src/ATen/test/cuda_vectorized_test.cu

@@ -46,17 +46,12 @@ TEST(TestLoops, HasSameArgTypes) {
 TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   char *ptr = reinterpret_cast<char *>(buffer1);
-#if defined(CUDA_VERSION) && CUDA_VERSION < 12080
-  constexpr auto vectorize_limit = 4;
-#else
-  constexpr auto vectorize_limit= 8;
-#endif
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr), 8);
 
   ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 1), 1);
   ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 1), 1);
@@ -70,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) {
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 4), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 4), 1);
-  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), vectorize_limit);
-  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), vectorize_limit);
+  ASSERT_EQ(memory::can_vectorize_up_to<bool>(ptr + 8), 8);
+  ASSERT_EQ(memory::can_vectorize_up_to<int8_t>(ptr + 8), 8);
   ASSERT_EQ(memory::can_vectorize_up_to<int16_t>(ptr + 8), 4);
   ASSERT_EQ(memory::can_vectorize_up_to<int>(ptr + 8), 2);
   ASSERT_EQ(memory::can_vectorize_up_to<int64_t>(ptr + 8), 1);