mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Maxpool backward NHWC Perf Improvement targeting Resnet scenarios (#152267)
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152267
Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
4c5cf18ee0
commit
4015166e5d
|
|
@ -297,6 +297,51 @@ __global__ void max_pool_backward_nhwc(const scalar_t* top_diff,
|
|||
int pwend = p_end(iw, pad_w, pooled_width, stride_w);
|
||||
int index_shift = ih * width + iw;
|
||||
if ((phstart + 1 != phend) || (pwstart + 1 != pwend)) {
|
||||
|
||||
#if defined (USE_ROCM)
|
||||
#define _MAXh 2
|
||||
#define _MAXw 2
|
||||
if (phend-phstart<=_MAXh && pwend-pwstart<=_MAXw) {
|
||||
int msk[_MAXh][_MAXw];
|
||||
scalar_t tpd[_MAXh][_MAXw];
|
||||
int cached_index = threadIdx.x;
|
||||
#pragma unroll
|
||||
for (int c = channel_offset; c < channels; c += blockDim.x*kernel_stride_C) {
|
||||
#pragma unroll
|
||||
for(int oh = 0; oh < _MAXh; ++oh) {
|
||||
#pragma unroll
|
||||
for(int ow = 0; ow < _MAXw; ++ow) {
|
||||
int oh_ = oh+phstart;
|
||||
int ow_ = ow+pwstart;
|
||||
const int64_t* ptr_top_mask = top_mask + oh_ * out_stride_h + ow_ * out_stride_w;
|
||||
if (oh_ >= phend || ow_ >= pwend) {
|
||||
msk[oh][ow] = ~index_shift;
|
||||
} else {
|
||||
msk[oh][ow] = ptr_top_mask[c*out_stride_c];
|
||||
tpd[oh][ow] = top_diff[oh_ * out_stride_h + ow_ * out_stride_w + c*out_stride_c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accscalar_t acm = 0;
|
||||
#pragma unroll
|
||||
for(int oh = 0; oh < _MAXh; ++oh) {
|
||||
#pragma unroll
|
||||
for(int ow = 0; ow < _MAXw; ++ow) {
|
||||
if (msk[oh][ow] == index_shift) {
|
||||
acm += static_cast<accscalar_t>(tpd[oh][ow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
out_cached[cached_index] += acm;
|
||||
cached_index += blockDim.x;
|
||||
}
|
||||
}
|
||||
else
|
||||
#undef _MAXh
|
||||
#undef _MAXw
|
||||
#endif
|
||||
|
||||
for(int oh = phstart; oh < phend; ++oh) {
|
||||
for(int ow = pwstart; ow < pwend; ++ow) {
|
||||
int cached_index = threadIdx.x;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user