Bump Cast op version to 7 with bfloat16 runtime kernel support

This CL enables the interpreter to execute the latest bfloat16 TFLite flatbuffers, which in turn enables their quantization. It is required because quantization validates the float TFLite flatbuffer by running it through the interpreter.
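For illustration, a minimal end-to-end sketch of what the new kernel enables, assuming a TF build that includes this change (>= 2.17); the tf.lite converter and interpreter calls are standard, but bfloat16 conversion support may vary by build:

import numpy as np
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([4], tf.float32)])
def cast_fn(x):
  # Lowers to tfl.Cast, which now has a bfloat16 runtime kernel.
  return tf.cast(x, tf.bfloat16)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [cast_fn.get_concrete_function()], cast_fn)
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"],
                       np.array([100.0, 0.4, 1.999, 1.1], dtype=np.float32))
interpreter.invoke()
print(interpreter.get_tensor(out["index"]))  # values rounded to bfloat16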

PiperOrigin-RevId: 695529144
Jae H. Yoo 2024-11-11 16:58:42 -08:00 committed by TensorFlower Gardener
parent 0634e71305
commit ca0bb02924
16 changed files with 203 additions and 7 deletions

View File

@@ -25,7 +25,8 @@
### Major Features and Improvements
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* `tf.lite`
* `tfl.Cast` op now supports `bfloat16` in the runtime kernel.
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
### Bug Fixes and Other Changes

View File

@@ -363,6 +363,7 @@ typedef union TfLitePtrUnion {
uint64_t* u64;
float* f;
TfLiteFloat16* f16;
TfLiteBFloat16* bf16;
double* f64;
char* raw;
const char* raw_const;

View File

@@ -175,7 +175,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version = */ 2);
AddBuiltin(BuiltinOperator_CAST, Register_CAST(),
/* min_version = */ 1,
/* max_version = */ 6);
/* max_version = */ 7);
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
/* min_version = */ 1,
/* max_version = */ 6);

View File

@@ -230,6 +230,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@eigen_archive//:eigen3",
"@flatbuffers",
"@local_tsl//tsl/platform:logging",
],
@@ -1551,7 +1552,7 @@ cc_test(
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest",
"@flatbuffers",
"@eigen_archive//:eigen3",
],
)

View File

@@ -106,6 +106,14 @@ void copyCast(const Eigen::half* in, std::complex<float>* out,
});
}
template <>
void copyCast(const Eigen::bfloat16* in, std::complex<float>* out,
int num_elements) {
std::transform(in, in + num_elements, out, [](Eigen::bfloat16 a) {
return std::complex<float>(Eigen::bfloat16_impl::bfloat16_to_float(a));
});
}
template <typename FromT>
void copyCastToFloat16(const FromT* in, Eigen::half* out, int num_elements) {
std::transform(in, in + num_elements, out, [](FromT a) {
@@ -127,6 +135,50 @@ void copyCastToFloat16(const Eigen::half* in, Eigen::half* out,
std::transform(in, in + num_elements, out, [](Eigen::half a) { return a; });
}
template <>
void copyCastToFloat16(const Eigen::bfloat16* in, Eigen::half* out,
int num_elements) {
// bfloat16 -> float -> half (fp16)
std::transform(in, in + num_elements, out, [](Eigen::bfloat16 a) {
return Eigen::half_impl::float_to_half_rtne(
Eigen::bfloat16_impl::bfloat16_to_float(a));
});
}
template <typename FromT>
void copyCastToBFloat16(const FromT* in, Eigen::bfloat16* out,
int num_elements) {
std::transform(in, in + num_elements, out, [](FromT a) {
return Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(
static_cast<float>(a));
});
}
template <>
void copyCastToBFloat16(const std::complex<float>* in, Eigen::bfloat16* out,
int num_elements) {
std::transform(in, in + num_elements, out, [](std::complex<float> a) {
return Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(std::real(a));
});
}
template <>
void copyCastToBFloat16(const Eigen::bfloat16* in, Eigen::bfloat16* out,
int num_elements) {
std::transform(in, in + num_elements, out,
[](Eigen::bfloat16 a) { return a; });
}
template <>
void copyCastToBFloat16(const Eigen::half* in, Eigen::bfloat16* out,
int num_elements) {
// half (fp16) -> float -> bfloat16
std::transform(in, in + num_elements, out, [](Eigen::half a) {
return Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(
Eigen::half_impl::half_to_float(a));
});
}
TfLiteStatus castInt4ToFloat(TfLiteContext* context, const TfLiteTensor* in,
TfLiteTensor* out, int num_elements) {
const int8_t* in_data = (const int8_t*)in->data.data;
@@ -213,6 +265,10 @@ TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
copyCastToFloat16(in, reinterpret_cast<Eigen::half*>(out->data.f16),
num_elements);
break;
case kTfLiteBFloat16:
copyCastToBFloat16(in, reinterpret_cast<Eigen::bfloat16*>(out->data.bf16),
num_elements);
break;
case kTfLiteFloat32:
copyCast(in, GetTensorData<float>(out), num_elements);
break;
@@ -254,6 +310,10 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
return copyToTensor(context,
reinterpret_cast<Eigen::half*>(input->data.f16),
output, num_elements);
case kTfLiteBFloat16:
return copyToTensor(context,
reinterpret_cast<Eigen::bfloat16*>(input->data.bf16),
output, num_elements);
case kTfLiteFloat32:
return copyToTensor(context, GetTensorData<float>(input), output,
num_elements);
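The fp16 <-> bfloat16 specializations above go through float32 because the two formats have different significand widths, so there is no direct conversion: widening to float32 is exact, and the narrowing step rounds to nearest even. A sketch of the same arithmetic in Python, assuming the ml_dtypes package (whose bfloat16 is bit-compatible with Eigen::bfloat16):

import numpy as np
import ml_dtypes

x = np.float16(1.1)                     # fp16: 10 explicit mantissa bits
as_f32 = np.float32(x)                  # exact widening, 1.099609375
bf = as_f32.astype(ml_dtypes.bfloat16)  # RTNE narrowing to 7 mantissa bits
print(float(bf))                        # 1.1015625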

View File

@@ -23,7 +23,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "Eigen/Core" // from @eigen_archive
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/kernels/cast_test_common.h"
@@ -327,6 +327,60 @@ TEST(CastOpModel, CastInt16ToUInt16) {
ElementsAreArray({10, 20, 30, 40, 50, 60}));
}
TEST(CastOpModel, CastFloatToFloat16) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_FLOAT16, {3, 2}});
m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(
m.ExtractVector<Eigen::half>(m.output()),
ElementsAreArray(
{static_cast<Eigen::half>(100.f), static_cast<Eigen::half>(1.0f),
static_cast<Eigen::half>(0.f), static_cast<Eigen::half>(0.4f),
{static_cast<Eigen::half>(1.999f), static_cast<Eigen::half>(1.1f)}));
}
TEST(CastOpModel, CastFloatToBFloat16) {
CastOpModel m({TensorType_FLOAT32, {3, 2}}, {TensorType_BFLOAT16, {3, 2}});
m.PopulateTensor<float>(m.input(), {100.f, 1.0f, 0.f, 0.4f, 1.999f, 1.1f});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.ExtractVector<Eigen::bfloat16>(m.output()),
ElementsAreArray({static_cast<Eigen::bfloat16>(100.f),
static_cast<Eigen::bfloat16>(1.0f),
static_cast<Eigen::bfloat16>(0.f),
static_cast<Eigen::bfloat16>(0.4f),
static_cast<Eigen::bfloat16>(1.999f),
static_cast<Eigen::bfloat16>(1.1f)}));
}
TEST(CastOpModel, CastFloat16ToFloat) {
CastOpModel m({TensorType_FLOAT16, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
m.PopulateTensor<Eigen::half>(
m.input(),
{static_cast<Eigen::half>(100.f), static_cast<Eigen::half>(1.0f),
static_cast<Eigen::half>(0.f), static_cast<Eigen::half>(0.4f),
static_cast<Eigen::half>(1.999f), static_cast<Eigen::half>(1.1f)});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray(ArrayFloatNear(
{100.f, 1.0f, 0.f, 0.399902344f, 1.99902344f, 1.09960938f},
/*max_abs_err=*/0.05f)));
}
TEST(CastOpModel, CastBFloat16ToFloat) {
CastOpModel m({TensorType_BFLOAT16, {3, 2}}, {TensorType_FLOAT32, {3, 2}});
m.PopulateTensor<Eigen::bfloat16>(
m.input(),
{static_cast<Eigen::bfloat16>(100.f), static_cast<Eigen::bfloat16>(1.0f),
static_cast<Eigen::bfloat16>(0.f), static_cast<Eigen::bfloat16>(0.4f),
static_cast<Eigen::bfloat16>(1.999f),
static_cast<Eigen::bfloat16>(1.1f)});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray(ArrayFloatNear(
{100.f, 1.0f, 0.f, 0.400390625f, 2.f, 1.1015625f},
/*max_abs_err=*/0.05f)));
}
TEST(CastOpModel, CastConstInputCachingWorks) {
// This tests the implementation of a performance optimization. If that
// optimization is changed, this test will likely break/need to be updated.
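The expected outputs in the bfloat16 tests above (0.400390625, 2.0, 1.1015625) are just the inputs rounded to bfloat16's 8-bit significand (7 explicit mantissa bits). They can be reproduced with ml_dtypes, assuming its bfloat16 matches Eigen's:

import numpy as np
import ml_dtypes

vals = np.array([100.0, 1.0, 0.0, 0.4, 1.999, 1.1], dtype=np.float32)
print(vals.astype(ml_dtypes.bfloat16).astype(np.float32))
# [100.  1.  0.  0.40039062  2.  1.1015625]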

View File

@@ -41,6 +41,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "Eigen/Core" // from @eigen_archive
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/c/common.h"
@@ -133,6 +134,11 @@ constexpr TfLiteType typeToTfLiteType<Eigen::half>() {
return kTfLiteFloat16;
}
template <>
constexpr TfLiteType typeToTfLiteType<Eigen::bfloat16>() {
return kTfLiteBFloat16;
}
// A test model that contains a single operator. All operator inputs and
// output are external to the model, so the tests can directly access them.
// Typical usage:
@@ -1186,6 +1192,8 @@ TFLITE_TENSOR_TYPE_ASSOC(uint32_t, TensorType_UINT32);
TFLITE_TENSOR_TYPE_ASSOC(uint64_t, TensorType_UINT64);
TFLITE_TENSOR_TYPE_ASSOC(TfLiteFloat16, TensorType_FLOAT16);
TFLITE_TENSOR_TYPE_ASSOC(Eigen::half, TensorType_FLOAT16);
TFLITE_TENSOR_TYPE_ASSOC(TfLiteBFloat16, TensorType_BFLOAT16);
TFLITE_TENSOR_TYPE_ASSOC(Eigen::bfloat16, TensorType_BFLOAT16);
TFLITE_TENSOR_TYPE_ASSOC(float, TensorType_FLOAT32);
TFLITE_TENSOR_TYPE_ASSOC(double, TensorType_FLOAT64);
TFLITE_TENSOR_TYPE_ASSOC(std::string, TensorType_STRING);
@@ -1281,6 +1289,26 @@ struct TypeUnion<uint8_t> {
typedef uint8_t ScalarType;
};
template <>
struct TypeUnion<Eigen::half> {
public:
// NOLINTNEXTLINE
static constexpr TensorType tensor_type = TensorType::TensorType_FLOAT16;
// NOLINTNEXTLINE
static constexpr TfLiteType tflite_type = TfLiteType::kTfLiteFloat16;
typedef Eigen::half ScalarType;
};
template <>
struct TypeUnion<Eigen::bfloat16> {
public:
// NOLINTNEXTLINE
static constexpr TensorType tensor_type = TensorType::TensorType_BFLOAT16;
// NOLINTNEXTLINE
static constexpr TfLiteType tflite_type = TfLiteType::kTfLiteBFloat16;
typedef Eigen::bfloat16 ScalarType;
};
class MultiOpModel : public SingleOpModel {
public:
MultiOpModel() : SingleOpModel() {}

View File

@@ -72,6 +72,7 @@ MATCH_TYPE_AND_TFLITE_TYPE(unsigned char, kTfLiteUInt8);
MATCH_TYPE_AND_TFLITE_TYPE(int8_t, kTfLiteInt8);
MATCH_TYPE_AND_TFLITE_TYPE(bool, kTfLiteBool);
MATCH_TYPE_AND_TFLITE_TYPE(TfLiteFloat16, kTfLiteFloat16);
MATCH_TYPE_AND_TFLITE_TYPE(TfLiteBFloat16, kTfLiteBFloat16);
MATCH_TYPE_AND_TFLITE_TYPE(double, kTfLiteFloat64);
MATCH_TYPE_AND_TFLITE_TYPE(uint64_t, kTfLiteUInt64);

View File

@@ -180,6 +180,7 @@ py_strict_library(
"//tensorflow/python/framework:convert_to_constants",
"//tensorflow/python/saved_model:signature_constants",
"//third_party/py/numpy",
"@ml_dtypes",
],
)

View File

@@ -118,6 +118,16 @@ def make_cast_tests(options):
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.bfloat16],
"output_dtype": [tf.float32],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
{
"input_dtype": [tf.float32],
"output_dtype": [tf.bfloat16],
"input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
},
]
def build_graph(parameters):

View File

@@ -112,6 +112,8 @@ bool DataExpectation::Check(bool verbose, const TfLiteTensor& tensor) {
return TypedCheck<double, double>(verbose, tensor);
case kTfLiteFloat16:
return TypedCheck<Eigen::half, float>(verbose, tensor);
case kTfLiteBFloat16:
return TypedCheck<Eigen::bfloat16, float>(verbose, tensor);
default:
fprintf(stderr, "Unsupported type %d in Check\n", tensor.type);
return false;

View File

@@ -209,6 +209,17 @@ inline std::vector<Eigen::half> Split(const string& s,
return fields;
}
template <>
inline std::vector<Eigen::bfloat16> Split(const string& s,
const string& delimiter) {
std::vector<Eigen::bfloat16> fields;
for (const auto& p : SplitToPos(s, delimiter)) {
fields.push_back(Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(
strtof(s.data() + p.first, nullptr)));
}
return fields;
}
} // namespace testing
} // namespace tflite

View File

@@ -413,6 +413,12 @@ void TfLiteDriver::SetInput(const std::string& name,
SetTensorData(values, tensor->data.raw);
break;
}
case kTfLiteBFloat16: {
const auto& values = testing::Split<Eigen::bfloat16>(csv_values, ",");
if (!CheckSizes<Eigen::bfloat16>(tensor->bytes, values.size())) return;
SetTensorData(values, tensor->data.raw);
break;
}
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),
@@ -493,6 +499,9 @@ void TfLiteDriver::SetExpectation(const std::string& name,
case kTfLiteFloat16:
expected_output_[id]->SetData<Eigen::half>(csv_values);
break;
case kTfLiteBFloat16:
expected_output_[id]->SetData<Eigen::bfloat16>(csv_values);
break;
default:
Invalidate(absl::StrCat("Unsupported tensor type ",
TfLiteTypeGetName(tensor->type),

View File

@@ -25,6 +25,7 @@ import tempfile
import traceback
import zipfile
import ml_dtypes
import numpy as np
import tensorflow as tf
@@ -130,6 +131,12 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
# Not the best strings, but they will do for some basic testing.
letters = list(string.ascii_uppercase)
return np.random.choice(letters, size=shape).astype(dtype)
elif dtype == tf.bfloat16:
value = (max_value - min_value) * np.random.random_sample(shape) + min_value
# NumPy has no native bfloat16 type; use ml_dtypes.bfloat16, which
# matches Eigen's bfloat16.
dtype = ml_dtypes.bfloat16
else:
raise ValueError("Unsupported dtype: %s" % dtype)
return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
dtype)
@@ -149,6 +156,12 @@ def create_scalar_data(dtype, min_value=-100, max_value=100):
elif dtype == np.bytes_:
l = np.random.randint(1, 6)
value = "".join(np.random.choice(list(string.ascii_uppercase), size=l))
elif dtype == tf.bfloat16:
value = (max_value - min_value) * np.random.random() + min_value
# NumPy has no native bfloat16 type; use ml_dtypes.bfloat16, which
# matches Eigen's bfloat16.
dtype = ml_dtypes.bfloat16
else:
raise ValueError("Unsupported dtype: %s" % dtype)
return np.array(value, dtype=dtype)
@@ -170,7 +183,12 @@ def format_result(t):
"""Convert a tensor to a format that can be used in test specs."""
if t.dtype.kind not in [np.dtype(np.bytes_).kind, np.dtype(np.object_).kind]:
# Output 9 digits after the point to ensure the precision is good enough.
values = ["{:.9f}".format(value) for value in list(t.flatten())]
# Formatting promotes a bfloat16 value to a string, not a float, so
# convert it to float explicitly before applying the format spec.
if t.dtype == ml_dtypes.bfloat16:
values = ["{:.9f}".format(float(value)) for value in list(t.flatten())]
else:
values = ["{:.9f}".format(value) for value in list(t.flatten())]
return ",".join(values)
else:
# SerializeAsHexString returns bytes in PY3, so decode if appropriate.
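To see the formatting special case above in isolation, a small sketch (assuming ml_dtypes is installed): each bfloat16 element is converted to float before applying "{:.9f}":

import numpy as np
import ml_dtypes

t = np.array([0.4, 1.999], dtype=ml_dtypes.bfloat16)
values = ["{:.9f}".format(float(v)) for v in t.flatten()]
print(",".join(values))  # 0.400390625,2.000000000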

View File

@@ -16,9 +16,7 @@ limitations under the License.
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/schema/schema_generated.h"

View File

@@ -110,6 +110,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CAST, 4}, "2.9.0"},
{{BuiltinOperator_CAST, 5}, "2.12.0"},
{{BuiltinOperator_CAST, 6}, "2.15.0"},
{{BuiltinOperator_CAST, 7}, "2.17.0"},
{{BuiltinOperator_CONCATENATION, 1}, "1.5.0"},
{{BuiltinOperator_CONCATENATION, 2}, "1.14.0"},
{{BuiltinOperator_CONCATENATION, 3}, "2.3.0"},
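With the new entry above, a model containing a bfloat16 Cast (op version 7) should be stamped with a minimum runtime version of 2.17.0. A sketch of how to inspect that stamp, assuming the flatbuffer_utils helper module and that the version lives in the model's "min_runtime_version" metadata buffer (tflite_model is the converted model bytes from the earlier sketch):

from tensorflow.lite.tools import flatbuffer_utils  # assumed helper module

model = flatbuffer_utils.convert_bytearray_to_object(bytearray(tflite_model))
for m in model.metadata:
  if m.name == b"min_runtime_version":
    data = bytes(model.buffers[m.buffer].data)
    print(data.decode().rstrip("\x00"))  # expected: 2.17.0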