Add SavedModel to StableHLO Converter to TensorFlow pip package
PiperOrigin-RevId: 635568823
This commit is contained in:
parent
98a7f0b73c
commit
651e9b8c8a
@@ -31,6 +31,7 @@
been added to TF binary distributions (Python wheels).
* Replace `DebuggerOptions` of TensorFlow Quantizer, and migrate to
  `DebuggerConfig` of StableHLO Quantizer.
* Add TensorFlow to StableHLO converter to TensorFlow pip package.

## Keras
@@ -1382,6 +1382,7 @@ tf_cc_shared_library(
        "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
        "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
        "//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization_lib_impl",
        "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl",
        "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op",
        "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
        "//tensorflow/compiler/mlir/quantization/tensorflow:passes",
@@ -6,12 +6,10 @@ load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allo
package_group(
    name = "internal_visibility_allowlist_package",
    packages = [
        "//learning/brain/mlir/quantization/stablehlo/python/integration_test/...",
        "//tensorflow/compiler/mlir/lite/...",
        "//tensorflow/compiler/mlir/quantization/...",
        "//tensorflow/compiler/mlir/tf2xla/transforms/...",
        "//tensorflow/lite/...",
        "//third_party/cloud_tpu/inference_converter/...",  # TPU Inference Converter V1
    ] + internal_visibility_allowlist(),
)
@@ -80,7 +78,6 @@ tf_cc_binary(
glob_lit_tests(
    name = "all_tests",
    data = [":test_utilities"],
    # TODO: b/288344501 - Enable OSS tests again when stable-quant-opt works well.
    default_tags = [
        "no_oss",
        "no_pip",
@@ -1,13 +1,11 @@
## Tensorflow SavedModel to StableHLO (tf-to-stablehlo-translate)
# Tensorflow SavedModel to StableHLO (tf-to-stablehlo-translate)

### Description

This tool converts TensorFlow models (SavedModel or MLIR module) to StableHLO
MLIR modules, preserving model structure and signatures. It enables seamless
Converts TensorFlow models (SavedModel or MLIR module) to StableHLO MLIR
modules, preserving model structure and signatures. It enables seamless
integration of TensorFlow models into MLIR-based compiler frameworks for further
optimization and deployment.

### Usage
## C++ APIs

```bash
tf-to-stablehlo-translate \
@@ -58,3 +56,68 @@ tf-to-stablehlo-translate <saved-model-path> --input-arg-shapes=1,12:1,12:1,12
* TensorFlow
* MLIR
* Abseil (absl)

## Python APIs

### `savedmodel_to_stablehlo`

Converts a TensorFlow SavedModel into StableHLO bytecode.

```Python
from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo

stablehlo_bytes = tf2shlo.savedmodel_to_stablehlo(
    input_path="/path/to/your/savedmodel",
    exported_model_signatures=["serving_default"],
    tag_names=["serve"],
    input_arg_shapes_str="1,28,28,3::32"
)
```

#### Arguments:

* `input_path` (required): Path to your SavedModel directory.
* `exported_model_signatures` (optional): List of signature names to convert.
  Defaults to `["serving_default"]`.
* `tag_names` (optional): List of tags associated with the SavedModel. Defaults
  to `["serve"]`.
* `input_arg_shapes_str` (optional): A string representation of input argument
  shapes for the 'main' entry point, separating tensors with ':', dimensions
  with ',', and using '?' for unknown sizes. For example,
  `input-arg-shapes=1,2::1,?` expresses argument shapes `[1,2]`, `[]`, and
  `[1,?]` (see the parsing sketch after this list).

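For illustration only, here is a small hypothetical helper (not part of the
package) that shows how a shape string of this form maps to per-argument
shapes under the semantics described above:

```python
def parse_input_arg_shapes(shapes_str: str):
    """Parses e.g. "1,2::1,?" into [[1, 2], [], [1, None]] (assumed semantics)."""
    if not shapes_str:
        return []
    shapes = []
    for tensor_spec in shapes_str.split(":"):
        # An empty spec between '::' denotes a rank-0 (scalar) argument.
        if tensor_spec == "":
            shapes.append([])
            continue
        # '?' marks an unknown dimension; represent it as None.
        shapes.append([None if dim == "?" else int(dim) for dim in tensor_spec.split(",")])
    return shapes

assert parse_input_arg_shapes("1,2::1,?") == [[1, 2], [], [1, None]]
```
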
#### Error Handling

An exception will be raised with details about the error.

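For example, a caller might guard the conversion like this (an illustrative
sketch; the path below is a placeholder):

```python
from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo

try:
    stablehlo_bytes = tf2shlo.savedmodel_to_stablehlo(
        input_path="/path/to/your/savedmodel",
        exported_model_signatures=["serving_default"],
        tag_names=["serve"],
    )
except Exception as err:  # the binding in this change reports failures as ValueError
    print(f"Conversion failed: {err}")
```
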
### `tensorflow_module_to_stablehlo`

Converts a TensorFlow MLIR module string into StableHLO bytecode.

```Python
from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo

stablehlo_bytes = tf2shlo.tensorflow_module_to_stablehlo(
    module_op_str="your_tensorflow_mlir_module_string",
    input_arg_shapes_str="1,28,28,3::32"
)
```

#### Arguments:

* `module_op_str` (required): String containing the TensorFlow MLIR module.
* `input_arg_shapes_str` (optional): A string representation of input argument
  shapes for the 'main' entry point, separating tensors with ':', dimensions
  with ',', and using '?' for unknown sizes. For example,
  `input-arg-shapes=1,2::1,?` expresses argument shapes `[1,2]`, `[]`, and
  `[1,?]`.

#### Error Handling

An exception will be raised with details about the error.

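As a purely illustrative end-to-end sketch of the MLIR-string path (the module
text below is a hypothetical minimal TF-dialect module, not taken from this
change, and may need adjustment for a real conversion):

```python
from tensorflow.compiler.mlir.quantization.tensorflow_to_stablehlo.python import pywrap_tensorflow_to_stablehlo as tf2shlo

# Hypothetical minimal TF-dialect module with a 'main' entry point.
TF_MODULE = """
module {
  func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
    %0 = "tf.Abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32>
    return %0 : tensor<1x2xf32>
  }
}
"""

stablehlo_bytes = tf2shlo.tensorflow_module_to_stablehlo(
    module_op_str=TF_MODULE,
    input_arg_shapes_str="1,2",
)

# Persist the serialized StableHLO bytecode for downstream MLIR tooling.
with open("model.stablehlo.mlirbc", "wb") as f:
    f.write(stablehlo_bytes)
```
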
@@ -1,5 +1,6 @@
load(
    "//tensorflow:tensorflow.default.bzl",
    "get_compatible_with_portable",
    "tf_py_strict_test",
    "tf_python_pybind_extension",
)
@@ -20,6 +21,7 @@ package(
    default_visibility = [
        ":internal_visibility_allowlist_package",
        "//tensorflow:__pkg__",
        "//tensorflow/python:__pkg__",
    ],
    licenses = ["notice"],
)
@@ -45,20 +47,60 @@ package(
# )
# copybara:uncomment_end

# This is a header-only target. The purpose of `pywrap_tensorflow_to_stablehlo_lib_*` targets is to expose only
# the symbols that are required by `pywrap_tensorflow_to_stablehlo`, which translates them to Python functions.
# The only intended use case of this library is by `pywrap_tensorflow_to_stablehlo`. Not letting
# `pywrap_tensorflow_to_stablehlo` directly depend on sub-libraries like `static_range_srq` and instead having
# a consolidated impl library `pywrap_tensorflow_to_stablehlo_lib_impl` allows the maintainers to avoid
# declaring multiple impl libraries to `libtensorflow_cc` and `lib_pywrap_tensorflow_internal`,
# which is required to avoid ODR violations.
cc_library(
    name = "pywrap_tensorflow_to_stablehlo_lib_header_only",
    srcs = [],
    hdrs = ["pywrap_tensorflow_to_stablehlo_lib.h"],
    compatible_with = get_compatible_with_portable(),
    visibility = ["//visibility:private"],  # ONLY for `pywrap_tensorflow_to_stablehlo`.
    deps = [
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

# See the comments for `pywrap_tensorflow_to_stablehlo_lib_header_only`.
cc_library(
    name = "pywrap_tensorflow_to_stablehlo_lib_impl",
    srcs = ["pywrap_tensorflow_to_stablehlo_lib.cc"],
    hdrs = ["pywrap_tensorflow_to_stablehlo_lib.h"],
    compatible_with = get_compatible_with_portable(),
    visibility = [
        "//tensorflow:__pkg__",  # For libtensorflow_cc.so.
        "//tensorflow/python:__pkg__",  # For lib_pywrap_tensorflow_internal.so.
    ],
    deps = [
        "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:tf_to_stablehlo",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:lib",
        "//third_party/python_runtime:headers",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:string_view",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:BytecodeWriter",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Support",
    ],
)

tf_python_pybind_extension(
    name = "pywrap_tensorflow_to_stablehlo",
    srcs = ["pywrap_tensorflow_to_stablehlo.cc"],
    pytype_srcs = ["pywrap_tensorflow_to_stablehlo.pyi"],
    # Each dependency MUST be either header-only or exclusive.
    deps = [
        "//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters",
        "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo:tf_to_stablehlo",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/strings:str_format",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:CAPIIR",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",
        "@llvm-project//mlir:Support",
        ":pywrap_tensorflow_to_stablehlo_lib_header_only",
        "//third_party/python_runtime:headers",
        "@pybind11",
        "@pybind11_abseil//pybind11_abseil:absl_casters",
        "@pybind11_abseil//pybind11_abseil:status_casters",
@@ -12,159 +12,79 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <Python.h>

#include "absl/strings/str_format.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "pybind11/pybind11.h"  // from @pybind11
#include "pybind11/pytypes.h"  // from @pybind11
#include "pybind11/stl.h"  // from @pybind11  // IWYU pragma: keep
#include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil  // IWYU pragma: keep
#include "pybind11_abseil/status_casters.h"  // from @pybind11_abseil  // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h"  // IWYU pragma: keep
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h"

namespace py = pybind11;

namespace mlir::pywrap {
namespace {

absl::StatusOr<std::string> ModuleToBytecode(ModuleOp module) {
  std::string bytecode;
  llvm::raw_string_ostream os(bytecode);
  mlir::BytecodeWriterConfig config;
  if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) {
    return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed");
  }
  return bytecode;
}
using mlir::tensorflow_to_stablehlo::pywrap::PywrapSavedModelToStablehlo;
using mlir::tensorflow_to_stablehlo::pywrap::PywrapTfModuleToStablehlo;

absl::StatusOr<std::string> ExportModule(ModuleOp module) {
  auto output_filename = absl::StrFormat(
      "%s/tf_module.mlir", std::filesystem::temp_directory_path());

  std::string error_msg;
  auto output = openOutputFile(output_filename, &error_msg);
  if (output == nullptr) {
    return absl::AbortedError(
        absl::StrCat("Unable to open output path: ", error_msg));
  }

  std::string result;
  llvm::raw_string_ostream os(result);
  OpPrintingFlags printing_flags;
  module.print(os, printing_flags);

  output->os() << result;
  output->keep();

  return output_filename;
}

py::bytes PywrapSavedModelToStablehlo(
    absl::string_view input_path,
    const std::vector<std::string>& exported_model_signatures =
        {"serving_default"},
    const std::vector<std::string>& tag_names = {"serve"},
    absl::string_view input_arg_shapes_str = "") {
  mlir::DialectRegistry registry;
  RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);
  context.loadAllAvailableDialects();

  auto module =
      TfToStablehlo(input_path, &context, exported_model_signatures, tag_names,
                    input_arg_shapes_str, /*is_input_mlir_module=*/false);

  if (!module.ok()) {
    PyErr_SetString(PyExc_ValueError,
                    "failed to converted TensorFlow to StableHLO");
    return {};
  }

  auto bytecode = ModuleToBytecode(module.value().get());
  if (!bytecode.ok()) {
    PyErr_SetString(PyExc_ValueError, "failed to write module to bytecode");
    return {};
  }

  return bytecode.value();
}

py::bytes PywrapTfModuleToStablehlo(
    absl::string_view module_op_str,
    absl::string_view input_arg_shapes_str = "") {
  mlir::DialectRegistry registry;
  RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);
  context.loadAllAvailableDialects();

  auto tf_module = mlir::parseSourceString<ModuleOp>(module_op_str, &context);
  if (!tf_module) {
    PyErr_SetString(PyExc_ValueError, "failed to parse TF module string");
    return {};
  }

  auto mlir_file_path = ExportModule(*tf_module);
  if (!mlir_file_path.ok()) {
    PyErr_SetString(PyExc_ValueError,
                    "failed to write TF module to a temporary file");
    return {};
  }

  auto module = TfToStablehlo(
      *mlir_file_path, &context, /*exported_model_signatures=*/{},
      /*tag_names=*/{}, input_arg_shapes_str, /*is_input_mlir_module=*/true);

  if (!module.ok()) {
    PyErr_SetString(PyExc_ValueError,
                    "failed to converted TensorFlow to StableHLO");
    return {};
  }

  auto bytecode = ModuleToBytecode(module.value().get());
  if (!bytecode.ok()) {
    PyErr_SetString(PyExc_ValueError, "failed to write module to bytecode");
    return {};
  }

  return bytecode.value();
}

}  // namespace mlir::pywrap
}  // namespace

PYBIND11_MODULE(pywrap_tensorflow_to_stablehlo, m) {
  m.doc() = "TensorFlow to StableHLO APIs.";

  // LINT.IfChange(savedmodel_to_stablehlo)
  m.def("savedmodel_to_stablehlo", &mlir::pywrap::PywrapSavedModelToStablehlo,
        R"pbdoc(
    This tool converts TensorFlow SavedModel to StableHLO.
  m.def(
      "savedmodel_to_stablehlo",
      [](absl::string_view input_path,
         const std::vector<std::string>& exported_model_signatures =
             {"serving_default"},
         const std::vector<std::string>& tag_names = {"serve"},
         absl::string_view input_arg_shapes_str = "") -> py::bytes {
        auto module_bytecode =
            PywrapSavedModelToStablehlo(input_path, exported_model_signatures,
                                        tag_names, input_arg_shapes_str);
        if (!module_bytecode.ok()) {
          PyErr_SetString(PyExc_ValueError,
                          module_bytecode.status().ToString().c_str());
          throw py::error_already_set();
        }
        return py::bytes(module_bytecode.value());
      },
      R"pbdoc(
    Converts a TensorFlow SavedModel into StableHLO bytecode.

    * input-path: The path to the input TensorFlow SavedModel.
    * exported-model-signatures: Comma-separated list of exported model
      signatures to convert. Ignored for MLIR input.
    * tag_names: Comma-separated list of tags for loading SavedModel. Ignored for MLIR
      input.
      signatures to convert.
    * tag_names: Comma-separated list of tags for loading SavedModel.
    * input-arg-shapes: A string representation of input argument shapes for
      'main' entry-point, separating tensors with ':', dimension with ',', and
      using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?'
      expresses argument shapes [1,2], [] and [1,?].
  )pbdoc",
        py::arg("input_path"),
        py::arg("exported_model_signatures") =
            std::vector<std::string>{"serving_default"},
        py::arg("tag_names") = std::vector<std::string>{"serve"},
        py::arg("input_arg_shapes_str") = "");
  // LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:tensorflow_to_stablehlo)
      py::arg("input_path"),
      py::arg("exported_model_signatures") =
          std::vector<std::string>{"serving_default"},
      py::arg("tag_names") = std::vector<std::string>{"serve"},
      py::arg("input_arg_shapes_str") = "");
  // LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:savedmodel_to_stablehlo)
  //
  // LINT.IfChange(tensorflow_mlir_to_stablehlo)
  m.def("tensorflow_module_to_stablehlo",
        &mlir::pywrap::PywrapTfModuleToStablehlo,
        R"pbdoc(
    This tool converts TensorFlow mlir module string to StableHLO.
  // LINT.IfChange(tensorflow_module_to_stablehlo)
  m.def(
      "tensorflow_module_to_stablehlo",
      [](absl::string_view module_op_str,
         absl::string_view input_arg_shapes_str) -> py::bytes {
        auto module_bytecode =
            PywrapTfModuleToStablehlo(module_op_str, input_arg_shapes_str);
        if (!module_bytecode.ok()) {
          PyErr_SetString(PyExc_ValueError,
                          module_bytecode.status().ToString().c_str());
          throw py::error_already_set();
        }
        return py::bytes(module_bytecode.value());
      },
      R"pbdoc(
    Converts a TensorFlow MLIR module string into StableHLO bytecode.

    * module: TensorFlow MLIR module string.
    * input-arg-shapes: A string representation of input argument shapes for

@@ -172,6 +92,6 @@ PYBIND11_MODULE(pywrap_tensorflow_to_stablehlo, m) {
      using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?'
      expresses argument shapes [1,2], [] and [1,?].
  )pbdoc",
      py::arg("module"), py::arg("input_arg_shapes_str") = "");
  // LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:tensorflow_mlir_to_stablehlo)
      py::arg("module"), py::arg("input_arg_shapes_str") = "");
  // LINT.ThenChange(pywrap_tensorflow_to_stablehlo.pyi:tensorflow_module_to_stablehlo)
}
@@ -0,0 +1,141 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h"

#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Bytecode/BytecodeWriter.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Support/FileUtilities.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/tf_to_stablehlo.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/core/platform/path.h"

namespace mlir::tensorflow_to_stablehlo::pywrap {

// Serializes `module` to MLIR bytecode and returns it as a string.
absl::StatusOr<std::string> ModuleToBytecode(ModuleOp module) {
  std::string bytecode;
  llvm::raw_string_ostream os(bytecode);
  mlir::BytecodeWriterConfig config;
  if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) {
    return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed");
  }
  return bytecode;
}

// Prints `module` in textual form to a temporary .mlir file and returns the
// path to that file.
absl::StatusOr<std::string> ExportModule(ModuleOp module) {
  const std::string output_filename = tensorflow::io::GetTempFilename(".mlir");
  std::string error_msg;
  auto output = openOutputFile(output_filename, &error_msg);
  if (output == nullptr) {
    return absl::UnknownError(
        absl::StrCat("Unable to open output path: ", error_msg));
  }

  std::string result;
  llvm::raw_string_ostream os(result);
  OpPrintingFlags printing_flags;
  module.print(os, printing_flags);

  output->os() << result;
  output->keep();

  return output_filename;
}

absl::StatusOr<std::string> PywrapSavedModelToStablehlo(
    absl::string_view input_path,
    const std::vector<std::string>& exported_model_signatures,
    const std::vector<std::string>& tag_names,
    absl::string_view input_arg_shapes_str) {
  mlir::DialectRegistry registry;
  RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);
  context.loadAllAvailableDialects();

  auto module =
      TfToStablehlo(input_path, &context, exported_model_signatures, tag_names,
                    input_arg_shapes_str, /*is_input_mlir_module=*/false);

  if (!module.ok()) {
    return absl::UnknownError(
        absl::StrCat("Failed to convert SavedModel to StableHLO: ",
                     module.status().message()));
  }

  auto bytecode = ModuleToBytecode(module.value().get());
  if (!bytecode.ok()) {
    return absl::UnknownError(
        absl::StrCat("Failed to serialize MLIR module to bytecode: ",
                     bytecode.status().message()));
  }

  return bytecode.value();
}

absl::StatusOr<std::string> PywrapTfModuleToStablehlo(
    absl::string_view module_op_str, absl::string_view input_arg_shapes_str) {
  mlir::DialectRegistry registry;
  RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);
  context.loadAllAvailableDialects();

  auto tf_module = mlir::parseSourceString<ModuleOp>(module_op_str, &context);
  if (!tf_module) {
    return absl::UnknownError("Failed to parse MLIR module");
  }

  auto mlir_file_path = ExportModule(*tf_module);
  if (!mlir_file_path.ok()) {
    return absl::UnknownError(
        absl::StrCat("Failed to write MLIR module to file: ",
                     mlir_file_path.status().message()));
  }

  auto module = TfToStablehlo(*mlir_file_path, &context,
                              /*exported_model_signatures=*/{},
                              /*tag_names=*/{}, input_arg_shapes_str,
                              /*is_input_mlir_module=*/true);

  if (!module.ok()) {
    return absl::UnknownError(
        absl::StrCat("Failed to convert SavedModel to StableHLO: ",
                     module.status().message()));
  }

  auto bytecode = ModuleToBytecode(module.value().get());
  if (!bytecode.ok()) {
    return absl::UnknownError(
        absl::StrCat("Failed to serialize MLIR module to bytecode: ",
                     bytecode.status().message()));
  }

  return bytecode.value();
}

}  // namespace mlir::tensorflow_to_stablehlo::pywrap
@@ -0,0 +1,67 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_

#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"

namespace mlir::tensorflow_to_stablehlo::pywrap {

// Converts a TensorFlow SavedModel to a StableHLO MLIR module and serializes
// it to bytecode.
//
// Args:
//   input_path: The path to the SavedModel directory.
//   exported_model_signatures: Comma-separated list of exported model
//     signatures to convert.
//   tag_names: Comma-separated list of tags for loading the SavedModel.
//   input_arg_shapes_str: A string representation of input argument
//     shapes for 'main' entry-point, separating tensors with ':', dimension
//     with ',', and using '?' for unknown sizes. For example,
//     'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and [1,?].
//
// Returns:
//   An absl::StatusOr containing the serialized bytecode of the StableHLO
//   module on success, or an error status on failure.
absl::StatusOr<std::string> PywrapSavedModelToStablehlo(
    absl::string_view input_path,
    const std::vector<std::string>& exported_model_signatures,
    const std::vector<std::string>& tag_names,
    absl::string_view input_arg_shapes_str);

// Converts a TensorFlow MLIR module string to a StableHLO MLIR module and
// serializes it to bytecode.
//
// Args:
//   module_op_str: TensorFlow MLIR module string.
//   input_arg_shapes_str: A string representation of input argument
//     shapes for 'main' entry-point, separating tensors with ':', dimension
//     with ',', and using '?' for unknown sizes. For example,
//     'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and [1,?].
//
// Returns:
//   An absl::StatusOr containing the serialized bytecode of the StableHLO
//   module on success, or an error status on failure.
absl::StatusOr<std::string> PywrapTfModuleToStablehlo(
    absl::string_view module_op_str, absl::string_view input_arg_shapes_str);

}  // namespace mlir::tensorflow_to_stablehlo::pywrap

#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_
@@ -134,6 +134,7 @@ py_strict_library(
        ":pywrap_tensorflow",
        ":pywrap_tfe",
        "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model",
        "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/dtensor/python:dtensor",
        "//tensorflow/python/autograph",
@@ -768,6 +769,7 @@ pywrap_tensorflow_macro(
        "//tensorflow/cc/saved_model:metrics_impl",
        "//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization_lib_impl",
        "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
        "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl",
        "//tensorflow/compiler/tf2xla:tf2xla_opset",
@@ -883,6 +885,7 @@ filegroup(
        "//tensorflow/compiler/jit:flags",  # tfe
        "//tensorflow/compiler/jit:get_compiler_ir",  # tfe
        "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",  # quantization
        "//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl",  # tensorflow_to_stablehlo
        "//tensorflow/compiler/tf2xla:tf2xla_opset",  # pywrap_xla_ops
        "//tensorflow/core:framework_internal_impl",  # op_def_registry
        "//tensorflow/core:lib_internal_impl",  # device_lib
@@ -566,6 +566,10 @@ tensorflow::quantization::QuantizeStaticRangePtq
tensorflow::quantization::QuantizeDynamicRangePtq
tensorflow::quantization::QuantizeWeightOnly

[//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl] # tensorflow_to_stablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapSavedModelToStablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapTfModuleToStablehlo

[//tensorflow/dtensor/cc:dtensor_device_cc] # DTensor
tensorflow::dtensor::AllocateDTensorDevice
tensorflow::dtensor::AddMesh
@@ -566,6 +566,10 @@ tensorflow::quantization::QuantizeStaticRangePtq
tensorflow::quantization::QuantizeDynamicRangePtq
tensorflow::quantization::QuantizeWeightOnly

[//tensorflow/compiler/mlir/quantization/tensorflow_to_stablehlo/python:pywrap_tensorflow_to_stablehlo_lib_impl] # tensorflow_to_stablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapSavedModelToStablehlo
mlir::tensorflow_to_stablehlo::pywrap::PywrapTfModuleToStablehlo

[//tensorflow/dtensor/cc:dtensor_device_cc] # DTensor
tensorflow::dtensor::AllocateDTensorDevice
tensorflow::dtensor::AddMesh