mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA:GPU] Ignore reductions over dimensions of size 1 in UnstableReductionDetect
The UnstableReductionDetector now considers reductions where all reduced dimensions have a size of 1 to be stable, as these operations are effectively no-ops and do not introduce numerical instability. A test case is added to verify this behavior. PiperOrigin-RevId: 825045042
This commit is contained in:
parent
7334d07917
commit
ffca28bcf8
|
|
@ -62,6 +62,22 @@ static constexpr absl::string_view kUnstableReductionNoMetadataHloModule = R"(
|
|||
}
|
||||
)";
|
||||
|
||||
static constexpr absl::string_view kNoOpUnstableReductionHloModule = R"(
|
||||
red {
|
||||
p0 = bf16[] parameter(0)
|
||||
p1 = bf16[] parameter(1)
|
||||
ROOT red = bf16[] add(p0, p1)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
p0 = bf16[1] parameter(0)
|
||||
init = bf16[] constant(1.0)
|
||||
ROOT red = bf16[] reduce(p0, init),
|
||||
to_apply=red,
|
||||
dimensions={0}
|
||||
}
|
||||
)";
|
||||
|
||||
using ::absl::LogSeverity;
|
||||
using ::absl_testing::IsOkAndHolds;
|
||||
using ::absl_testing::StatusIs;
|
||||
|
|
@ -153,5 +169,22 @@ TEST(UnstableReductionDetectorTest, DoNothingOnUnstableReduction) {
|
|||
IsOkAndHolds(false));
|
||||
}
|
||||
|
||||
TEST(UnstableReductionDetectorTest, NoOpUnstableReduction) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
|
||||
kNoOpUnstableReductionHloModule));
|
||||
module->mutable_config()
|
||||
.mutable_debug_options()
|
||||
.set_xla_detect_unstable_reductions(
|
||||
DebugOptions::UNSTABLE_REDUCTION_DETECTION_MODE_WARNING);
|
||||
UnstableReductionDetector detector;
|
||||
::absl::ScopedMockLog log;
|
||||
EXPECT_CALL(log, Log(LogSeverity::kWarning, _, _)).Times(0);
|
||||
EXPECT_CALL(log, Log(LogSeverity::kError, _, _)).Times(0);
|
||||
log.StartCapturingLogs();
|
||||
EXPECT_THAT(detector.Run(module.get(), /*execution_threads=*/{}),
|
||||
IsOkAndHolds(false));
|
||||
log.StopCapturingLogs();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/primitive_util.h"
|
||||
#include "xla/service/pattern_matcher.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
|
@ -62,7 +63,13 @@ bool IsKnownStableReduction(const HloReduceInstruction* reduction) {
|
|||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
Shape operand_shape = reduction->operand(0)->shape();
|
||||
for (auto dim : reduction->dimensions()) {
|
||||
if (operand_shape.dimensions(dim) != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user