mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] [Normalization] Update block size (#165941)
* Seeing up to 6x improvement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165941 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
5b35fc8777
commit
9f82535c5a
|
|
@ -23,7 +23,7 @@ namespace at::native {
|
|||
|
||||
// The maximum number of threads in a block
|
||||
#if defined(USE_ROCM)
|
||||
constexpr int MAX_BLOCK_SIZE = 256;
|
||||
constexpr int MAX_BLOCK_SIZE = 1024;
|
||||
#else
|
||||
constexpr int MAX_BLOCK_SIZE = 512;
|
||||
#endif
|
||||
|
|
@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
|
|||
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
|
||||
static int getNumThreads(int nElem) {
|
||||
#if defined(USE_ROCM)
|
||||
int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
|
||||
int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
|
||||
#else
|
||||
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user