Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 12:20:11 +01:00
Add support for int8 dots, and allow bf16 to be used on any CPU.
PiperOrigin-RevId: 824272399
parent 5edcd28152
commit 32c1551f24

third_party/xla/xla/backends/cpu/BUILD (4 changes)
@@ -235,14 +235,11 @@ cc_library(
         "//xla:shape_util",
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/backends/cpu/codegen:target_machine_features",
         "//xla/backends/cpu/runtime:dot_lib",
         "//xla/backends/cpu/runtime/ynnpack:ynn_interop",
         "//xla/hlo/ir:hlo",
         "//xla/service:pattern_matcher",
         "//xla/tsl/platform:statusor",
         "@XNNPACK//ynnpack",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:no_destructor",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -250,7 +247,6 @@ cc_library(
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -46,12 +46,22 @@ absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(

 absl::StatusOr<ynn_type> YnnType(const PrimitiveType& type) {
   switch (type) {
+    case S4:
+      return ynn_type_int4;
+    case U4:
+      return ynn_type_uint4;
+    case S8:
+      return ynn_type_int8;
+    case U8:
+      return ynn_type_uint8;
     case BF16:
       return ynn_type_bf16;
     case F16:
       return ynn_type_fp16;
     case F32:
       return ynn_type_fp32;
+    case S32:
+      return ynn_type_int32;
     default:
       return InvalidArgument("Unsupported YNNPACK type: %s",
                              primitive_util::LowercasePrimitiveTypeName(type));
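With the integer cases in place, element types flow from HLO shapes straight into YNNPACK. A minimal sketch of the resulting call pattern for an int8 dot; YnnType and TF_ASSIGN_OR_RETURN come from the diff, while ResolveDotTypes is a hypothetical wrapper (requires the statusor and shape headers already included by the emitter):

// Hypothetical helper showing how the extended mapping is consumed.
// S8 inputs resolve to ynn_type_int8, the S32 output to ynn_type_int32;
// an unmapped type (e.g. F64) now surfaces as InvalidArgument instead of
// being silently coerced to fp32 as the old hardcoded mapping did.
absl::Status ResolveDotTypes(const Shape& lhs_shape, const Shape& rhs_shape,
                             const Shape& out_shape) {
  TF_ASSIGN_OR_RETURN(ynn_type lhs_type, YnnType(lhs_shape.element_type()));
  TF_ASSIGN_OR_RETURN(ynn_type rhs_type, YnnType(rhs_shape.element_type()));
  TF_ASSIGN_OR_RETURN(ynn_type out_type, YnnType(out_shape.element_type()));
  (void)lhs_type; (void)rhs_type; (void)out_type;  // consumed by the emitter
  return absl::OkStatus();
}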
third_party/xla/xla/backends/cpu/ynn_emitter.cc (17 changes)
@@ -349,31 +349,26 @@ static absl::StatusOr<YnnSubgraph> EmitYnnDotSubgraph(
   std::vector<size_t> rhs_dims = dims(rhs_shape.dimensions());
   std::vector<size_t> out_dims = dims(out_shape.dimensions());

-  PrimitiveType dtype = lhs->shape().element_type();
-  if (dtype != F32 && dtype != BF16) {
-    return InvalidArgument("Unsupported input data type for YnnDotThunk: %s",
-                           primitive_util::LowercasePrimitiveTypeName(dtype));
-  }
-
-  ynn_type input_type = (dtype == F32) ? ynn_type_fp32 : ynn_type_bf16;
-  ynn_type output_type = ynn_type_fp32;
+  TF_ASSIGN_OR_RETURN(ynn_type ynn_lhs_type, YnnType(lhs_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(ynn_type ynn_rhs_type, YnnType(rhs_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(ynn_type ynn_out_type, YnnType(out_shape.element_type()));

   const uint32_t input_tensor_flags = YNN_VALUE_FLAG_EXTERNAL_INPUT;
   YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
-      subgraph.get(), input_type, lhs_dims.size(), lhs_dims.data(),
+      subgraph.get(), ynn_lhs_type, lhs_dims.size(), lhs_dims.data(),
       /*data=*/nullptr,
       /*zero_point_id=*/YNN_INVALID_VALUE_ID,
       /*scale_id=*/YNN_INVALID_VALUE_ID, input_tensor_flags, &lhs_id));

   YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
-      subgraph.get(), input_type, rhs_dims.size(), rhs_dims.data(),
+      subgraph.get(), ynn_rhs_type, rhs_dims.size(), rhs_dims.data(),
       capture_rhs ? arguments_buffers[1].opaque() : nullptr,
       /*zero_point_id=*/YNN_INVALID_VALUE_ID,
       /*scale_id=*/YNN_INVALID_VALUE_ID, input_tensor_flags, &rhs_id));

   const uint32_t output_tensor_flags = YNN_VALUE_FLAG_EXTERNAL_OUTPUT;
   YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
-      subgraph.get(), output_type, out_dims.size(), out_dims.data(),
+      subgraph.get(), ynn_out_type, out_dims.size(), out_dims.data(),
       /*data=*/nullptr,
       /*zero_point_id=*/YNN_INVALID_VALUE_ID,
       /*scale_id=*/YNN_INVALID_VALUE_ID, output_tensor_flags, &out_id));
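Each operand is now declared to the subgraph with its own resolved type instead of a single shared input_type. A condensed restatement of one such declaration, using only the names and argument positions visible in the diff (the local lhs_id initializer here is a placeholder):

// Declaring the LHS as an external input with its per-operand ynn_type.
// Note zero_point_id and scale_id stay YNN_INVALID_VALUE_ID, so int8 dots
// are plain integer dots at this point, with no quantization parameters.
uint32_t lhs_id = 0;
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
    subgraph.get(), ynn_lhs_type, lhs_dims.size(), lhs_dims.data(),
    /*data=*/nullptr,
    /*zero_point_id=*/YNN_INVALID_VALUE_ID,
    /*scale_id=*/YNN_INVALID_VALUE_ID, YNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id));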
third_party/xla/xla/backends/cpu/ynn_support.cc (65 changes)
@@ -16,19 +16,24 @@ limitations under the License.
 #include "xla/backends/cpu/ynn_support.h"

 #include <algorithm>
 #include <tuple>

 #include "ynnpack/include/ynnpack.h"
 #include "absl/base/no_destructor.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/statusor.h"
 #include "xla/backends/cpu/runtime/dot_lib.h"
 #include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/layout_util.h"
 #include "xla/shape.h"
 #include "xla/tsl/platform/statusor.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"

 namespace xla::cpu {
|
@ -138,4 +143,64 @@ bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo) {
|
|||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<bool> IsDotSupportedByYnn(
|
||||
const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
|
||||
const Shape& rhs_shape, const Shape& out_shape) {
|
||||
// Stores tuple of allowed (input, output) dtypes.
|
||||
static const absl::NoDestructor<absl::flat_hash_set<
|
||||
std::tuple<PrimitiveType, PrimitiveType, PrimitiveType>>>
|
||||
kAllowedTypes({
|
||||
{F32, F32, F32},
|
||||
// TODO(b/449998002): We don't have fast fp16 kernels yet.
|
||||
// {F16, F16, F32},
|
||||
{BF16, BF16, F32},
|
||||
{S8, S8, S32},
|
||||
{U8, S8, S32},
|
||||
// TODO(b/441600372): We don't have fast int4 kernels yet. Even the
|
||||
// reference kernel might be pretty good though?
|
||||
// {S8, S4, S32},
|
||||
});
|
||||
|
||||
// Types must be in the allowed set.
|
||||
PrimitiveType lhs_dtype = lhs_shape.element_type();
|
||||
PrimitiveType rhs_dtype = rhs_shape.element_type();
|
||||
PrimitiveType out_dtype = out_shape.element_type();
|
||||
if (!kAllowedTypes->contains({lhs_dtype, rhs_dtype, out_dtype})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsLayoutSupportedByYnn(lhs_shape) ||
|
||||
!IsLayoutSupportedByYnn(rhs_shape) ||
|
||||
!IsLayoutSupportedByYnn(out_shape)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check shapes.
|
||||
TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape,
|
||||
rhs_shape, out_shape));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims,
|
||||
GetDotCanonicalDims(dot_dimensions, dot_shape));
|
||||
|
||||
if (dot_canonical_dims.m == 1 && dot_canonical_dims.n == 1 &&
|
||||
dot_shape.batch_size > 1) {
|
||||
// TODO(b/430079105): YNNPACK does not handle batch dimensions that are not
|
||||
// matrix dimensions. We could handle this case by fully implementing dot
|
||||
// (b/430079105), but we also could just insert dummy dimensions of size 1
|
||||
// for the matrix dimensions, so the batch dimensions get handled correctly.
|
||||
return false;
|
||||
}
|
||||
|
||||
// YNNPACK supports transposing the inputs efficiently if possible (they will
|
||||
// fuse with dot packing), but we don't currently support generating the
|
||||
// necessary transposes.
|
||||
if (!dot_canonical_dims.lhs_canonical ||
|
||||
dot_canonical_dims.lhs_column_major ||
|
||||
dot_canonical_dims.rhs_column_major) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
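The predicate gates on three things in order: the dtype triple, dense supported layouts, and canonical dot dimensions. A small usage sketch for the newly allowed s8 x s8 -> s32 case (the shapes and dimension numbers below are illustrative, not from the commit):

// Query support for a plain [64,128] x [128,32] -> [64,32] int8 matmul.
Shape lhs = ShapeUtil::MakeShape(S8, {64, 128});
Shape rhs = ShapeUtil::MakeShape(S8, {128, 32});
Shape out = ShapeUtil::MakeShape(S32, {64, 32});
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
absl::StatusOr<bool> supported = IsDotSupportedByYnn(dnums, lhs, rhs, out);
// An error is returned only if the dot shape itself is invalid; a case that
// is merely unsupported (bad layout, batched vectors) comes back as false.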
third_party/xla/xla/backends/cpu/ynn_support.h

@@ -22,6 +22,7 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/shape.h"

 namespace xla::cpu {

@@ -55,6 +56,12 @@ bool IsConstantSupportedByYnn(const HloInstruction* hlo);
 // Returns true if the nonconstant elementwise op is supported by YNNPACK.
 bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo);

+// Returns true if the dot operation is supported by YNNPACK. Returns an error
+// if the dot operation shape is invalid.
+absl::StatusOr<bool> IsDotSupportedByYnn(
+    const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
+    const Shape& rhs_shape, const Shape& out_shape);
+
 }  // namespace xla::cpu

 #endif  // XLA_BACKENDS_CPU_YNN_SUPPORT_H_
third_party/xla/xla/service/cpu/BUILD (2 changes)
@@ -406,6 +406,8 @@ cc_library(
         ":onednn_contraction_rewriter",
         ":onednn_float_support",
         ":onednn_ops_rewriter",
+    ]) + if_ynnpack([
+        "//xla/backends/cpu:ynn_support",
     ]),
 )
third_party/xla/xla/service/cpu/cpu_compiler.cc (49 changes)
@@ -257,6 +257,10 @@ limitations under the License.
 #include "xla/service/cpu/onednn_ops_rewriter.h"
 #endif  // XLA_ONEDNN

+#ifdef XLA_YNNPACK
+#include "xla/backends/cpu/ynn_support.h"
+#endif  // XLA_YNNPACK
+
 namespace xla {
 namespace {
@@ -628,15 +632,42 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
     if (!call_library_for_dot(*instr)) {
       return true;
     }
-    bool use_cost_model = module->config()
-                              .debug_options()
-                              .xla_cpu_experimental_xnn_graph_fusion_mode() !=
-                          DebugOptions::XNN_GRAPH_FUSION_MODE_BYPASS_COST_MODEL;
-    return !IsDotSupportedByXnn(instr->dot_dimension_numbers(),
-                                instr->operand(0)->shape(),
-                                instr->operand(1)->shape(), instr->shape(),
-                                target_machine_features, use_cost_model)
-                .value_or(false);
+
+#ifdef XLA_YNNPACK
+    if (absl::c_linear_search(
+            module->config()
+                .debug_options()
+                .xla_cpu_experimental_ynn_fusion_type(),
+            DebugOptions::LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT)) {
+      if (IsDotSupportedByYnn(instr->dot_dimension_numbers(),
+                              instr->operand(0)->shape(),
+                              instr->operand(1)->shape(), instr->shape())
+              .value_or(false)) {
+        return false;
+      }
+    }
+#endif  // XLA_YNNPACK
+
+    auto xnn_graph_fusion_mode =
+        module->config()
+            .debug_options()
+            .xla_cpu_experimental_xnn_graph_fusion_mode();
+    if (xnn_graph_fusion_mode != DebugOptions::XNN_GRAPH_FUSION_MODE_DISABLED) {
+      bool use_cost_model =
+          module->config()
+              .debug_options()
+              .xla_cpu_experimental_xnn_graph_fusion_mode() !=
+          DebugOptions::XNN_GRAPH_FUSION_MODE_BYPASS_COST_MODEL;
+      if (IsDotSupportedByXnn(instr->dot_dimension_numbers(),
+                              instr->operand(0)->shape(),
+                              instr->operand(1)->shape(), instr->shape(),
+                              target_machine_features, use_cost_model)
+              .value_or(false)) {
+        return false;
+      }
+    }
+
+    return true;
   };

   // xla::cpu::GetDotImplementationStrategy (used by call_library_for_dot)
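In this lambda, returning true keeps the dot on the default emission path and false leaves it for a library fusion pass; the rewrite makes YNNPACK take precedence over the cost-modeled XNNPACK check. A condensed, hypothetical restatement of the decision order (standalone, not the actual function):

// Boiled-down version of the predicate above: YNNPACK is consulted first
// (when individual-dot fusion is enabled), then XNNPACK as the fallback.
bool KeepDotOnDefaultPath(bool ynn_dot_enabled, bool ynn_supports_dot,
                          bool xnn_fusion_enabled, bool xnn_supports_dot) {
  if (ynn_dot_enabled && ynn_supports_dot) return false;     // YNN fusion wins
  if (xnn_fusion_enabled && xnn_supports_dot) return false;  // then XNN
  return true;  // neither library claims the dot
}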
@@ -1091,13 +1091,10 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk(
           .xla_cpu_experimental_ynn_fusion_type(),
       DebugOptions::LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT);
   if (use_ynn) {
-    // TODO(ashaposhnikov): Replace IsDotSupportedByXnn with
-    // IsDotSupportedByYnn.
     TF_ASSIGN_OR_RETURN(
         auto is_dot_supported,
-        IsDotSupportedByXnn(dnums, lhs->shape(), rhs->shape(),
-                            instruction->shape(), &target_machine_features_,
-                            /*use_cost_model=*/false));
+        IsDotSupportedByYnn(dnums, lhs->shape(), rhs->shape(),
+                            instruction->shape()));
     if (is_dot_supported) {
       return EmitYnnFusionThunk(instruction);
     }