mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Limit number of values per thread for reductions on three dimensions (#159652)
In the current implementation of reductions in three dimensions for AMD GPUs the number of values per thread is unbounded and can end up being in the hundreds of thousands for certain tensors. This of course is bad for performance. This patch fixes this issue by increasing the parallelism and thus lowering the number of value per thread to reasonable limits i.e. less than 2048 values per thread. The performance gains can be between 10x-17x for certain examples where the number of values per thread was originally very high. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159652 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
c24ca7f4bf
commit
f27232a213
|
|
@ -209,6 +209,10 @@ struct ReduceConfig {
|
||||||
int values_per_thread() const {
|
int values_per_thread() const {
|
||||||
return div_up(num_inputs, step_input);
|
return div_up(num_inputs, step_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int mock_values_per_thread(int parallelism) {
|
||||||
|
return div_up(num_inputs, step_input * parallelism);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
|
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
|
||||||
|
|
@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
|
||||||
else if (config.ctas_per_output < 16)
|
else if (config.ctas_per_output < 16)
|
||||||
config.ctas_per_output = 1;
|
config.ctas_per_output = 1;
|
||||||
bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
|
bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
|
||||||
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
|
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
|
||||||
config.ctas_per_output = 4;
|
config.ctas_per_output = 4;
|
||||||
|
int vpt = config.values_per_thread();
|
||||||
|
// Capping the number of values per thread to 2048 for now
|
||||||
|
// based on known use cases.
|
||||||
|
while (vpt >= 2048) {
|
||||||
|
config.ctas_per_output *= 2;
|
||||||
|
// Computes the new values per thread without side effects
|
||||||
|
vpt = config.mock_values_per_thread(config.ctas_per_output);
|
||||||
|
}
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
if (config.ctas_per_output > 1) {
|
if (config.ctas_per_output > 1) {
|
||||||
config.input_mult[2] = config.split_input(config.ctas_per_output);
|
config.input_mult[2] = config.split_input(config.ctas_per_output);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user