mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Tune 3d tensor sums when not using fastest dimension (#146170)
Tune 3d tensor sums when not using fastest dimension. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146170 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
7997ecf809
commit
a79d8f8ba4
|
|
@ -1159,6 +1159,8 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
|
|||
config.ctas_per_output = div_up(num_mp, 2);
|
||||
else if (config.ctas_per_output < 16)
|
||||
config.ctas_per_output = 1;
|
||||
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension)
|
||||
config.ctas_per_output = 4;
|
||||
#endif
|
||||
if (config.ctas_per_output > 1) {
|
||||
config.input_mult[2] = config.split_input(config.ctas_per_output);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user