[XLA:GPU] Ignore reductions over dimensions of size 1 in UnstableReductionDetect

The UnstableReductionDetector now considers reductions where all reduced dimensions have a size of 1 to be stable, as these operations are effectively no-ops and do not introduce numerical instability. A test case is added to verify this behavior.

PiperOrigin-RevId: 825045042
This commit is contained in:
Ilya Tikhonovskiy 2025-10-28 08:08:16 -07:00 committed by TensorFlower Gardener
parent 7334d07917
commit ffca28bcf8
2 changed files with 41 additions and 1 deletions

View File

@ -62,6 +62,22 @@ static constexpr absl::string_view kUnstableReductionNoMetadataHloModule = R"(
}
)";
static constexpr absl::string_view kNoOpUnstableReductionHloModule = R"(
red {
p0 = bf16[] parameter(0)
p1 = bf16[] parameter(1)
ROOT red = bf16[] add(p0, p1)
}
ENTRY main {
p0 = bf16[1] parameter(0)
init = bf16[] constant(1.0)
ROOT red = bf16[] reduce(p0, init),
to_apply=red,
dimensions={0}
}
)";
using ::absl::LogSeverity;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
@ -153,5 +169,22 @@ TEST(UnstableReductionDetectorTest, DoNothingOnUnstableReduction) {
IsOkAndHolds(false));
}
TEST(UnstableReductionDetectorTest, NoOpUnstableReduction) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
kNoOpUnstableReductionHloModule));
module->mutable_config()
.mutable_debug_options()
.set_xla_detect_unstable_reductions(
DebugOptions::UNSTABLE_REDUCTION_DETECTION_MODE_WARNING);
UnstableReductionDetector detector;
::absl::ScopedMockLog log;
EXPECT_CALL(log, Log(LogSeverity::kWarning, _, _)).Times(0);
EXPECT_CALL(log, Log(LogSeverity::kError, _, _)).Times(0);
log.StartCapturingLogs();
EXPECT_THAT(detector.Run(module.get(), /*execution_threads=*/{}),
IsOkAndHolds(false));
log.StopCapturingLogs();
}
} // namespace
} // namespace xla

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/primitive_util.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"
namespace xla {
@ -62,7 +63,13 @@ bool IsKnownStableReduction(const HloReduceInstruction* reduction) {
return true;
}
return false;
Shape operand_shape = reduction->operand(0)->shape();
for (auto dim : reduction->dimensions()) {
if (operand_shape.dimensions(dim) != 1) {
return false;
}
}
return true;
}
} // namespace