mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:GPU] Ignore reductions over dimensions of size 1 in UnstableReductionDetect
The UnstableReductionDetector now considers reductions where all reduced dimensions have a size of 1 to be stable, as these operations are effectively no-ops and do not introduce numerical instability. A test case is added to verify this behavior. PiperOrigin-RevId: 825045042
This commit is contained in:
parent
7334d07917
commit
ffca28bcf8
|
|
@ -62,6 +62,22 @@ static constexpr absl::string_view kUnstableReductionNoMetadataHloModule = R"(
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
static constexpr absl::string_view kNoOpUnstableReductionHloModule = R"(
|
||||||
|
red {
|
||||||
|
p0 = bf16[] parameter(0)
|
||||||
|
p1 = bf16[] parameter(1)
|
||||||
|
ROOT red = bf16[] add(p0, p1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY main {
|
||||||
|
p0 = bf16[1] parameter(0)
|
||||||
|
init = bf16[] constant(1.0)
|
||||||
|
ROOT red = bf16[] reduce(p0, init),
|
||||||
|
to_apply=red,
|
||||||
|
dimensions={0}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
using ::absl::LogSeverity;
|
using ::absl::LogSeverity;
|
||||||
using ::absl_testing::IsOkAndHolds;
|
using ::absl_testing::IsOkAndHolds;
|
||||||
using ::absl_testing::StatusIs;
|
using ::absl_testing::StatusIs;
|
||||||
|
|
@ -153,5 +169,22 @@ TEST(UnstableReductionDetectorTest, DoNothingOnUnstableReduction) {
|
||||||
IsOkAndHolds(false));
|
IsOkAndHolds(false));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(UnstableReductionDetectorTest, NoOpUnstableReduction) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
|
||||||
|
kNoOpUnstableReductionHloModule));
|
||||||
|
module->mutable_config()
|
||||||
|
.mutable_debug_options()
|
||||||
|
.set_xla_detect_unstable_reductions(
|
||||||
|
DebugOptions::UNSTABLE_REDUCTION_DETECTION_MODE_WARNING);
|
||||||
|
UnstableReductionDetector detector;
|
||||||
|
::absl::ScopedMockLog log;
|
||||||
|
EXPECT_CALL(log, Log(LogSeverity::kWarning, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(log, Log(LogSeverity::kError, _, _)).Times(0);
|
||||||
|
log.StartCapturingLogs();
|
||||||
|
EXPECT_THAT(detector.Run(module.get(), /*execution_threads=*/{}),
|
||||||
|
IsOkAndHolds(false));
|
||||||
|
log.StopCapturingLogs();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||||
#include "xla/hlo/ir/hlo_opcode.h"
|
#include "xla/hlo/ir/hlo_opcode.h"
|
||||||
#include "xla/primitive_util.h"
|
#include "xla/primitive_util.h"
|
||||||
#include "xla/service/pattern_matcher.h"
|
#include "xla/service/pattern_matcher.h"
|
||||||
|
#include "xla/shape.h"
|
||||||
#include "xla/xla_data.pb.h"
|
#include "xla/xla_data.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
@ -62,7 +63,13 @@ bool IsKnownStableReduction(const HloReduceInstruction* reduction) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
Shape operand_shape = reduction->operand(0)->shape();
|
||||||
|
for (auto dim : reduction->dimensions()) {
|
||||||
|
if (operand_shape.dimensions(dim) != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user