Fix dot library predicates

This changes the predicates used when deciding whether to call a library for a dot: they now indicate whether we will actually call the library, not just whether the library supports the dot. This fixes bugs where we incorrectly claimed a library would handle a dot.

PiperOrigin-RevId: 825231317
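
To make the distinction concrete, here is a minimal, self-contained C++ sketch of the difference between a "library supports this dot" predicate and a "we will actually call a library for this dot" predicate. Every name in it (Instruction, Opcode, DotStrategy, and both functions) is a hypothetical stand-in invented for illustration, not the XLA API:

#include <iostream>

// Hypothetical stand-ins for the XLA types involved; the real
// HloInstruction, HloOpcode, and DotImplementationStrategy are far
// richer than this.
enum class Opcode { kDot, kAdd };
enum class DotStrategy { kEigen, kNaiveLoop };

struct Instruction {
  Opcode opcode;
  DotStrategy strategy;    // How the compiler plans to lower the op.
  bool library_supported;  // Whether a library could handle it at all.
};

// Old-style predicate: "could a library handle this dot?"
bool LibrarySupportsDot(const Instruction& instr) {
  return instr.opcode == Opcode::kDot && instr.library_supported;
}

// New-style predicate: "will we actually call a library for this dot?"
// A dot the library supports is still not a library call when the
// chosen lowering strategy never reaches the library path.
bool CallLibraryForDot(const Instruction& instr) {
  if (instr.opcode != Opcode::kDot) return false;
  if (instr.strategy != DotStrategy::kEigen) return false;
  return LibrarySupportsDot(instr);
}

int main() {
  // A dot the library supports, but which will be lowered as a loop:
  Instruction dot{Opcode::kDot, DotStrategy::kNaiveLoop,
                  /*library_supported=*/true};
  std::cout << "supports: " << LibrarySupportsDot(dot)    // prints 1
            << ", will call: " << CallLibraryForDot(dot)  // prints 0
            << "\n";
}

The old predicate answers the first question, so any pass keying off it treats the loop-lowered dot as a library call; the diff below switches those passes to the second question.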
Authored by A. Unique TensorFlower on 2025-10-28 15:30:44 -07:00; committed by TensorFlower Gardener
parent d7b371034b
commit 281fa6f4d3

@@ -654,17 +654,30 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<BatchedGatherScatterNormalizer>();
   pipeline.AddPass<ResultCaster>();
-  // If XNNPACK is enabled, we only need to upcast dots that XnnDotThunk does
-  // not support. `upcaster_filter` returns false if the instruction shouldn't
-  // be processed.
   auto library_supports_dot =
       LibrarySupportsDot(module, target_machine_features);
-  HloPredicate upcaster_filter = [&](const HloInstruction* instr) {
-    if (instr->opcode() != HloOpcode::kDot) {
-      return true;
+  auto call_library_for_dot = [&](const HloInstruction& instr) {
+    if (instr.opcode() != HloOpcode::kDot) {
+      return false;
     }
-    return !library_supports_dot(*instr);
+    auto dot_strategy = GetDotImplementationStrategy(
+        module->config(), instr, *target_machine_features,
+        /*allow_runtime_calls=*/true);
+    if (dot_strategy != DotImplementationStrategy::kEigen) {
+      // We aren't going to call a library for this dot.
+      return false;
+    }
+    return library_supports_dot(instr);
   };
+  // If YNNPACK is enabled, we only need to upcast dots that YnnDotThunk does
+  // not support. `upcaster_filter` returns false if the instruction shouldn't
+  // be processed.
+  HloPredicate upcaster_filter = [&](const HloInstruction* instr) {
+    return !call_library_for_dot(*instr);
+  };
+  // xla::cpu::GetDotImplementationStrategy (used by call_library_for_dot)
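
Note the double negation in the new upcaster_filter above: the filter returns true when the upcaster should process an instruction, so it is simply the complement of call_library_for_dot. Spelling that out against the hypothetical predicates from the earlier sketch (a derivation from the diff, not the real XLA code):

// instruction                            CallLibraryForDot  upcaster_filter
// non-dot                                false              true
// dot, strategy != kEigen                false              true
// dot, strategy == kEigen, unsupported   false              true
// dot, strategy == kEigen, supported     true               false
auto upcaster_filter = [](const Instruction* instr) {
  return !CallLibraryForDot(*instr);
};

The row that changes relative to the old code is the dot whose strategy is not kEigen: previously such a dot was skipped whenever the library merely supported it, even though no library call was ever going to happen.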
@@ -728,7 +741,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   // Convert BF16 and F8 operations to F32 and F16 respectively so that the CPU
   // backend can support BF16/F8 operations without directly implementing a
   // BF16/F8 lowering for most ops.
-  CpuFloatSupport bf16_support(BF16, library_supports_dot);
+  CpuFloatSupport bf16_support(BF16, call_library_for_dot);
 #ifdef XLA_ONEDNN
   OneDnnFloatSupport onednn_bf16_support(BF16);
   if (use_onednn_custom_call) {
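
The second hunk threads the same predicate into CpuFloatSupport, which decides whether BF16 ops may keep their type through float normalization. A rough sketch of why that matters, assuming a simplified policy class (hypothetical, reusing the Instruction type from the first sketch; the real CpuFloatSupport interface differs):

#include <functional>
#include <utility>

struct Instruction;  // As defined in the first sketch.

// Hypothetical simplification of a float-support policy.
class FloatSupportSketch {
 public:
  using DotPredicate = std::function<bool(const Instruction&)>;

  explicit FloatSupportSketch(DotPredicate call_library_for_dot)
      : call_library_for_dot_(std::move(call_library_for_dot)) {}

  // Treat BF16 as natively supported exactly when a library will run
  // this dot; otherwise the normalization pass upcasts the op to F32.
  bool SupportsLowPrecision(const Instruction& instr) const {
    return call_library_for_dot_(instr);
  }

 private:
  DotPredicate call_library_for_dot_;
};

With the old library_supports_dot argument, a BF16 dot that a library supported but would never receive kept its BF16 operands and then reached a lowering path with no BF16 implementation; passing call_library_for_dot ensures such dots are upcast, which is presumably the bug class the commit message describes.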