[ROCm] Remove use of warpsize on host-side compilation (#156979)

Changes needed for ROCm 7.0:
* `warpSize` is _no longer_ a compile-time constant in device-side compilation for ROCm
* `warpSize` is _not_ defined in host-side compilation, so `at::cuda::warp_size()` must be used to query the warp size at runtime
* `C10_WARP_SIZE` is redefined as a compile-time constant, with a reasonable value for device-side compilation but a deliberately unreasonable value of 1 for host-side compilation, so that stray host-side uses fail loudly (see the sketch after this list)
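A minimal sketch of the resulting pattern (hypothetical call site, not part of this diff): device code may keep using the compile-time constant, while host code queries the warp size at runtime:

    // Device side: C10_WARP_SIZE is still a compile-time constant, so it can
    // still size shared-memory arrays.
    __global__ void warp_store(float* out) {
      __shared__ float smem[C10_WARP_SIZE];
      smem[threadIdx.x % C10_WARP_SIZE] = static_cast<float>(threadIdx.x);
      out[threadIdx.x] = smem[threadIdx.x % C10_WARP_SIZE];
    }

    // Host side: warpSize is no longer available; query it at runtime.
    void launch(float* out) {
      int warp_size = at::cuda::warp_size();  // 32 or 64, depending on the device
      warp_store<<<1, warp_size>>>(out);
    }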

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156979
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Author: Ethan Wee
Date: 2025-07-01 04:55:31 +00:00
Committed by: PyTorch MergeBot
Parent: c811f41cf5
Commit: 04bd7e6850
9 changed files with 47 additions and 18 deletions

View File

@@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
   int warp_size = at::cuda::warp_size();
   TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
-                        num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
+                        num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(),
                         "BlockReduceSum requires all warps be active");
   const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr<int64_t>();
   dim3 grid = unique_indices.numel();

View File

@@ -86,7 +86,7 @@ void renormRows(Tensor& t) {
   TORCH_CHECK(props != nullptr);
   int numSM = props->multiProcessorCount;
   const int64_t maxThreads = std::min(
-      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);
+      props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads());
   int warp_size = at::cuda::warp_size();
   dim3 grid(rows < numSM * 4 ? rows : numSM * 4);

View File

@@ -183,15 +183,16 @@ inline dim3 SoftMaxForward_getBlockSize(uint64_t dim_size) {
   uint64_t block_size = 1;
   uint64_t max_block_size = std::min(dim_size, static_cast<uint64_t>(max_threads));
-  // We need a block size that is a multiple of C10_WARP_SIZE in order
+  // We need a block size that is a multiple of at::cuda::warp_size() in order
   // to perform block size reductions using warp shuffle instructions.
-  // Since max_threads is also a multiple of C10_WARPS_SIZE we do not
+  // Since max_threads is also a multiple of at::cuda::warp_size() we do not
   // risk creating a block size larger than the limit.
-  if (max_block_size % C10_WARP_SIZE == 0) {
+  int warp_size = at::cuda::warp_size();
+  if (max_block_size % warp_size == 0) {
     block_size = max_block_size;
   } else {
-    block_size = (max_block_size / C10_WARP_SIZE + 1) * C10_WARP_SIZE;
+    block_size = (max_block_size / warp_size + 1) * warp_size;
   }
   return dim3(block_size);
@@ -1107,7 +1108,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
       constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
       if constexpr (use_fast_softmax) {
         dim3 block(512);
-        size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+        size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
         if (dim_size % ILP == 0) {
           cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, scalar_t, EpilogueWithMul>
             <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
@@ -1117,7 +1118,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
         }
       } else {
         dim3 block = SoftMaxForward_getBlockSize(dim_size);
-        size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+        size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
         auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                       smem_reduction_sz) / sizeof(scalar_t);
@@ -1198,7 +1199,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
       constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
       if constexpr (use_fast_softmax) {
         dim3 block(512);
-        size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+        size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
         if (dim_size % ILP == 0) {
           cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, accscalar_t, EpilogueWithMul>
             <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
@@ -1208,7 +1209,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
         }
       } else {
         dim3 block = SoftMaxForward_getBlockSize(dim_size);
-        size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+        size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
         auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                       smem_reduction_sz) / sizeof(scalar_t);
@@ -1274,7 +1275,7 @@ void dispatch_host_softmax_backward(int64_t dim_size, dim3 grid, Tensor &grad, T
   constexpr int ILP = sizeof(float4) / sizeof(output_t);
   dim3 block = SoftMax_getBlockSize(ILP, dim_size);
-  size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+  size_t smem_reduction_sz = block.x / at::cuda::warp_size() * sizeof(accscalar_t);
   auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
                                 smem_reduction_sz) / sizeof(output_t);
   bool can_use_smem = static_cast<size_t>(dim_size) < max_elements_per_smem;

View File

@@ -207,7 +207,7 @@ void handle_fused_mode(
   constexpr int num_threads = size / 2;
   int warp_size = at::cuda::warp_size();
   TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 &&
-                        num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, "");
+                        num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), "");
   const auto memsize =
       (sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int));
   compute_mode<scalar_t, size>

View File

@@ -439,8 +439,12 @@ __global__ void computeBlockwiseWithinKCounts(
     warp_counts[warp] = count;
   }
   __syncthreads();
+#ifdef USE_ROCM
+  CUDA_KERNEL_ASSERT(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE);
+#else
   static_assert(RADIX_DIGITS < C10_WARP_SIZE * C10_WARP_SIZE,
                 "Assuming only 1 warp is needed for final reduction");
+#endif
   if (warp != 0) {
     return;
   }
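Note: the ROCm branch trades the compile-time check for a runtime CUDA_KERNEL_ASSERT because the host pass of the HIP compiler also sees this kernel body, and with C10_WARP_SIZE defined as 1 there the static_assert would fail to compile.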

View File

@@ -12,7 +12,17 @@ constexpr int kCUDABlockReduceNumThreads = 512;
 // of which reduces C10_WARP_SIZE elements. So, at most
 // C10_WARP_SIZE**2 elements can be reduced at a time.
 // NOTE: This is >= the max block size on current hardware anyway (1024).
-constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
+// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions,
+// and kCUDABlockReduceMaxThreads is a host-side variable.
+#ifdef USE_ROCM
+static int kCUDABlockReduceMaxThreads() {
+  return at::cuda::warp_size() * at::cuda::warp_size();
+}
+#else
+constexpr int kCUDABlockReduceMaxThreads() {
+  return C10_WARP_SIZE * C10_WARP_SIZE;
+}
+#endif
 // Sums `val` across all threads in a warp.
 //
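Call sites change accordingly from reading a constant to invoking a function, as the embedding and mode diffs above show. A minimal sketch of the updated pattern (hypothetical caller):

    // On ROCm this resolves at runtime via at::cuda::warp_size();
    // on CUDA it is still constexpr (32 * 32 = 1024).
    const int max_threads = cuda_utils::kCUDABlockReduceMaxThreads();
    TORCH_INTERNAL_ASSERT(num_threads <= max_threads,
                          "BlockReduceSum requires all warps be active");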

View File

@@ -312,7 +312,21 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 #endif
 #if defined(USE_ROCM)
-#define C10_WARP_SIZE warpSize // = 64 or 32 (Defined in hip_runtime.h)
+// C10_WARP_SIZE is only allowed for device code.
+// Host code _must_ use at::cuda::warp_size()
+// HIP header used to define warpSize as a constexpr that was either 32 or 64
+// depending on the target device, and then always set it to 64 for host code.
+// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we
+// set it to something unreasonable to trigger obvious host code errors.
+#if defined(__HIP_DEVICE_COMPILE__)
+#if defined(__GFX9__)
+static constexpr int C10_WARP_SIZE = 64;
+#else // __GFX9__
+static constexpr int C10_WARP_SIZE = 32;
+#endif // __GFX9__
+#else
+static constexpr int C10_WARP_SIZE = 1;
+#endif // __HIP_DEVICE_COMPILE__
 #else
 #define C10_WARP_SIZE 32
 #endif
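The host-side value of 1 is deliberately unreasonable: host code that still consumes C10_WARP_SIZE now produces an obviously wrong launch configuration (a one-thread block) instead of silently assuming a warp of 64. A brief illustration (hypothetical host snippet):

    dim3 bad_block(C10_WARP_SIZE);           // ROCm host pass: bad_block.x == 1, visibly broken
    dim3 good_block(at::cuda::warp_size());  // correct: 32 or 64, resolved at runtime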

View File

@@ -255,7 +255,7 @@ static __global__ void barrier_kernel(
 void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  barrier_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       channel,
       rank_,
@@ -293,7 +293,7 @@ void CUDASymmetricMemory::put_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  put_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       dst_rank,
       channel,
@@ -337,7 +337,7 @@ void CUDASymmetricMemory::wait_signal(
     size_t timeout_ms) {
   check_channel(channel, world_size_);
   c10::cuda::CUDAGuard guard(local_device_idx_);
-  wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>(
+  wait_signal_kernel<<<1, at::cuda::warp_size(), 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<uint32_t**>(signal_pads_dev_),
       src_rank,
       channel,

View File

@@ -114,7 +114,7 @@ void init_elementwise_launch_config(
     num_blocks = 1;
     num_threads = at::round_up(
         at::ceil_div(numel_per_split, numel_per_thread),
-        static_cast<size_t>(C10_WARP_SIZE));
+        static_cast<size_t>(at::cuda::warp_size()));
   } else {
     num_blocks = std::min(
         at::ceil_div(numel_per_split, max_num_threads * numel_per_thread),