Add SavedModel to StableHLO Converter to TensorFlow pip package

PiperOrigin-RevId: 635568823
This commit is contained in:
Sandeep Dasgupta 2024-05-20 14:19:45 -07:00 committed by TensorFlower Gardener
parent 98a7f0b73c
commit 651e9b8c8a
11 changed files with 392 additions and 149 deletions

View File

@ -31,6 +31,7 @@
been added to TF binary distributions (Python wheels).
* Replace `DebuggerOptions` of TensorFlow Quantizer, and migrate to
`DebuggerConfig` of StableHLO Quantizer.
* Add TensorFlow to StableHLO converter to TensorFlow pip package.
## Keras

View File

@ -1382,6 +1382,7 @@ tf_cc_shared_library(
"//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
"//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
"//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization_lib_impl",
"//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl",
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
"//tensorflow/compiler/mlir/quantization/tensorflow:passes",

View File

@ -6,12 +6,10 @@ load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allo
package_group(
name = "internal_visibility_allowlist_package",
packages = [
"//learning/brain/mlir/quantization/stablehlo/python/integration_test/...",
"//tensorflow/compiler/mlir/lite/...",
"//tensorflow/compiler/mlir/quantization/...",
"//tensorflow/compiler/mlir/tf2xla/transforms/...",
"//tensorflow/lite/...",
"//third_party/cloud_tpu/inference_converter/...", # TPU Inference Converter V1
] + internal_visibility_allowlist(),
)
@ -80,7 +78,6 @@ tf_cc_binary(
glob_lit_tests(
name = "all_tests",
data = [":test_utilities"],
# TODO: b/288344501 - Enable OSS tests again when stable-quant-opt works well.
default_tags = [
"no_oss",
"no_pip",

View File

@ -1,13 +1,11 @@
## Tensorflow SavedModel to StableHLO (tf-to-stablehlo-translate)
# Tensorflow SavedModel to StableHLO (tf-to-stablehlo-translate)
### Description
This tool converts TensorFlow models (SavedModel or MLIR module) to StableHLO
MLIR modules, preserving model structure and signatures. It enables seamless
Converts TensorFlow models (SavedModel or MLIR module) to StableHLO MLIR
modules, preserving model structure and signatures. It enables seamless
integration of TensorFlow models into MLIR-based compiler frameworks for further
optimization and deployment.
### Usage
## C++ APIs
```bash
tf-to-stablehlo-translate \
@ -58,3 +56,68 @@ tf-to-stablehlo-translate <saved-model-path> --input-arg-shapes=1,12:1,12:1,12
* TensorFlow
* MLIR
* Abseil (absl)
## Python APIs
### `savedmodel_to_stablehlo`
Converts a TensorFlow SavedModel into StableHLO bytecode.
```Python
from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo
stablehlo_bytes = tf2shlo.savedmodel_to_stablehlo(
input_path="/path/to/your/savedmodel",
exported_model_signatures=["serving_default"],
tag_names=["serve"],
input_arg_shapes_str="1,28,28,3::32"
)
```
#### Arguments:
* `input_path` (required): Path to your SavedModel directory.
* `exported_model_signatures` (optional): List of signature names to convert.
Defaults to ["serving_default"].
* `tag_names` (optional): List of tags associated with the SavedModel. Defaults
to ["serve"].
* `input_arg_shapes_str` (optional): A string representation of input argument
shapes for 'main' entry-point, separating
tensors with ':', dimension with ',', and
using '?' for unknown sizes. For example,
`input-arg-shapes=1,2::1,?` expresses
argument shapes `[1,2], [] and [1,?]`.
#### Error Handling
An exception will be raised with details about the error.
### `tensorflow_module_to_stablehlo`
Converts a TensorFlow MLIR module string into StableHLO bytecode.
```Python
from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo
stablehlo_bytes = tf2shlo.tensorflow_module_to_stablehlo(
module_op_str="your_tensorflow_mlir_module_string",
input_arg_shapes_str="1,28,28,3::32"
)
```
#### Arguments:
* `module_op_str` (required): String containing the TensorFlow MLIR module.
* `input_arg_shapes_str` (optional): A string representation of input argument
shapes for 'main' entry-point, separating
tensors with ':', dimension with ',', and
using '?' for unknown sizes. For example,
`input-arg-shapes=1,2::1,?` expresses
argument shapes `[1,2], [] and [1,?]`.
#### Error Handling
Return `py::none()` (equivalent to Python's `None`) if there's an error. An
exception will be raised with details about the error.

View File

@ -1,5 +1,6 @@
load(
"//tensorflow:tensorflow.default.bzl",
"get_compatible_with_portable",
"tf_py_strict_test",
"tf_python_pybind_extension",
)
@ -20,6 +21,7 @@ package(
default_visibility = [
":internal_visibility_allowlist_package",
"//tensorflow:__pkg__",
"//tensorflow/python:__pkg__",
],
licenses = ["notice"],
)
@ -45,20 +47,60 @@ package(
# )
# copybara:uncomment_end
# This is a header-only target. The purpose of `pywrap_tensorflow_to_stablehlo_lib_*` targets is to expose only
# the symbols that are required by `pywrap_tensorflow_to_stablehlo` that translates them to python functions.
# The only intended use case of this library is by `pywrap_tensorflow_to_stablehlo`. Not letting
# `pywrap_tensorflow_to_stablehlo` directly depend on sub-libraries like `static_range_srq` and instead haiving
# a consolidated impl library `pywrap_tensorflow_to_stablehlo_lib_impl` allows the maintainers to avoid
# declaring multiple impl libraries to `libtensorflow_cc` and `lib_pywrap_tensorflow_internal`,
# which is required to avoid ODR violations.
cc_library(
name = "pywrap_tensorflow_to_stablehlo_lib_header_only",
srcs = [],
hdrs = ["pywrap_tensorflow_to_stablehlo_lib.h"],
compatible_with = get_compatible_with_portable(),
visibility = ["//visibility:private"], # ONLY for `pywrap_tensorflow_to_stablehlo`.
deps = [
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
],
)
# See the comments for `pywrap_tensorflow_to_stablehlo_lib_header_only`.
cc_library(
name = "pywrap_tensorflow_to_stablehlo_lib_impl",
srcs = ["pywrap_tensorflow_to_stablehlo_lib.cc"],
hdrs = ["pywrap_tensorflow_to_stablehlo_lib.h"],
compatible_with = get_compatible_with_portable(),
visibility = [
"//tensorflow:__pkg__", # For libtensorflow_cc.so.
"//tensorflow/python:__pkg__", # For lib_pywrap_tensorflow_internal.so.
],
deps = [
"//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:tf_to_stablehlo",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:lib",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
],
)
tf_python_pybind_extension(
name = "pywrap_tensorflow_to_stablehlo",
srcs = ["pywrap_tensorflow_to_stablehlo.cc"],
pytype_srcs = ["pywrap_tensorflow_to_stablehlo.pyi"],
# Each dependency MUST be either header-only or exclusive.
deps = [
"//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters",
"//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:tf_to_stablehlo",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings:str_format",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Support",
":pywrap_tensorflow_to_stablehlo_lib_header_only",
"//third_party/python_runtime:headers",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:absl_casters",
"@pybind11_abseil//pybind11_abseil:status_casters",

View File

@ -12,142 +12,51 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <Python.h>
#include "absl/strings/str_format.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "pybind11/pybind11.h" // from @pybind11
#include "pybind11/pytypes.h" // from @pybind11
#include "pybind11/stl.h" // from @pybind11 // IWYU pragma: keep
#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep
#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h"
namespace py = pybind11;
namespace mlir::pywrap {
namespace {
absl::StatusOr<std::string> ModuleToBytecode(ModuleOp module) {
std::string bytecode;
llvm::raw_string_ostream os(bytecode);
mlir::BytecodeWriterConfig config;
if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) {
return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed");
}
return bytecode;
}
using mlir::tensorflow_to_stablehlo::pywrap::PywrapSavedModelToStablehlo;
using mlir::tensorflow_to_stablehlo::pywrap::PywrapTfModuleToStablehlo;
absl::StatusOr<std::string> ExportModule(ModuleOp module) {
auto output_filename = absl::StrFormat(
"%s/tf_module.mlir", std::filesystem::temp_directory_path());
std::string error_msg;
auto output = openOutputFile(output_filename, &error_msg);
if (output == nullptr) {
return absl::AbortedError(
absl::StrCat("Unable to open output path: ", error_msg));
}
std::string result;
llvm::raw_string_ostream os(result);
OpPrintingFlags printing_flags;
module.print(os, printing_flags);
output->os() << result;
output->keep();
return output_filename;
}
py::bytes PywrapSavedModelToStablehlo(
absl::string_view input_path,
const std::vector<std::string>& exported_model_signatures =
{"serving_default"},
const std::vector<std::string>& tag_names = {"serve"},
absl::string_view input_arg_shapes_str = "") {
mlir::DialectRegistry registry;
RegisterAllTensorFlowDialects(registry);
mlir::MLIRContext context(registry);
context.loadAllAvailableDialects();
auto module =
TfToStablehlo(input_path, &context, exported_model_signatures, tag_names,
input_arg_shapes_str, /*is_input_mlir_module=*/false);
if (!module.ok()) {
PyErr_SetString(PyExc_ValueError,
"failed to converted TensorFlow to StableHLO");
return {};
}
auto bytecode = ModuleToBytecode(module.value().get());
if (!bytecode.ok()) {
PyErr_SetString(PyExc_ValueError, "failed to write module to bytecode");
return {};
}
return bytecode.value();
}
py::bytes PywrapTfModuleToStablehlo(
absl::string_view module_op_str,
absl::string_view input_arg_shapes_str = "") {
mlir::DialectRegistry registry;
RegisterAllTensorFlowDialects(registry);
mlir::MLIRContext context(registry);
context.loadAllAvailableDialects();
auto tf_module = mlir::parseSourceString<ModuleOp>(module_op_str, &context);
if (!tf_module) {
PyErr_SetString(PyExc_ValueError, "failed to parse TF module string");
return {};
}
auto mlir_file_path = ExportModule(*tf_module);
if (!mlir_file_path.ok()) {
PyErr_SetString(PyExc_ValueError,
"failed to write TF module to a temporary file");
return {};
}
auto module = TfToStablehlo(
*mlir_file_path, &context, /*exported_model_signatures=*/{},
/*tag_names=*/{}, input_arg_shapes_str, /*is_input_mlir_module=*/true);
if (!module.ok()) {
PyErr_SetString(PyExc_ValueError,
"failed to converted TensorFlow to StableHLO");
return {};
}
auto bytecode = ModuleToBytecode(module.value().get());
if (!bytecode.ok()) {
PyErr_SetString(PyExc_ValueError, "failed to write module to bytecode");
return {};
}
return bytecode.value();
}
} // namespace mlir::pywrap
} // namespace
PYBIND11_MODULE(pywrap_tensorflow_to_stablehlo, m) {
m.doc() = "TensorFlow to StableHLO APIs.";
// LINT.IfChange(savedmodel_to_stablehlo)
m.def("savedmodel_to_stablehlo", &mlir::pywrap::PywrapSavedModelToStablehlo,
m.def(
"savedmodel_to_stablehlo",
[](absl::string_view input_path,
const std::vector<std::string>& exported_model_signatures =
{"serving_default"},
const std::vector<std::string>& tag_names = {"serve"},
absl::string_view input_arg_shapes_str = "") -> py::bytes {
auto module_bytecode =
PywrapSavedModelToStablehlo(input_path, exported_model_signatures,
tag_names, input_arg_shapes_str);
if (!module_bytecode.ok()) {
PyErr_SetString(PyExc_ValueError,
module_bytecode.status().ToString().c_str());
throw py::error_already_set();
}
return py::bytes(module_bytecode.value());
},
R"pbdoc(
This tool converts TensorFlow SavedModel to StableHLO.
Converts a TensorFlow SavedModel into StableHLO bytecode.
* input-path: The path to the input TensorFlow SavedModel.
* exported-model-signatures: Comma-separated list of exported model
signatures to convert. Ignored for MLIR input.
* tag_names: Comma-separated list of tags for loading SavedModel. Ignored for MLIR
input.
signatures to convert.
* tag_names: Comma-separated list of tags for loading SavedModel.
* input-arg-shapes: A string representation of input argument shapes for
'main' entry-point, separating tensors with ':', dimension with ',', and
using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?'
@ -158,13 +67,24 @@ PYBIND11_MODULE(pywrap_tensorflow_to_stablehlo, m) {
std::vector<std::string>{"serving_default"},
py::arg("tag_names") = std::vector<std::string>{"serve"},
py::arg("input_arg_shapes_str") = "");
// LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:tensorflow_to_stablehlo)
// LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:savedmodel_to_stablehlo)
//
// LINT.IfChange(tensorflow_mlir_to_stablehlo)
m.def("tensorflow_module_to_stablehlo",
&mlir::pywrap::PywrapTfModuleToStablehlo,
// LINT.IfChange(tensorflow_module_to_stablehlo)
m.def(
"tensorflow_module_to_stablehlo",
[](absl::string_view module_op_str,
absl::string_view input_arg_shapes_str) -> py::bytes {
auto module_bytecode =
PywrapTfModuleToStablehlo(module_op_str, input_arg_shapes_str);
if (!module_bytecode.ok()) {
PyErr_SetString(PyExc_ValueError,
module_bytecode.status().ToString().c_str());
throw py::error_already_set();
}
return py::bytes(module_bytecode.value());
},
R"pbdoc(
This tool converts TensorFlow mlir module string to StableHLO.
Converts a TensorFlow MLIR module string into StableHLO bytecode.
* module: TensorFlow MLIR module string.
* input-arg-shapes: A string representation of input argument shapes for
@ -173,5 +93,5 @@ PYBIND11_MODULE(pywrap_tensorflow_to_stablehlo, m) {
expresses argument shapes [1,2], [] and [1,?].
)pbdoc",
py::arg("module"), py::arg("input_arg_shapes_str") = "");
// LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:tensorflow_mlir_to_stablehlo)
// LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:tensorflow_module_to_stablehlo)
}

View File

@ -0,0 +1,141 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h"
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/core/platform/path.h"
namespace mlir::tensorflow_to_stablehlo::pywrap {
absl::StatusOr<std::string> ModuleToBytecode(ModuleOp module) {
std::string bytecode;
llvm::raw_string_ostream os(bytecode);
mlir::BytecodeWriterConfig config;
if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) {
return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed");
}
return bytecode;
}
absl::StatusOr<std::string> ExportModule(ModuleOp module) {
const std::string output_filename = tensorflow::io::GetTempFilename(".mlir");
std::string error_msg;
auto output = openOutputFile(output_filename, &error_msg);
if (output == nullptr) {
return absl::UnknownError(
absl::StrCat("Unable to open output path: ", error_msg));
}
std::string result;
llvm::raw_string_ostream os(result);
OpPrintingFlags printing_flags;
module.print(os, printing_flags);
output->os() << result;
output->keep();
return output_filename;
}
absl::StatusOr<std::string> PywrapSavedModelToStablehlo(
absl::string_view input_path,
const std::vector<std::string>& exported_model_signatures,
const std::vector<std::string>& tag_names,
absl::string_view input_arg_shapes_str) {
mlir::DialectRegistry registry;
RegisterAllTensorFlowDialects(registry);
mlir::MLIRContext context(registry);
context.loadAllAvailableDialects();
auto module =
TfToStablehlo(input_path, &context, exported_model_signatures, tag_names,
input_arg_shapes_str, /*is_input_mlir_module=*/false);
if (!module.ok()) {
return absl::UnknownError(
absl::StrCat("Failed to convert SavedModel to StableHLO: ",
module.status().message()));
}
auto bytecode = ModuleToBytecode(module.value().get());
if (!bytecode.ok()) {
return absl::UnknownError(
absl::StrCat("Failed to serialize MLIR module to bytecode: ",
bytecode.status().message()));
}
return bytecode.value();
}
absl::StatusOr<std::string> PywrapTfModuleToStablehlo(
absl::string_view module_op_str, absl::string_view input_arg_shapes_str) {
mlir::DialectRegistry registry;
RegisterAllTensorFlowDialects(registry);
mlir::MLIRContext context(registry);
context.loadAllAvailableDialects();
auto tf_module = mlir::parseSourceString<ModuleOp>(module_op_str, &context);
if (!tf_module) {
return absl::UnknownError("Failed to parse MLIR module");
}
auto mlir_file_path = ExportModule(*tf_module);
if (!mlir_file_path.ok()) {
return absl::UnknownError(
absl::StrCat("Failed to write MLIR module to file.",
mlir_file_path.status().message()));
}
auto module = TfToStablehlo(*mlir_file_path, &context,
/*exported_model_signatures=*/{},
/*tag_names=*/{}, input_arg_shapes_str,
/*is_input_mlir_module=*/true);
if (!module.ok()) {
return absl::UnknownError(
absl::StrCat(" Failed to convert SavedModel to StableHLO: ",
module.status().message()));
}
auto bytecode = ModuleToBytecode(module.value().get());
if (!bytecode.ok()) {
return absl::UnknownError(
absl::StrCat("Failed to serialize MLIR module to bytecode: ",
bytecode.status().message()));
}
return bytecode.value();
}
} // namespace mlir::tensorflow_to_stablehlo::pywrap

View File

@ -0,0 +1,67 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_
#include <string>
#include <vector>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
namespace mlir::tensorflow_to_stablehlo::pywrap {
// Converts a TensorFlow SavedModel to a StableHLO MLIR module and serializes it
// to bytecode.
//
// Args:
// input_path: The path to the SavedModel directory.
// exported_model_signatures: Comma-separated list of exported model
// signatures to convert. tag_names: Comma-separated list of tags for loading
// SavedModel.
// input_arg_shapes_str: A string representation of input argument
// shapes for 'main' entry-point, separating tensors with ':', dimension
// with ',', and using '?' for unknown sizes. For example,
// 'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and [1,?].
//
// Returns:
// An absl::StatusOr containing the serialized bytecode of the StableHLO
// module on success, or an error status on failure.
absl::StatusOr<std::string> PywrapSavedModelToStablehlo(
absl::string_view input_path,
const std::vector<std::string>& exported_model_signatures,
const std::vector<std::string>& tag_names,
absl::string_view input_arg_shapes_str);
// Converts a TensorFlow MLIR module string to a StableHLO MLIR module and
// serializes it to bytecode.
//
// Args:
// module_op_str: TensorFlow MLIR module string.
// input_arg_shapes_str: A string representation of input argument
// shapes for 'main' entry-point, separating tensors with ':', dimension
// with ',', and using '?' for unknown sizes. For example,
// 'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and [1,?].
//
// Returns:
// An absl::StatusOr containing the serialized bytecode of the StableHLO
// module on success, or an error status on failure.
absl::StatusOr<std::string> PywrapTfModuleToStablehlo(
absl::string_view module_op_str, absl::string_view input_arg_shapes_str);
} // namespace mlir::tensorflow_to_stablehlo::pywrap
#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_

View File

@ -134,6 +134,7 @@ py_strict_library(
":pywrap_tensorflow",
":pywrap_tfe",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model",
"//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo",
"//tensorflow/core:protos_all_py",
"//tensorflow/dtensor/python:dtensor",
"//tensorflow/python/autograph",
@ -768,6 +769,7 @@ pywrap_tensorflow_macro(
"//tensorflow/cc/saved_model:metrics_impl",
"//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization_lib_impl",
"//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
"//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl",
"//tensorflow/compiler/tf2xla:tf2xla_opset",
@ -883,6 +885,7 @@ filegroup(
"//tensorflow/compiler/jit:flags", # tfe
"//tensorflow/compiler/jit:get_compiler_ir", # tfe
"//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl", # quantization
"//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl", # tensorflow_to_stablehlo
"//tensorflow/compiler/tf2xla:tf2xla_opset", # pywrap_xla_ops
"//tensorflow/core:framework_internal_impl", # op_def_registry
"//tensorflow/core:lib_internal_impl", # device_lib

View File

@ -566,6 +566,10 @@ tensorflow::quantization::QuantizeStaticRangePtq
tensorflow::quantization::QuantizeDynamicRangePtq
tensorflow::quantization::QuantizeWeightOnly
[//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl] # tensorflow_to_stablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapSavedModelToStablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapTfModuleToStablehlo
[//tensorflow/dtensor/cc:dtensor_device_cc] # DTensor
tensorflow::dtensor::AllocateDTensorDevice
tensorflow::dtensor::AddMesh

View File

@ -566,6 +566,10 @@ tensorflow::quantization::QuantizeStaticRangePtq
tensorflow::quantization::QuantizeDynamicRangePtq
tensorflow::quantization::QuantizeWeightOnly
[//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl] # tensorflow_to_stablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapSavedModelToStablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapTfModuleToStablehlo
[//tensorflow/dtensor/cc:dtensor_device_cc] # DTensor
tensorflow::dtensor::AllocateDTensorDevice
tensorflow::dtensor::AddMesh