[XLA:GPU] Make ReductionEmitter deterministic.

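Previously, the emitted output could be non-deterministic when multiple
reductions were grouped together, because the loop's iter args were laid out
in the iteration order of the `inits` map. This change lays them out in the
order of the group's `reduction_heroes_` list instead, which makes the output
deterministic.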

PiperOrigin-RevId: 824965037
This commit is contained in:
Adrian Kuegel 2025-10-28 04:23:25 -07:00 committed by TensorFlower Gardener
parent 29bc205be3
commit a12d2cfb31
2 changed files with 41 additions and 2 deletions

View File

@ -195,9 +195,9 @@ PerThreadOutputs ReductionFusion::EmitterState::EmitPerThreadElements(
const auto& reductions = owner.reduction_heroes_[group_id];
absl::flat_hash_map<const HloInstruction*, int> iter_arg_starts;
for (const auto& [reduction, init] : inits) {
for (const HloInstruction* reduction : reductions) {
iter_arg_starts[reduction] = iter_arg_inits.size();
iter_arg_inits.append(init);
iter_arg_inits.append(inits.find(reduction)->second);
}
auto body_builder = [&](ImplicitLocOpBuilder& nested_b,

View File

@ -0,0 +1,39 @@
// RUN: fusion_to_mlir %s | FileCheck %s
// RUN: gpu_test_correctness %s
// Scalar f32 addition: takes two f32[] parameters and returns their sum.
// Referenced as the to_apply reducer of reduce.812.1 in the fusion below.
%add_f32 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
// Second scalar f32 addition with the same body as %add_f32 but a distinct
// name. Referenced as the to_apply reducer of reduce.816.1.clone.1.
// NOTE(review): presumably the separate name lets the FileCheck expectations
// at the end of this file tell the two reductions apart by callee -- confirm.
%add_f32.2 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
// Fusion under test: two reductions along dimension 1 of f32[128,288,32]
// operands, both using constant_8186_50 (0) as the init value, plus one
// non-reduction output in the root tuple.
fusion {
// Input of the first reduction:
// (param_1 * 0.844827533 + param_0) * 0.393919319, bitcast to [128,288,32].
param_1.23864 = f32[1,1,144,256,32]{4,3,2,1,0} parameter(1)
constant_8203_2_clone_1 = f32[] constant(0.844827533)
broadcast.6321.5.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} broadcast(constant_8203_2_clone_1), dimensions={}
mul.770.5.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} multiply(param_1.23864, broadcast.6321.5.clone.1)
bitcast.125.3.clone.1 = f32[1,144,256,32]{3,2,1,0} bitcast(mul.770.5.clone.1)
param_0.21084 = f32[1,144,256,32]{3,2,1,0} parameter(0)
add_any.68.3.clone.1 = f32[1,144,256,32]{3,2,1,0} add(bitcast.125.3.clone.1, param_0.21084)
constant_8190_1_clone_1 = f32[] constant(0.393919319)
broadcast.6364.3.clone.1 = f32[1,144,256,32]{3,2,1,0} broadcast(constant_8190_1_clone_1), dimensions={}
mul.675.1.clone.1 = f32[1,144,256,32]{3,2,1,0} multiply(add_any.68.3.clone.1, broadcast.6364.3.clone.1)
bitcast.15178.1 = f32[128,288,32]{2,1,0} bitcast(mul.675.1.clone.1)
constant_8186_50 = f32[] constant(0)
// First reduction: sums over dimension 1 using add_f32.
reduce.812.1 = f32[128,32]{1,0} reduce(bitcast.15178.1, constant_8186_50), dimensions={1}, to_apply=add_f32
// Input of the second reduction: param_1 * 0.362068981, bitcast to
// [128,288,32].
constant_8204_2_clone_1 = f32[] constant(0.362068981)
broadcast.6327.3.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} broadcast(constant_8204_2_clone_1), dimensions={}
mul.771.3.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} multiply(param_1.23864, broadcast.6327.3.clone.1)
bitcast.15180.1.clone.1 = f32[128,288,32]{2,1,0} bitcast(mul.771.3.clone.1)
// Second reduction: sums over dimension 1 using add_f32.2.
reduce.816.1.clone.1 = f32[128,32]{1,0} reduce(bitcast.15180.1.clone.1, constant_8186_50), dimensions={1}, to_apply=add_f32.2
// Outputs, in order: first reduction, the elementwise intermediate
// mul.675.1.clone.1, second reduction.
ROOT tuple.1351 = (f32[128,32]{1,0}, f32[1,144,256,32]{3,2,1,0}, f32[128,32]{1,0}) tuple(reduce.812.1, mul.675.1.clone.1, reduce.816.1.clone.1)
}
// CHECK: xla.pure_call @add_f32_add
// CHECK: xla.pure_call @add_f32_2_add_1