[XLA:GPU] Make ReductionEmitter deterministic.
Previously, the output could be non-deterministic when multiple reductions were grouped together. This change makes it deterministic.

PiperOrigin-RevId: 824965037
This commit is contained in:
parent
29bc205be3
commit
a12d2cfb31
@@ -195,9 +195,9 @@ PerThreadOutputs ReductionFusion::EmitterState::EmitPerThreadElements(
   const auto& reductions = owner.reduction_heroes_[group_id];
   absl::flat_hash_map<const HloInstruction*, int> iter_arg_starts;
 
-  for (const auto& [reduction, init] : inits) {
+  for (const HloInstruction* reduction : reductions) {
     iter_arg_starts[reduction] = iter_arg_inits.size();
-    iter_arg_inits.append(init);
+    iter_arg_inits.append(inits.find(reduction)->second);
   }
 
   auto body_builder = [&](ImplicitLocOpBuilder& nested_b,
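Why this hunk fixes the non-determinism: `inits` is an `absl::flat_hash_map`, and Abseil hash containers do not guarantee a stable iteration order across runs, so when a group contained several reductions the order in which their init values were appended to `iter_arg_inits` (and therefore the iter-arg layout of the emitted loop) could change from run to run. Iterating the group's `reductions` list instead pins that order. The following standalone sketch, with made-up names and `std::string` stand-ins for `HloInstruction*` (not the XLA emitter code, and assuming Abseil is available), illustrates the difference:

// Standalone sketch, not XLA code: iterating the hash map gives an
// unspecified order, iterating the ordered hero list gives a stable one.
#include <cstdio>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"

int main() {
  // Stand-ins for the reduction heroes of one group, in a fixed order.
  const std::vector<std::string> reductions = {"reduce.812", "reduce.816"};

  // Stand-in for `inits`: an init value keyed by reduction.
  absl::flat_hash_map<std::string, float> inits = {{"reduce.812", 0.0f},
                                                   {"reduce.816", 1.0f}};

  // Old behavior: flat_hash_map iteration order is unspecified and may
  // differ between runs, so the append order (the iter-arg layout) may too.
  std::vector<float> iter_arg_inits_old;
  for (const auto& [reduction, init] : inits) {
    iter_arg_inits_old.push_back(init);
  }

  // New behavior: iterate the deterministic `reductions` sequence and look
  // each init up in the map, so the append order is always the same.
  std::vector<float> iter_arg_inits_new;
  for (const std::string& reduction : reductions) {
    iter_arg_inits_new.push_back(inits.find(reduction)->second);
  }

  std::printf("old: %g %g\n", iter_arg_inits_old[0], iter_arg_inits_old[1]);
  std::printf("new: %g %g\n", iter_arg_inits_new[0], iter_arg_inits_new[1]);
  return 0;
}

The new test below exercises exactly this situation: two reductions with separate reducer computations are grouped in one fusion, and the CHECK lines verify that their reducers are emitted in a fixed order.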
@@ -0,0 +1,39 @@
+// RUN: fusion_to_mlir %s | FileCheck %s
+// RUN: gpu_test_correctness %s
+
+%add_f32 {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+%add_f32.2 {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+
+fusion {
+  param_1.23864 = f32[1,1,144,256,32]{4,3,2,1,0} parameter(1)
+  constant_8203_2_clone_1 = f32[] constant(0.844827533)
+  broadcast.6321.5.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} broadcast(constant_8203_2_clone_1), dimensions={}
+  mul.770.5.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} multiply(param_1.23864, broadcast.6321.5.clone.1)
+  bitcast.125.3.clone.1 = f32[1,144,256,32]{3,2,1,0} bitcast(mul.770.5.clone.1)
+  param_0.21084 = f32[1,144,256,32]{3,2,1,0} parameter(0)
+  add_any.68.3.clone.1 = f32[1,144,256,32]{3,2,1,0} add(bitcast.125.3.clone.1, param_0.21084)
+  constant_8190_1_clone_1 = f32[] constant(0.393919319)
+  broadcast.6364.3.clone.1 = f32[1,144,256,32]{3,2,1,0} broadcast(constant_8190_1_clone_1), dimensions={}
+  mul.675.1.clone.1 = f32[1,144,256,32]{3,2,1,0} multiply(add_any.68.3.clone.1, broadcast.6364.3.clone.1)
+  bitcast.15178.1 = f32[128,288,32]{2,1,0} bitcast(mul.675.1.clone.1)
+  constant_8186_50 = f32[] constant(0)
+  reduce.812.1 = f32[128,32]{1,0} reduce(bitcast.15178.1, constant_8186_50), dimensions={1}, to_apply=add_f32
+  constant_8204_2_clone_1 = f32[] constant(0.362068981)
+  broadcast.6327.3.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} broadcast(constant_8204_2_clone_1), dimensions={}
+  mul.771.3.clone.1 = f32[1,1,144,256,32]{4,3,2,1,0} multiply(param_1.23864, broadcast.6327.3.clone.1)
+  bitcast.15180.1.clone.1 = f32[128,288,32]{2,1,0} bitcast(mul.771.3.clone.1)
+  reduce.816.1.clone.1 = f32[128,32]{1,0} reduce(bitcast.15180.1.clone.1, constant_8186_50), dimensions={1}, to_apply=add_f32.2
+  ROOT tuple.1351 = (f32[128,32]{1,0}, f32[1,144,256,32]{3,2,1,0}, f32[128,32]{1,0}) tuple(reduce.812.1, mul.675.1.clone.1, reduce.816.1.clone.1)
+}
+
+// CHECK: xla.pure_call @add_f32_add
+// CHECK: xla.pure_call @add_f32_2_add_1