[ROCm] Tune 3d tensor sums when not using fastest dimension (#146170)

Tune sum reductions over 3D tensors on ROCm when the reduction is not over the fastest-striding dimension: in that case ctas_per_output is set to 4.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146170
Approved by: https://github.com/jeffdaily
Doru Bercea 2025-02-04 04:02:14 +00:00 committed by PyTorch MergeBot
parent 7997ecf809
commit a79d8f8ba4

@@ -1159,6 +1159,8 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
       config.ctas_per_output = div_up(num_mp, 2);
     else if (config.ctas_per_output < 16)
       config.ctas_per_output = 1;
+    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension)
+      config.ctas_per_output = 4;
 #endif
     if (config.ctas_per_output > 1) {
       config.input_mult[2] = config.split_input(config.ctas_per_output);
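
For readers who want to reason about the new branch in isolation, here is a minimal standalone sketch of the override the hunk adds. It is not the PyTorch implementation: the function name tune_ctas_per_output and its plain int/bool parameters are hypothetical stand-ins for config.ctas_per_output, iter.ndim() and reduction_on_fastest_striding_dimension. It also assumes the override runs unconditionally after the earlier clamps, as the hunk suggests.

```cpp
// Hypothetical helper: mirrors only the two lines added by this commit.
// `ctas_per_output` is the value left by the earlier ROCm clamps
// (the div_up(num_mp, 2) and "< 16 -> 1" branches seen in the hunk context).
int tune_ctas_per_output(int ctas_per_output,
                         int ndim,
                         bool reduction_on_fastest_striding_dimension) {
  // Assumption: the override applies regardless of the incoming value,
  // so a 3D reduction over a non-fastest dimension always gets 4 blocks
  // per output element.
  if (ndim == 3 && !reduction_on_fastest_striding_dimension) {
    return 4;
  }
  // Every other case keeps whatever the earlier heuristics chose.
  return ctas_per_output;
}
```

Under these assumptions, a 3D sum over a non-fastest dimension is always split across 4 blocks per output, so the following `config.ctas_per_output > 1` branch that calls split_input is taken for that case.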