[ROCm] Limit number of values per thread for reductions on three dimensions (#159652)

In the current implementation of reductions in three dimensions for AMD GPUs the number of values per thread is unbounded and can end up being in the hundreds of thousands for certain tensors. This of course is bad for performance. This patch fixes this issue by increasing the parallelism and thus lowering the number of value per thread to reasonable limits i.e. less than 2048 values per thread. The performance gains can be between 10x-17x for certain examples where the number of values per thread was originally very high.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159652
Approved by: https://github.com/jeffdaily
This commit is contained in:
Gheorghe-Teodor Bercea 2025-08-12 21:15:52 +00:00 committed by PyTorch MergeBot
parent c24ca7f4bf
commit f27232a213

View File

@ -209,6 +209,10 @@ struct ReduceConfig {
int values_per_thread() const { int values_per_thread() const {
return div_up(num_inputs, step_input); return div_up(num_inputs, step_input);
} }
int mock_values_per_thread(int parallelism) {
return div_up(num_inputs, step_input * parallelism);
}
}; };
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config); std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
else if (config.ctas_per_output < 16) else if (config.ctas_per_output < 16)
config.ctas_per_output = 1; config.ctas_per_output = 1;
bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast); bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
config.ctas_per_output = 4; config.ctas_per_output = 4;
int vpt = config.values_per_thread();
// Capping the number of values per thread to 2048 for now
// based on known use cases.
while (vpt >= 2048) {
config.ctas_per_output *= 2;
// Computes the new values per thread without side effects
vpt = config.mock_values_per_thread(config.ctas_per_output);
}
}
#endif #endif
if (config.ctas_per_output > 1) { if (config.ctas_per_output > 1) {
config.input_mult[2] = config.split_input(config.ctas_per_output); config.input_mult[2] = config.split_input(config.ctas_per_output);