Add initial bits to support reductions in the YNNPACK CPU backend.

PiperOrigin-RevId: 825292462
Alexander Shaposhnikov 2025-10-28 18:21:42 -07:00 committed by TensorFlower Gardener
parent e3549cef96
commit b81ecb432f
9 changed files with 248 additions and 5 deletions

View File

@@ -239,6 +239,7 @@ cc_library(
"//xla/backends/cpu/runtime:dot_lib",
"//xla/backends/cpu/runtime/ynnpack:ynn_interop",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/tsl/platform:statusor",
"@XNNPACK//ynnpack",
"@com_google_absl//absl/base:no_destructor",

View File

@@ -3,6 +3,7 @@ load("//xla/tsl:tsl.bzl", "internal_visibility")
load("//xla/tsl/mkl:build_defs.bzl", "if_graph_api")
load("//xla/tsl/mkl:graph.bzl", "onednn_graph_cc_library")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
load("//xla/tsl/xnnpack:build_defs.bzl", "if_ynnpack")
package(
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -42,7 +43,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:protobuf",
] + if_ynnpack([":ynn_matcher"]),
)
xla_cc_test(
@@ -115,6 +116,21 @@ cc_library(
],
)
cc_library(
name = "ynn_matcher",
hdrs = ["ynn_matcher.h"],
deps = [
":library_matcher",
"//xla/backends/cpu/codegen:target_machine_features",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:protobuf",
] + if_ynnpack(["//xla/backends/cpu:ynn_support"]),
)
cc_library(
name = "xnn_graph_fusion",
srcs = ["xnn_graph_fusion.cc"],

View File

@@ -39,6 +39,10 @@ limitations under the License.
#include "xla/backends/cpu/transforms/onednn_matcher.h"
#endif // XLA_ONEDNN_USE_GRAPH_API
#ifdef XLA_YNNPACK
#include "xla/backends/cpu/transforms/ynn_matcher.h"
#endif
namespace xla::cpu {
enum class FusionDirection {
@@ -50,8 +54,10 @@ enum class FusionDirection {
struct LibraryRewriterOptions {
bool use_onednn = false;
bool use_xnnpack = false;
bool use_ynnpack = false;
const tsl::protobuf::RepeatedField<int>* onednn_fusion_types = nullptr;
const tsl::protobuf::RepeatedField<int>* xnn_fusion_types = nullptr;
const tsl::protobuf::RepeatedField<int>* ynn_fusion_types = nullptr;
};
// Rewrites suitable Dot operations into library fusions.
@@ -74,6 +80,14 @@ class LibraryRewriter : public HloModulePass {
libs_.push_back(std::make_unique<XnnMatcher>(target_machine_features_,
options_.xnn_fusion_types));
}
#ifdef XLA_YNNPACK
if (options_.use_ynnpack && options_.ynn_fusion_types != nullptr &&
!options_.ynn_fusion_types->empty()) {
libs_.push_back(std::make_unique<YnnMatcher>(target_machine_features_,
options_.ynn_fusion_types));
}
#endif // XLA_YNNPACK
for (std::unique_ptr<LibraryMatcher>& lib : libs_) {
supported_ops_.merge(lib->SupportedOps());
}

View File

@@ -101,11 +101,16 @@ class CpuLibraryTest : public TargetMachineTestBase {
tsl::protobuf::RepeatedField<int> empty_fusion_types;
bool use_onednn = spec.lib == "onednn";
bool use_xnnpack = spec.lib == "xnn";
bool use_ynnpack = spec.lib == "ynn";
LibraryRewriterOptions options = {
use_onednn,
use_xnnpack,
use_ynnpack,
/*onednn_fusion_types=*/
use_onednn ? &fusion_types : &empty_fusion_types,
/*xnn_fusion_types=*/use_xnnpack ? &fusion_types : &empty_fusion_types,
/*ynn_fusion_types=*/use_ynnpack ? &fusion_types : &empty_fusion_types,
};
LibraryRewriter rewriter(features.get(), options);
EXPECT_EQ(expected.changed, rewriter.Run(module.get()).value());
if (!expected.changed) {

View File

@@ -0,0 +1,115 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_BACKENDS_CPU_TRANSFORMS_YNN_MATCHER_H_
#define XLA_BACKENDS_CPU_TRANSFORMS_YNN_MATCHER_H_
#include <string>
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/backends/cpu/transforms/library_matcher.h"
#include "xla/backends/cpu/ynn_support.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "tsl/platform/protobuf.h"
namespace xla::cpu {
class YnnMatcher : public LibraryMatcher {
public:
explicit YnnMatcher(const TargetMachineFeatures* target_machine_features,
const tsl::protobuf::RepeatedField<int>* fusion_types)
: LibraryMatcher(target_machine_features, fusion_types) {}
~YnnMatcher() override = default;
// Returns the set of supported HLO instructions.
absl::flat_hash_set<HloOpcode> SupportedOps() const override {
static const absl::NoDestructor<absl::flat_hash_set<HloOpcode>>
kSupportedOps{[]() {
absl::flat_hash_set<HloOpcode> supported_ops{
HloOpcode::kDot, HloOpcode::kReduce, HloOpcode::kConstant};
for (const auto& [op, _] : GetYnnUnaryOpMap()) {
supported_ops.insert(op);
}
for (const auto& [op, _] : GetYnnBinaryOpMap()) {
supported_ops.insert(op);
}
return supported_ops;
}()};
return *kSupportedOps;
}
// Returns true if the HLO instruction is supported by the library.
absl::StatusOr<bool> IsOpSupported(const HloInstruction* instr) override {
if (instr->opcode() == HloOpcode::kDot) {
return IsDotSupportedByYnn(instr->dot_dimension_numbers(),
instr->operand(0)->shape(),
instr->operand(1)->shape(), instr->shape());
}
if (instr->opcode() == HloOpcode::kReduce) {
return IsReduceOpSupportedByYnn(instr);
}
if (instr->IsConstant()) {
return IsConstantSupportedByYnn(instr);
}
// TODO(b/441837668): Need to get the reduction performance right before
// enabling fusions. Fusions make performance analysis quite challenging.
if (fuse_reduce_) {
return false;
}
if (instr->IsElementwise()) {
return IsElementwiseOpSupportedByYnn(instr);
}
return false;
}
// Returns true if we should start a new fusion containing just the given HLO
// instruction. We control the instructions that can start a fusion with the
// `--xla_cpu_experimental_ynn_fusion_type` flag.
bool ShouldCreateFusion(const HloInstruction* instr) override {
if (fuse_dot_ && instr->opcode() == HloOpcode::kDot) {
return true;
}
if (fuse_reduce_ && instr->opcode() == HloOpcode::kReduce) {
return true;
}
return fuse_eltwise_ && instr->IsElementwise();
}
PrimitiveType LibraryOpOutputType(const HloInstruction* instr) override {
auto out_type = instr->shape().element_type();
if (instr->opcode() != HloOpcode::kDot) {
return out_type;
}
return out_type == BF16 ? F32 : out_type;
}
// Returns a prefix string for the fusion op's name.
std::string fusion_prefix() const override { return "ynn_"; }
// Returns a string for FusionBackendConfig's fusion kind.
absl::string_view fusion_kind() const override { return kYnnFusionKind; }
private:
absl::flat_hash_set<DebugOptions::LibraryFusionType> fusion_types_;
};
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_TRANSFORMS_YNN_MATCHER_H_
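For orientation, a minimal sketch of driving the new matcher directly, assuming a build with YNNPACK enabled; the constructor, SupportedOps(), and fusion_prefix() come from the header above, while the surrounding scaffolding is hypothetical:

#include "xla/backends/cpu/transforms/ynn_matcher.h"

// Sketch only, not part of the commit: restrict fusions to reductions and
// query the matcher. LIBRARY_FUSION_TYPE_REDUCE is the DebugOptions enum
// value also used in cpu_compiler.cc below.
void SketchYnnMatcherUsage(const xla::cpu::TargetMachineFeatures* features) {
  tsl::protobuf::RepeatedField<int> fusion_types;
  fusion_types.Add(xla::DebugOptions::LIBRARY_FUSION_TYPE_REDUCE);

  xla::cpu::YnnMatcher matcher(features, &fusion_types);
  // Includes kDot, kReduce, kConstant plus the YNN unary/binary op maps.
  const auto& supported = matcher.SupportedOps();
  (void)supported;  // Fusion ops created from matches get the "ynn_" prefix.
}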

View File

@@ -32,6 +32,7 @@ limitations under the License.
#include "xla/backends/cpu/runtime/dot_lib.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/backends/cpu/ynn_support.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
@@ -206,6 +207,48 @@ static absl::StatusOr<uint32_t> DefineBinaryOp(ynn_subgraph_t subgraph,
return out;
}
static absl::StatusOr<uint32_t> DefineReduceOp(ynn_subgraph_t subgraph,
TensorIdMap& tensor_ids,
const HloInstruction* instr) {
VLOG(3) << absl::StreamFormat("Define tensor value for reduce op: %s",
instr->ToString());
CHECK_EQ(instr->opcode(), HloOpcode::kReduce);
const HloReduceInstruction* reduce_instr = Cast<HloReduceInstruction>(instr);
const HloInstruction* input = instr->operand(0);
const HloInstruction* init = instr->operand(1);
CHECK_EQ(input->shape().element_type(), instr->shape().element_type());
CHECK_EQ(init->shape().element_type(), instr->shape().element_type());
ynn_reduce_operator ynn_reduce_op = ynn_reduce_invalid;
CHECK_EQ(reduce_instr->to_apply()->num_parameters(), 2);
CHECK_EQ(reduce_instr->to_apply()->instruction_count(), 3);
switch (reduce_instr->to_apply()->root_instruction()->opcode()) {
case HloOpcode::kAdd:
ynn_reduce_op = ynn_reduce_sum;
break;
case HloOpcode::kMaximum:
ynn_reduce_op = ynn_reduce_max;
break;
case HloOpcode::kMinimum:
ynn_reduce_op = ynn_reduce_min;
break;
default:
LOG(FATAL) << "Unsupported reduction: " << instr->to_apply()->ToString();
}
const absl::Span<const int64_t> reduce_dims = reduce_instr->dimensions();
const std::vector<int32_t> dims(reduce_dims.begin(), reduce_dims.end());
TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input));
TF_ASSIGN_OR_RETURN(auto init_id, FindTensorValue(tensor_ids, init));
TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));
YNN_RETURN_IF_ERROR(
ynn_define_reduce(subgraph, ynn_reduce_op, /*num_axes=*/dims.size(),
/*axes=*/dims.data(), in, init_id, &out, /*flags=*/0));
return out;
}
//===----------------------------------------------------------------------===//
// Emit YNNPACK subgraph for the given HLO computation.
//===----------------------------------------------------------------------===//
@@ -279,6 +322,11 @@ static absl::StatusOr<YnnSubgraph> EmitYnnSubgraph(
DefineBitcastOp(subgraph.get(), tensor_ids, instr));
} break;
case HloOpcode::kReduce: {
TF_ASSIGN_OR_RETURN(tensor_ids[instr],
DefineReduceOp(subgraph.get(), tensor_ids, instr));
} break;
default: {
return InvalidArgument("Unsupported fusion instruction: %s",
instr->ToString());
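To make the lowering concrete, a worked example of DefineReduceOp's mapping (a sketch; the HLO is illustrative, the ynn_* identifiers are the ones used in the hunk above):

// HLO (illustrative):
//   out = f32[8] reduce(f32[8,16] in, f32[] init),
//         dimensions={1}, to_apply=maximum
// The apply computation's root opcode kMaximum selects ynn_reduce_max, the
// int64 HLO dimensions {1} are narrowed to int32 YNNPACK axes, and the
// emitter issues:
//   ynn_define_reduce(subgraph, ynn_reduce_max, /*num_axes=*/1,
//                     /*axes=*/dims.data(), in, init_id, &out, /*flags=*/0);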

View File

@@ -27,9 +27,12 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/dot_lib.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
@@ -203,4 +206,36 @@ absl::StatusOr<bool> IsDotSupportedByYnn(
return true;
}
bool IsReduceOpSupportedByYnn(const HloInstruction* hlo) {
CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
if (!YnnType(hlo->shape().element_type()).ok()) {
return false;
}
const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
CHECK_NE(reduce, nullptr);
// TODO(ashaposhnikov): We can support this edge case; planning to come
// back to it later.
if (reduce->dimensions().empty()) {
return false;
}
HloInstruction* init = reduce->init_values().front();
const PrimitiveType type = init->shape().element_type();
// TODO(ashaposhnikov): The list of supported types can be extended.
if (type != F32) {
return false;
}
if (type != hlo->shape().element_type()) {
return false;
}
const HloComputation* to_apply = reduce->to_apply();
CHECK_NE(to_apply, nullptr);
return Match(to_apply->root_instruction(),
match::AnyOf<HloInstruction>(match::Add(), match::Maximum(),
match::Minimum())
.WithBinaryOperandsAnyOrder(match::Parameter(0),
match::Parameter(1)));
}
} // namespace xla::cpu
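A sketch of what IsReduceOpSupportedByYnn accepts; the HLO text is illustrative, not taken from the commit:

// Accepted: F32 element type, a non-empty dimensions list, and an apply
// computation whose root is add/maximum/minimum over the two parameters.
constexpr char kSupportedReduce[] = R"(
HloModule m

add {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT sum = f32[] add(p0, p1)
}

ENTRY main {
  in = f32[8,16] parameter(0)
  zero = f32[] constant(0)
  ROOT out = f32[8] reduce(in, zero), dimensions={1}, to_apply=add
})";

// Rejected by the checks above: an empty dimensions list, a non-F32
// accumulator type, or an apply computation that is not a plain
// add/maximum/minimum of its two parameters.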

View File

@@ -62,6 +62,9 @@ absl::StatusOr<bool> IsDotSupportedByYnn(
const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
const Shape& rhs_shape, const Shape& out_shape);
// Returns true if the reduce op is supported by YNNPACK.
bool IsReduceOpSupportedByYnn(const HloInstruction* hlo);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_YNN_SUPPORT_H_

View File

@@ -969,14 +969,20 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
// XNNPACK ops availability checks depend on the layout information,
// so until another solution is developed the passes creating XNNPACK fusions
// have to run after layout assignment.
const bool use_ynnpack = absl::c_linear_search(
debug_options.xla_cpu_experimental_ynn_fusion_type(),
DebugOptions::LIBRARY_FUSION_TYPE_REDUCE);
LibraryRewriterOptions options = {
/*use_onednn=*/debug_options.xla_cpu_use_onednn(),
/*use_xnnpack=*/debug_options.xla_cpu_use_xnnpack(),
/*use_ynnpack=*/use_ynnpack,
/*onednn_fusion_types=*/
&debug_options.xla_cpu_experimental_onednn_fusion_type(),
/*xnn_fusion_types=*/
&debug_options.xla_cpu_experimental_xnn_fusion_type(),
/*ynn_fusion_types=*/
&debug_options.xla_cpu_experimental_ynn_fusion_type()};
if (options.use_onednn || options.use_xnnpack || options.use_ynnpack) {
HloPassPipeline lib_pipeline("dot-library-passes");
lib_pipeline.AddPass<DotDecomposer>();
lib_pipeline.AddPass<LibraryRewriter>(target_machine_features, options);
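End to end, the new path is gated on the reduce fusion type read above; a sketch of opting in programmatically, assuming the standard protoc-generated adder for the repeated DebugOptions field:

// Sketch only: enable YNNPACK reduce fusions before compilation.
// add_xla_cpu_experimental_ynn_fusion_type is the assumed generated
// accessor for the repeated field read by RunHloPassesAfterLayoutAssn.
xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();
debug_options.add_xla_cpu_experimental_ynn_fusion_type(
    xla::DebugOptions::LIBRARY_FUSION_TYPE_REDUCE);
// With this set, use_ynnpack is true, and the "dot-library-passes" pipeline
// runs DotDecomposer followed by LibraryRewriter with a registered YnnMatcher.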