[ROCm] Maxpool backward NHWC Perf Improvement targeting Resnet scenarios (#152267)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152267
Approved by: https://github.com/jeffdaily
This commit is contained in:
Hashem Hashemi 2025-05-14 06:59:29 +00:00 committed by PyTorch MergeBot
parent 4c5cf18ee0
commit 4015166e5d

View File

@ -297,6 +297,51 @@ __global__ void max_pool_backward_nhwc(const scalar_t* top_diff,
int pwend = p_end(iw, pad_w, pooled_width, stride_w);
int index_shift = ih * width + iw;
if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
#if defined (USE_ROCM)
#define _MAXh 2
#define _MAXw 2
if (phend-phstart<=_MAXh && pwend-pwstart<=_MAXw) {
int msk[_MAXh][_MAXw];
scalar_t tpd[_MAXh][_MAXw];
int cached_index = threadIdx.x;
#pragma unroll
for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
#pragma unroll
for(int oh = 0; oh < _MAXh; ++oh) {
#pragma unroll
for(int ow = 0; ow < _MAXw; ++ow) {
int oh_ = oh+phstart;
int ow_ = ow+pwstart;
const int64_t* ptr_top_mask = top_mask + oh_ * out_stride_h + ow_ * out_stride_w;
if (oh_ >= phend || ow_ >= pwend) {
msk[oh][ow] = ~index_shift;
} else {
msk[oh][ow] = ptr_top_mask[c*out_stride_c];
tpd[oh][ow] = top_diff[oh_ * out_stride_h + ow_ * out_stride_w + c*out_stride_c];
}
}
}
accscalar_t acm = 0;
#pragma unroll
for(int oh = 0; oh < _MAXh; ++oh) {
#pragma unroll
for(int ow = 0; ow < _MAXw; ++ow) {
if (msk[oh][ow] == index_shift) {
acm += static_cast<accscalar_t>(tpd[oh][ow]);
}
}
}
out_cached[cached_index] += acm;
cached_index += blockDim.x;
}
}
else
#undef _MAXh
#undef _MAXw
#endif
for(int oh = phstart; oh < phend; ++oh) {
for(int ow = pwstart; ow < pwend; ++ow) {
int cached_index = threadIdx.x;