Add support for int8 dots, and allow bf16 to be used on any CPU.

PiperOrigin-RevId: 824272399
A. Unique TensorFlower 2025-10-26 16:06:27 -07:00 committed by TensorFlower Gardener
parent 5edcd28152
commit 32c1551f24
8 changed files with 132 additions and 29 deletions

@@ -235,14 +235,11 @@ cc_library(
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/cpu/codegen:target_machine_features",
"//xla/backends/cpu/runtime:dot_lib",
"//xla/backends/cpu/runtime/ynnpack:ynn_interop",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/tsl/platform:statusor",
"@XNNPACK//ynnpack",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@@ -250,7 +247,6 @@ cc_library(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)

@@ -46,12 +46,22 @@ absl::StatusOr<YnnThreadpool> CreateYnnThreadpool(
absl::StatusOr<ynn_type> YnnType(const PrimitiveType& type) {
switch (type) {
case S4:
return ynn_type_int4;
case U4:
return ynn_type_uint4;
case S8:
return ynn_type_int8;
case U8:
return ynn_type_uint8;
case BF16:
return ynn_type_bf16;
case F16:
return ynn_type_fp16;
case F32:
return ynn_type_fp32;
case S32:
return ynn_type_int32;
default:
return InvalidArgument("Unsupported YNNPACK type: %s",
primitive_util::LowercasePrimitiveTypeName(type));
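
For context, a minimal usage sketch (a hypothetical helper, not part of this change) showing how the StatusOr return value doubles as a support check; call sites in the dot emitter below instead use TF_ASSIGN_OR_RETURN so the error propagates:

// Hypothetical helper: an element type is representable in YNNPACK iff
// YnnType() does not return InvalidArgument.
inline bool IsYnnRepresentable(PrimitiveType type) {
  return YnnType(type).ok();
}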

@@ -349,31 +349,26 @@ static absl::StatusOr<YnnSubgraph> EmitYnnDotSubgraph(
std::vector<size_t> rhs_dims = dims(rhs_shape.dimensions());
std::vector<size_t> out_dims = dims(out_shape.dimensions());
PrimitiveType dtype = lhs->shape().element_type();
if (dtype != F32 && dtype != BF16) {
return InvalidArgument("Unsupported input data type for YnnDotThunk: %s",
primitive_util::LowercasePrimitiveTypeName(dtype));
}
ynn_type input_type = (dtype == F32) ? ynn_type_fp32 : ynn_type_bf16;
ynn_type output_type = ynn_type_fp32;
TF_ASSIGN_OR_RETURN(ynn_type ynn_lhs_type, YnnType(lhs_shape.element_type()));
TF_ASSIGN_OR_RETURN(ynn_type ynn_rhs_type, YnnType(rhs_shape.element_type()));
TF_ASSIGN_OR_RETURN(ynn_type ynn_out_type, YnnType(out_shape.element_type()));
const uint32_t input_tensor_flags = YNN_VALUE_FLAG_EXTERNAL_INPUT;
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph.get(), input_type, lhs_dims.size(), lhs_dims.data(),
subgraph.get(), ynn_lhs_type, lhs_dims.size(), lhs_dims.data(),
/*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, input_tensor_flags, &lhs_id));
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph.get(), input_type, rhs_dims.size(), rhs_dims.data(),
subgraph.get(), ynn_rhs_type, rhs_dims.size(), rhs_dims.data(),
capture_rhs ? arguments_buffers[1].opaque() : nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, input_tensor_flags, &rhs_id));
const uint32_t output_tensor_flags = YNN_VALUE_FLAG_EXTERNAL_OUTPUT;
YNN_RETURN_IF_ERROR(ynn_define_tensor_value(
subgraph.get(), output_type, out_dims.size(), out_dims.data(),
subgraph.get(), ynn_out_type, out_dims.size(), out_dims.data(),
/*data=*/nullptr,
/*zero_point_id=*/YNN_INVALID_VALUE_ID,
/*scale_id=*/YNN_INVALID_VALUE_ID, output_tensor_flags, &out_id));

@@ -16,19 +16,24 @@ limitations under the License.
#include "xla/backends/cpu/ynn_support.h"
#include <algorithm>
#include <tuple>
#include "ynnpack/include/ynnpack.h"
#include "absl/base/no_destructor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/dot_lib.h"
#include "xla/backends/cpu/runtime/ynnpack/ynn_interop.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/shape.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
namespace xla::cpu {
@@ -138,4 +143,64 @@ bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo) {
}
}
absl::StatusOr<bool> IsDotSupportedByYnn(
const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
const Shape& rhs_shape, const Shape& out_shape) {
// Stores tuples of allowed (lhs, rhs, output) dtypes.
static const absl::NoDestructor<absl::flat_hash_set<
std::tuple<PrimitiveType, PrimitiveType, PrimitiveType>>>
kAllowedTypes({
{F32, F32, F32},
// TODO(b/449998002): We don't have fast fp16 kernels yet.
// {F16, F16, F32},
{BF16, BF16, F32},
{S8, S8, S32},
{U8, S8, S32},
// TODO(b/441600372): We don't have fast int4 kernels yet. Even the
// reference kernel might be pretty good though?
// {S8, S4, S32},
});
// Types must be in the allowed set.
PrimitiveType lhs_dtype = lhs_shape.element_type();
PrimitiveType rhs_dtype = rhs_shape.element_type();
PrimitiveType out_dtype = out_shape.element_type();
if (!kAllowedTypes->contains({lhs_dtype, rhs_dtype, out_dtype})) {
return false;
}
if (!IsLayoutSupportedByYnn(lhs_shape) ||
!IsLayoutSupportedByYnn(rhs_shape) ||
!IsLayoutSupportedByYnn(out_shape)) {
return false;
}
// Check shapes.
TF_ASSIGN_OR_RETURN(DotShape dot_shape, GetDotShape(dot_dimensions, lhs_shape,
rhs_shape, out_shape));
TF_ASSIGN_OR_RETURN(DotCanonicalDims dot_canonical_dims,
GetDotCanonicalDims(dot_dimensions, dot_shape));
if (dot_canonical_dims.m == 1 && dot_canonical_dims.n == 1 &&
dot_shape.batch_size > 1) {
// TODO(b/430079105): YNNPACK does not handle batch dimensions that are not
// matrix dimensions. We could handle this case by fully implementing dot
// (b/430079105), but we also could just insert dummy dimensions of size 1
// for the matrix dimensions, so the batch dimensions get handled correctly.
return false;
}
// YNNPACK supports transposing the inputs efficiently if possible (they will
// fuse with dot packing), but we don't currently support generating the
// necessary transposes.
if (!dot_canonical_dims.lhs_canonical ||
dot_canonical_dims.lhs_column_major ||
dot_canonical_dims.rhs_column_major) {
return false;
}
return true;
}
} // namespace xla::cpu
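
A minimal sketch of how a caller might query this predicate for a plain s8 x s8 -> s32 matmul; the shapes and dimension numbers below are illustrative, not taken from this change:

// Illustrative only: a [64,128] x [128,32] -> [64,32] dot with int8 inputs,
// an int32 accumulator, and lhs dim 1 contracted against rhs dim 0.
Shape lhs = ShapeUtil::MakeShape(S8, {64, 128});
Shape rhs = ShapeUtil::MakeShape(S8, {128, 32});
Shape out = ShapeUtil::MakeShape(S32, {64, 32});
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
absl::StatusOr<bool> supported = IsDotSupportedByYnn(dnums, lhs, rhs, out);
// The compiler and thunk emitter consume this as supported.value_or(false).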

@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/shape.h"
namespace xla::cpu {
@@ -55,6 +56,12 @@ bool IsConstantSupportedByYnn(const HloInstruction* hlo);
// Returns true if the nonconstant elementwise op is supported by YNNPACK.
bool IsElementwiseOpSupportedByYnn(const HloInstruction* hlo);
// Returns true if the dot operation is supported by YNNPACK. Returns an error
// if the dot operation shape is invalid.
absl::StatusOr<bool> IsDotSupportedByYnn(
const DotDimensionNumbers& dot_dimensions, const Shape& lhs_shape,
const Shape& rhs_shape, const Shape& out_shape);
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_YNN_SUPPORT_H_

@@ -406,6 +406,8 @@ cc_library(
":onednn_contraction_rewriter",
":onednn_float_support",
":onednn_ops_rewriter",
]) + if_ynnpack([
"//xla/backends/cpu:ynn_support",
]),
)

@@ -257,6 +257,10 @@ limitations under the License.
#include "xla/service/cpu/onednn_ops_rewriter.h"
#endif // XLA_ONEDNN
#ifdef XLA_YNNPACK
#include "xla/backends/cpu/ynn_support.h"
#endif // XLA_YNNPACK
namespace xla {
namespace {
@@ -628,15 +632,42 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
if (!call_library_for_dot(*instr)) {
return true;
}
bool use_cost_model = module->config()
.debug_options()
.xla_cpu_experimental_xnn_graph_fusion_mode() !=
DebugOptions::XNN_GRAPH_FUSION_MODE_BYPASS_COST_MODEL;
return !IsDotSupportedByXnn(instr->dot_dimension_numbers(),
instr->operand(0)->shape(),
instr->operand(1)->shape(), instr->shape(),
target_machine_features, use_cost_model)
.value_or(false);
#ifdef XLA_YNNPACK
if (absl::c_linear_search(
module->config()
.debug_options()
.xla_cpu_experimental_ynn_fusion_type(),
DebugOptions::LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT)) {
if (IsDotSupportedByYnn(instr->dot_dimension_numbers(),
instr->operand(0)->shape(),
instr->operand(1)->shape(), instr->shape())
.value_or(false)) {
return false;
}
}
#endif // XLA_YNNPACK
auto xnn_graph_fusion_mode =
module->config()
.debug_options()
.xla_cpu_experimental_xnn_graph_fusion_mode();
if (xnn_graph_fusion_mode != DebugOptions::XNN_GRAPH_FUSION_MODE_DISABLED) {
bool use_cost_model =
module->config()
.debug_options()
.xla_cpu_experimental_xnn_graph_fusion_mode() !=
DebugOptions::XNN_GRAPH_FUSION_MODE_BYPASS_COST_MODEL;
if (IsDotSupportedByXnn(instr->dot_dimension_numbers(),
instr->operand(0)->shape(),
instr->operand(1)->shape(), instr->shape(),
target_machine_features, use_cost_model)
.value_or(false)) {
return false;
}
}
return true;
};
// xla::cpu::GetDotImplementationStrategy (used by call_library_for_dot)
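
The YNN branch above only fires when the individual-dot fusion type is present in the debug options. A minimal sketch of opting a module into that path (the surrounding setup is illustrative; only the flag and enum value come from this change):

DebugOptions opts;
opts.add_xla_cpu_experimental_ynn_fusion_type(
    DebugOptions::LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT);
HloModuleConfig config;
config.set_debug_options(opts);
// With this set, the predicate above recognizes YNN-supported dots and the
// thunk emitter below lowers them via EmitYnnFusionThunk.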

@@ -1091,13 +1091,10 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitDotThunk(
.xla_cpu_experimental_ynn_fusion_type(),
DebugOptions::LIBRARY_FUSION_TYPE_INDIVIDUAL_DOT);
if (use_ynn) {
// TODO(ashaposhnikov): Replace IsDotSupportedByXnn with
// IsDotSupportedByYnn.
TF_ASSIGN_OR_RETURN(
auto is_dot_supported,
IsDotSupportedByXnn(dnums, lhs->shape(), rhs->shape(),
instruction->shape(), &target_machine_features_,
/*use_cost_model=*/false));
IsDotSupportedByYnn(dnums, lhs->shape(), rhs->shape(),
instruction->shape()));
if (is_dot_supported) {
return EmitYnnFusionThunk(instruction);
}