[XLA:CPU] Add initial bits for YNNPACK support.

+ Do not build XLA with YNNPACK on Windows.

Co-authored-by: Penporn Koanantakool <penporn@google.com>
PiperOrigin-RevId: 820896434
This commit is contained in:
Alexander Shaposhnikov 2025-10-17 18:19:12 -07:00 committed by TensorFlower Gardener
parent f0057ee4b7
commit ce65a0ad5c
23 changed files with 1849 additions and 5 deletions

View File

@ -153,6 +153,30 @@ cc_library(
],
)
cc_library(
name = "ynn_emitter",
srcs = ["ynn_emitter.cc"],
hdrs = ["ynn_emitter.h"],
deps = [
":ynn_support",
"//xla:literal",
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/runtime/xnnpack:xnn_interop",
"//xla/backends/cpu/runtime/ynnpack:ynn_interop",
"//xla/hlo/ir:hlo",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@XNNPACK//ynnpack",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "xnn_gemm_config",
srcs = ["xnn_gemm_config.cc"],
@ -234,6 +258,33 @@ cc_library(
],
)
cc_library(
name = "ynn_support",
srcs = ["ynn_support.cc"],
hdrs = ["ynn_support.h"],
deps = [
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/codegen:target_machine_features",
"//xla/backends/cpu/runtime:dot_lib",
"//xla/backends/cpu/runtime/ynnpack:ynn_interop",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/tsl/platform:statusor",
"@XNNPACK//ynnpack",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "constant_allocation",
srcs = ["constant_allocation.cc"],

View File

@ -4,6 +4,7 @@ load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility")
load("//xla/tsl:tsl.default.bzl", "filegroup")
load("//xla/tsl/platform:build_config.bzl", "tf_proto_library")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/xnnpack:build_defs.bzl", "if_ynnpack")
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@ -160,6 +161,7 @@ cc_library(
name = "thunk",
srcs = ["thunk.cc"],
hdrs = ["thunk.h"],
defines = if_ynnpack(["XLA_YNNPACK"]),
deps = [
":buffer_allocations",
":function_library",
@ -188,7 +190,10 @@ cc_library(
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/profiler/lib:traceme",
"@local_tsl//tsl/profiler/lib:traceme_encode",
],
] + if_ynnpack([
"//xla/backends/cpu/runtime/ynnpack:ynn_interop",
"//xla/backends/cpu/runtime/ynnpack:ynn_threadpool",
]),
)
cc_library(
@ -1246,6 +1251,7 @@ xla_cc_test(
xla_cc_test(
name = "thunk_sequence_serdes_test",
srcs = ["thunk_sequence_serdes_test.cc"],
local_defines = if_ynnpack(["XLA_YNNPACK"]),
deps = [
":all_gather_thunk",
":all_reduce_thunk",
@ -1261,7 +1267,6 @@ xla_cc_test(
":dot_thunk",
":fft_thunk",
":infeed_thunk",
":kernel",
":kernel_thunk",
":logical_id_thunk",
":outfeed_thunk",
@ -1293,7 +1298,6 @@ xla_cc_test(
"//xla/service:hlo_proto_cc",
"//xla/stream_executor:launch_dim",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@ -1305,7 +1309,9 @@ xla_cc_test(
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:casts",
],
] + if_ynnpack([
"//xla/backends/cpu/runtime/ynnpack:ynn_fusion_thunk",
]),
)
cc_library(

View File

@ -43,6 +43,11 @@ limitations under the License.
#include "tsl/profiler/lib/traceme.h"
#include "tsl/profiler/lib/traceme_encode.h"
#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#endif // XLA_YNNPACK
namespace xla::cpu {
// Ok execute event allocated with the static storage duration.
@ -88,6 +93,8 @@ absl::string_view Thunk::KindToString(Kind kind) {
return "while";
case Kind::kXnnFusion:
return "xnn-fusion";
case Kind::kYnnFusion:
return "ynn-fusion";
case Kind::kOneDnnFusion:
return "onednn-fusion";
}
@ -168,6 +175,18 @@ absl::StatusOr<Thunk::XnnParams> Thunk::XnnParams::Create(
Thunk::XnnParams::XnnParams(XnnThreadpool threadpool)
: threadpool(std::move(threadpool)) {}
#ifdef XLA_YNNPACK
absl::StatusOr<Thunk::YnnParams> Thunk::YnnParams::Create(
const ExecutableRunOptions* run_options) {
TF_ASSIGN_OR_RETURN(YnnThreadpool threadpool,
CreateYnnThreadpool(run_options->intra_op_thread_pool()));
return YnnParams(std::move(threadpool));
}
Thunk::YnnParams::YnnParams(YnnThreadpool threadpool)
: threadpool(std::move(threadpool)) {}
#endif // XLA_YNNPACK
Thunk::ExecuteSession::ExecuteSession(int64_t max_workers,
int64_t split_threshold)
: lock_(std::make_shared<std::nullopt_t>(std::nullopt)),

View File

@ -47,6 +47,11 @@ limitations under the License.
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#endif // XLA_YNNPACK
namespace Eigen {
struct ThreadPoolDevice;
} // namespace Eigen
@ -87,6 +92,7 @@ class Thunk {
kTopK,
kWhile,
kXnnFusion,
kYnnFusion,
kOneDnnFusion,
};
@ -262,6 +268,25 @@ class Thunk {
explicit XnnParams(XnnThreadpool threadpool);
};
//===--------------------------------------------------------------------===//
// YnnParams
//===--------------------------------------------------------------------===//
#ifdef XLA_YNNPACK
// Parameters capturing all the details required for running XNNPACK fusions.
struct YnnParams {
static absl::StatusOr<YnnParams> Create(
const ExecutableRunOptions* run_options);
YnnThreadpool threadpool = nullptr;
explicit YnnParams(YnnThreadpool threadpool);
};
#else
// Use XnnParams for placeholder. The parameter won't be used anyway.
using YnnParams = XnnParams;
#endif // XLA_YNNPACK
//===--------------------------------------------------------------------===//
// ExecuteParams
//===--------------------------------------------------------------------===//
@ -277,6 +302,7 @@ class Thunk {
CollectiveExecuteParams* collective_params = nullptr;
CustomCallExecuteParams* custom_call_params = nullptr;
XnnParams* xnn_params = nullptr;
YnnParams* ynn_params = nullptr;
int64_t run_id = -1; // -1 means no run id is set.
int64_t device_ordinal = -1; // -1 means no device ordinal is set.
ExecuteSession session = ExecuteSession(ExecuteSession::kMaxWorkers,

View File

@ -80,6 +80,10 @@ limitations under the License.
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"
#endif // XLA_YNNPACK
namespace xla::cpu {
namespace {
@ -1103,6 +1107,15 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
return false;
}
#ifdef XLA_YNNPACK
bool VerifyYnnFusionThunkEquality(const YnnFusionThunk& thunk_1,
const YnnFusionThunk& thunk_2) {
// TODO(ashaposhnikov) assume this is always false until we implement
// serialization of YnnFusionThunk.
return false;
}
#endif // XLA_YNNPACK
bool VerifyXnnDotThunkEquality(const XnnDotThunk& thunk_1,
const XnnDotThunk& thunk_2) {
const bool are_dot_dimensions_equal =
@ -1412,6 +1425,24 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
tsl::down_cast<const XnnConvolutionThunk&>(thunk_2));
}
}
case Thunk::Kind::kYnnFusion: {
#ifdef XLA_YNNPACK
const YnnFusionThunk& ynn_fusion_thunk_1 =
tsl::down_cast<const YnnFusionThunk&>(thunk_1);
const YnnFusionThunk& ynn_fusion_thunk_2 =
tsl::down_cast<const YnnFusionThunk&>(thunk_2);
if (ynn_fusion_thunk_1.ynn_fusion_kind() !=
ynn_fusion_thunk_2.ynn_fusion_kind()) {
return false;
}
return VerifyYnnFusionThunkEquality(
tsl::down_cast<const YnnFusionThunk&>(thunk_1),
tsl::down_cast<const YnnFusionThunk&>(thunk_2));
#else
CHECK(false) << "Unsupported YNN fusion thunk type";
return false;
#endif // XLA_YNNPACK
}
case Thunk::Kind::kKernel:
return VerifyKernelThunkEquality(
tsl::down_cast<const KernelThunkBase&>(thunk_1),

View File

@ -0,0 +1,103 @@
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/xnnpack:build_defs.bzl", "ynn_cc_test")
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
default_visibility = [":friends"],
licenses = ["notice"],
)
package_group(
name = "friends",
includes = [
"//xla:friends",
],
)
cc_library(
name = "ynn_interop",
srcs = ["ynn_interop.cc"],
hdrs = ["ynn_interop.h"],
deps = [
"//xla:shape_util",
"//xla:util",
"//xla/tsl/platform:logging",
"@XNNPACK//ynnpack",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "ynn_threadpool",
srcs = ["ynn_threadpool.cc"],
hdrs = ["ynn_threadpool.h"],
deps = [
":ynn_interop",
"@XNNPACK//ynnpack:ynnpack_h",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status:statusor",
"@eigen_archive//:eigen3",
],
)
cc_library(
name = "ynn_fusion_thunk",
srcs = ["ynn_fusion_thunk.cc"],
hdrs = ["ynn_fusion_thunk.h"],
deps = [
":ynn_interop",
"//xla:shape_util",
"//xla/backends/cpu/runtime:thunk",
"//xla/runtime:buffer_use",
"//xla/runtime:object_pool",
"//xla/service:buffer_assignment",
"//xla/stream_executor:device_memory",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
"//xla/tsl/platform:statusor",
"@XNNPACK//ynnpack:ynnpack_h",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)
ynn_cc_test(
name = "ynn_fusion_thunk_test",
srcs = ["ynn_fusion_thunk_test.cc"],
deps = [
":ynn_fusion_thunk",
":ynn_interop",
":ynn_threadpool",
"//xla:literal_util",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/runtime:buffer_allocations",
"//xla/backends/cpu/runtime:thunk",
"//xla/backends/cpu/runtime:thunk_testlib",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:env",
"//xla/tsl/platform:statusor",
"//xla/tsl/platform:test",
"@XNNPACK//ynnpack:ynnpack_h",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@eigen_archive//:eigen3",
],
)

View File

@ -0,0 +1,371 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>
#include <utility>
#include <vector>
#include "ynnpack/include/ynnpack.h"
#include "absl/algorithm/container.h"
#include "absl/base/no_destructor.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/bind_front.h"
#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/runtime/buffer_use.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
namespace xla::cpu {
absl::string_view YnnFusionThunk::YnnFusionKindToString(YnnFusionKind kind) {
switch (kind) {
case YnnFusionKind::kFusion:
return "ynn-fusion";
}
}
std::ostream& operator<<(std::ostream& os, YnnFusionThunk::YnnFusionKind kind) {
return os << YnnFusionThunk::YnnFusionKindToString(kind);
}
// YNNPACK executable instantiated for the fusion operation.
struct YnnFusionThunk::YnnExecutable {
tsl::AsyncValueRef<YnnFusionThunk::ExecuteEvent> Invoke(
const YnnThreadpool& threadpool,
absl::Span<se::DeviceMemoryBase> arguments,
absl::Span<se::DeviceMemoryBase> results,
absl::FunctionRef<bool(size_t)> is_captured_argument);
// Resets YNNPACK runtime and subgraph.
absl::Status Reset();
YnnSubgraph subgraph = nullptr;
YnnRuntime runtime = nullptr;
// TODO(ezhulenev): Today we rely on device memory as an identity of the
// captured argument, and this is not correct as we can have multiple
// arguments allocated to the heap address. This is work in progress, and will
// be migrated to a buffer identity passed to XLA by the client (PjRt).
std::vector<se::DeviceMemoryBase> captured_arguments;
};
namespace {
struct YnnExternalValue {
uint32_t id;
void* data;
};
} // namespace
static enum ynn_status set_external_values(
ynn_runtime_t runtime, absl::Span<const YnnExternalValue> external_values) {
for (const auto& [id, data] : external_values) {
enum ynn_status status = ynn_set_external_value_data(runtime, id, data);
if (status != ynn_status_success) {
return status;
}
}
return ynn_status_success;
}
tsl::AsyncValueRef<YnnFusionThunk::ExecuteEvent>
YnnFusionThunk::YnnExecutable::Invoke(
const YnnThreadpool& threadpool, absl::Span<se::DeviceMemoryBase> arguments,
absl::Span<se::DeviceMemoryBase> results,
absl::FunctionRef<bool(size_t)> is_captured_argument) {
// Create external values for all arguments and results.
absl::InlinedVector<YnnExternalValue, 8> external_values;
external_values.reserve(arguments.size() + results.size());
// External tensor id for arguments and results.
uint32_t id = 0;
for (const se::DeviceMemoryBase& argument : arguments) {
YnnExternalValue value{id++, argument.opaque()};
if (!is_captured_argument(value.id)) {
external_values.push_back(value);
}
}
for (const se::DeviceMemoryBase& result : results) {
YnnExternalValue value{id++, result.opaque()};
external_values.push_back(value);
}
DCHECK_NE(runtime.get(), nullptr) << "YNNPACK runtime is not initialized";
YNN_RETURN_IF_ERROR(set_external_values(runtime.get(), external_values));
// Update threadpool used by the YNNPACK runtime.
YNN_RETURN_IF_ERROR(ynn_update_runtime_with_threadpool(
runtime.get(), reinterpret_cast<ynn_threadpool_t>(threadpool.get())));
// Execute YNNPACK runtime in the caller thread.
YNN_RETURN_IF_ERROR(ynn_invoke_runtime(runtime.get()));
return OkExecuteEvent();
}
absl::Status YnnFusionThunk::YnnExecutable::Reset() {
runtime.reset();
subgraph.reset();
return absl::OkStatus();
}
absl::StatusOr<YnnFusionThunk::YnnExecutable>
YnnFusionThunk::CreateYnnExecutable(
const YnnThreadpool& threadpool,
absl::Span<const se::DeviceMemoryBase> arguments_buffers) {
bool capturing = !captured_arguments_ids_.empty();
VLOG(3) << absl::StreamFormat(
"Create %s YNN executable for `%s` operation: num_created=%d",
capturing ? "capturing" : "pooled", info().op_name,
capturing ? num_capturing_created_.fetch_add(1)
: ynn_executable_pool_.num_created());
YnnExecutable executable;
// Keep track of the arguments captured by value.
executable.captured_arguments = CaptureArguments(arguments_buffers);
if (builder_) {
TF_ASSIGN_OR_RETURN(executable.subgraph, builder_(arguments_, results_));
} else {
TF_ASSIGN_OR_RETURN(
executable.subgraph,
capturing_builder_(arguments_, results_, arguments_buffers));
}
TF_ASSIGN_OR_RETURN(
executable.runtime, CreateYnnRuntime([&](ynn_runtime_t* runtime) {
uint32_t ynn_flags = 0;
return ynn_create_runtime(
executable.subgraph.get(),
reinterpret_cast<ynn_threadpool_t>(threadpool.get()), ynn_flags,
runtime);
}));
YNN_RETURN_IF_ERROR(ynn_reshape_runtime(executable.runtime.get()));
return {std::move(executable)};
}
absl::Status YnnFusionThunk::UpdateYnnExecutable(
const YnnThreadpool& threadpool, YnnExecutable& executable,
absl::Span<const se::DeviceMemoryBase> arguments_buffers) {
DCHECK(capturing_builder_) << "YNN executable is not capturing arguments";
DCHECK_EQ(executable.captured_arguments.size(),
captured_arguments_ids_.size())
<< "Unexpected number of captured arguments";
// If all arguments captured by value are the same as the last execution,
// we can reuse the YNN executable.
auto capture_arguments = CaptureArguments(arguments_buffers);
if (executable.captured_arguments == capture_arguments) {
VLOG(3) << absl::StreamFormat("Reuse YNN executable for `%s` operation",
info().op_name);
return absl::OkStatus();
}
VLOG(3) << absl::StreamFormat("Update YNN executable for `%s` operation",
info().op_name);
TF_RETURN_IF_ERROR(executable.Reset());
// Keep track of the updated arguments captured by value.
executable.captured_arguments = std::move(capture_arguments);
TF_ASSIGN_OR_RETURN(
executable.subgraph,
capturing_builder_(arguments_, results_, arguments_buffers));
TF_ASSIGN_OR_RETURN(
executable.runtime, CreateYnnRuntime([&](ynn_runtime_t* runtime) {
uint32_t ynn_flags = 0;
return ynn_create_runtime(
executable.subgraph.get(),
reinterpret_cast<ynn_threadpool_t>(threadpool.get()), ynn_flags,
runtime);
}));
YNN_RETURN_IF_ERROR(ynn_reshape_runtime(executable.runtime.get()));
return absl::OkStatus();
}
std::vector<se::DeviceMemoryBase> YnnFusionThunk::CaptureArguments(
absl::Span<const se::DeviceMemoryBase> arguments_buffers) {
std::vector<se::DeviceMemoryBase> captured_arguments_ids;
captured_arguments_ids.reserve(captured_arguments_ids_.size());
for (int64_t i = 0; i < captured_arguments_ids_.size(); ++i) {
int32_t arg_index = captured_arguments_ids_[i];
captured_arguments_ids.push_back(arguments_buffers[arg_index]);
}
return captured_arguments_ids;
}
absl::StatusOr<std::unique_ptr<YnnFusionThunk>> YnnFusionThunk::Create(
Options options, Info info, std::vector<Argument> arguments,
std::vector<Result> results, Builder builder) {
return absl::WrapUnique(new YnnFusionThunk(
YnnFusionKind::kFusion, std::move(options), std::move(info),
std::move(arguments), std::move(results), std::move(builder)));
}
absl::StatusOr<std::unique_ptr<YnnFusionThunk>> YnnFusionThunk::Create(
Options options, Info info, std::vector<Argument> arguments,
std::vector<Result> results, CapturingBuilder capturing_builder,
absl::Span<const int64_t> captured_arguments_ids) {
return absl::WrapUnique(new YnnFusionThunk(
YnnFusionKind::kFusion, std::move(options), std::move(info),
std::move(arguments), std::move(results), std::move(capturing_builder),
captured_arguments_ids));
}
YnnFusionThunk::YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
std::vector<Argument> arguments,
std::vector<Result> results, Builder builder)
: Thunk(Kind::kYnnFusion, std::move(info)),
ynn_fusion_kind_(kind),
options_(std::move(options)),
arguments_(std::move(arguments)),
results_(std::move(results)),
builder_(std::move(builder)),
ynn_executable_pool_(
absl::bind_front(&YnnFusionThunk::CreateYnnExecutable, this)) {}
YnnFusionThunk::YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
std::vector<Argument> arguments,
std::vector<Result> results,
CapturingBuilder capturing_builder,
absl::Span<const int64_t> captured_arguments_ids)
: Thunk(Kind::kYnnFusion, std::move(info)),
ynn_fusion_kind_(kind),
options_(std::move(options)),
arguments_(std::move(arguments)),
results_(std::move(results)),
capturing_builder_(std::move(capturing_builder)),
captured_arguments_ids_(captured_arguments_ids.begin(),
captured_arguments_ids.end()),
ynn_executable_pool_(
absl::bind_front(&YnnFusionThunk::CreateYnnExecutable, this)) {}
YnnFusionThunk::~YnnFusionThunk() = default;
YnnFusionThunk::BufferUses YnnFusionThunk::buffer_uses() const {
BufferUses buffer_uses;
for (const Argument& argument : arguments_) {
buffer_uses.push_back(BufferUse::Read(argument.slice));
}
for (const Result& result : results_) {
buffer_uses.push_back(BufferUse::Write(result.slice));
}
return buffer_uses;
}
const YnnThreadpool& GetYnnThreadpool(const Thunk::ExecuteParams& params) {
static absl::NoDestructor<YnnThreadpool> no_threadpool(nullptr);
return params.ynn_params ? params.ynn_params->threadpool : *no_threadpool;
}
tsl::AsyncValueRef<YnnFusionThunk::ExecuteEvent> YnnFusionThunk::Execute(
const ExecuteParams& params) {
VLOG(3) << absl::StreamFormat("YNN %s `%s`: %s", fusion_kind(),
info().op_name, fusion_description());
if (VLOG_IS_ON(3) && has_fusion_details()) {
for (auto& detail : fusion_details()) {
VLOG(3) << detail;
}
}
// Resolve device memory for arguments.
absl::InlinedVector<se::DeviceMemoryBase, 8> arguments_buffers;
arguments_buffers.resize(arguments_.size());
for (size_t i = 0; i < arguments_.size(); ++i) {
Argument& argument = arguments_[i];
TF_ASSIGN_OR_RETURN(
arguments_buffers[i],
params.buffer_allocations->GetDeviceAddress(argument.slice));
VLOG(3) << absl::StreamFormat(" %s: %s in slice %s (%p)", argument_name(i),
argument.shape.ToString(true),
argument.slice.ToString(),
arguments_buffers[i].opaque());
}
// Resolve device memory for results.
absl::InlinedVector<se::DeviceMemoryBase, 4> results_buffers;
results_buffers.resize(results_.size());
for (size_t i = 0; i < results_.size(); ++i) {
Result& result = results_[i];
TF_ASSIGN_OR_RETURN(
results_buffers[i],
params.buffer_allocations->GetDeviceAddress(results_[i].slice));
VLOG(3) << absl::StreamFormat(" %s: %s in slice %s (%p)", result_name(i),
result.shape.ToString(true),
result.slice.ToString(),
results_buffers[i].opaque());
}
DCHECK(builder_ || capturing_builder_) << "One of the builders must be set.";
auto invoke = [&](typename YnnExecutablePool::BorrowedObject executable) {
auto executed = executable->Invoke(
GetYnnThreadpool(params), absl::MakeSpan(arguments_buffers),
absl::MakeSpan(results_buffers), [&](size_t id) {
return absl::c_linear_search(captured_arguments_ids_, id);
});
// Do not return executable to the pool until the execution is done.
executed.AndThen([executable = std::move(executable)] {});
return executed;
};
// Borrow YnnExecutable from the pool.
TF_ASSIGN_OR_RETURN(auto executable,
ynn_executable_pool_.GetOrCreate(GetYnnThreadpool(params),
arguments_buffers));
// If YNN graph doesn't capture any of the arguments by value, we can execute
// XnnExecutable immediately.
if (captured_arguments_ids_.empty()) {
return invoke(std::move(executable));
}
// Otherwise reset YnnExecutable to capture new arguments buffers.
TF_RETURN_IF_ERROR(UpdateYnnExecutable(GetYnnThreadpool(params), *executable,
arguments_buffers));
return invoke(std::move(executable));
}
} // namespace xla::cpu

View File

@ -0,0 +1,182 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_FUSION_THUNK_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_FUSION_THUNK_H_
#include <stdbool.h>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/runtime/object_pool.h"
#include "xla/service/buffer_assignment.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
namespace xla::cpu {
// YNN fusion thunk encapsulates YNNPACK subgraph constructed from an XLA fusion
// operation.
class YnnFusionThunk : public Thunk {
public:
enum class YnnFusionKind {
kFusion,
};
static absl::string_view YnnFusionKindToString(YnnFusionKind kind);
~YnnFusionThunk() override;
struct Options {
// Pass YnnThreadpool constructed from the intra-op threadpool to the
// YNNPACK runtime to allow YNNPACK to parallelize the execution.
bool use_threadpool = true;
};
struct Argument {
BufferAllocation::Slice slice;
Shape shape;
};
struct Result {
BufferAllocation::Slice slice;
Shape shape;
};
// Builder function constructs YNNPACK subgraph for the fusion operation.
using Builder = absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
absl::Span<const Argument> arguments, absl::Span<const Result> results)>;
// Builder function that constructs YNNPACK subgraph for the fusion operation
// and captures some of the arguments buffers by value. Such YNNPACK subgraphs
// can't be reused if captured arguments are not the same, and can lead to
// crashes and undefined behavior if captured arguments are destroyed.
// Capturing arguments by value allows YNNPACK to do packing at graph compile
// time, and avoid re-packing costs at run time (at inference weights stay
// constant, i.e. convolution filters and one of the dot arguments).
using CapturingBuilder = absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
absl::Span<const Argument> arguments, absl::Span<const Result> results,
absl::Span<const se::DeviceMemoryBase> arguments_buffers)>;
static absl::StatusOr<std::unique_ptr<YnnFusionThunk>> Create(
Options options, Info info, std::vector<Argument> arguments,
std::vector<Result> results, Builder builder);
static absl::StatusOr<std::unique_ptr<YnnFusionThunk>> Create(
Options options, Info info, std::vector<Argument> arguments,
std::vector<Result> results, CapturingBuilder capturing_builder,
absl::Span<const int64_t> captured_arguments_ids);
tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;
bool ExecuteMayBlock() const final { return true; }
BufferUses buffer_uses() const final;
Options options() const { return options_; }
YnnFusionKind ynn_fusion_kind() const { return ynn_fusion_kind_; }
protected:
YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
std::vector<Argument> arguments, std::vector<Result> results,
Builder builder);
YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
std::vector<Argument> arguments, std::vector<Result> results,
CapturingBuilder capturing_builder,
absl::Span<const int64_t> captured_arguments_ids);
// Extension points for subclasses to customize the logging behavior.
virtual std::string fusion_kind() const { return "fusion"; }
virtual std::string fusion_description() const { return ""; }
virtual bool has_fusion_details() const { return false; }
virtual std::vector<std::string> fusion_details() const { return {}; }
virtual std::string argument_name(size_t index) const {
return absl::StrCat("arg #", index);
}
virtual std::string result_name(size_t index) const {
return absl::StrCat("res #", index);
}
private:
// YNNPACK subgraph + runtime instantiated and ready for execution.
struct YnnExecutable;
// Creates YnnExecutable for the fusion operation using one of the builders.
absl::StatusOr<YnnExecutable> CreateYnnExecutable(
const YnnThreadpool& threadpool,
absl::Span<const se::DeviceMemoryBase> arguments_buffers);
// Updates YnnExecutable to the YNN subgraph constructed with the given
// arguments buffers.
absl::Status UpdateYnnExecutable(
const YnnThreadpool& threadpool, YnnExecutable& executable,
absl::Span<const se::DeviceMemoryBase> arguments_buffers);
// Returns the list of captured arguments buffers.
std::vector<se::DeviceMemoryBase> CaptureArguments(
absl::Span<const se::DeviceMemoryBase> arguments_buffers);
YnnFusionKind ynn_fusion_kind_;
Options options_;
std::vector<Argument> arguments_;
std::vector<Result> results_;
// Builder that constructs YNNPACK subgraph for the fusion operation.
Builder builder_;
// Builder that constructs YNNPACK subgraph for the fusion operation and
// captures some of the arguments buffers by value. Such subgraphs can't be
// reused if captured arguments changed since the last execution.
CapturingBuilder capturing_builder_;
// Indices of arguments that are captured by YNNPACK subgraph by value.
std::vector<int64_t> captured_arguments_ids_;
// XLA:CPU executable can be called concurrently from multiple threads,
// and we need to keep a pool of YNNPACK executables to avoid data races.
using YnnExecutablePool = ObjectPool<YnnExecutable, const YnnThreadpool&,
absl::Span<const se::DeviceMemoryBase>>;
YnnExecutablePool ynn_executable_pool_;
// The number of YNNPACK executables created for capturing graphs.
std::atomic<int64_t> num_capturing_created_{0};
};
std::ostream& operator<<(std::ostream& os, YnnFusionThunk::YnnFusionKind kind);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_FUSION_THUNK_H_

View File

@ -0,0 +1,161 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>
#include "ynnpack/include/ynnpack.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/buffer_allocations.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk_testlib.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#include "xla/literal_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/xla_data.pb.h"
#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"
namespace xla::cpu {
namespace {
static absl::StatusOr<YnnSubgraph> BuildBinaryAddSubgraph(
absl::Span<const YnnFusionThunk::Argument> arguments,
absl::Span<const YnnFusionThunk::Result> results) {
TF_ASSIGN_OR_RETURN(YnnSubgraph subgraph,
CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) {
return ynn_create_subgraph(
/*external_value_ids=*/3,
/*flags=*/0, subgraph);
}));
auto dims = [](absl::Span<const int64_t> dims) -> std::vector<size_t> {
return {dims.begin(), dims.end()};
};
uint32_t lhs_id = 0;
uint32_t rhs_id = 1;
uint32_t out_id = 2;
std::vector<size_t> lhs_dims = dims(arguments[0].shape.dimensions());
std::vector<size_t> rhs_dims = dims(arguments[1].shape.dimensions());
std::vector<size_t> out_dims = dims(results[0].shape.dimensions());
YNN_RETURN_IF_ERROR(
ynn_define_tensor_value(subgraph.get(), ynn_type_fp32, lhs_dims.size(),
lhs_dims.data(), /*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID,
YNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id));
YNN_RETURN_IF_ERROR(
ynn_define_tensor_value(subgraph.get(), ynn_type_fp32, rhs_dims.size(),
rhs_dims.data(), /*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID,
YNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id));
YNN_RETURN_IF_ERROR(
ynn_define_tensor_value(subgraph.get(), ynn_type_fp32, rhs_dims.size(),
rhs_dims.data(), /*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID,
YNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id));
YNN_RETURN_IF_ERROR(ynn_define_binary(subgraph.get(), ynn_binary_add, lhs_id,
rhs_id, &out_id, /*flags=*/0));
return subgraph;
}
class YnnFusionThunkTest : public testing::TestWithParam<bool> {
public:
static std::string Name(const ::testing::TestParamInfo<bool>& info) {
return absl::StrCat(info.param ? "threadpool" : "single_threaded");
}
protected:
bool use_threadpool() const { return GetParam(); }
};
TEST_P(YnnFusionThunkTest, ElementwiseAdd) {
if (use_threadpool()) {
GTEST_SKIP() << "Threadpool is not yet supported. Needs more clean-up.";
}
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
threads.NumThreads());
auto lhs = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
auto rhs = LiteralUtil::CreateR1<float>({4.0, 3.0, 2.0, 1.0});
auto out = LiteralUtil::CreateR1<float>({0.0, 0.0, 0.0, 0.0});
BufferAllocations allocations = CreateBufferAllocations(lhs, rhs, out);
auto [lhs_alloc, rhs_alloc, out_alloc] =
CreateBufferAllocation(lhs, rhs, out);
auto [lhs_slice, rhs_slice, out_slice] =
CreateBufferAllocationSlice(lhs_alloc, rhs_alloc, out_alloc);
Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
YnnFusionThunk::Argument lhs_arg = {lhs_slice, shape};
YnnFusionThunk::Argument rhs_arg = {rhs_slice, shape};
YnnFusionThunk::Result out_res = {out_slice, shape};
TF_ASSERT_OK_AND_ASSIGN(
auto thunk, YnnFusionThunk::Create(
YnnFusionThunk::Options{use_threadpool()}, {"fusion"},
{lhs_arg, rhs_arg}, {out_res}, &BuildBinaryAddSubgraph));
YnnThreadpool threadpool;
if (use_threadpool()) {
TF_ASSERT_OK_AND_ASSIGN(threadpool, CreateYnnThreadpool(&device));
}
Thunk::YnnParams ynn_params(std::move(threadpool));
Thunk::ExecuteParams params;
params.buffer_allocations = &allocations;
params.intra_op_threadpool = use_threadpool() ? &device : nullptr;
params.ynn_params = &ynn_params;
auto execute_event = thunk->Execute(params);
tsl::BlockUntilReady(execute_event);
ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError();
EXPECT_EQ(out, LiteralUtil::CreateR1<float>({5.0, 5.0, 5.0, 5.0}));
}
INSTANTIATE_TEST_SUITE_P(YnnFusion, YnnFusionThunkTest, ::testing::Bool(),
YnnFusionThunkTest::Name);
} // namespace
} // namespace xla::cpu

View File

@ -0,0 +1,61 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "ynnpack/include/ynnpack.h"
#include "absl/functional/function_ref.h"
#include "absl/status/statusor.h"
#include "xla/primitive_util.h"
#include "xla/util.h"
namespace xla::cpu {
absl::StatusOr<YnnSubgraph> CreateYnnSubgraph(
absl::FunctionRef<ynn_status(ynn_subgraph_t*)> builder) {
ynn_subgraph_t subgraph = nullptr;
YNN_RETURN_IF_ERROR(builder(&subgraph));
return YnnSubgraph(subgraph);
}
absl::StatusOr<YnnRuntime> CreateYnnRuntime(
absl::FunctionRef<ynn_status(ynn_runtime_t*)> builder) {
ynn_runtime_t runtime = nullptr;
YNN_RETURN_IF_ERROR(builder(&runtime));
return YnnRuntime(runtime);
}
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
absl::FunctionRef<ynn_status(ynn_threadpool_t*)> builder) {
ynn_threadpool_t threadpool = nullptr;
YNN_RETURN_IF_ERROR(builder(&threadpool));
return YnnThreadpool(threadpool);
}
absl::StatusOr<ynn_type> YnnType(const PrimitiveType& type) {
switch (type) {
case BF16:
return ynn_type_bf16;
case F16:
return ynn_type_fp16;
case F32:
return ynn_type_fp32;
default:
return InvalidArgument("Unsupported YNNPACK type: %s",
primitive_util::LowercasePrimitiveTypeName(type));
}
}
} // namespace xla::cpu

View File

@ -0,0 +1,111 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_INTEROP_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_INTEROP_H_
#include <memory>
#include "ynnpack/include/ynnpack.h"
#include "absl/base/optimization.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "xla/tsl/platform/logging.h"
#include "xla/util.h"
namespace xla::cpu {
//===----------------------------------------------------------------------===//
// YNNPACK status to ABSL status conversion macros.
//===----------------------------------------------------------------------===//
#define YNN_RETURN_IF_ERROR(expr) \
do { \
absl::Status s = YnnStatusToStatus(expr); \
if (!s.ok()) { \
return s; \
} \
} while (0)
#define YNN_LOG_IF_ERROR(expr) \
do { \
absl::Status s = YnnStatusToStatus(expr); \
if (!s.ok()) { \
LOG(ERROR) << "YNNPACK operation failed: " << s; \
} \
} while (0)
// Converts YNNPACK status to absl::Status.
inline absl::Status YnnStatusToStatus(ynn_status status) {
if (ABSL_PREDICT_TRUE(status == ynn_status_success)) {
return absl::OkStatus();
}
auto error_message = [](ynn_status status) {
switch (status) {
case ynn_status_success:
return "";
case ynn_status_deprecated:
return "deprecated";
case ynn_status_error:
return "error";
case ynn_status_invalid_parameter:
return "invalid parameter";
case ynn_status_unsupported_parameter:
return "unsupported parameter";
}
};
return Internal("YNNPACK operation failed: %s", error_message(status));
}
//===----------------------------------------------------------------------===//
// XLA to YNNPACK type conversions.
//===----------------------------------------------------------------------===//
absl::StatusOr<ynn_type> YnnType(const PrimitiveType& type);
//===----------------------------------------------------------------------===//
// RAII wrappers for YNNPACK types.
//===----------------------------------------------------------------------===//
namespace internal {
struct YnnDeleter {
void operator()(ynn_subgraph* subgraph) { ynn_delete_subgraph(subgraph); }
void operator()(ynn_runtime* runtime) { ynn_delete_runtime(runtime); }
void operator()(ynn_threadpool* threadpool) {
ynn_delete_threadpool(threadpool);
}
};
} // namespace internal
using YnnSubgraph = std::unique_ptr<ynn_subgraph, internal::YnnDeleter>;
using YnnRuntime = std::unique_ptr<ynn_runtime, internal::YnnDeleter>;
using YnnThreadpool = std::unique_ptr<ynn_threadpool, internal::YnnDeleter>;
absl::StatusOr<YnnSubgraph> CreateYnnSubgraph(
absl::FunctionRef<ynn_status(ynn_subgraph_t*)> builder);
absl::StatusOr<YnnRuntime> CreateYnnRuntime(
absl::FunctionRef<ynn_status(ynn_runtime_t*)> builder);
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
absl::FunctionRef<ynn_status(ynn_threadpool_t*)> builder);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_INTEROP_H_

View File

@ -0,0 +1,62 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#include <cstdint>
#include "ynnpack/include/ynnpack.h"
#include "absl/base/optimization.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#define EIGEN_USE_THREADS
#include "Eigen/ThreadPool"
#include "unsupported/Eigen/CXX11/Tensor"
namespace xla::cpu {
static int32_t NumThreads(void* pool) {
if (ABSL_PREDICT_FALSE(pool == nullptr)) {
return 0;
}
return reinterpret_cast<Eigen::ThreadPoolInterface*>(pool)->NumThreads();
}
static void Schedule(void* pool, void* context, void (*task)(void* context)) {
if (ABSL_PREDICT_FALSE(pool == nullptr)) {
(*task)(context);
}
reinterpret_cast<Eigen::ThreadPoolInterface*>(pool)->Schedule(
[task, context]() { (*task)(context); });
}
// An adaptor from Eigen::ThreadPoolInterface to xnn_threadpool_t.
static constexpr ynn_scheduler kYnnScheduler = {&NumThreads, &Schedule};
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
Eigen::ThreadPoolInterface* threadpool) {
return CreateYnnThreadpool([&](ynn_threadpool_t* ynn_threadpool) {
return ynn_create_threadpool(&kYnnScheduler, threadpool, /*flags=*/1,
ynn_threadpool);
});
}
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
const Eigen::ThreadPoolDevice* device) {
return CreateYnnThreadpool(device->getPool());
}
} // namespace xla::cpu

View File

@ -0,0 +1,39 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_THREADPOOL_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_THREADPOOL_H_
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
namespace Eigen {
struct ThreadPoolDevice;
class ThreadPoolInterface;
} // namespace Eigen
namespace xla::cpu {
// Creates an YNNPACK threadpool from an Eigen threadpool.
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
Eigen::ThreadPoolInterface* threadpool);
// Creates an YNNPACK threadpool from an Eigen ThreadPoolDevice.
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
const Eigen::ThreadPoolDevice* device);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_THREADPOOL_H_

View File

@ -0,0 +1,306 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/ynn_emitter.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>
#include "ynnpack/include/ynnpack.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/ynn_support.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu {
// A mapping from HloInstruction to YNNPACK subgraph tensor id.
using TensorIdMap = absl::flat_hash_map<const HloInstruction*, uint32_t>;
//===----------------------------------------------------------------------===//
// XLA <-> YNNPACK type conversion library.
//===----------------------------------------------------------------------===//
static std::vector<size_t> YnnDimensions(const Shape& shape) {
std::vector<size_t> dims;
for (auto& dim : shape.dimensions()) {
dims.push_back(dim);
}
return dims;
}
//===----------------------------------------------------------------------===//
// XLA <-> YNNPACK emitters.
//===----------------------------------------------------------------------===//
static absl::StatusOr<uint32_t> FindTensorValue(const TensorIdMap& tensor_ids,
const HloInstruction* instr) {
if (auto it = tensor_ids.find(instr); it != tensor_ids.end()) {
return it->second;
}
return Internal("Can't fine YNNPACK tensor value for instruction %s",
instr->ToString());
}
static absl::StatusOr<uint32_t> DefineTensorValue(ynn_subgraph_t subgraph,
const HloInstruction* instr) {
// We do not support instructions with multiple results (tuples).
if (!instr->shape().IsArray()) {
return Internal("Unsupported YNNPACK instruction shape: %s",
instr->ToString());
}
auto dims = YnnDimensions(instr->shape());
TF_ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type()));
uint32_t tensor_id = YNN_INVALID_VALUE_ID;
uint32_t tensor_flags = 0;
// If instruction is a root instruction of the parent computation we assign it
// an external tensor id corresponding to the result index.
const HloComputation* computation = instr->parent();
if (computation->root_instruction() == instr) {
tensor_id = computation->num_parameters();
tensor_flags = YNN_VALUE_FLAG_EXTERNAL_OUTPUT;
}
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph, type, dims.size(), dims.data(), /*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, tensor_flags, &tensor_id));
return tensor_id;
}
static absl::StatusOr<uint32_t> DefineConstant(
ynn_subgraph_t subgraph, std::vector<std::unique_ptr<Literal>>& literals,
const HloInstruction* instr) {
// We do not support instructions with multiple results (tuples).
if (!instr->shape().IsArray()) {
return Internal("Unsupported YNNPACK instruction shape: %s",
instr->ToString());
}
auto dims = YnnDimensions(instr->shape());
TF_ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type()));
uint32_t tensor_id = YNN_INVALID_VALUE_ID;
literals.push_back(instr->literal().CloneToUnique());
const void* value = literals.back()->untyped_data();
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph, type, dims.size(), dims.data(), /*data=*/value,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID,
/*flags=*/0, &tensor_id));
return tensor_id;
}
static absl::StatusOr<uint32_t> DefineParameter(ynn_subgraph_t subgraph,
const HloInstruction* param) {
VLOG(3) << absl::StreamFormat("Define tensor value for parameter: %s",
param->ToString());
auto dims = YnnDimensions(param->shape());
TF_ASSIGN_OR_RETURN(auto type, YnnType(param->shape().element_type()));
uint32_t tensor_id = param->parameter_number();
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph, type, dims.size(), dims.data(), /*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, YNN_VALUE_FLAG_EXTERNAL_INPUT,
&tensor_id));
return tensor_id;
}
static absl::StatusOr<uint32_t> DefineBitcastOp(ynn_subgraph_t subgraph,
TensorIdMap& tensor_ids,
const HloInstruction* instr) {
VLOG(3) << absl::StreamFormat("Define tensor value for bitcast op: %s",
instr->ToString());
CHECK_EQ(instr->opcode(), HloOpcode::kBitcast);
const HloInstruction* input = instr->operand(0);
CHECK_EQ(input->shape().element_type(), instr->shape().element_type());
TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input));
TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));
auto dims = YnnDimensions(instr->shape());
YNN_RETURN_IF_ERROR(ynn_define_static_reshape(subgraph, dims.size(),
dims.data(), in, &out,
/*flags=*/0));
return out;
}
static absl::StatusOr<uint32_t> DefineUnaryOp(ynn_subgraph_t subgraph,
TensorIdMap& tensor_ids,
const HloInstruction* instr) {
VLOG(3) << absl::StreamFormat("Define tensor value for unary op: %s",
instr->ToString());
TF_ASSIGN_OR_RETURN(auto unary_op, YnnUnaryOperator(instr->opcode()));
TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, instr->operand(0)));
TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));
VLOG(3) << absl::StreamFormat(" tensors: in=%d, out=%d", in, out);
YNN_RETURN_IF_ERROR(
ynn_define_unary(subgraph, unary_op, in, &out, /*flags=*/0));
return out;
}
static absl::StatusOr<uint32_t> DefineBinaryOp(ynn_subgraph_t subgraph,
TensorIdMap& tensor_ids,
const HloInstruction* instr) {
VLOG(3) << absl::StreamFormat("Define tensor value for binary op: %s",
instr->ToString());
TF_ASSIGN_OR_RETURN(auto binary_op, YnnBinaryOperator(instr->opcode()));
TF_ASSIGN_OR_RETURN(auto lhs, FindTensorValue(tensor_ids, instr->operand(0)));
TF_ASSIGN_OR_RETURN(auto rhs, FindTensorValue(tensor_ids, instr->operand(1)));
TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));
VLOG(3) << absl::StreamFormat(" tensors: lhs=%d, rhs=%d, out=%d", lhs, rhs,
out);
YNN_RETURN_IF_ERROR(
ynn_define_binary(subgraph, binary_op, lhs, rhs, &out, /*flags=*/0));
return out;
}
//===----------------------------------------------------------------------===//
// Emit YNNPACK subgraph for the given HLO computation.
//===----------------------------------------------------------------------===//
static absl::StatusOr<YnnSubgraph> EmitYnnSubgraph(
const HloComputation* computation,
std::vector<std::unique_ptr<Literal>>& literals) {
VLOG(3) << "Emit YNNPACK subgraph for computation: " << computation->name();
TF_ASSIGN_OR_RETURN(
YnnSubgraph subgraph, CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) {
return ynn_create_subgraph(
/*external_value_ids=*/computation->num_parameters() + 1,
/*flags=*/0, subgraph);
}));
// Traverse fused computation in post-order and define YNNPACK operations
// corresponding to each HLO instruction.
TensorIdMap tensor_ids;
auto instructions = computation->MakeInstructionPostOrder();
for (const HloInstruction* instr : instructions) {
if (!IsLayoutSupportedByYnn(instr->shape())) {
return InvalidArgument(
"Instruction with unsupported layout in YNN fusion: %s",
instr->ToString());
}
if (instr->IsConstant()) {
if (!IsConstantSupportedByYnn(instr)) {
return InvalidArgument(
"Unsupported constant instruction in YNN fusion: %s",
instr->ToString());
}
TF_ASSIGN_OR_RETURN(tensor_ids[instr],
DefineConstant(subgraph.get(), literals, instr));
continue;
}
if (instr->IsElementwise()) {
if (!IsElementwiseOpSupportedByYnn(instr)) {
return InvalidArgument(
"Unsupported elementwise instruction in YNN fusion: %s",
instr->ToString());
}
if (instr->operand_count() == 1) {
TF_ASSIGN_OR_RETURN(tensor_ids[instr],
DefineUnaryOp(subgraph.get(), tensor_ids, instr));
} else if (instr->operand_count() == 2) {
TF_ASSIGN_OR_RETURN(tensor_ids[instr],
DefineBinaryOp(subgraph.get(), tensor_ids, instr));
} else {
LOG(FATAL) << "Unexpected operand count " << instr->operand_count();
}
continue;
}
switch (instr->opcode()) {
case HloOpcode::kParameter: {
TF_ASSIGN_OR_RETURN(tensor_ids[instr],
DefineParameter(subgraph.get(), instr));
} break;
case HloOpcode::kBitcast: {
if (!IsBitcastOpSupportedByYnn(instr)) {
return InvalidArgument(
"Unsupported bitcast instruction in YNN fusion: %s",
instr->ToString());
}
TF_ASSIGN_OR_RETURN(tensor_ids[instr],
DefineBitcastOp(subgraph.get(), tensor_ids, instr));
} break;
default: {
return InvalidArgument("Unsupported fusion instruction: %s",
instr->ToString());
}
}
}
return subgraph;
}
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>()>>
EmitYnnFusionBuilder(const HloComputation* computation) {
// We do not support non-array parameters for YNNPACK operations.
for (auto& param : computation->parameter_instructions()) {
if (!param->shape().IsArray()) {
return InvalidArgument(
"YNNPACK fusion parameters must have array shapes, got %s",
param->shape().ToString());
}
}
// Result also must be a single array.
if (!computation->root_instruction()->shape().IsArray()) {
return InvalidArgument("YNNPACK fusion result must be an array, got %s",
computation->root_instruction()->shape().ToString());
}
return [computation,
literals = std::vector<std::unique_ptr<Literal>>()]() mutable {
return EmitYnnSubgraph(computation, literals);
};
}
} // namespace xla::cpu

View File

@ -0,0 +1,31 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_YNN_EMITTER_H_
#define XLA_BACKENDS_CPU_YNN_EMITTER_H_
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/hlo/ir/hlo_computation.h"
namespace xla::cpu {
absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>()>>
EmitYnnFusionBuilder(const HloComputation* computation);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_YNN_EMITTER_H_

View File

@ -0,0 +1,141 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/ynn_support.h"
#include <algorithm>
#include "ynnpack/include/ynnpack.h"
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/shape.h"
#include "xla/util.h"
namespace xla::cpu {
const absl::flat_hash_map<HloOpcode, ynn_unary_operator>& GetYnnUnaryOpMap() {
static absl::NoDestructor<absl::flat_hash_map<HloOpcode, ynn_unary_operator>>
unary_op_map({
{HloOpcode::kAbs, ynn_unary_abs},
{HloOpcode::kCeil, ynn_unary_ceil},
{HloOpcode::kConvert, ynn_unary_convert},
{HloOpcode::kCos, ynn_unary_cosine},
{HloOpcode::kExp, ynn_unary_exp},
{HloOpcode::kCbrt, ynn_unary_cube_root},
{HloOpcode::kFloor, ynn_unary_floor},
{HloOpcode::kLog, ynn_unary_log},
{HloOpcode::kLogistic, ynn_unary_sigmoid},
{HloOpcode::kNegate, ynn_unary_negate},
{HloOpcode::kRoundNearestEven, ynn_unary_round},
{HloOpcode::kRsqrt, ynn_unary_reciprocal_square_root},
{HloOpcode::kSign, ynn_unary_sign},
{HloOpcode::kSin, ynn_unary_sine},
{HloOpcode::kSqrt, ynn_unary_square_root},
{HloOpcode::kTanh, ynn_unary_tanh},
});
return *unary_op_map;
}
absl::StatusOr<ynn_unary_operator> YnnUnaryOperator(const HloOpcode& opcode) {
const auto& unary_op_map = GetYnnUnaryOpMap();
auto result = unary_op_map.find(opcode);
if (result == unary_op_map.end()) {
return InvalidArgument("Unsupported YNNPACK unary operator: %s",
HloOpcodeString(opcode));
}
return result->second;
}
const absl::flat_hash_map<HloOpcode, ynn_binary_operator>& GetYnnBinaryOpMap() {
static absl::NoDestructor<absl::flat_hash_map<HloOpcode, ynn_binary_operator>>
binary_op_map({
{HloOpcode::kAdd, ynn_binary_add},
{HloOpcode::kDivide, ynn_binary_divide},
{HloOpcode::kMaximum, ynn_binary_max},
{HloOpcode::kMinimum, ynn_binary_min},
{HloOpcode::kMultiply, ynn_binary_multiply},
{HloOpcode::kPower, ynn_binary_pow},
{HloOpcode::kSubtract, ynn_binary_subtract},
});
return *binary_op_map;
}
absl::StatusOr<ynn_binary_operator> YnnBinaryOperator(const HloOpcode& opcode) {
const auto& binary_op_map = GetYnnBinaryOpMap();
auto result = binary_op_map.find(opcode);
if (result == binary_op_map.end()) {
return InvalidArgument("Unsupported YNNPACK binary operator: %s",
HloOpcodeString(opcode));
}
return result->second;
}
bool IsLayoutSupportedByYnn(const Shape& shape) {
return !shape.has_layout() || LayoutUtil::HasDescendingLayout(shape.layout());
}
bool IsBitcastOpSupportedByYnn(const HloInstruction* hlo) {
CHECK_EQ(hlo->opcode(), HloOpcode::kBitcast);
if (!YnnType(hlo->shape().element_type()).ok()) {
return false;
}
const HloInstruction* input = hlo->operand(0);
return hlo->shape().element_type() == input->shape().element_type();
}
bool IsConstantSupportedByYnn(const HloInstruction* hlo) {
CHECK(hlo->IsConstant());
if (!YnnType(hlo->shape().element_type()).ok()) {
return false;
}
return hlo->shape().IsArray();
}
bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo) {
CHECK(hlo->IsElementwise());
// In XLA IsElementwise is true for constants.
CHECK(!hlo->IsConstant());
if (!YnnType(hlo->shape().element_type()).ok()) {
return false;
}
if (!std::all_of(hlo->operands().begin(), hlo->operands().end(),
[](const HloInstruction* op) {
return YnnType(op->shape().element_type()).ok();
})) {
return false;
}
switch (hlo->operand_count()) {
case 1:
return YnnUnaryOperator(hlo->opcode()).ok();
case 2:
return YnnBinaryOperator(hlo->opcode()).ok();
default:
return false;
}
}
} // namespace xla::cpu

View File

@ -0,0 +1,60 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_YNN_SUPPORT_H_
#define XLA_BACKENDS_CPU_YNN_SUPPORT_H_
#include "ynnpack/include/ynnpack.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
namespace xla::cpu {
inline constexpr absl::string_view kYnnFusionKind = "__ynn_fusion";
// Returns the mappings from HLO opcodes to YNNPACK unary operators.
const absl::flat_hash_map<HloOpcode, ynn_unary_operator>& GetYnnUnaryOpMap();
// Returns the YNNPACK unary operator corresponding to the given HLO opcode.
// Returns `InvalidArgument` if the opcode is not supported.
absl::StatusOr<ynn_unary_operator> YnnUnaryOperator(const HloOpcode& opcode);
// Returns the mappings from HLO opcodes to YNNPACK binary operators.
const absl::flat_hash_map<HloOpcode, ynn_binary_operator>& GetYnnBinaryOpMap();
// Returns the YNNPACK binary operator corresponding to the given HLO opcode.
// Returns `InvalidArgument` if the opcode is not supported.
absl::StatusOr<ynn_binary_operator> YnnBinaryOperator(const HloOpcode& opcode);
// Returns true if the shape either doesn't have a layout or the layout is
// descending. Shapes without layout are accepted to make HLO tests less
// verbose.
bool IsLayoutSupportedByYnn(const Shape& shape);
// Returns true if the bitcast op is supported by YNNPACK.
bool IsBitcastOpSupportedByYnn(const HloInstruction* hlo);
// Returns true if the constant is supported by YNNPACK.
bool IsConstantSupportedByYnn(const HloInstruction* hlo);
// Returns true if the nonconstant elementwise op is supported by YNNPACK.
bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_YNN_SUPPORT_H_

View File

@ -1676,6 +1676,12 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
cpu::Thunk::XnnParams::Create(&run_options));
}
std::optional<cpu::Thunk::YnnParams> ynn_params;
if (cpu_executable->has_ynn_fusions()) {
TF_ASSIGN_OR_RETURN(ynn_params,
cpu::Thunk::YnnParams::Create(&run_options));
}
cpu::ThreadPoolTaskRunner task_runner(
run_options.intra_op_thread_pool()->getPool());
@ -1688,6 +1694,7 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
&collective_params,
&custom_call_execute_params,
xnn_params ? &*xnn_params : nullptr,
ynn_params ? &*ynn_params : nullptr,
run_options.run_id().ToInt(),
run_options.device_ordinal(),
};
@ -1814,6 +1821,12 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
xnn_params = cpu::Thunk::XnnParams::Create(&run_options);
}
absl::StatusOr<std::optional<cpu::Thunk::YnnParams>> ynn_params(
std::nullopt);
if (cpu_executable->has_ynn_fusions()) {
ynn_params = cpu::Thunk::YnnParams::Create(&run_options);
}
cpu::ThreadPoolTaskRunner task_runner(
run_options.intra_op_thread_pool()->getPool());
@ -1827,6 +1840,7 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
&*collective_params,
&*custom_call_params,
*xnn_params ? &**xnn_params : nullptr,
*ynn_params ? &**ynn_params : nullptr,
run_options.run_id().ToInt(),
run_options.device_ordinal(),
};

View File

@ -24,6 +24,7 @@ load(
"if_llvm_x86_available",
)
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/xnnpack:build_defs.bzl", "if_ynnpack")
load(":build_defs.bzl", "runtime_copts")
package(
@ -931,7 +932,7 @@ cc_library(
srcs = ["thunk_emitter.cc"],
hdrs = ["thunk_emitter.h"],
copts = tsl_copts(),
local_defines = if_graph_api(["XLA_ONEDNN_USE_GRAPH_API=1"]),
local_defines = if_graph_api(["XLA_ONEDNN_USE_GRAPH_API=1"]) + if_ynnpack(["XLA_YNNPACK"]),
deps = [
":backend_config_proto_cc",
":cpu_options",
@ -1023,6 +1024,10 @@ cc_library(
"@local_tsl//tsl/profiler/lib:traceme",
] + if_onednn([
"//xla/backends/cpu/runtime/onednn:onednn_op_thunk",
]) + if_ynnpack([
"//xla/backends/cpu:ynn_emitter",
"//xla/backends/cpu:ynn_support",
"//xla/backends/cpu/runtime/ynnpack:ynn_fusion_thunk",
]),
)

View File

@ -110,6 +110,12 @@ absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
executable->has_xnn_fusions_ |= thunk.kind() == Thunk::Kind::kXnnFusion;
});
// Find if the thunk sequence contains any YNN fusion thunks. If we do have
// any, we will prepare the YNNPACK thread pool for them at run time.
executable->thunks_->thunk_sequence().ForEach([&](const Thunk& thunk) {
executable->has_ynn_fusions_ |= thunk.kind() == Thunk::Kind::kYnnFusion;
});
// Re-index constants by their allocation index to allow efficient lookup.
for (auto& constant : constants) {
if (executable->constants_.size() <= constant.index) {

View File

@ -128,6 +128,7 @@ class CpuExecutable : public Executable {
ThunkExecutor& thunks() { return *thunks_; }
bool has_xnn_fusions() const { return has_xnn_fusions_; }
bool has_ynn_fusions() const { return has_ynn_fusions_; }
const BufferAssignment& buffer_assignment() const { return *assignment_; }
absl::Span<const ConstantAllocation> constants() const { return constants_; }
@ -230,6 +231,9 @@ class CpuExecutable : public Executable {
// Whether the thunk executor contains any XNN fusion thunks.
bool has_xnn_fusions_ = false;
// Whether the thunk executor contains any YNN fusion thunks.
bool has_ynn_fusions_ = false;
// Entry function name for the computation.
std::string entry_function_name_;

View File

@ -126,6 +126,12 @@ limitations under the License.
#include "xla/backends/cpu/runtime/onednn/onednn_fusion_thunk.h"
#endif // XLA_ONEDNN_USE_GRAPH_API
#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"
#include "xla/backends/cpu/ynn_emitter.h"
#include "xla/backends/cpu/ynn_support.h"
#endif // XLA_YNNPACK
namespace xla::cpu {
namespace {
@ -440,6 +446,12 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitHloInstruction(
return EmitXnnFusionThunk(instruction);
}
#ifdef XLA_YNNPACK
if (backend_config.fusion_config().kind() == kYnnFusionKind) {
return EmitYnnFusionThunk(instruction);
}
#endif // XLA_YNNPACK
return Internal("Unsupported custom fusion kind: %s",
backend_config.DebugString());
}
@ -1494,6 +1506,45 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitXnnFusionThunk(
[b = std::move(builder)](auto, auto) mutable { return b(); });
}
absl::StatusOr<ThunkSequence> ThunkEmitter::EmitYnnFusionThunk(
const HloInstruction* instruction) {
#ifdef XLA_YNNPACK
auto* fusion = Cast<HloFusionInstruction>(instruction);
// Collect YNNPACK fusion arguments.
std::vector<YnnFusionThunk::Argument> arguments;
for (HloInstruction* operand : instruction->operands()) {
for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) {
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice slice,
buffer_assignment_.GetUniqueSlice(operand, indexed.index));
arguments.push_back(YnnFusionThunk::Argument{slice, indexed.shape});
}
}
// Collect YNNPACK fusion results.
std::vector<YnnFusionThunk::Result> results;
for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) {
TF_ASSIGN_OR_RETURN(
BufferAllocation::Slice slice,
buffer_assignment_.GetUniqueSlice(instruction, indexed.index));
results.push_back(YnnFusionThunk::Result{slice, indexed.shape});
}
const HloComputation* computation = fusion->fused_instructions_computation();
// Construct YNNPACK subgraph builder from the fusion computation.
TF_ASSIGN_OR_RETURN(auto builder, EmitYnnFusionBuilder(computation));
return ThunkSequence::Of<YnnFusionThunk>(
YnnFusionThunk::Options{}, ThunkInfo(instruction), std::move(arguments),
std::move(results),
[b = std::move(builder)](auto, auto) mutable { return b(); });
#else
return Unimplemented("XLA is not built with YNNPACK.");
#endif // XLA_YNNPACK
}
absl::StatusOr<ThunkEmitter::HostKernelAllocationSlices>
ThunkEmitter::GetHostKernelAllocationSlices(const HloInstruction* instruction) {
HostKernelAllocationSlices slices;

View File

@ -217,6 +217,9 @@ class ThunkEmitter {
absl::StatusOr<ThunkSequence> EmitXnnFusionThunk(
const HloInstruction* instruction);
absl::StatusOr<ThunkSequence> EmitYnnFusionThunk(
const HloInstruction* instruction);
absl::StatusOr<ThunkSequence> EmitOneDnnFusionThunk(
const HloInstruction* instruction);