[XLA:CPU] Add initial bits for YNNPACK support.

+ Do not build XLA with YNNPACK on Windows.

Co-authored-by: Penporn Koanantakool <penporn@google.com>
PiperOrigin-RevId: 820896434

commit ce65a0ad5c (parent f0057ee4b7)

third_party/xla/xla/backends/cpu/BUILD (51 lines changed)
@@ -153,6 +153,30 @@ cc_library(
    ],
)

cc_library(
    name = "ynn_emitter",
    srcs = ["ynn_emitter.cc"],
    hdrs = ["ynn_emitter.h"],
    deps = [
        ":ynn_support",
        "//xla:literal",
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/backends/cpu/runtime/xnnpack:xnn_interop",
        "//xla/backends/cpu/runtime/ynnpack:ynn_interop",
        "//xla/hlo/ir:hlo",
        "//xla/tsl/platform:logging",
        "//xla/tsl/platform:statusor",
        "@XNNPACK//ynnpack",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "xnn_gemm_config",
    srcs = ["xnn_gemm_config.cc"],
@@ -234,6 +258,33 @@ cc_library(
    ],
)

cc_library(
    name = "ynn_support",
    srcs = ["ynn_support.cc"],
    hdrs = ["ynn_support.h"],
    deps = [
        "//xla:shape_util",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/backends/cpu/codegen:target_machine_features",
        "//xla/backends/cpu/runtime:dot_lib",
        "//xla/backends/cpu/runtime/ynnpack:ynn_interop",
        "//xla/hlo/ir:hlo",
        "//xla/service:pattern_matcher",
        "//xla/tsl/platform:statusor",
        "@XNNPACK//ynnpack",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "constant_allocation",
    srcs = ["constant_allocation.cc"],
third_party/xla/xla/backends/cpu/runtime/BUILD (14 lines changed)

@@ -4,6 +4,7 @@ load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility")
load("//xla/tsl:tsl.default.bzl", "filegroup")
load("//xla/tsl/platform:build_config.bzl", "tf_proto_library")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/xnnpack:build_defs.bzl", "if_ynnpack")

package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -160,6 +161,7 @@ cc_library(
    name = "thunk",
    srcs = ["thunk.cc"],
    hdrs = ["thunk.h"],
    defines = if_ynnpack(["XLA_YNNPACK"]),
    deps = [
        ":buffer_allocations",
        ":function_library",
@@ -188,7 +190,10 @@ cc_library(
        "@com_google_absl//absl/strings:string_view",
        "@local_tsl//tsl/profiler/lib:traceme",
        "@local_tsl//tsl/profiler/lib:traceme_encode",
    ] + if_ynnpack([
        "//xla/backends/cpu/runtime/ynnpack:ynn_interop",
        "//xla/backends/cpu/runtime/ynnpack:ynn_threadpool",
    ]),
)

cc_library(
@@ -1246,6 +1251,7 @@ xla_cc_test(
xla_cc_test(
    name = "thunk_sequence_serdes_test",
    srcs = ["thunk_sequence_serdes_test.cc"],
    local_defines = if_ynnpack(["XLA_YNNPACK"]),
    deps = [
        ":all_gather_thunk",
        ":all_reduce_thunk",
@@ -1261,7 +1267,6 @@ xla_cc_test(
        ":dot_thunk",
        ":fft_thunk",
        ":infeed_thunk",
        ":kernel",
        ":kernel_thunk",
        ":logical_id_thunk",
        ":outfeed_thunk",
@@ -1293,7 +1298,6 @@ xla_cc_test(
        "//xla/service:hlo_proto_cc",
        "//xla/stream_executor:launch_dim",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:status",
        "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
@@ -1305,7 +1309,9 @@ xla_cc_test(
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@local_tsl//tsl/platform:casts",
    ] + if_ynnpack([
        "//xla/backends/cpu/runtime/ynnpack:ynn_fusion_thunk",
    ]),
)

cc_library(
third_party/xla/xla/backends/cpu/runtime/thunk.cc

@@ -43,6 +43,11 @@ limitations under the License.
#include "tsl/profiler/lib/traceme.h"
#include "tsl/profiler/lib/traceme_encode.h"

#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#endif  // XLA_YNNPACK

namespace xla::cpu {

// Ok execute event allocated with static storage duration.

@@ -88,6 +93,8 @@ absl::string_view Thunk::KindToString(Kind kind) {
      return "while";
    case Kind::kXnnFusion:
      return "xnn-fusion";
    case Kind::kYnnFusion:
      return "ynn-fusion";
    case Kind::kOneDnnFusion:
      return "onednn-fusion";
  }

@@ -168,6 +175,18 @@ absl::StatusOr<Thunk::XnnParams> Thunk::XnnParams::Create(
Thunk::XnnParams::XnnParams(XnnThreadpool threadpool)
    : threadpool(std::move(threadpool)) {}

#ifdef XLA_YNNPACK
absl::StatusOr<Thunk::YnnParams> Thunk::YnnParams::Create(
    const ExecutableRunOptions* run_options) {
  TF_ASSIGN_OR_RETURN(YnnThreadpool threadpool,
                      CreateYnnThreadpool(run_options->intra_op_thread_pool()));
  return YnnParams(std::move(threadpool));
}

Thunk::YnnParams::YnnParams(YnnThreadpool threadpool)
    : threadpool(std::move(threadpool)) {}
#endif  // XLA_YNNPACK

Thunk::ExecuteSession::ExecuteSession(int64_t max_workers,
                                      int64_t split_threshold)
    : lock_(std::make_shared<std::nullopt_t>(std::nullopt)),
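A minimal sketch of how a runtime caller is expected to wire these params into a thunk execution, mirroring what ynn_fusion_thunk_test.cc later in this commit does. The helper name and the bare ExecuteParams setup are illustrative only; a real caller also supplies buffer allocations:

#ifdef XLA_YNNPACK
// Illustrative helper (not part of this commit): create YnnParams from the
// run options and hand them to a thunk via ExecuteParams.
static absl::Status RunThunkWithYnn(const ExecutableRunOptions* run_options,
                                    Thunk& thunk) {
  TF_ASSIGN_OR_RETURN(Thunk::YnnParams ynn_params,
                      Thunk::YnnParams::Create(run_options));

  Thunk::ExecuteParams params;
  params.ynn_params = &ynn_params;  // Borrowed for the duration of Execute.
  // A real caller must also set params.buffer_allocations (elided here).

  tsl::BlockUntilReady(thunk.Execute(params));
  return absl::OkStatus();
}
#endif  // XLA_YNNPACK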
third_party/xla/xla/backends/cpu/runtime/thunk.h (26 lines changed)

@@ -47,6 +47,11 @@ limitations under the License.
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"

#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#endif  // XLA_YNNPACK

namespace Eigen {
struct ThreadPoolDevice;
}  // namespace Eigen

@@ -87,6 +92,7 @@ class Thunk {
    kTopK,
    kWhile,
    kXnnFusion,
    kYnnFusion,
    kOneDnnFusion,
  };

@@ -262,6 +268,25 @@ class Thunk {
    explicit XnnParams(XnnThreadpool threadpool);
  };

  //===--------------------------------------------------------------------===//
  // YnnParams
  //===--------------------------------------------------------------------===//

#ifdef XLA_YNNPACK
  // Parameters capturing all the details required for running YNNPACK fusions.
  struct YnnParams {
    static absl::StatusOr<YnnParams> Create(
        const ExecutableRunOptions* run_options);

    YnnThreadpool threadpool = nullptr;

    explicit YnnParams(YnnThreadpool threadpool);
  };
#else
  // Use XnnParams as a placeholder. The parameter won't be used anyway.
  using YnnParams = XnnParams;
#endif  // XLA_YNNPACK

  //===--------------------------------------------------------------------===//
  // ExecuteParams
  //===--------------------------------------------------------------------===//

@@ -277,6 +302,7 @@ class Thunk {
    CollectiveExecuteParams* collective_params = nullptr;
    CustomCallExecuteParams* custom_call_params = nullptr;
    XnnParams* xnn_params = nullptr;
    YnnParams* ynn_params = nullptr;
    int64_t run_id = -1;          // -1 means no run id is set.
    int64_t device_ordinal = -1;  // -1 means no device ordinal is set.
    ExecuteSession session = ExecuteSession(ExecuteSession::kMaxWorkers,
third_party/xla/xla/backends/cpu/runtime/thunk_sequence_serdes_test.cc

@@ -80,6 +80,10 @@ limitations under the License.
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"

#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"
#endif  // XLA_YNNPACK

namespace xla::cpu {
namespace {

@@ -1103,6 +1107,15 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
    return false;
  }

#ifdef XLA_YNNPACK
  bool VerifyYnnFusionThunkEquality(const YnnFusionThunk& thunk_1,
                                    const YnnFusionThunk& thunk_2) {
    // TODO(ashaposhnikov): assume this is always false until we implement
    // serialization of YnnFusionThunk.
    return false;
  }
#endif  // XLA_YNNPACK

  bool VerifyXnnDotThunkEquality(const XnnDotThunk& thunk_1,
                                 const XnnDotThunk& thunk_2) {
    const bool are_dot_dimensions_equal =

@@ -1412,6 +1425,24 @@ class ThunkSequenceSerdesTest : public ::testing::Test {
            tsl::down_cast<const XnnConvolutionThunk&>(thunk_2));
      }
    }
    case Thunk::Kind::kYnnFusion: {
#ifdef XLA_YNNPACK
      const YnnFusionThunk& ynn_fusion_thunk_1 =
          tsl::down_cast<const YnnFusionThunk&>(thunk_1);
      const YnnFusionThunk& ynn_fusion_thunk_2 =
          tsl::down_cast<const YnnFusionThunk&>(thunk_2);
      if (ynn_fusion_thunk_1.ynn_fusion_kind() !=
          ynn_fusion_thunk_2.ynn_fusion_kind()) {
        return false;
      }
      return VerifyYnnFusionThunkEquality(
          tsl::down_cast<const YnnFusionThunk&>(thunk_1),
          tsl::down_cast<const YnnFusionThunk&>(thunk_2));
#else
      CHECK(false) << "Unsupported YNN fusion thunk type";
      return false;
#endif  // XLA_YNNPACK
    }
    case Thunk::Kind::kKernel:
      return VerifyKernelThunkEquality(
          tsl::down_cast<const KernelThunkBase&>(thunk_1),
third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD (new file, 103 lines)

@@ -0,0 +1,103 @@
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/xnnpack:build_defs.bzl", "ynn_cc_test")

package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    default_visibility = [":friends"],
    licenses = ["notice"],
)

package_group(
    name = "friends",
    includes = [
        "//xla:friends",
    ],
)

cc_library(
    name = "ynn_interop",
    srcs = ["ynn_interop.cc"],
    hdrs = ["ynn_interop.h"],
    deps = [
        "//xla:shape_util",
        "//xla:util",
        "//xla/tsl/platform:logging",
        "@XNNPACK//ynnpack",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_library(
    name = "ynn_threadpool",
    srcs = ["ynn_threadpool.cc"],
    hdrs = ["ynn_threadpool.h"],
    deps = [
        ":ynn_interop",
        "@XNNPACK//ynnpack:ynnpack_h",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status:statusor",
        "@eigen_archive//:eigen3",
    ],
)

cc_library(
    name = "ynn_fusion_thunk",
    srcs = ["ynn_fusion_thunk.cc"],
    hdrs = ["ynn_fusion_thunk.h"],
    deps = [
        ":ynn_interop",
        "//xla:shape_util",
        "//xla/backends/cpu/runtime:thunk",
        "//xla/runtime:buffer_use",
        "//xla/runtime:object_pool",
        "//xla/service:buffer_assignment",
        "//xla/stream_executor:device_memory",
        "//xla/tsl/concurrency:async_value",
        "//xla/tsl/platform:errors",
        "//xla/tsl/platform:logging",
        "//xla/tsl/platform:statusor",
        "@XNNPACK//ynnpack:ynnpack_h",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:no_destructor",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/functional:any_invocable",
        "@com_google_absl//absl/functional:bind_front",
        "@com_google_absl//absl/functional:function_ref",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/types:span",
    ],
)

ynn_cc_test(
    name = "ynn_fusion_thunk_test",
    srcs = ["ynn_fusion_thunk_test.cc"],
    deps = [
        ":ynn_fusion_thunk",
        ":ynn_interop",
        ":ynn_threadpool",
        "//xla:literal_util",
        "//xla:shape_util",
        "//xla:xla_data_proto_cc",
        "//xla/backends/cpu/runtime:buffer_allocations",
        "//xla/backends/cpu/runtime:thunk",
        "//xla/backends/cpu/runtime:thunk_testlib",
        "//xla/tsl/concurrency:async_value",
        "//xla/tsl/platform:env",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/platform:test",
        "@XNNPACK//ynnpack:ynnpack_h",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@eigen_archive//:eigen3",
    ],
)
third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.cc (new file, 371 lines)

@@ -0,0 +1,371 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>
#include <utility>
#include <vector>

#include "ynnpack/include/ynnpack.h"
#include "absl/algorithm/container.h"
#include "absl/base/no_destructor.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/bind_front.h"
#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/runtime/buffer_use.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"

namespace xla::cpu {

absl::string_view YnnFusionThunk::YnnFusionKindToString(YnnFusionKind kind) {
  switch (kind) {
    case YnnFusionKind::kFusion:
      return "ynn-fusion";
  }
}

std::ostream& operator<<(std::ostream& os, YnnFusionThunk::YnnFusionKind kind) {
  return os << YnnFusionThunk::YnnFusionKindToString(kind);
}

// YNNPACK executable instantiated for the fusion operation.
struct YnnFusionThunk::YnnExecutable {
  tsl::AsyncValueRef<YnnFusionThunk::ExecuteEvent> Invoke(
      const YnnThreadpool& threadpool,
      absl::Span<se::DeviceMemoryBase> arguments,
      absl::Span<se::DeviceMemoryBase> results,
      absl::FunctionRef<bool(size_t)> is_captured_argument);

  // Resets YNNPACK runtime and subgraph.
  absl::Status Reset();

  YnnSubgraph subgraph = nullptr;
  YnnRuntime runtime = nullptr;

  // TODO(ezhulenev): Today we rely on device memory as an identity of the
  // captured argument, and this is not correct as we can have multiple
  // arguments allocated to the same heap address. This is work in progress,
  // and will be migrated to a buffer identity passed to XLA by the client
  // (PjRt).
  std::vector<se::DeviceMemoryBase> captured_arguments;
};

namespace {
struct YnnExternalValue {
  uint32_t id;
  void* data;
};
}  // namespace

static enum ynn_status set_external_values(
    ynn_runtime_t runtime, absl::Span<const YnnExternalValue> external_values) {
  for (const auto& [id, data] : external_values) {
    enum ynn_status status = ynn_set_external_value_data(runtime, id, data);
    if (status != ynn_status_success) {
      return status;
    }
  }
  return ynn_status_success;
}

tsl::AsyncValueRef<YnnFusionThunk::ExecuteEvent>
YnnFusionThunk::YnnExecutable::Invoke(
    const YnnThreadpool& threadpool, absl::Span<se::DeviceMemoryBase> arguments,
    absl::Span<se::DeviceMemoryBase> results,
    absl::FunctionRef<bool(size_t)> is_captured_argument) {
  // Create external values for all arguments and results.
  absl::InlinedVector<YnnExternalValue, 8> external_values;
  external_values.reserve(arguments.size() + results.size());

  // External tensor id for arguments and results.
  uint32_t id = 0;

  for (const se::DeviceMemoryBase& argument : arguments) {
    YnnExternalValue value{id++, argument.opaque()};
    if (!is_captured_argument(value.id)) {
      external_values.push_back(value);
    }
  }

  for (const se::DeviceMemoryBase& result : results) {
    YnnExternalValue value{id++, result.opaque()};
    external_values.push_back(value);
  }

  DCHECK_NE(runtime.get(), nullptr) << "YNNPACK runtime is not initialized";

  YNN_RETURN_IF_ERROR(set_external_values(runtime.get(), external_values));

  // Update threadpool used by the YNNPACK runtime.
  YNN_RETURN_IF_ERROR(ynn_update_runtime_with_threadpool(
      runtime.get(), reinterpret_cast<ynn_threadpool_t>(threadpool.get())));

  // Execute YNNPACK runtime in the caller thread.
  YNN_RETURN_IF_ERROR(ynn_invoke_runtime(runtime.get()));
  return OkExecuteEvent();
}

absl::Status YnnFusionThunk::YnnExecutable::Reset() {
  runtime.reset();
  subgraph.reset();
  return absl::OkStatus();
}

absl::StatusOr<YnnFusionThunk::YnnExecutable>
YnnFusionThunk::CreateYnnExecutable(
    const YnnThreadpool& threadpool,
    absl::Span<const se::DeviceMemoryBase> arguments_buffers) {
  bool capturing = !captured_arguments_ids_.empty();
  VLOG(3) << absl::StreamFormat(
      "Create %s YNN executable for `%s` operation: num_created=%d",
      capturing ? "capturing" : "pooled", info().op_name,
      capturing ? num_capturing_created_.fetch_add(1)
                : ynn_executable_pool_.num_created());

  YnnExecutable executable;

  // Keep track of the arguments captured by value.
  executable.captured_arguments = CaptureArguments(arguments_buffers);

  if (builder_) {
    TF_ASSIGN_OR_RETURN(executable.subgraph, builder_(arguments_, results_));
  } else {
    TF_ASSIGN_OR_RETURN(
        executable.subgraph,
        capturing_builder_(arguments_, results_, arguments_buffers));
  }

  TF_ASSIGN_OR_RETURN(
      executable.runtime, CreateYnnRuntime([&](ynn_runtime_t* runtime) {
        uint32_t ynn_flags = 0;
        return ynn_create_runtime(
            executable.subgraph.get(),
            reinterpret_cast<ynn_threadpool_t>(threadpool.get()), ynn_flags,
            runtime);
      }));
  YNN_RETURN_IF_ERROR(ynn_reshape_runtime(executable.runtime.get()));

  return {std::move(executable)};
}

absl::Status YnnFusionThunk::UpdateYnnExecutable(
    const YnnThreadpool& threadpool, YnnExecutable& executable,
    absl::Span<const se::DeviceMemoryBase> arguments_buffers) {
  DCHECK(capturing_builder_) << "YNN executable is not capturing arguments";
  DCHECK_EQ(executable.captured_arguments.size(),
            captured_arguments_ids_.size())
      << "Unexpected number of captured arguments";

  // If all arguments captured by value are the same as the last execution,
  // we can reuse the YNN executable.
  auto capture_arguments = CaptureArguments(arguments_buffers);
  if (executable.captured_arguments == capture_arguments) {
    VLOG(3) << absl::StreamFormat("Reuse YNN executable for `%s` operation",
                                  info().op_name);
    return absl::OkStatus();
  }

  VLOG(3) << absl::StreamFormat("Update YNN executable for `%s` operation",
                                info().op_name);

  TF_RETURN_IF_ERROR(executable.Reset());

  // Keep track of the updated arguments captured by value.
  executable.captured_arguments = std::move(capture_arguments);

  TF_ASSIGN_OR_RETURN(
      executable.subgraph,
      capturing_builder_(arguments_, results_, arguments_buffers));

  TF_ASSIGN_OR_RETURN(
      executable.runtime, CreateYnnRuntime([&](ynn_runtime_t* runtime) {
        uint32_t ynn_flags = 0;
        return ynn_create_runtime(
            executable.subgraph.get(),
            reinterpret_cast<ynn_threadpool_t>(threadpool.get()), ynn_flags,
            runtime);
      }));
  YNN_RETURN_IF_ERROR(ynn_reshape_runtime(executable.runtime.get()));

  return absl::OkStatus();
}

std::vector<se::DeviceMemoryBase> YnnFusionThunk::CaptureArguments(
    absl::Span<const se::DeviceMemoryBase> arguments_buffers) {
  std::vector<se::DeviceMemoryBase> captured_arguments_ids;
  captured_arguments_ids.reserve(captured_arguments_ids_.size());
  for (int64_t i = 0; i < captured_arguments_ids_.size(); ++i) {
    int32_t arg_index = captured_arguments_ids_[i];
    captured_arguments_ids.push_back(arguments_buffers[arg_index]);
  }
  return captured_arguments_ids;
}

absl::StatusOr<std::unique_ptr<YnnFusionThunk>> YnnFusionThunk::Create(
    Options options, Info info, std::vector<Argument> arguments,
    std::vector<Result> results, Builder builder) {
  return absl::WrapUnique(new YnnFusionThunk(
      YnnFusionKind::kFusion, std::move(options), std::move(info),
      std::move(arguments), std::move(results), std::move(builder)));
}

absl::StatusOr<std::unique_ptr<YnnFusionThunk>> YnnFusionThunk::Create(
    Options options, Info info, std::vector<Argument> arguments,
    std::vector<Result> results, CapturingBuilder capturing_builder,
    absl::Span<const int64_t> captured_arguments_ids) {
  return absl::WrapUnique(new YnnFusionThunk(
      YnnFusionKind::kFusion, std::move(options), std::move(info),
      std::move(arguments), std::move(results), std::move(capturing_builder),
      captured_arguments_ids));
}

YnnFusionThunk::YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
                               std::vector<Argument> arguments,
                               std::vector<Result> results, Builder builder)
    : Thunk(Kind::kYnnFusion, std::move(info)),
      ynn_fusion_kind_(kind),
      options_(std::move(options)),
      arguments_(std::move(arguments)),
      results_(std::move(results)),
      builder_(std::move(builder)),
      ynn_executable_pool_(
          absl::bind_front(&YnnFusionThunk::CreateYnnExecutable, this)) {}

YnnFusionThunk::YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
                               std::vector<Argument> arguments,
                               std::vector<Result> results,
                               CapturingBuilder capturing_builder,
                               absl::Span<const int64_t> captured_arguments_ids)
    : Thunk(Kind::kYnnFusion, std::move(info)),
      ynn_fusion_kind_(kind),
      options_(std::move(options)),
      arguments_(std::move(arguments)),
      results_(std::move(results)),
      capturing_builder_(std::move(capturing_builder)),
      captured_arguments_ids_(captured_arguments_ids.begin(),
                              captured_arguments_ids.end()),
      ynn_executable_pool_(
          absl::bind_front(&YnnFusionThunk::CreateYnnExecutable, this)) {}

YnnFusionThunk::~YnnFusionThunk() = default;

YnnFusionThunk::BufferUses YnnFusionThunk::buffer_uses() const {
  BufferUses buffer_uses;
  for (const Argument& argument : arguments_) {
    buffer_uses.push_back(BufferUse::Read(argument.slice));
  }
  for (const Result& result : results_) {
    buffer_uses.push_back(BufferUse::Write(result.slice));
  }

  return buffer_uses;
}

const YnnThreadpool& GetYnnThreadpool(const Thunk::ExecuteParams& params) {
  static absl::NoDestructor<YnnThreadpool> no_threadpool(nullptr);
  return params.ynn_params ? params.ynn_params->threadpool : *no_threadpool;
}

tsl::AsyncValueRef<YnnFusionThunk::ExecuteEvent> YnnFusionThunk::Execute(
    const ExecuteParams& params) {
  VLOG(3) << absl::StreamFormat("YNN %s `%s`: %s", fusion_kind(),
                                info().op_name, fusion_description());

  if (VLOG_IS_ON(3) && has_fusion_details()) {
    for (auto& detail : fusion_details()) {
      VLOG(3) << detail;
    }
  }

  // Resolve device memory for arguments.
  absl::InlinedVector<se::DeviceMemoryBase, 8> arguments_buffers;
  arguments_buffers.resize(arguments_.size());
  for (size_t i = 0; i < arguments_.size(); ++i) {
    Argument& argument = arguments_[i];

    TF_ASSIGN_OR_RETURN(
        arguments_buffers[i],
        params.buffer_allocations->GetDeviceAddress(argument.slice));

    VLOG(3) << absl::StreamFormat("  %s: %s in slice %s (%p)", argument_name(i),
                                  argument.shape.ToString(true),
                                  argument.slice.ToString(),
                                  arguments_buffers[i].opaque());
  }

  // Resolve device memory for results.
  absl::InlinedVector<se::DeviceMemoryBase, 4> results_buffers;
  results_buffers.resize(results_.size());
  for (size_t i = 0; i < results_.size(); ++i) {
    Result& result = results_[i];

    TF_ASSIGN_OR_RETURN(
        results_buffers[i],
        params.buffer_allocations->GetDeviceAddress(results_[i].slice));

    VLOG(3) << absl::StreamFormat("  %s: %s in slice %s (%p)", result_name(i),
                                  result.shape.ToString(true),
                                  result.slice.ToString(),
                                  results_buffers[i].opaque());
  }

  DCHECK(builder_ || capturing_builder_) << "One of the builders must be set.";

  auto invoke = [&](typename YnnExecutablePool::BorrowedObject executable) {
    auto executed = executable->Invoke(
        GetYnnThreadpool(params), absl::MakeSpan(arguments_buffers),
        absl::MakeSpan(results_buffers), [&](size_t id) {
          return absl::c_linear_search(captured_arguments_ids_, id);
        });

    // Do not return executable to the pool until the execution is done.
    executed.AndThen([executable = std::move(executable)] {});
    return executed;
  };

  // Borrow YnnExecutable from the pool.
  TF_ASSIGN_OR_RETURN(auto executable,
                      ynn_executable_pool_.GetOrCreate(GetYnnThreadpool(params),
                                                       arguments_buffers));

  // If YNN graph doesn't capture any of the arguments by value, we can execute
  // YnnExecutable immediately.
  if (captured_arguments_ids_.empty()) {
    return invoke(std::move(executable));
  }

  // Otherwise reset YnnExecutable to capture new arguments buffers.
  TF_RETURN_IF_ERROR(UpdateYnnExecutable(GetYnnThreadpool(params), *executable,
                                         arguments_buffers));
  return invoke(std::move(executable));
}

}  // namespace xla::cpu
third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h (new file, 182 lines)

@@ -0,0 +1,182 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_FUSION_THUNK_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_FUSION_THUNK_H_

#include <stdbool.h>

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/runtime/object_pool.h"
#include "xla/service/buffer_assignment.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"

namespace xla::cpu {

// YNN fusion thunk encapsulates a YNNPACK subgraph constructed from an XLA
// fusion operation.
class YnnFusionThunk : public Thunk {
 public:
  enum class YnnFusionKind {
    kFusion,
  };

  static absl::string_view YnnFusionKindToString(YnnFusionKind kind);

  ~YnnFusionThunk() override;

  struct Options {
    // Pass YnnThreadpool constructed from the intra-op threadpool to the
    // YNNPACK runtime to allow YNNPACK to parallelize the execution.
    bool use_threadpool = true;
  };

  struct Argument {
    BufferAllocation::Slice slice;
    Shape shape;
  };

  struct Result {
    BufferAllocation::Slice slice;
    Shape shape;
  };

  // Builder function constructs YNNPACK subgraph for the fusion operation.
  using Builder = absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
      absl::Span<const Argument> arguments, absl::Span<const Result> results)>;

  // Builder function that constructs YNNPACK subgraph for the fusion operation
  // and captures some of the arguments buffers by value. Such YNNPACK
  // subgraphs can't be reused if captured arguments are not the same, and can
  // lead to crashes and undefined behavior if captured arguments are
  // destroyed. Capturing arguments by value allows YNNPACK to do packing at
  // graph compile time, and avoid re-packing costs at run time (at inference
  // weights stay constant, i.e. convolution filters and one of the dot
  // arguments).
  using CapturingBuilder = absl::AnyInvocable<absl::StatusOr<YnnSubgraph>(
      absl::Span<const Argument> arguments, absl::Span<const Result> results,
      absl::Span<const se::DeviceMemoryBase> arguments_buffers)>;

  static absl::StatusOr<std::unique_ptr<YnnFusionThunk>> Create(
      Options options, Info info, std::vector<Argument> arguments,
      std::vector<Result> results, Builder builder);

  static absl::StatusOr<std::unique_ptr<YnnFusionThunk>> Create(
      Options options, Info info, std::vector<Argument> arguments,
      std::vector<Result> results, CapturingBuilder capturing_builder,
      absl::Span<const int64_t> captured_arguments_ids);

  tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

  bool ExecuteMayBlock() const final { return true; }

  BufferUses buffer_uses() const final;

  Options options() const { return options_; }

  YnnFusionKind ynn_fusion_kind() const { return ynn_fusion_kind_; }

 protected:
  YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
                 std::vector<Argument> arguments, std::vector<Result> results,
                 Builder builder);

  YnnFusionThunk(YnnFusionKind kind, Options options, Info info,
                 std::vector<Argument> arguments, std::vector<Result> results,
                 CapturingBuilder capturing_builder,
                 absl::Span<const int64_t> captured_arguments_ids);

  // Extension points for subclasses to customize the logging behavior.
  virtual std::string fusion_kind() const { return "fusion"; }
  virtual std::string fusion_description() const { return ""; }

  virtual bool has_fusion_details() const { return false; }
  virtual std::vector<std::string> fusion_details() const { return {}; }

  virtual std::string argument_name(size_t index) const {
    return absl::StrCat("arg #", index);
  }

  virtual std::string result_name(size_t index) const {
    return absl::StrCat("res #", index);
  }

 private:
  // YNNPACK subgraph + runtime instantiated and ready for execution.
  struct YnnExecutable;

  // Creates YnnExecutable for the fusion operation using one of the builders.
  absl::StatusOr<YnnExecutable> CreateYnnExecutable(
      const YnnThreadpool& threadpool,
      absl::Span<const se::DeviceMemoryBase> arguments_buffers);

  // Updates YnnExecutable to the YNN subgraph constructed with the given
  // arguments buffers.
  absl::Status UpdateYnnExecutable(
      const YnnThreadpool& threadpool, YnnExecutable& executable,
      absl::Span<const se::DeviceMemoryBase> arguments_buffers);

  // Returns the list of captured arguments buffers.
  std::vector<se::DeviceMemoryBase> CaptureArguments(
      absl::Span<const se::DeviceMemoryBase> arguments_buffers);

  YnnFusionKind ynn_fusion_kind_;
  Options options_;

  std::vector<Argument> arguments_;
  std::vector<Result> results_;

  // Builder that constructs YNNPACK subgraph for the fusion operation.
  Builder builder_;

  // Builder that constructs YNNPACK subgraph for the fusion operation and
  // captures some of the arguments buffers by value. Such subgraphs can't be
  // reused if captured arguments changed since the last execution.
  CapturingBuilder capturing_builder_;

  // Indices of arguments that are captured by YNNPACK subgraph by value.
  std::vector<int64_t> captured_arguments_ids_;

  // XLA:CPU executable can be called concurrently from multiple threads,
  // and we need to keep a pool of YNNPACK executables to avoid data races.
  using YnnExecutablePool = ObjectPool<YnnExecutable, const YnnThreadpool&,
                                       absl::Span<const se::DeviceMemoryBase>>;
  YnnExecutablePool ynn_executable_pool_;

  // The number of YNNPACK executables created for capturing graphs.
  std::atomic<int64_t> num_capturing_created_{0};
};

std::ostream& operator<<(std::ostream& os, YnnFusionThunk::YnnFusionKind kind);

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_FUSION_THUNK_H_
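To make the CapturingBuilder contract above concrete, here is a hedged sketch of a builder that captures one argument buffer by value. Only the signature and the capture semantics come from the header above; the subgraph construction details are illustrative:

// Sketch only (not part of this commit): a CapturingBuilder that would bake
// arguments_buffers[1] (e.g. dot weights) into the subgraph as constant data.
// Because the pointer is baked in, YnnFusionThunk must rebuild the subgraph
// whenever that buffer changes (see UpdateYnnExecutable above).
YnnFusionThunk::CapturingBuilder capturing_builder =
    [](absl::Span<const YnnFusionThunk::Argument> arguments,
       absl::Span<const YnnFusionThunk::Result> results,
       absl::Span<const se::DeviceMemoryBase> arguments_buffers)
    -> absl::StatusOr<YnnSubgraph> {
  TF_ASSIGN_OR_RETURN(
      YnnSubgraph subgraph, CreateYnnSubgraph([&](ynn_subgraph_t* s) {
        return ynn_create_subgraph(
            /*external_value_ids=*/arguments.size() + results.size(),
            /*flags=*/0, s);
      }));
  // Hypothetical: a real builder would define a tensor whose data pointer is
  // the captured buffer, plus the remaining tensors and operations.
  const void* captured_data = arguments_buffers[1].opaque();
  (void)captured_data;
  return subgraph;
};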
third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk_test.cc (new file, 161 lines)

@@ -0,0 +1,161 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "ynnpack/include/ynnpack.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "xla/backends/cpu/runtime/buffer_allocations.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk_testlib.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#include "xla/literal_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/xla_data.pb.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {
namespace {

static absl::StatusOr<YnnSubgraph> BuildBinaryAddSubgraph(
    absl::Span<const YnnFusionThunk::Argument> arguments,
    absl::Span<const YnnFusionThunk::Result> results) {
  TF_ASSIGN_OR_RETURN(YnnSubgraph subgraph,
                      CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) {
                        return ynn_create_subgraph(
                            /*external_value_ids=*/3,
                            /*flags=*/0, subgraph);
                      }));

  auto dims = [](absl::Span<const int64_t> dims) -> std::vector<size_t> {
    return {dims.begin(), dims.end()};
  };

  uint32_t lhs_id = 0;
  uint32_t rhs_id = 1;
  uint32_t out_id = 2;

  std::vector<size_t> lhs_dims = dims(arguments[0].shape.dimensions());
  std::vector<size_t> rhs_dims = dims(arguments[1].shape.dimensions());
  std::vector<size_t> out_dims = dims(results[0].shape.dimensions());

  YNN_RETURN_IF_ERROR(
      ynn_define_tensor_value(subgraph.get(), ynn_type_fp32, lhs_dims.size(),
                              lhs_dims.data(), /*data=*/nullptr,
                              /*zero_point_id=*/YNN_INVALID_VALUE_ID,
                              /*scale_id=*/YNN_INVALID_VALUE_ID,
                              YNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id));

  YNN_RETURN_IF_ERROR(
      ynn_define_tensor_value(subgraph.get(), ynn_type_fp32, rhs_dims.size(),
                              rhs_dims.data(), /*data=*/nullptr,
                              /*zero_point_id=*/YNN_INVALID_VALUE_ID,
                              /*scale_id=*/YNN_INVALID_VALUE_ID,
                              YNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id));

  YNN_RETURN_IF_ERROR(
      ynn_define_tensor_value(subgraph.get(), ynn_type_fp32, rhs_dims.size(),
                              rhs_dims.data(), /*data=*/nullptr,
                              /*zero_point_id=*/YNN_INVALID_VALUE_ID,
                              /*scale_id=*/YNN_INVALID_VALUE_ID,
                              YNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id));

  YNN_RETURN_IF_ERROR(ynn_define_binary(subgraph.get(), ynn_binary_add, lhs_id,
                                        rhs_id, &out_id, /*flags=*/0));

  return subgraph;
}

class YnnFusionThunkTest : public testing::TestWithParam<bool> {
 public:
  static std::string Name(const ::testing::TestParamInfo<bool>& info) {
    return absl::StrCat(info.param ? "threadpool" : "single_threaded");
  }

 protected:
  bool use_threadpool() const { return GetParam(); }
};

TEST_P(YnnFusionThunkTest, ElementwiseAdd) {
  if (use_threadpool()) {
    GTEST_SKIP() << "Threadpool is not yet supported. Needs more clean-up.";
  }

  tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());

  auto lhs = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
  auto rhs = LiteralUtil::CreateR1<float>({4.0, 3.0, 2.0, 1.0});
  auto out = LiteralUtil::CreateR1<float>({0.0, 0.0, 0.0, 0.0});

  BufferAllocations allocations = CreateBufferAllocations(lhs, rhs, out);

  auto [lhs_alloc, rhs_alloc, out_alloc] =
      CreateBufferAllocation(lhs, rhs, out);
  auto [lhs_slice, rhs_slice, out_slice] =
      CreateBufferAllocationSlice(lhs_alloc, rhs_alloc, out_alloc);

  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});

  YnnFusionThunk::Argument lhs_arg = {lhs_slice, shape};
  YnnFusionThunk::Argument rhs_arg = {rhs_slice, shape};
  YnnFusionThunk::Result out_res = {out_slice, shape};

  TF_ASSERT_OK_AND_ASSIGN(
      auto thunk, YnnFusionThunk::Create(
                      YnnFusionThunk::Options{use_threadpool()}, {"fusion"},
                      {lhs_arg, rhs_arg}, {out_res}, &BuildBinaryAddSubgraph));

  YnnThreadpool threadpool;
  if (use_threadpool()) {
    TF_ASSERT_OK_AND_ASSIGN(threadpool, CreateYnnThreadpool(&device));
  }
  Thunk::YnnParams ynn_params(std::move(threadpool));

  Thunk::ExecuteParams params;
  params.buffer_allocations = &allocations;
  params.intra_op_threadpool = use_threadpool() ? &device : nullptr;
  params.ynn_params = &ynn_params;

  auto execute_event = thunk->Execute(params);
  tsl::BlockUntilReady(execute_event);
  ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError();

  EXPECT_EQ(out, LiteralUtil::CreateR1<float>({5.0, 5.0, 5.0, 5.0}));
}

INSTANTIATE_TEST_SUITE_P(YnnFusion, YnnFusionThunkTest, ::testing::Bool(),
                         YnnFusionThunkTest::Name);

}  // namespace
}  // namespace xla::cpu
third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_interop.cc (new file, 61 lines)

@@ -0,0 +1,61 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"

#include "ynnpack/include/ynnpack.h"
#include "absl/functional/function_ref.h"
#include "absl/status/statusor.h"
#include "xla/primitive_util.h"
#include "xla/util.h"

namespace xla::cpu {

absl::StatusOr<YnnSubgraph> CreateYnnSubgraph(
    absl::FunctionRef<ynn_status(ynn_subgraph_t*)> builder) {
  ynn_subgraph_t subgraph = nullptr;
  YNN_RETURN_IF_ERROR(builder(&subgraph));
  return YnnSubgraph(subgraph);
}

absl::StatusOr<YnnRuntime> CreateYnnRuntime(
    absl::FunctionRef<ynn_status(ynn_runtime_t*)> builder) {
  ynn_runtime_t runtime = nullptr;
  YNN_RETURN_IF_ERROR(builder(&runtime));
  return YnnRuntime(runtime);
}

absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
    absl::FunctionRef<ynn_status(ynn_threadpool_t*)> builder) {
  ynn_threadpool_t threadpool = nullptr;
  YNN_RETURN_IF_ERROR(builder(&threadpool));
  return YnnThreadpool(threadpool);
}

absl::StatusOr<ynn_type> YnnType(const PrimitiveType& type) {
  switch (type) {
    case BF16:
      return ynn_type_bf16;
    case F16:
      return ynn_type_fp16;
    case F32:
      return ynn_type_fp32;
    default:
      return InvalidArgument("Unsupported YNNPACK type: %s",
                             primitive_util::LowercasePrimitiveTypeName(type));
  }
}

}  // namespace xla::cpu
third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_interop.h (new file, 111 lines)

@@ -0,0 +1,111 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_INTEROP_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_INTEROP_H_

#include <memory>

#include "ynnpack/include/ynnpack.h"
#include "absl/base/optimization.h"
#include "absl/functional/function_ref.h"
#include "absl/status/status.h"
#include "xla/tsl/platform/logging.h"
#include "xla/util.h"

namespace xla::cpu {

//===----------------------------------------------------------------------===//
// YNNPACK status to ABSL status conversion macros.
//===----------------------------------------------------------------------===//

#define YNN_RETURN_IF_ERROR(expr)             \
  do {                                        \
    absl::Status s = YnnStatusToStatus(expr); \
    if (!s.ok()) {                            \
      return s;                               \
    }                                         \
  } while (0)

#define YNN_LOG_IF_ERROR(expr)                         \
  do {                                                 \
    absl::Status s = YnnStatusToStatus(expr);          \
    if (!s.ok()) {                                     \
      LOG(ERROR) << "YNNPACK operation failed: " << s; \
    }                                                  \
  } while (0)

// Converts YNNPACK status to absl::Status.
inline absl::Status YnnStatusToStatus(ynn_status status) {
  if (ABSL_PREDICT_TRUE(status == ynn_status_success)) {
    return absl::OkStatus();
  }

  auto error_message = [](ynn_status status) {
    switch (status) {
      case ynn_status_success:
        return "";
      case ynn_status_deprecated:
        return "deprecated";
      case ynn_status_error:
        return "error";
      case ynn_status_invalid_parameter:
        return "invalid parameter";
      case ynn_status_unsupported_parameter:
        return "unsupported parameter";
    }
  };

  return Internal("YNNPACK operation failed: %s", error_message(status));
}

//===----------------------------------------------------------------------===//
// XLA to YNNPACK type conversions.
//===----------------------------------------------------------------------===//

absl::StatusOr<ynn_type> YnnType(const PrimitiveType& type);

//===----------------------------------------------------------------------===//
// RAII wrappers for YNNPACK types.
//===----------------------------------------------------------------------===//

namespace internal {

struct YnnDeleter {
  void operator()(ynn_subgraph* subgraph) { ynn_delete_subgraph(subgraph); }
  void operator()(ynn_runtime* runtime) { ynn_delete_runtime(runtime); }
  void operator()(ynn_threadpool* threadpool) {
    ynn_delete_threadpool(threadpool);
  }
};

}  // namespace internal

using YnnSubgraph = std::unique_ptr<ynn_subgraph, internal::YnnDeleter>;
using YnnRuntime = std::unique_ptr<ynn_runtime, internal::YnnDeleter>;
using YnnThreadpool = std::unique_ptr<ynn_threadpool, internal::YnnDeleter>;

absl::StatusOr<YnnSubgraph> CreateYnnSubgraph(
    absl::FunctionRef<ynn_status(ynn_subgraph_t*)> builder);

absl::StatusOr<YnnRuntime> CreateYnnRuntime(
    absl::FunctionRef<ynn_status(ynn_runtime_t*)> builder);

absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
    absl::FunctionRef<ynn_status(ynn_threadpool_t*)> builder);

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_INTEROP_H_
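A hedged sketch of how these interop pieces compose: the factory wraps a C-style create call, YNN_RETURN_IF_ERROR converts the ynn_status into an early-returned absl::Status, and the resulting unique_ptr runs the matching ynn_delete_* via YnnDeleter. Only calls that appear elsewhere in this commit are used:

// Sketch (not part of this commit): create an empty subgraph and rely on
// RAII for cleanup on every path, including early error returns.
absl::StatusOr<YnnSubgraph> MakeEmptySubgraph() {
  TF_ASSIGN_OR_RETURN(YnnSubgraph subgraph,
                      CreateYnnSubgraph([](ynn_subgraph_t* s) {
                        return ynn_create_subgraph(/*external_value_ids=*/0,
                                                   /*flags=*/0, s);
                      }));
  // `subgraph` now owns the ynn_subgraph; ynn_delete_subgraph runs when it
  // goes out of scope.
  return subgraph;
}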
third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_threadpool.cc (new file, 62 lines)

@@ -0,0 +1,62 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"

#include <cstdint>

#include "ynnpack/include/ynnpack.h"
#include "absl/base/optimization.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"

#define EIGEN_USE_THREADS
#include "Eigen/ThreadPool"
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {

static int32_t NumThreads(void* pool) {
  if (ABSL_PREDICT_FALSE(pool == nullptr)) {
    return 0;
  }
  return reinterpret_cast<Eigen::ThreadPoolInterface*>(pool)->NumThreads();
}

static void Schedule(void* pool, void* context, void (*task)(void* context)) {
  if (ABSL_PREDICT_FALSE(pool == nullptr)) {
    // Without a pool, run the task inline and return so we don't fall
    // through and dereference the null pool below.
    (*task)(context);
    return;
  }
  reinterpret_cast<Eigen::ThreadPoolInterface*>(pool)->Schedule(
      [task, context]() { (*task)(context); });
}

// An adaptor from Eigen::ThreadPoolInterface to ynn_threadpool_t.
static constexpr ynn_scheduler kYnnScheduler = {&NumThreads, &Schedule};

absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
    Eigen::ThreadPoolInterface* threadpool) {
  return CreateYnnThreadpool([&](ynn_threadpool_t* ynn_threadpool) {
    return ynn_create_threadpool(&kYnnScheduler, threadpool, /*flags=*/1,
                                 ynn_threadpool);
  });
}

absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
    const Eigen::ThreadPoolDevice* device) {
  return CreateYnnThreadpool(device->getPool());
}

}  // namespace xla::cpu
third_party/xla/xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h (new file, 39 lines)

@@ -0,0 +1,39 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_THREADPOOL_H_
#define XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_THREADPOOL_H_

#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"

namespace Eigen {
struct ThreadPoolDevice;
class ThreadPoolInterface;
}  // namespace Eigen

namespace xla::cpu {

// Creates a YNNPACK threadpool from an Eigen threadpool.
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
    Eigen::ThreadPoolInterface* threadpool);

// Creates a YNNPACK threadpool from an Eigen ThreadPoolDevice.
absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
    const Eigen::ThreadPoolDevice* device);

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_RUNTIME_YNNPACK_YNN_THREADPOOL_H_
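Usage sketch, mirroring ynn_fusion_thunk_test.cc earlier in this commit: wrap a TSL thread pool in an Eigen device and hand it to YNNPACK. The helper name is illustrative:

// Sketch (not part of this commit): build a YnnThreadpool backed by an Eigen
// pool. The underlying Eigen pool (owned by `threads`) must outlive the
// returned handle; the device wrapper itself is only needed during creation.
#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/backends/cpu/runtime/ynnpack/ynn_threadpool.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/threadpool.h"

absl::StatusOr<xla::cpu::YnnThreadpool> MakeYnnThreadpool(
    tsl::thread::ThreadPool& threads) {
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());
  return xla::cpu::CreateYnnThreadpool(&device);
}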
306
third_party/xla/xla/backends/cpu/ynn_emitter.cc
vendored
Normal file
306
third_party/xla/xla/backends/cpu/ynn_emitter.cc
vendored
Normal file
@ -0,0 +1,306 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/ynn_emitter.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "ynnpack/include/ynnpack.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/ynn_support.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"

namespace xla::cpu {

// A mapping from HloInstruction to YNNPACK subgraph tensor id.
using TensorIdMap = absl::flat_hash_map<const HloInstruction*, uint32_t>;

//===----------------------------------------------------------------------===//
// XLA <-> YNNPACK type conversion library.
//===----------------------------------------------------------------------===//

static std::vector<size_t> YnnDimensions(const Shape& shape) {
  std::vector<size_t> dims;
  for (auto& dim : shape.dimensions()) {
    dims.push_back(dim);
  }
  return dims;
}

//===----------------------------------------------------------------------===//
// XLA <-> YNNPACK emitters.
//===----------------------------------------------------------------------===//

static absl::StatusOr<uint32_t> FindTensorValue(const TensorIdMap& tensor_ids,
                                                const HloInstruction* instr) {
  if (auto it = tensor_ids.find(instr); it != tensor_ids.end()) {
    return it->second;
  }
  return Internal("Can't find YNNPACK tensor value for instruction %s",
                  instr->ToString());
}

static absl::StatusOr<uint32_t> DefineTensorValue(ynn_subgraph_t subgraph,
                                                  const HloInstruction* instr) {
  // We do not support instructions with multiple results (tuples).
  if (!instr->shape().IsArray()) {
    return Internal("Unsupported YNNPACK instruction shape: %s",
                    instr->ToString());
  }

  auto dims = YnnDimensions(instr->shape());
  TF_ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type()));

  uint32_t tensor_id = YNN_INVALID_VALUE_ID;
  uint32_t tensor_flags = 0;

  // If the instruction is the root of the parent computation, we assign it an
  // external tensor id corresponding to the result index.
  const HloComputation* computation = instr->parent();
  if (computation->root_instruction() == instr) {
    tensor_id = computation->num_parameters();
    tensor_flags = YNN_VALUE_FLAG_EXTERNAL_OUTPUT;
  }

  YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
      subgraph, type, dims.size(), dims.data(), /*data=*/nullptr,
      /*zero_point_id=*/YNN_INVALID_VALUE_ID,
      /*scale_id=*/YNN_INVALID_VALUE_ID, tensor_flags, &tensor_id));
  return tensor_id;
}

static absl::StatusOr<uint32_t> DefineConstant(
    ynn_subgraph_t subgraph, std::vector<std::unique_ptr<Literal>>& literals,
    const HloInstruction* instr) {
  // We do not support instructions with multiple results (tuples).
  if (!instr->shape().IsArray()) {
    return Internal("Unsupported YNNPACK instruction shape: %s",
                    instr->ToString());
  }

  auto dims = YnnDimensions(instr->shape());
  TF_ASSIGN_OR_RETURN(auto type, YnnType(instr->shape().element_type()));

  uint32_t tensor_id = YNN_INVALID_VALUE_ID;

  literals.push_back(instr->literal().CloneToUnique());
  const void* value = literals.back()->untyped_data();

  YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
      subgraph, type, dims.size(), dims.data(), /*data=*/value,
      /*zero_point_id=*/YNN_INVALID_VALUE_ID,
      /*scale_id=*/YNN_INVALID_VALUE_ID,
      /*flags=*/0, &tensor_id));

  return tensor_id;
}

static absl::StatusOr<uint32_t> DefineParameter(ynn_subgraph_t subgraph,
                                                const HloInstruction* param) {
  VLOG(3) << absl::StreamFormat("Define tensor value for parameter: %s",
                                param->ToString());

  auto dims = YnnDimensions(param->shape());
  TF_ASSIGN_OR_RETURN(auto type, YnnType(param->shape().element_type()));

  uint32_t tensor_id = param->parameter_number();
  YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
      subgraph, type, dims.size(), dims.data(), /*data=*/nullptr,
      /*zero_point_id=*/YNN_INVALID_VALUE_ID,
      /*scale_id=*/YNN_INVALID_VALUE_ID, YNN_VALUE_FLAG_EXTERNAL_INPUT,
      &tensor_id));

  return tensor_id;
}

static absl::StatusOr<uint32_t> DefineBitcastOp(ynn_subgraph_t subgraph,
                                                TensorIdMap& tensor_ids,
                                                const HloInstruction* instr) {
  VLOG(3) << absl::StreamFormat("Define tensor value for bitcast op: %s",
                                instr->ToString());
  CHECK_EQ(instr->opcode(), HloOpcode::kBitcast);
  const HloInstruction* input = instr->operand(0);
  CHECK_EQ(input->shape().element_type(), instr->shape().element_type());
  TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input));
  TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));

  auto dims = YnnDimensions(instr->shape());
  YNN_RETURN_IF_ERROR(ynn_define_static_reshape(subgraph, dims.size(),
                                                dims.data(), in, &out,
                                                /*flags=*/0));
  return out;
}

static absl::StatusOr<uint32_t> DefineUnaryOp(ynn_subgraph_t subgraph,
                                              TensorIdMap& tensor_ids,
                                              const HloInstruction* instr) {
  VLOG(3) << absl::StreamFormat("Define tensor value for unary op: %s",
                                instr->ToString());
  TF_ASSIGN_OR_RETURN(auto unary_op, YnnUnaryOperator(instr->opcode()));

  TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, instr->operand(0)));
  TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));

  VLOG(3) << absl::StreamFormat("  tensors: in=%d, out=%d", in, out);

  YNN_RETURN_IF_ERROR(
      ynn_define_unary(subgraph, unary_op, in, &out, /*flags=*/0));

  return out;
}

static absl::StatusOr<uint32_t> DefineBinaryOp(ynn_subgraph_t subgraph,
                                               TensorIdMap& tensor_ids,
                                               const HloInstruction* instr) {
  VLOG(3) << absl::StreamFormat("Define tensor value for binary op: %s",
                                instr->ToString());

  TF_ASSIGN_OR_RETURN(auto binary_op, YnnBinaryOperator(instr->opcode()));

  TF_ASSIGN_OR_RETURN(auto lhs, FindTensorValue(tensor_ids, instr->operand(0)));
  TF_ASSIGN_OR_RETURN(auto rhs, FindTensorValue(tensor_ids, instr->operand(1)));
  TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));

  VLOG(3) << absl::StreamFormat("  tensors: lhs=%d, rhs=%d, out=%d", lhs, rhs,
                                out);

  YNN_RETURN_IF_ERROR(
      ynn_define_binary(subgraph, binary_op, lhs, rhs, &out, /*flags=*/0));

  return out;
}

//===----------------------------------------------------------------------===//
// Emit YNNPACK subgraph for the given HLO computation.
//===----------------------------------------------------------------------===//

static absl::StatusOr<YnnSubgraph> EmitYnnSubgraph(
    const HloComputation* computation,
    std::vector<std::unique_ptr<Literal>>& literals) {
  VLOG(3) << "Emit YNNPACK subgraph for computation: " << computation->name();

  TF_ASSIGN_OR_RETURN(
      YnnSubgraph subgraph, CreateYnnSubgraph([&](ynn_subgraph_t* subgraph) {
        return ynn_create_subgraph(
            /*external_value_ids=*/computation->num_parameters() + 1,
            /*flags=*/0, subgraph);
      }));

  // Traverse the fused computation in post-order and define YNNPACK
  // operations corresponding to each HLO instruction.
  TensorIdMap tensor_ids;
  auto instructions = computation->MakeInstructionPostOrder();

  for (const HloInstruction* instr : instructions) {
    if (!IsLayoutSupportedByYnn(instr->shape())) {
      return InvalidArgument(
          "Instruction with unsupported layout in YNN fusion: %s",
          instr->ToString());
    }

    if (instr->IsConstant()) {
      if (!IsConstantSupportedByYnn(instr)) {
        return InvalidArgument(
            "Unsupported constant instruction in YNN fusion: %s",
            instr->ToString());
      }
      TF_ASSIGN_OR_RETURN(tensor_ids[instr],
                          DefineConstant(subgraph.get(), literals, instr));
      continue;
    }

    if (instr->IsElementwise()) {
      if (!IsElementwiseOpSupportedByYnn(instr)) {
        return InvalidArgument(
            "Unsupported elementwise instruction in YNN fusion: %s",
            instr->ToString());
      }
      if (instr->operand_count() == 1) {
        TF_ASSIGN_OR_RETURN(tensor_ids[instr],
                            DefineUnaryOp(subgraph.get(), tensor_ids, instr));
      } else if (instr->operand_count() == 2) {
        TF_ASSIGN_OR_RETURN(tensor_ids[instr],
                            DefineBinaryOp(subgraph.get(), tensor_ids, instr));
      } else {
        LOG(FATAL) << "Unexpected operand count " << instr->operand_count();
      }
      continue;
    }

    switch (instr->opcode()) {
      case HloOpcode::kParameter: {
        TF_ASSIGN_OR_RETURN(tensor_ids[instr],
                            DefineParameter(subgraph.get(), instr));
      } break;

      case HloOpcode::kBitcast: {
        if (!IsBitcastOpSupportedByYnn(instr)) {
          return InvalidArgument(
              "Unsupported bitcast instruction in YNN fusion: %s",
              instr->ToString());
        }
        TF_ASSIGN_OR_RETURN(tensor_ids[instr],
                            DefineBitcastOp(subgraph.get(), tensor_ids, instr));
      } break;

      default: {
        return InvalidArgument("Unsupported fusion instruction: %s",
                               instr->ToString());
      }
    }
  }

  return subgraph;
}

absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>()>>
EmitYnnFusionBuilder(const HloComputation* computation) {
  // We do not support non-array parameters for YNNPACK operations.
  for (auto& param : computation->parameter_instructions()) {
    if (!param->shape().IsArray()) {
      return InvalidArgument(
          "YNNPACK fusion parameters must have array shapes, got %s",
          param->shape().ToString());
    }
  }

  // The result must also be a single array.
  if (!computation->root_instruction()->shape().IsArray()) {
    return InvalidArgument("YNNPACK fusion result must be an array, got %s",
                           computation->root_instruction()->shape().ToString());
  }

  return [computation,
          literals = std::vector<std::unique_ptr<Literal>>()]() mutable {
    return EmitYnnSubgraph(computation, literals);
  };
}

}  // namespace xla::cpu
31
third_party/xla/xla/backends/cpu/ynn_emitter.h
vendored
Normal file
@ -0,0 +1,31 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_YNN_EMITTER_H_
#define XLA_BACKENDS_CPU_YNN_EMITTER_H_

#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/hlo/ir/hlo_computation.h"

namespace xla::cpu {

absl::StatusOr<absl::AnyInvocable<absl::StatusOr<YnnSubgraph>()>>
EmitYnnFusionBuilder(const HloComputation* computation);

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_YNN_EMITTER_H_
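A short sketch of the intended call pattern, assuming `computation` is the fused computation of a YNN fusion instruction (the variable and error handling here are illustrative, not part of the commit):

// Build once, then invoke lazily -- as the thunk emitter further below does.
// The returned closure owns the cloned constant literals it emits.
TF_ASSIGN_OR_RETURN(auto builder, xla::cpu::EmitYnnFusionBuilder(computation));
TF_ASSIGN_OR_RETURN(xla::cpu::YnnSubgraph subgraph, builder());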
141
third_party/xla/xla/backends/cpu/ynn_support.cc
vendored
Normal file
@ -0,0 +1,141 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/ynn_support.h"

#include <algorithm>

#include "ynnpack/include/ynnpack.h"
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/shape.h"
#include "xla/util.h"

namespace xla::cpu {

const absl::flat_hash_map<HloOpcode, ynn_unary_operator>& GetYnnUnaryOpMap() {
  static absl::NoDestructor<absl::flat_hash_map<HloOpcode, ynn_unary_operator>>
      unary_op_map({
          {HloOpcode::kAbs, ynn_unary_abs},
          {HloOpcode::kCeil, ynn_unary_ceil},
          {HloOpcode::kConvert, ynn_unary_convert},
          {HloOpcode::kCos, ynn_unary_cosine},
          {HloOpcode::kExp, ynn_unary_exp},
          {HloOpcode::kCbrt, ynn_unary_cube_root},
          {HloOpcode::kFloor, ynn_unary_floor},
          {HloOpcode::kLog, ynn_unary_log},
          {HloOpcode::kLogistic, ynn_unary_sigmoid},
          {HloOpcode::kNegate, ynn_unary_negate},
          {HloOpcode::kRoundNearestEven, ynn_unary_round},
          {HloOpcode::kRsqrt, ynn_unary_reciprocal_square_root},
          {HloOpcode::kSign, ynn_unary_sign},
          {HloOpcode::kSin, ynn_unary_sine},
          {HloOpcode::kSqrt, ynn_unary_square_root},
          {HloOpcode::kTanh, ynn_unary_tanh},
      });
  return *unary_op_map;
}

absl::StatusOr<ynn_unary_operator> YnnUnaryOperator(const HloOpcode& opcode) {
  const auto& unary_op_map = GetYnnUnaryOpMap();
  auto result = unary_op_map.find(opcode);
  if (result == unary_op_map.end()) {
    return InvalidArgument("Unsupported YNNPACK unary operator: %s",
                           HloOpcodeString(opcode));
  }
  return result->second;
}

const absl::flat_hash_map<HloOpcode, ynn_binary_operator>& GetYnnBinaryOpMap() {
  static absl::NoDestructor<absl::flat_hash_map<HloOpcode, ynn_binary_operator>>
      binary_op_map({
          {HloOpcode::kAdd, ynn_binary_add},
          {HloOpcode::kDivide, ynn_binary_divide},
          {HloOpcode::kMaximum, ynn_binary_max},
          {HloOpcode::kMinimum, ynn_binary_min},
          {HloOpcode::kMultiply, ynn_binary_multiply},
          {HloOpcode::kPower, ynn_binary_pow},
          {HloOpcode::kSubtract, ynn_binary_subtract},
      });
  return *binary_op_map;
}

absl::StatusOr<ynn_binary_operator> YnnBinaryOperator(const HloOpcode& opcode) {
  const auto& binary_op_map = GetYnnBinaryOpMap();
  auto result = binary_op_map.find(opcode);
  if (result == binary_op_map.end()) {
    return InvalidArgument("Unsupported YNNPACK binary operator: %s",
                           HloOpcodeString(opcode));
  }
  return result->second;
}

bool IsLayoutSupportedByYnn(const Shape& shape) {
  return !shape.has_layout() || LayoutUtil::HasDescendingLayout(shape.layout());
}

bool IsBitcastOpSupportedByYnn(const HloInstruction* hlo) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kBitcast);
  if (!YnnType(hlo->shape().element_type()).ok()) {
    return false;
  }
  const HloInstruction* input = hlo->operand(0);
  return hlo->shape().element_type() == input->shape().element_type();
}

bool IsConstantSupportedByYnn(const HloInstruction* hlo) {
  CHECK(hlo->IsConstant());

  if (!YnnType(hlo->shape().element_type()).ok()) {
    return false;
  }

  return hlo->shape().IsArray();
}

bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo) {
  CHECK(hlo->IsElementwise());
  // In XLA, IsElementwise is true for constants.
  CHECK(!hlo->IsConstant());

  if (!YnnType(hlo->shape().element_type()).ok()) {
    return false;
  }

  if (!std::all_of(hlo->operands().begin(), hlo->operands().end(),
                   [](const HloInstruction* op) {
                     return YnnType(op->shape().element_type()).ok();
                   })) {
    return false;
  }

  switch (hlo->operand_count()) {
    case 1:
      return YnnUnaryOperator(hlo->opcode()).ok();
    case 2:
      return YnnBinaryOperator(hlo->opcode()).ok();
    default:
      return false;
  }
}

}  // namespace xla::cpu
60
third_party/xla/xla/backends/cpu/ynn_support.h
vendored
Normal file
@ -0,0 +1,60 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_YNN_SUPPORT_H_
#define XLA_BACKENDS_CPU_YNN_SUPPORT_H_

#include "ynnpack/include/ynnpack.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"

namespace xla::cpu {

inline constexpr absl::string_view kYnnFusionKind = "__ynn_fusion";

// Returns the mapping from HLO opcodes to YNNPACK unary operators.
const absl::flat_hash_map<HloOpcode, ynn_unary_operator>& GetYnnUnaryOpMap();

// Returns the YNNPACK unary operator corresponding to the given HLO opcode.
// Returns `InvalidArgument` if the opcode is not supported.
absl::StatusOr<ynn_unary_operator> YnnUnaryOperator(const HloOpcode& opcode);

// Returns the mapping from HLO opcodes to YNNPACK binary operators.
const absl::flat_hash_map<HloOpcode, ynn_binary_operator>& GetYnnBinaryOpMap();

// Returns the YNNPACK binary operator corresponding to the given HLO opcode.
// Returns `InvalidArgument` if the opcode is not supported.
absl::StatusOr<ynn_binary_operator> YnnBinaryOperator(const HloOpcode& opcode);

// Returns true if the shape either doesn't have a layout or the layout is
// descending. Shapes without a layout are accepted to make HLO tests less
// verbose.
bool IsLayoutSupportedByYnn(const Shape& shape);

// Returns true if the bitcast op is supported by YNNPACK.
bool IsBitcastOpSupportedByYnn(const HloInstruction* hlo);

// Returns true if the constant is supported by YNNPACK.
bool IsConstantSupportedByYnn(const HloInstruction* hlo);

// Returns true if the non-constant elementwise op is supported by YNNPACK.
bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo);

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_YNN_SUPPORT_H_
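Taken together, these predicates suggest the gating a fusion pass would apply before forming a `__ynn_fusion`. A simplified sketch, mirroring the checks ynn_emitter.cc performs (the `CanFuseIntoYnn` helper itself is illustrative, not part of the commit):

bool CanFuseIntoYnn(const xla::HloInstruction* instr) {
  namespace cpu = xla::cpu;
  if (!cpu::IsLayoutSupportedByYnn(instr->shape())) return false;
  if (instr->opcode() == xla::HloOpcode::kParameter) return true;
  // Check constants before elementwise ops: IsElementwise() is also true
  // for constants.
  if (instr->IsConstant()) return cpu::IsConstantSupportedByYnn(instr);
  if (instr->IsElementwise()) return cpu::IsElementwiseOpSupportedByYnn(instr);
  return instr->opcode() == xla::HloOpcode::kBitcast &&
         cpu::IsBitcastOpSupportedByYnn(instr);
}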
14
third_party/xla/xla/pjrt/cpu/cpu_client.cc
vendored
@ -1676,6 +1676,12 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
        cpu::Thunk::XnnParams::Create(&run_options));
  }

  std::optional<cpu::Thunk::YnnParams> ynn_params;
  if (cpu_executable->has_ynn_fusions()) {
    TF_ASSIGN_OR_RETURN(ynn_params,
                        cpu::Thunk::YnnParams::Create(&run_options));
  }

  cpu::ThreadPoolTaskRunner task_runner(
      run_options.intra_op_thread_pool()->getPool());

@ -1688,6 +1694,7 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
      &collective_params,
      &custom_call_execute_params,
      xnn_params ? &*xnn_params : nullptr,
      ynn_params ? &*ynn_params : nullptr,
      run_options.run_id().ToInt(),
      run_options.device_ordinal(),
  };

@ -1814,6 +1821,12 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
    xnn_params = cpu::Thunk::XnnParams::Create(&run_options);
  }

  absl::StatusOr<std::optional<cpu::Thunk::YnnParams>> ynn_params(
      std::nullopt);
  if (cpu_executable->has_ynn_fusions()) {
    ynn_params = cpu::Thunk::YnnParams::Create(&run_options);
  }

  cpu::ThreadPoolTaskRunner task_runner(
      run_options.intra_op_thread_pool()->getPool());

@ -1827,6 +1840,7 @@ absl::StatusOr<PjRtLoadedExecutable::Result> PjRtCpuExecutable::ExecuteHelper(
      &*collective_params,
      &*custom_call_params,
      *xnn_params ? &**xnn_params : nullptr,
      *ynn_params ? &**ynn_params : nullptr,
      run_options.run_id().ToInt(),
      run_options.device_ordinal(),
  };
7
third_party/xla/xla/service/cpu/BUILD
vendored
@ -24,6 +24,7 @@ load(
    "if_llvm_x86_available",
)
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/xnnpack:build_defs.bzl", "if_ynnpack")
load(":build_defs.bzl", "runtime_copts")

package(

@ -931,7 +932,7 @@ cc_library(
    srcs = ["thunk_emitter.cc"],
    hdrs = ["thunk_emitter.h"],
    copts = tsl_copts(),
    local_defines = if_graph_api(["XLA_ONEDNN_USE_GRAPH_API=1"]),
    local_defines = if_graph_api(["XLA_ONEDNN_USE_GRAPH_API=1"]) + if_ynnpack(["XLA_YNNPACK"]),
    deps = [
        ":backend_config_proto_cc",
        ":cpu_options",

@ -1023,6 +1024,10 @@ cc_library(
        "@local_tsl//tsl/profiler/lib:traceme",
    ] + if_onednn([
        "//xla/backends/cpu/runtime/onednn:onednn_op_thunk",
    ]) + if_ynnpack([
        "//xla/backends/cpu:ynn_emitter",
        "//xla/backends/cpu:ynn_support",
        "//xla/backends/cpu/runtime/ynnpack:ynn_fusion_thunk",
    ]),
)
@ -110,6 +110,12 @@ absl::StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
    executable->has_xnn_fusions_ |= thunk.kind() == Thunk::Kind::kXnnFusion;
  });

  // Find out whether the thunk sequence contains any YNN fusion thunks. If it
  // does, we will prepare the YNNPACK thread pool for them at run time.
  executable->thunks_->thunk_sequence().ForEach([&](const Thunk& thunk) {
    executable->has_ynn_fusions_ |= thunk.kind() == Thunk::Kind::kYnnFusion;
  });

  // Re-index constants by their allocation index to allow efficient lookup.
  for (auto& constant : constants) {
    if (executable->constants_.size() <= constant.index) {
@ -128,6 +128,7 @@ class CpuExecutable : public Executable {
  ThunkExecutor& thunks() { return *thunks_; }

  bool has_xnn_fusions() const { return has_xnn_fusions_; }
  bool has_ynn_fusions() const { return has_ynn_fusions_; }

  const BufferAssignment& buffer_assignment() const { return *assignment_; }
  absl::Span<const ConstantAllocation> constants() const { return constants_; }

@ -230,6 +231,9 @@ class CpuExecutable : public Executable {
  // Whether the thunk executor contains any XNN fusion thunks.
  bool has_xnn_fusions_ = false;

  // Whether the thunk executor contains any YNN fusion thunks.
  bool has_ynn_fusions_ = false;

  // Entry function name for the computation.
  std::string entry_function_name_;
51
third_party/xla/xla/service/cpu/thunk_emitter.cc
vendored
@ -126,6 +126,12 @@ limitations under the License.
#include "xla/backends/cpu/runtime/onednn/onednn_fusion_thunk.h"
#endif  // XLA_ONEDNN_USE_GRAPH_API

#ifdef XLA_YNNPACK
#include "xla/backends/cpu/runtime/ynnpack/ynn_fusion_thunk.h"
#include "xla/backends/cpu/ynn_emitter.h"
#include "xla/backends/cpu/ynn_support.h"
#endif  // XLA_YNNPACK

namespace xla::cpu {

namespace {

@ -440,6 +446,12 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitHloInstruction(
    return EmitXnnFusionThunk(instruction);
  }

#ifdef XLA_YNNPACK
  if (backend_config.fusion_config().kind() == kYnnFusionKind) {
    return EmitYnnFusionThunk(instruction);
  }
#endif  // XLA_YNNPACK

  return Internal("Unsupported custom fusion kind: %s",
                  backend_config.DebugString());
}

@ -1494,6 +1506,45 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitXnnFusionThunk(
      [b = std::move(builder)](auto, auto) mutable { return b(); });
}

absl::StatusOr<ThunkSequence> ThunkEmitter::EmitYnnFusionThunk(
    const HloInstruction* instruction) {
#ifdef XLA_YNNPACK
  auto* fusion = Cast<HloFusionInstruction>(instruction);

  // Collect YNNPACK fusion arguments.
  std::vector<YnnFusionThunk::Argument> arguments;
  for (HloInstruction* operand : instruction->operands()) {
    for (auto& indexed : ShapeUtil::GetLeafShapes(operand->shape())) {
      TF_ASSIGN_OR_RETURN(
          BufferAllocation::Slice slice,
          buffer_assignment_.GetUniqueSlice(operand, indexed.index));
      arguments.push_back(YnnFusionThunk::Argument{slice, indexed.shape});
    }
  }

  // Collect YNNPACK fusion results.
  std::vector<YnnFusionThunk::Result> results;
  for (auto& indexed : ShapeUtil::GetLeafShapes(instruction->shape())) {
    TF_ASSIGN_OR_RETURN(
        BufferAllocation::Slice slice,
        buffer_assignment_.GetUniqueSlice(instruction, indexed.index));
    results.push_back(YnnFusionThunk::Result{slice, indexed.shape});
  }

  const HloComputation* computation = fusion->fused_instructions_computation();

  // Construct a YNNPACK subgraph builder from the fusion computation.
  TF_ASSIGN_OR_RETURN(auto builder, EmitYnnFusionBuilder(computation));

  return ThunkSequence::Of<YnnFusionThunk>(
      YnnFusionThunk::Options{}, ThunkInfo(instruction), std::move(arguments),
      std::move(results),
      [b = std::move(builder)](auto, auto) mutable { return b(); });
#else
  return Unimplemented("XLA is not built with YNNPACK.");
#endif  // XLA_YNNPACK
}

absl::StatusOr<ThunkEmitter::HostKernelAllocationSlices>
ThunkEmitter::GetHostKernelAllocationSlices(const HloInstruction* instruction) {
  HostKernelAllocationSlices slices;
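The dispatch above keys off the fusion's backend config kind. A hypothetical sketch of tagging a fusion instruction so it reaches `EmitYnnFusionThunk`; the `BackendConfig` plumbing mirrors how XNN fusions are tagged, and the exact proto fields used here are an assumption:

// Hypothetical: mark `fusion` so EmitHloInstruction routes it to
// EmitYnnFusionThunk; kYnnFusionKind comes from ynn_support.h.
xla::cpu::BackendConfig backend_config;
backend_config.mutable_fusion_config()->set_kind(
    std::string(xla::cpu::kYnnFusionKind));
TF_RETURN_IF_ERROR(fusion->set_backend_config(backend_config));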
@ -217,6 +217,9 @@ class ThunkEmitter {
  absl::StatusOr<ThunkSequence> EmitXnnFusionThunk(
      const HloInstruction* instruction);

  absl::StatusOr<ThunkSequence> EmitYnnFusionThunk(
      const HloInstruction* instruction);

  absl::StatusOr<ThunkSequence> EmitOneDnnFusionThunk(
      const HloInstruction* instruction);