mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Fix dot library predicates
This changes the predicates for calling a library for a dot to indicate whether we will actually call the library, not just whether the library supports the dot. This fixes bugs where we incorrectly claimed that a library would handle the dot.

PiperOrigin-RevId: 825231317
This commit is contained in:
parent
d7b371034b
commit
281fa6f4d3
29
third_party/xla/xla/service/cpu/cpu_compiler.cc
vendored
29
third_party/xla/xla/service/cpu/cpu_compiler.cc
vendored
|
|
@@ -654,17 +654,30 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
|||
pipeline.AddPass<BatchedGatherScatterNormalizer>();
|
||||
pipeline.AddPass<ResultCaster>();
|
||||
|
||||
// If XNNPACK is enabled, we only need to upcast dots that XnnDotThunk does
|
||||
// not support. `upcaster_filter` returns false if the instruction shouldn't
|
||||
// be processed.
|
||||
auto library_supports_dot =
|
||||
LibrarySupportsDot(module, target_machine_features);
|
||||
|
||||
HloPredicate upcaster_filter = [&](const HloInstruction* instr) {
|
||||
if (instr->opcode() != HloOpcode::kDot) {
|
||||
return true;
|
||||
auto call_library_for_dot = [&](const HloInstruction& instr) {
|
||||
if (instr.opcode() != HloOpcode::kDot) {
|
||||
return false;
|
||||
}
|
||||
return !library_supports_dot(*instr);
|
||||
|
||||
auto dot_strategy = GetDotImplementationStrategy(
|
||||
module->config(), instr, *target_machine_features,
|
||||
/*allow_runtime_calls=*/true);
|
||||
if (dot_strategy != DotImplementationStrategy::kEigen) {
|
||||
// We aren't going to call a library for this dot.
|
||||
return false;
|
||||
}
|
||||
|
||||
return library_supports_dot(instr);
|
||||
};
|
||||
|
||||
// If YNNPACK is enabled, we only need to upcast dots that YnnDotThunk does
|
||||
// not support. `upcaster_filter` returns false if the instruction shouldn't
|
||||
// be processed.
|
||||
HloPredicate upcaster_filter = [&](const HloInstruction* instr) {
|
||||
return !call_library_for_dot(*instr);
|
||||
};
|
||||
|
||||
// xla::cpu::GetDotImplementationStrategy (used by call_library_for_dot)
|
||||
|
|
@@ -728,7 +741,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
|||
// Convert BF16 and F8 operations to F32 and F16 respectively so that the CPU
|
||||
// backend can support BF16/F8 operations without directly implementing a
|
||||
// BF16/F8 lowering for most ops.
|
||||
CpuFloatSupport bf16_support(BF16, library_supports_dot);
|
||||
CpuFloatSupport bf16_support(BF16, call_library_for_dot);
|
||||
#ifdef XLA_ONEDNN
|
||||
OneDnnFloatSupport onednn_bf16_support(BF16);
|
||||
if (use_onednn_custom_call) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user