From ffca28bcf875269d1d89c677388bb6633491b447 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Tue, 28 Oct 2025 08:08:16 -0700 Subject: [PATCH] [XLA:GPU] Ignore reductions over dimensions of size 1 in UnstableReductionDetect The UnstableReductionDetector now considers reductions where all reduced dimensions have a size of 1 to be stable, as these operations are effectively no-ops and do not introduce numerical instability. A test case is added to verify this behavior. PiperOrigin-RevId: 825045042 --- .../debug/unstable_reduction_detector_test.cc | 33 +++++++++++++++++++ .../debug/unstable_reduction_finder.cc | 9 ++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/debug/unstable_reduction_detector_test.cc b/third_party/xla/xla/service/debug/unstable_reduction_detector_test.cc index b2c93426281..5e131c612a3 100644 --- a/third_party/xla/xla/service/debug/unstable_reduction_detector_test.cc +++ b/third_party/xla/xla/service/debug/unstable_reduction_detector_test.cc @@ -62,6 +62,22 @@ static constexpr absl::string_view kUnstableReductionNoMetadataHloModule = R"( } )"; +static constexpr absl::string_view kNoOpUnstableReductionHloModule = R"( + red { + p0 = bf16[] parameter(0) + p1 = bf16[] parameter(1) + ROOT red = bf16[] add(p0, p1) + } + + ENTRY main { + p0 = bf16[1] parameter(0) + init = bf16[] constant(1.0) + ROOT red = bf16[] reduce(p0, init), + to_apply=red, + dimensions={0} + } +)"; + using ::absl::LogSeverity; using ::absl_testing::IsOkAndHolds; using ::absl_testing::StatusIs; @@ -153,5 +169,22 @@ TEST(UnstableReductionDetectorTest, DoNothingOnUnstableReduction) { IsOkAndHolds(false)); } +TEST(UnstableReductionDetectorTest, NoOpUnstableReduction) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule( + kNoOpUnstableReductionHloModule)); + module->mutable_config() + .mutable_debug_options() + .set_xla_detect_unstable_reductions( + DebugOptions::UNSTABLE_REDUCTION_DETECTION_MODE_WARNING); + UnstableReductionDetector detector; + ::absl::ScopedMockLog log; + EXPECT_CALL(log, Log(LogSeverity::kWarning, _, _)).Times(0); + EXPECT_CALL(log, Log(LogSeverity::kError, _, _)).Times(0); + log.StartCapturingLogs(); + EXPECT_THAT(detector.Run(module.get(), /*execution_threads=*/{}), + IsOkAndHolds(false)); + log.StopCapturingLogs(); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/debug/unstable_reduction_finder.cc b/third_party/xla/xla/service/debug/unstable_reduction_finder.cc index 764b115d326..07f866c4a00 100644 --- a/third_party/xla/xla/service/debug/unstable_reduction_finder.cc +++ b/third_party/xla/xla/service/debug/unstable_reduction_finder.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape.h" #include "xla/xla_data.pb.h" namespace xla { @@ -62,7 +63,13 @@ bool IsKnownStableReduction(const HloReduceInstruction* reduction) { return true; } - return false; + Shape operand_shape = reduction->operand(0)->shape(); + for (auto dim : reduction->dimensions()) { + if (operand_shape.dimensions(dim) != 1) { + return false; + } + } + return true; } } // namespace