Add initial bits to support reductions.
PiperOrigin-RevId: 825292462
parent e3549cef96
commit b81ecb432f
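This change wires the initial pieces for YNNPACK-backed reductions through the CPU backend: a new YnnMatcher (transforms/ynn_matcher.h), subgraph emission (ynn_emitter.cc), a support predicate (ynn_support.cc/.h), and compiler plumbing (cpu_compiler.cc). For orientation, here is a hypothetical example, not part of the commit, of the kind of reduce the new path accepts: f32 element type, a non-empty dimension list, and a to_apply computation that is a bare add/max/min of its two parameters.

// Hypothetical HLO, for illustration only: an f32 sum over one axis whose
// reducer is a plain add of its two parameters.
const char* const kSampleReduceHlo = R"(
HloModule sample_sum

add {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT sum = f32[] add(p0, p1)
}

ENTRY main {
  x = f32[8,16] parameter(0)
  zero = f32[] constant(0)
  ROOT reduce = f32[16] reduce(x, zero), dimensions={0}, to_apply=add
})";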
third_party/xla/xla/backends/cpu/BUILD (1 line changed)
@@ -239,6 +239,7 @@ cc_library(
         "//xla/backends/cpu/runtime:dot_lib",
         "//xla/backends/cpu/runtime/ynnpack:ynn_interop",
         "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
        "//xla/tsl/platform:statusor",
         "@XNNPACK//ynnpack",
         "@com_google_absl//absl/base:no_destructor",
third_party/xla/xla/backends/cpu/transforms/BUILD (file header lost in the mirror; path inferred from hdrs = ["ynn_matcher.h"] below)

@@ -3,6 +3,7 @@ load("//xla/tsl:tsl.bzl", "internal_visibility")
 load("//xla/tsl/mkl:build_defs.bzl", "if_graph_api")
 load("//xla/tsl/mkl:graph.bzl", "onednn_graph_cc_library")
 load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
+load("//xla/tsl/xnnpack:build_defs.bzl", "if_ynnpack")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -42,7 +43,7 @@ cc_library(
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
         "@local_tsl//tsl/platform:protobuf",
-    ],
+    ] + if_ynnpack([":ynn_matcher"]),
 )
 
 xla_cc_test(
@@ -115,6 +116,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ynn_matcher",
+    hdrs = ["ynn_matcher.h"],
+    deps = [
+        ":library_matcher",
+        "//xla/backends/cpu/codegen:target_machine_features",
+        "//xla/hlo/ir:hlo",
+        "@com_google_absl//absl/base:no_destructor",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@local_tsl//tsl/platform:protobuf",
+    ] + if_ynnpack(["//xla/backends/cpu:ynn_support"]),
+)
+
 cc_library(
     name = "xnn_graph_fusion",
     srcs = ["xnn_graph_fusion.cc"],
third_party/xla/xla/backends/cpu/transforms/library_rewriter.h (file header lost in the mirror; path inferred from the LibraryRewriter class shown)

@@ -39,6 +39,10 @@ limitations under the License.
 #include "xla/backends/cpu/transforms/onednn_matcher.h"
 #endif  // XLA_ONEDNN_USE_GRAPH_API
 
+#ifdef XLA_YNNPACK
+#include "xla/backends/cpu/transforms/ynn_matcher.h"
+#endif
+
 namespace xla::cpu {
 
 enum class FusionDirection {
@@ -50,8 +54,10 @@ enum class FusionDirection {
 struct LibraryRewriterOptions {
   bool use_onednn = false;
   bool use_xnnpack = false;
+  bool use_ynnpack = false;
   const tsl::protobuf::RepeatedField<int>* onednn_fusion_types = nullptr;
   const tsl::protobuf::RepeatedField<int>* xnn_fusion_types = nullptr;
+  const tsl::protobuf::RepeatedField<int>* ynn_fusion_types = nullptr;
 };
 
 // Rewrites suitable Dot operations into library fusions.
@@ -74,6 +80,14 @@ class LibraryRewriter : public HloModulePass {
       libs_.push_back(std::make_unique<XnnMatcher>(target_machine_features_,
                                                    options_.xnn_fusion_types));
     }
+#ifdef XLA_YNNPACK
+    if (options_.use_ynnpack && options_.ynn_fusion_types != nullptr &&
+        !options_.ynn_fusion_types->empty()) {
+      libs_.push_back(std::make_unique<YnnMatcher>(target_machine_features_,
+                                                   options_.ynn_fusion_types));
+    }
+#endif  // XLA_YNNPACK
+
     for (std::unique_ptr<LibraryMatcher>& lib : libs_) {
       supported_ops_.merge(lib->SupportedOps());
     }
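For orientation, a minimal sketch (hypothetical, not from the commit) of options that would exercise only the new YNNPACK branch above. Field order follows the updated LibraryRewriterOptions; note the matcher registers only when ynn_fusion_types is non-null and non-empty.

// Illustrative only. The RepeatedField must outlive the rewriter, mirroring
// how cpu_compiler.cc (below) passes pointers to DebugOptions-owned fields.
tsl::protobuf::RepeatedField<int> ynn_types;
ynn_types.Add(DebugOptions::LIBRARY_FUSION_TYPE_REDUCE);
LibraryRewriterOptions options = {
    /*use_onednn=*/false,
    /*use_xnnpack=*/false,
    /*use_ynnpack=*/true,
    /*onednn_fusion_types=*/nullptr,
    /*xnn_fusion_types=*/nullptr,
    /*ynn_fusion_types=*/&ynn_types,
};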
library rewriter test (file header lost in the mirror)

@@ -101,11 +101,16 @@ class CpuLibraryTest : public TargetMachineTestBase {
     tsl::protobuf::RepeatedField<int> empty_fusion_types;
     bool use_onednn = spec.lib == "onednn";
     bool use_xnnpack = spec.lib == "xnn";
+    bool use_ynnpack = spec.lib == "ynn";
     LibraryRewriterOptions options = {
-        use_onednn, use_xnnpack,
+        use_onednn,
+        use_xnnpack,
+        use_ynnpack,
         /*onednn_fusion_types=*/
         use_onednn ? &fusion_types : &empty_fusion_types,
-        /*xnn_fusion_types=*/use_xnnpack ? &fusion_types : &empty_fusion_types};
+        /*xnn_fusion_types=*/use_xnnpack ? &fusion_types : &empty_fusion_types,
+        /*ynn_fusion_types=*/use_ynnpack ? &fusion_types : &empty_fusion_types,
+    };
     LibraryRewriter rewriter(features.get(), options);
     EXPECT_EQ(expected.changed, rewriter.Run(module.get()).value());
     if (!expected.changed) {
third_party/xla/xla/backends/cpu/transforms/ynn_matcher.h (new file, 115 lines)

@@ -0,0 +1,115 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_TRANSFORMS_YNN_MATCHER_H_
#define XLA_BACKENDS_CPU_TRANSFORMS_YNN_MATCHER_H_

#include <string>

#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/backends/cpu/codegen/target_machine_features.h"
#include "xla/backends/cpu/transforms/library_matcher.h"
#include "xla/backends/cpu/ynn_support.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "tsl/platform/protobuf.h"

namespace xla::cpu {

class YnnMatcher : public LibraryMatcher {
 public:
  explicit YnnMatcher(const TargetMachineFeatures* target_machine_features,
                      const tsl::protobuf::RepeatedField<int>* fusion_types)
      : LibraryMatcher(target_machine_features, fusion_types) {}
  ~YnnMatcher() override = default;

  // Returns the set of supported HLO instructions.
  absl::flat_hash_set<HloOpcode> SupportedOps() const override {
    static const absl::NoDestructor<absl::flat_hash_set<HloOpcode>>
        kSupportedOps{[]() {
          absl::flat_hash_set<HloOpcode> supported_ops{
              HloOpcode::kDot, HloOpcode::kReduce, HloOpcode::kConstant};
          for (const auto& [op, _] : GetYnnUnaryOpMap()) {
            supported_ops.insert(op);
          }
          for (const auto& [op, _] : GetYnnBinaryOpMap()) {
            supported_ops.insert(op);
          }
          return supported_ops;
        }()};
    return *kSupportedOps;
  }

  // Returns true if the HLO instruction is supported by the library.
  absl::StatusOr<bool> IsOpSupported(const HloInstruction* instr) override {
    if (instr->opcode() == HloOpcode::kDot) {
      return IsDotSupportedByYnn(instr->dot_dimension_numbers(),
                                 instr->operand(0)->shape(),
                                 instr->operand(1)->shape(), instr->shape());
    }
    if (instr->opcode() == HloOpcode::kReduce) {
      return IsReduceOpSupportedByYnn(instr);
    }
    if (instr->IsConstant()) {
      return IsConstantSupportedByYnn(instr);
    }
    // TODO(b/441837668): Need to get the reduction performance right before
    // enabling fusions. Fusions make performance analysis quite challenging.
    if (fuse_reduce_) {
      return false;
    }
    if (instr->IsElementwise()) {
      return IsElementwiseOpSupportedByYnn(instr);
    }
    return false;
  }

  // Returns true if we should start a new fusion containing just the given HLO
  // instruction. We control the instructions that can start a fusion with the
  // `--xla_cpu_experimental_ynn_fusion_type` flag.
  bool ShouldCreateFusion(const HloInstruction* instr) override {
    if (fuse_dot_ && instr->opcode() == HloOpcode::kDot) {
      return true;
    }
    if (fuse_reduce_ && instr->opcode() == HloOpcode::kReduce) {
      return true;
    }
    return fuse_eltwise_ && instr->IsElementwise();
  }

  PrimitiveType LibraryOpOutputType(const HloInstruction* instr) override {
    auto out_type = instr->shape().element_type();
    if (instr->opcode() != HloOpcode::kDot) {
      return out_type;
    }
    return out_type == BF16 ? F32 : out_type;
  }

  // Returns a prefix string for the fusion op's name.
  std::string fusion_prefix() const override { return "ynn_"; }

  // Returns a string for FusionBackendConfig's fusion kind.
  absl::string_view fusion_kind() const override { return kYnnFusionKind; }

 private:
  absl::flat_hash_set<DebugOptions::LibraryFusionType> fusion_types_;
};

}  // namespace xla::cpu

#endif  // XLA_BACKENDS_CPU_TRANSFORMS_YNN_MATCHER_H_
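A hypothetical caller-side sketch (not in the commit) of the two-step check LibraryRewriter performs with this matcher: opcode membership first, then the per-instruction predicate, which for kReduce forwards to IsReduceOpSupportedByYnn. The helper name is illustrative; the matcher API is as defined in the header above.

// Hypothetical helper, illustrative only; assumes the XLA headers above.
absl::StatusOr<bool> WouldYnnHandle(YnnMatcher& matcher,
                                    const HloInstruction* instr) {
  if (!matcher.SupportedOps().contains(instr->opcode())) return false;
  return matcher.IsOpSupported(instr);  // kReduce -> IsReduceOpSupportedByYnn
}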
third_party/xla/xla/backends/cpu/ynn_emitter.cc (48 lines changed)

@@ -32,6 +32,7 @@ limitations under the License.
 #include "xla/backends/cpu/runtime/dot_lib.h"
 #include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
+#include "xla/backends/cpu/ynn_support.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_instructions.h"
@@ -206,6 +207,48 @@ static absl::StatusOr<uint32_t> DefineBinaryOp(ynn_subgraph_t subgraph,
   return out;
 }
 
+static absl::StatusOr<uint32_t> DefineReduceOp(ynn_subgraph_t subgraph,
+                                               TensorIdMap& tensor_ids,
+                                               const HloInstruction* instr) {
+  VLOG(3) << absl::StreamFormat("Define tensor value for reduce op: %s",
+                                instr->ToString());
+  CHECK_EQ(instr->opcode(), HloOpcode::kReduce);
+  const HloReduceInstruction* reduce_instr = Cast<HloReduceInstruction>(instr);
+  const HloInstruction* input = instr->operand(0);
+  const HloInstruction* init = instr->operand(1);
+  CHECK_EQ(input->shape().element_type(), instr->shape().element_type());
+  CHECK_EQ(init->shape().element_type(), instr->shape().element_type());
+
+  ynn_reduce_operator ynn_reduce_op = ynn_reduce_invalid;
+  CHECK_EQ(reduce_instr->to_apply()->num_parameters(), 2);
+  CHECK_EQ(reduce_instr->to_apply()->instruction_count(), 3);
+
+  switch (reduce_instr->to_apply()->root_instruction()->opcode()) {
+    case HloOpcode::kAdd:
+      ynn_reduce_op = ynn_reduce_sum;
+      break;
+    case HloOpcode::kMaximum:
+      ynn_reduce_op = ynn_reduce_max;
+      break;
+    case HloOpcode::kMinimum:
+      ynn_reduce_op = ynn_reduce_min;
+      break;
+    default:
+      LOG(FATAL) << "Unsupported reduction: " << instr->to_apply()->ToString();
+  }
+
+  const absl::Span<const int64_t> reduce_dims = reduce_instr->dimensions();
+  const std::vector<int32_t> dims(reduce_dims.begin(), reduce_dims.end());
+  TF_ASSIGN_OR_RETURN(auto in, FindTensorValue(tensor_ids, input));
+  TF_ASSIGN_OR_RETURN(auto init_id, FindTensorValue(tensor_ids, init));
+  TF_ASSIGN_OR_RETURN(auto out, DefineTensorValue(subgraph, instr));
+
+  YNN_RETURN_IF_ERROR(
+      ynn_define_reduce(subgraph, ynn_reduce_op, /*num_axes=*/dims.size(),
+                        /*axes=*/dims.data(), in, init_id, &out, /*flags=*/0));
+  return out;
+}
+
 //===----------------------------------------------------------------------===//
 // Emit YNNPACK subgraph for the given HLO computation.
 //===----------------------------------------------------------------------===//
@@ -279,6 +322,11 @@ static absl::StatusOr<YnnSubgraph> EmitYnnSubgraph(
         TF_ASSIGN_OR_RETURN(tensor_ids[instr],
                             DefineBitcastOp(subgraph.get(), tensor_ids, instr));
       } break;
+
+      case HloOpcode::kReduce: {
+        TF_ASSIGN_OR_RETURN(tensor_ids[instr],
+                            DefineReduceOp(subgraph.get(), tensor_ids, instr));
+      } break;
 
       default: {
        return InvalidArgument("Unsupported fusion instruction: %s",
                               instr->ToString());
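Applied to the sample module at the top of this page, DefineReduceOp maps the add-reducer to ynn_reduce_sum and dimensions={0} to a single reduce axis. A hypothetical condensed sketch, using the same ynn_define_reduce call as above (the helper name and tensor ids are illustrative):

// Sketch only: what DefineReduceOp boils down to for the sample
// f32[8,16] -> f32[16] sum over dimension 0. Tensor ids `in`, `init_id`,
// `out` come from FindTensorValue/DefineTensorValue as in the real code.
static absl::StatusOr<uint32_t> DefineSampleRowSum(ynn_subgraph_t subgraph,
                                                   uint32_t in,
                                                   uint32_t init_id,
                                                   uint32_t out) {
  const std::vector<int32_t> axes = {0};  // dimensions={0} -> one reduce axis
  YNN_RETURN_IF_ERROR(ynn_define_reduce(subgraph, ynn_reduce_sum,
                                        /*num_axes=*/axes.size(),
                                        /*axes=*/axes.data(), in, init_id,
                                        &out, /*flags=*/0));
  return out;
}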
third_party/xla/xla/backends/cpu/ynn_support.cc (35 lines changed)

@@ -27,9 +27,12 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "xla/backends/cpu/runtime/dot_lib.h"
 #include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout_util.h"
+#include "xla/service/pattern_matcher.h"
 #include "xla/shape.h"
 #include "xla/tsl/platform/statusor.h"
 #include "xla/util.h"
@@ -203,4 +206,36 @@ absl::StatusOr<bool> IsDotSupportedByYnn(
   return true;
 }
 
+bool IsReduceOpSupportedByYnn(const HloInstruction* hlo) {
+  CHECK_EQ(hlo->opcode(), HloOpcode::kReduce);
+  if (!YnnType(hlo->shape().element_type()).ok()) {
+    return false;
+  }
+  const HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
+  CHECK_NE(reduce, nullptr);
+  // TODO(ashaposhnikov): we can support this edge case,
+  // planning to come back to this later.
+  if (reduce->dimensions().empty()) {
+    return false;
+  }
+
+  HloInstruction* init = reduce->init_values().front();
+  const PrimitiveType type = init->shape().element_type();
+  // TODO(ashaposhnikov): The list of supported types can be extended.
+  if (type != F32) {
+    return false;
+  }
+  if (type != hlo->shape().element_type()) {
+    return false;
+  }
+
+  const HloComputation* to_apply = reduce->to_apply();
+  CHECK_NE(to_apply, nullptr);
+  return Match(to_apply->root_instruction(),
+               match::AnyOf<HloInstruction>(match::Add(), match::Maximum(),
+                                            match::Minimum())
+                   .WithBinaryOperandsAnyOrder(match::Parameter(0),
+                                               match::Parameter(1)));
+}
+
 }  // namespace xla::cpu
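To make the final Match() concrete, here is a hypothetical reducer it rejects (not from the commit): the computation multiplies rather than adds/maxes/mins, so IsReduceOpSupportedByYnn returns false even though the shapes and the f32 type are fine.

// Hypothetical counter-example: rejected because the reducer root is a
// multiply, which AnyOf(Add, Maximum, Minimum) above does not match.
const char* const kRejectedReduceHlo = R"(
HloModule sample_prod

mul {
  p0 = f32[] parameter(0)
  p1 = f32[] parameter(1)
  ROOT prod = f32[] multiply(p0, p1)
}

ENTRY main {
  x = f32[8,16] parameter(0)
  one = f32[] constant(1)
  ROOT reduce = f32[16] reduce(x, one), dimensions={0}, to_apply=mul
})";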
third_party/xla/xla/backends/cpu/ynn_support.h (file header lost in the mirror; path from the include guard)

@@ -62,6 +62,9 @@ absl::StatusOr<bool> IsDotSupportedByYnn(
     const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
     const Shape& rhs_shape, const Shape& out_shape);
 
+// Returns true if the reduce op is supported by YNNPACK.
+bool IsReduceOpSupportedByYnn(const HloInstruction* hlo);
+
 }  // namespace xla::cpu
 
 #endif  // XLA_BACKENDS_CPU_YNN_SUPPORT_H_
third_party/xla/xla/service/cpu/cpu_compiler.cc (10 lines changed)

@@ -969,14 +969,20 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn(
   // XNNPACK ops availability checks depend on the layout information,
   // so until another solution is developed the passes creating XNNPACK fusions
   // have to run after layout assignment.
+  const bool use_ynnpack = absl::c_linear_search(
+      debug_options.xla_cpu_experimental_ynn_fusion_type(),
+      DebugOptions::LIBRARY_FUSION_TYPE_REDUCE);
   LibraryRewriterOptions options = {
       /*use_onednn=*/debug_options.xla_cpu_use_onednn(),
       /*use_xnnpack=*/debug_options.xla_cpu_use_xnnpack(),
+      /*use_ynnpack=*/use_ynnpack,
       /*onednn_fusion_types=*/
       &debug_options.xla_cpu_experimental_onednn_fusion_type(),
       /*xnn_fusion_types=*/
-      &debug_options.xla_cpu_experimental_xnn_fusion_type()};
-  if (options.use_onednn || options.use_xnnpack) {
+      &debug_options.xla_cpu_experimental_xnn_fusion_type(),
+      /*ynn_fusion_types=*/
+      &debug_options.xla_cpu_experimental_ynn_fusion_type()};
+  if (options.use_onednn || options.use_xnnpack || options.use_ynnpack) {
     HloPassPipeline lib_pipeline("dot-library-passes");
     lib_pipeline.AddPass<DotDecomposer>();
     lib_pipeline.AddPass<LibraryRewriter>(target_machine_features, options);
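As wired here, the YNN path only activates when the experimental fusion-type list contains LIBRARY_FUSION_TYPE_REDUCE. A hedged sketch of flipping it on programmatically; the add_* setter is standard protobuf-generated code for the repeated enum field, not something this commit adds:

// Sketch, illustrative only: enable the new reduce fusions via DebugOptions.
DebugOptions debug_options;
debug_options.add_xla_cpu_experimental_ynn_fusion_type(
    DebugOptions::LIBRARY_FUSION_TYPE_REDUCE);
// cpu_compiler.cc (above) derives use_ynnpack from this list via
// absl::c_linear_search and then registers YnnMatcher in LibraryRewriter.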