Add initial support for offloading dots to YNNPACK.

PiperOrigin-RevId: 823318539
Author: Alexander Shaposhnikov (2025-10-23 20:57:12 -07:00), committed by TensorFlower Gardener
Parent: c3c6653a98
Commit: 3be9a21d7e
7 changed files with 211 additions and 16 deletions

@@ -157,9 +157,10 @@ cc_library(
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/runtime/xnnpack:xnn_interop",
"//xla/backends/cpu/runtime:dot_lib",
"//xla/backends/cpu/runtime/ynnpack:ynn_interop",
"//xla/hlo/ir:hlo",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@XNNPACK//ynnpack",

@@ -15,9 +15,12 @@ limitations under the License.
#include "xla/backends/cpu/ynn_emitter.h"
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
#include "ynnpack/include/ynnpack.h"
@@ -25,13 +28,18 @@ limitations under the License.
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/dot_lib.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/ynn_support.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/primitive_util.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
@@ -280,7 +288,113 @@ static absl::StatusOr<YnnSubgraph> EmitYnnSubgraph(
return subgraph;
}
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>()>>
//===----------------------------------------------------------------------===//
// Emit YNNPACK subgraph for the given HLO dot instruction.
//===----------------------------------------------------------------------===//
// TODO(ashaposhnikov): Use DefineBatchMatrixMultiply in EmitYnnSubgraph.
static ynn_status DefineBatchMatrixMultiply(ynn_subgraph_t subgraph,
uint32_t input1_id,
uint32_t input2_id,
uint32_t output_id, size_t b_rank,
bool transpose_b) {
if (transpose_b) {
uint32_t input2_id_transposed = YNN_INVALID_VALUE_ID;
std::array<int32_t, YNN_MAX_TENSOR_RANK> perm;
std::iota(perm.begin(), perm.end(), 0);
CHECK_LT(b_rank, YNN_MAX_TENSOR_RANK);
std::swap(perm[b_rank - 1], perm[b_rank - 2]);
ynn_status status = ynn_define_static_transpose(
subgraph,
/*num_dims=*/b_rank, perm.data(), input2_id, &input2_id_transposed,
/*flags=*/0);
if (status != ynn_status_success) {
return status;
}
input2_id = input2_id_transposed;
}
return ynn_define_dot(subgraph, /*num_k_dims=*/1, input1_id, input2_id,
YNN_INVALID_VALUE_ID, &output_id, /*flags=*/0);
}
static absl::StatusOr<YnnSubgraph> EmitYnnDotSubgraph(
const HloDotInstruction* dot,
std::vector<std::unique_ptr<Literal>>& literals,
absl::Span<const se::DeviceMemoryBase> arguments_buffers,
bool capture_rhs) {
TF_ASSIGN_OR_RETURN(YnnSubgraph subgraph,
CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) {
return ynn_create_subgraph(
/*external_value_ids=*/3,
/*flags=*/0, subgraph);
}));
uint32_t lhs_id = 0;
uint32_t rhs_id = 1;
uint32_t out_id = 2;
const HloInstruction* lhs = dot->operand(0);
const HloInstruction* rhs = dot->operand(1);
const Shape& lhs_shape = lhs->shape();
const Shape& rhs_shape = rhs->shape();
const Shape& out_shape = dot->shape();
auto dims = [](absl::Span<const int64_t> dims) -> std::vector<size_t> {
return {dims.begin(), dims.end()};
};
std::vector<size_t> lhs_dims = dims(lhs_shape.dimensions());
std::vector<size_t> rhs_dims = dims(rhs_shape.dimensions());
std::vector<size_t> out_dims = dims(out_shape.dimensions());
PrimitiveType dtype = lhs->shape().element_type();
if (dtype != F32 && dtype != BF16) {
return InvalidArgument("Unsupported input data type for YnnDotThunk: %s",
primitive_util::LowercasePrimitiveTypeName(dtype));
}
ynn_type input_type = (dtype == F32) ? ynn_type_fp32 : ynn_type_bf16;
ynn_type output_type = ynn_type_fp32;
const uint32_t input_tensor_flags = YNN_VALUE_FLAG_EXTERNAL_INPUT;
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph.get(), input_type, lhs_dims.size(), lhs_dims.data(),
/*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, input_tensor_flags, &lhs_id));
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph.get(), input_type, rhs_dims.size(), rhs_dims.data(),
capture_rhs ? arguments_buffers[1].opaque() : nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, input_tensor_flags, &rhs_id));
const uint32_t output_tensor_flags = YNN_VALUE_FLAG_EXTERNAL_OUTPUT;
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph.get(), output_type, out_dims.size(), out_dims.data(),
/*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, output_tensor_flags, &out_id));
DotDimensionNumbers dot_dimensions = dot->dot_dimension_numbers();
TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape,
rhs_shape, out_shape));
TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims,
GetDotCanonicalDims(dot_dimensions, dot_shape));
const size_t b_rank = rhs_shape.dimensions_size();
const bool transpose_b = !dot_canonical_dims.rhs_canonical;
YNN_RETURN_IF_ERROR(DefineBatchMatrixMultiply(subgraph.get(), lhs_id, rhs_id,
out_id, b_rank, transpose_b));
return subgraph;
}
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
absl::Span<const se::DeviceMemoryBase> arguments_buffers)>>
EmitYnnFusionBuilder(const HloComputation* computation) {
// We do not support non-array parameters for YNNPACK operations.
for (auto& param : computation->parameter_instructions()) {
@@ -297,10 +411,19 @@ EmitYnnFusionBuilder(const HloComputation* computation) {
computation->root_instruction()->shape().ToString());
}
return [computation,
literals = std::vector<std::unique_ptr<Literal>>()]() mutable {
return [computation, literals = std::vector<std::unique_ptr<Literal>>()](
absl::Span<const se::DeviceMemoryBase> arguments_buffers) mutable {
return EmitYnnSubgraph(computation, literals);
};
}
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
absl::Span<const se::DeviceMemoryBase> arguments_buffers)>>
EmitYnnDotBuilder(const HloDotInstruction* dot, bool capture_rhs) {
return [dot, capture_rhs, literals = std::vector<std::unique_ptr<Literal>>()](
absl::Span<const se::DeviceMemoryBase> arguments_buffers) mutable {
return EmitYnnDotSubgraph(dot, literals, arguments_buffers, capture_rhs);
};
}
} // namespace xla::cpu

@@ -20,12 +20,19 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/stream_executor/device_memory.h"
namespace xla::cpu {
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>()>>
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
absl::Span<const se::DeviceMemoryBase> arguments_buffers)>>
EmitYnnFusionBuilder(const HloComputation* computation);
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
absl::Span<const se::DeviceMemoryBase> arguments_buffers)>>
EmitYnnDotBuilder(const HloDotInstruction* dot, bool capture_rhs);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_YNN_EMITTER_H_
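For orientation, a minimal caller sketch (not part of this commit; the wrapper name and surrounding plumbing are assumptions) of how the builder declared above is meant to be used: it is created once at thunk-emission time and invoked later with the runtime argument buffers. The real call site is ThunkEmitter::EmitYnnFusionThunk further down in this change.

// Assumed includes: "xla/backends/cpu/ynn_emitter.h",
// "xla/tsl/platform/statusor.h", plus the usual
// `namespace se = ::stream_executor;` alias.
static absl::StatusOr<YnnSubgraph> BuildDotSubgraphOnce(
    const HloDotInstruction* dot,
    absl::Span<const se::DeviceMemoryBase> arguments_buffers) {
  // The builder is created once, when the thunk is emitted.
  TF_ASSIGN_OR_RETURN(auto build_subgraph,
                      EmitYnnDotBuilder(dot, /*capture_rhs=*/false));
  // It is invoked later with the runtime argument buffers.
  return build_subgraph(arguments_buffers);
}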

@@ -1096,7 +1096,7 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
&DebugOptions::LibraryFusionType_Parse,
debug_options->mutable_xla_cpu_experimental_xnn_fusion_type()),
"",
"Comma-separated list of XNN fusion types to be enabled.; "
"Comma-separated list of XNN fusion types to be enabled; "
"no whitespace around commas. Two ways to pass values:\n"
" 1. Exact type names. This overwrites the default setting.\n"
" 2. '+' or '-' prefix: This adds or removes a fusion type "
@@ -1104,6 +1104,21 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"mode. Every item must have the sign prefix.\n"
"Available fusion types: dot, eltwise, and reduce.\n"
"The default list is currently empty."));
flag_list->push_back(tsl::Flag(
"xla_cpu_experimental_ynn_fusion_type",
SetterForRepeatedEnum<DebugOptions::LibraryFusionType>(
"xla_cpu_experimental_ynn_fusion_type",
/*enum_prefix=*/"LIBRARY_FUSION_TYPE_",
&DebugOptions::LibraryFusionType_Parse,
debug_options->mutable_xla_cpu_experimental_ynn_fusion_type()),
"",
"Comma-separated list of YNN fusion types to be enabled; "
"no whitespace around commas. Two ways to pass values:\n"
" 1. Exact type names. This overwrites the default setting.\n"
" 2. '+' or '-' prefix: This adds or removes a fusion type "
"from the default list. Cannot be mixed with the overwrite "
"mode. Every item must have the sign prefix.\n"
"The default list is currently empty."));
flag_list->push_back(tsl::Flag(
"xla_cpu_experimental_xnn_graph_fusion_mode",
setter_for_xla_cpu_experimental_xnn_graph_fusion_mode,
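As a hedged usage example (not part of the change itself): with the flag above and the new LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT enum value, offloading individual dots to YNNPACK would be requested with something like --xla_cpu_experimental_ynn_fusion_type=individual_dot, assuming the parser accepts the lowercase enum name with the LIBRARY_FUSION_TYPE_ prefix stripped, as documented for the existing XNN flag; prefixing the value with '+' or '-' instead adds it to or removes it from the default list.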

@@ -988,12 +988,14 @@ cc_library(
"//xla/service:pattern_matcher",
"//xla/service/gpu/model/experimental:symbolic_expr",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor:device_memory",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
@@ -1010,6 +1012,7 @@ cc_library(
]) + if_ynnpack([
"//xla/backends/cpu:ynn_emitter",
"//xla/backends/cpu:ynn_support",
"//xla/backends/cpu/runtime/ynnpack:ynn_interop",
"//xla/backends/cpu/runtime/ynnpack:ynn_fusion_thunk",
]),
)

@@ -24,6 +24,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
@@ -107,6 +108,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
@@ -127,6 +129,7 @@ limitations under the License.
#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/ynn_emitter.h"
#include "xla/backends/cpu/ynn_support.h"
#endif // XLA_YNNPACK
@@ -1082,6 +1085,25 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk(
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice out_slice,
GetAllocationSlice(instruction));
#ifdef XLA_YNNPACK
const bool use_ynn = absl::c_linear_search(
hlo_module_config_.debug_options()
.xla_cpu_experimental_ynn_fusion_type(),
DebugOptions::LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT);
if (use_ynn) {
// TODO(ashaposhnikov): Replace IsDotSupportedByXnn with
// IsDotSupportedByYnn.
TF_ASSIGN_OR_RETURN(
auto is_dot_supported,
IsDotSupportedByXnn(dnums, lhs->shape(), rhs->shape(),
instruction->shape(), &target_machine_features_,
/*use_cost_model=*/false));
if (is_dot_supported) {
return EmitYnnFusionThunk(instruction);
}
}
#endif // XLA_YNNPACK
// Decide whether to use XNNPACK or Eigen.
bool use_xnn = hlo_module_config_.debug_options().xla_cpu_use_xnnpack();
if (use_xnn) {
@@ -1508,8 +1530,6 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitXnnFusionThunk(
absl::StatusOr<ThunkSequence> ThunkEmitter::EmitYnnFusionThunk(
const HloInstruction* instruction) {
#ifdef XLA_YNNPACK
auto* fusion = Cast<HloFusionInstruction>(instruction);
// Collect YNNPACK fusion arguments.
std::vector<YnnFusionThunk::Argument> arguments;
for (HloInstruction* operand : instruction->operands()) {
@@ -1530,15 +1550,36 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitYnnFusionThunk(
results.push_back(YnnFusionThunk::Result{slice, indexed.shape});
}
const HloComputation* computation = fusion->fused_instructions_computation();
// Construct YNNPACK subgraph builder from the fusion computation.
TF_ASSIGN_OR_RETURN(auto builder, EmitYnnFusionBuilder(computation));
absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
absl::Span<const se::DeviceMemoryBase> arguments_buffers)>
builder;
absl::Span<const int64_t> captured_arguments_ids;
if (instruction->opcode() == HloOpcode::kDot) {
const HloDotInstruction* dot = Cast<HloDotInstruction>(instruction);
// TODO(ashaposhnikov): Revisit this if we ever get a reliable way
// to determine that RHS is constant.
bool capture_rhs = false;
// Construct YNNPACK subgraph builder from the dot instruction.
TF_ASSIGN_OR_RETURN(builder, EmitYnnDotBuilder(dot, capture_rhs));
static constexpr int64_t kCapturedIds[1] = {1};
if (capture_rhs) {
captured_arguments_ids = kCapturedIds;
}
} else {
auto* fusion = Cast<HloFusionInstruction>(instruction);
const HloComputation* computation =
fusion->fused_instructions_computation();
// Construct YNNPACK subgraph builder from the fusion computation.
TF_ASSIGN_OR_RETURN(builder, EmitYnnFusionBuilder(computation));
}
return ThunkSequence::Of<YnnFusionThunk>(
YnnFusionThunk::Options{}, ThunkInfo(instruction), std::move(arguments),
std::move(results),
[b = std::move(builder)](auto, auto) mutable { return b(); });
[b = std::move(builder)](auto, auto, auto arg_buffers) mutable {
return b(arg_buffers);
},
captured_arguments_ids);
#else
return Unimplemented("XLA is not built with YNNPACK.");
#endif // XLA_YNNPACK

@@ -155,6 +155,7 @@ message DebugOptions {
LIBRARY_FUSION_TYPE_DOT = 1; // Dot and any eltwise ops around it.
LIBRARY_FUSION_TYPE_ELTWISE = 2;
LIBRARY_FUSION_TYPE_REDUCE = 3;
LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT = 4;
}
enum XnnGraphFusionMode {
@@ -205,15 +206,19 @@ message DebugOptions {
// Call oneDNN custom call thunks in the CPU backend
optional bool xla_cpu_experimental_onednn_custom_call = 412;
// Stores the fusion types enabled for oneDNN in DotLibraryRewriter pass.
// Stores the fusion types enabled for oneDNN in LibraryRewriter pass.
repeated LibraryFusionType xla_cpu_experimental_onednn_fusion_type = 399;
// Stores the fusion types enabled for XNNPACK in DotLibraryRewriter pass.
// Stores the fusion types enabled for XNNPACK in LibraryRewriter pass.
repeated LibraryFusionType xla_cpu_experimental_xnn_fusion_type = 400;
// Controls XnnGraphFusion HLO pass.
optional XnnGraphFusionMode xla_cpu_experimental_xnn_graph_fusion_mode = 365;
// Stores the fusion types enabled for YNNPACK in LibraryRewriter pass or
// for individual operations.
repeated LibraryFusionType xla_cpu_experimental_ynn_fusion_type = 422;
// When xla_cpu_enable_fast_math is true then this controls whether we forbid
// to use the reciprocal of an argument instead of division. Ignored when
// xla_cpu_enable_fast_math is false.
@@ -1361,7 +1366,7 @@ message DebugOptions {
// Note: when adding a new flag, please add it to one of the hardware-specific
// or hardware-agnostic sections at the top of this proto message.
// Next id: 422
// Next id: 423
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.