[ROCm] [Normalization] Update block size (#165941)

* Seeing up to a 6x improvement

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165941
Approved by: https://github.com/jeffdaily
Jerry Mannil 2025-10-21 20:53:01 +00:00 committed by PyTorch MergeBot
parent 5b35fc8777
commit 9f82535c5a


@@ -23,7 +23,7 @@ namespace at::native {
 // The maximum number of threads in a block
 #if defined(USE_ROCM)
-constexpr int MAX_BLOCK_SIZE = 256;
+constexpr int MAX_BLOCK_SIZE = 1024;
 #else
 constexpr int MAX_BLOCK_SIZE = 512;
 #endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
 // Number of threads in a block given an input size up to MAX_BLOCK_SIZE
 static int getNumThreads(int nElem) {
 #if defined(USE_ROCM)
-  int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
 #else
   int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
 #endif
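
For context, the effect of the new ROCm table on block-size selection can be illustrated with a small standalone sketch. The hunk cuts off before the body of getNumThreads, so the selection loop below is an assumption, not part of this diff; only the threadSizes values and the new MAX_BLOCK_SIZE of 1024 come from the change above. The names getNumThreadsSketch and main are hypothetical.

// Sketch only: how a size-selection helper consistent with the new ROCm
// table could map an element count to a block size. The loop is an
// assumption about the elided function body.
#include <cstdio>

constexpr int MAX_BLOCK_SIZE = 1024;   // new ROCm value from this commit

static int getNumThreadsSketch(int nElem) {
  int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
  for (int size : threadSizes) {
    if (nElem <= size) {
      return size;                     // smallest bucket that covers nElem
    }
  }
  return MAX_BLOCK_SIZE;               // larger inputs use the full 1024-thread block
}

int main() {
  // e.g. a 300-element reduction would now select 512 threads on ROCm,
  // instead of capping at the previous 256-thread maximum.
  std::printf("%d\n", getNumThreadsSketch(300));
  return 0;
}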