Replace GetFingerprintMap() and MaybeReadSavedModelChecksum() with ReadSavedModelFingerprint().

PiperOrigin-RevId: 515432111
This commit is contained in:
Adam Cogdell 2023-03-09 13:54:28 -08:00 committed by TensorFlower Gardener
parent b523cef78b
commit b01db583ac
13 changed files with 141 additions and 170 deletions

View File

@ -90,6 +90,12 @@
* `tf.data.Dataset.zip` now supports Python-style zipping, i.e. * `tf.data.Dataset.zip` now supports Python-style zipping, i.e.
`Dataset.zip(a, b, c)`. `Dataset.zip(a, b, c)`.
* `tf.SavedModel`
* Introduce class method
`tf.saved_model.experimental.Fingerprint.from_proto(proto)`, which can
be used to construct a `Fingerprint` object directly from a protobuf.
## Bug Fixes and Other Changes ## Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES> * <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>

View File

@ -38,7 +38,6 @@ limitations under the License.
#include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/naming.h" #include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/tsl/lib/strings/proto_serialization.h"
namespace tensorflow::saved_model::fingerprinting { namespace tensorflow::saved_model::fingerprinting {
@ -71,7 +70,7 @@ uint64 RegularizeAndHashSignatureDefs(
// The SavedObjectGraph contains two parts: the list of nodes and the map of // The SavedObjectGraph contains two parts: the list of nodes and the map of
// concrete functions. Regularization treats these two parts separately. // concrete functions. Regularization treats these two parts separately.
uint64 RegularizeAndHashSavedObjectGraph( StatusOr<uint64> RegularizeAndHashSavedObjectGraph(
const SavedObjectGraph& object_graph_def) { const SavedObjectGraph& object_graph_def) {
// Sort `concrete_functions`, which is an unordered map from function names to // Sort `concrete_functions`, which is an unordered map from function names to
// SavedConcreteFunction, using the suffix UID of the function name. Assumes // SavedConcreteFunction, using the suffix UID of the function name. Assumes
@ -80,13 +79,9 @@ uint64 RegularizeAndHashSavedObjectGraph(
absl::btree_map<int, std::string> uid_to_function_names; absl::btree_map<int, std::string> uid_to_function_names;
for (const auto& [name, concrete_function] : for (const auto& [name, concrete_function] :
object_graph_def.concrete_functions()) { object_graph_def.concrete_functions()) {
StatusOr<int> uid = graph_regularization::GetSuffixUID(name);
// All valid function names should end in an UID. // All valid function names should end in an UID.
if (uid.ok()) { TF_ASSIGN_OR_RETURN(int uid, graph_regularization::GetSuffixUID(name));
uid_to_function_names.insert({*uid, name}); uid_to_function_names.insert({uid, name});
} else {
LOG(ERROR) << uid.status().error_message();
}
} }
uint64 result_hash = 0; uint64 result_hash = 0;
for (const auto& [uid, function_name] : uid_to_function_names) { for (const auto& [uid, function_name] : uid_to_function_names) {
@ -115,15 +110,14 @@ uint64 HashCheckpointIndexFile(absl::string_view model_dir) {
if (read_status.ok()) { if (read_status.ok()) {
return tensorflow::Fingerprint64(data); return tensorflow::Fingerprint64(data);
} else { } else {
LOG(WARNING) << read_status.error_message();
return 0; return 0;
} }
} }
} // namespace } // namespace
FingerprintDef CreateFingerprintDef(const SavedModel& saved_model, StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
absl::string_view export_dir) { absl::string_view export_dir) {
// Create a copy of `metagraph` which will be used and mutated for fingerprint // Create a copy of `metagraph` which will be used and mutated for fingerprint
// computation. // computation.
MetaGraphDef metagraph_copy = saved_model.meta_graphs(0); MetaGraphDef metagraph_copy = saved_model.meta_graphs(0);
@ -138,10 +132,10 @@ FingerprintDef CreateFingerprintDef(const SavedModel& saved_model,
fingerprint_def.set_signature_def_hash( fingerprint_def.set_signature_def_hash(
RegularizeAndHashSignatureDefs(metagraph_copy.signature_def())); RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
// Set fingerprint field #4. // Set fingerprint field #4.
StatusOr<uint64> object_graph_hash = TF_ASSIGN_OR_RETURN(
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()); StatusOr<uint64> object_graph_hash,
fingerprint_def.set_saved_object_graph_hash(
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def())); RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
fingerprint_def.set_saved_object_graph_hash(object_graph_hash.value());
// Set fingerprint field #5. // Set fingerprint field #5.
fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir));
// Set version of the fingerprint. // Set version of the fingerprint.
@ -168,18 +162,4 @@ StatusOr<FingerprintDef> ReadSavedModelFingerprint(
return found_pb; return found_pb;
} }
std::unordered_map<std::string, uint64_t> MakeFingerprintMap(
const FingerprintDef& fingerprint) {
std::unordered_map<std::string, uint64_t> fingerprint_map;
fingerprint_map["saved_model_checksum"] = fingerprint.saved_model_checksum();
fingerprint_map["graph_def_program_hash"] =
fingerprint.graph_def_program_hash();
fingerprint_map["signature_def_hash"] = fingerprint.signature_def_hash();
fingerprint_map["saved_object_graph_hash"] =
fingerprint.saved_object_graph_hash();
fingerprint_map["checkpoint_hash"] = fingerprint.checkpoint_hash();
fingerprint_map["version"] = fingerprint.version().producer();
return fingerprint_map;
}
} // namespace tensorflow::saved_model::fingerprinting } // namespace tensorflow::saved_model::fingerprinting

View File

@ -28,18 +28,14 @@ namespace tensorflow::saved_model::fingerprinting {
// Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file // Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file
// (.index) in `export_dir`. // (.index) in `export_dir`.
FingerprintDef CreateFingerprintDef(const SavedModel& saved_model, StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
absl::string_view export_dir); absl::string_view export_dir);
// Loads the `fingerprint.pb` from `export_dir`, returns an error if there is // Loads the `fingerprint.pb` from `export_dir`, returns an error if there is
// none. // none.
StatusOr<FingerprintDef> ReadSavedModelFingerprint( StatusOr<FingerprintDef> ReadSavedModelFingerprint(
absl::string_view export_dir); absl::string_view export_dir);
// Converts the fingerprint into a dictionary mapping field names to values.
std::unordered_map<std::string, uint64_t> MakeFingerprintMap(
const FingerprintDef& fingerprint);
} // namespace tensorflow::saved_model::fingerprinting } // namespace tensorflow::saved_model::fingerprinting
#endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_ #endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_

View File

@ -52,8 +52,8 @@ TEST(FingerprintingTest, TestCreateFingerprint) {
"VarsAndArithmeticObjectGraph"); "VarsAndArithmeticObjectGraph");
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
ReadSavedModel(export_dir)); ReadSavedModel(export_dir));
FingerprintDef fingerprint_def = TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
CreateFingerprintDef(saved_model_pb, export_dir); CreateFingerprintDef(saved_model_pb, export_dir));
EXPECT_GT(fingerprint_def.saved_model_checksum(), 0); EXPECT_GT(fingerprint_def.saved_model_checksum(), 0);
EXPECT_EQ(fingerprint_def.graph_def_program_hash(), 10127142238652115842U); EXPECT_EQ(fingerprint_def.graph_def_program_hash(), 10127142238652115842U);
@ -72,15 +72,15 @@ TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) {
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
ReadSavedModel(export_dir)); ReadSavedModel(export_dir));
FingerprintDef fingerprint_def = TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
CreateFingerprintDef(saved_model_pb, export_dir); CreateFingerprintDef(saved_model_pb, export_dir));
const std::string export_dir2 = io::JoinPath( const std::string export_dir2 = io::JoinPath(
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert2"); testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert2");
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb2, TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb2,
ReadSavedModel(export_dir2)); ReadSavedModel(export_dir2));
FingerprintDef fingerprint_def2 = TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def2,
CreateFingerprintDef(saved_model_pb2, export_dir2); CreateFingerprintDef(saved_model_pb2, export_dir2));
EXPECT_EQ(fingerprint_def.graph_def_program_hash(), EXPECT_EQ(fingerprint_def.graph_def_program_hash(),
fingerprint_def2.graph_def_program_hash()); fingerprint_def2.graph_def_program_hash());
@ -95,10 +95,10 @@ TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) {
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1"); testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1");
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
ReadSavedModel(export_dir)); ReadSavedModel(export_dir));
FingerprintDef fingerprint_def = TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
CreateFingerprintDef(saved_model_pb, export_dir); CreateFingerprintDef(saved_model_pb, export_dir));
FingerprintDef fingerprint_def2 = TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def2,
CreateFingerprintDef(saved_model_pb, export_dir); CreateFingerprintDef(saved_model_pb, export_dir));
EXPECT_EQ(fingerprint_def.saved_model_checksum(), EXPECT_EQ(fingerprint_def.saved_model_checksum(),
fingerprint_def2.saved_model_checksum()); fingerprint_def2.saved_model_checksum());
@ -109,8 +109,8 @@ TEST(FingerprintingTest, TestFingerprintHasVersion) {
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1"); testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1");
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
ReadSavedModel(export_dir)); ReadSavedModel(export_dir));
FingerprintDef fingerprint_def = TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
CreateFingerprintDef(saved_model_pb, export_dir); CreateFingerprintDef(saved_model_pb, export_dir));
EXPECT_EQ(fingerprint_def.version().producer(), 1); EXPECT_EQ(fingerprint_def.version().producer(), 1);
} }
@ -119,8 +119,8 @@ TEST(FingerprintingTest, TestHashCheckpointForModelWithNoVariables) {
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1"); testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1");
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
ReadSavedModel(export_dir)); ReadSavedModel(export_dir));
FingerprintDef fingerprint_def = TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
CreateFingerprintDef(saved_model_pb, export_dir); CreateFingerprintDef(saved_model_pb, export_dir));
EXPECT_EQ(fingerprint_def.checkpoint_hash(), 0); EXPECT_EQ(fingerprint_def.checkpoint_hash(), 0);
} }
@ -139,16 +139,5 @@ TEST(FingerprintingTest, TestReadNonexistentFingerprint) {
EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok()); EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok());
} }
TEST(FingerprintingTest, TestMakeFingerprintMap) {
const std::string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
"VarsAndArithmeticObjectGraph");
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_pb,
ReadSavedModelFingerprint(export_dir));
auto fingerprint_map = MakeFingerprintMap(fingerprint_pb);
EXPECT_EQ(fingerprint_pb.saved_model_checksum(),
fingerprint_map["saved_model_checksum"]);
}
} // namespace } // namespace
} // namespace tensorflow::saved_model::fingerprinting } // namespace tensorflow::saved_model::fingerprinting

View File

@ -426,6 +426,7 @@ py_library(
"ignore_for_dep=third_party.py.keras.optimizers.optimizer_v2", "ignore_for_dep=third_party.py.keras.optimizers.optimizer_v2",
], ],
deps = [ deps = [
":fingerprinting",
":function_deserialization", ":function_deserialization",
":load_options", ":load_options",
":load_v1_in_v2", ":load_v1_in_v2",
@ -466,6 +467,7 @@ py_library(
"//tensorflow/python/trackable:trackable_utils", "//tensorflow/python/trackable:trackable_utils",
"//tensorflow/python/training/saving:saveable_object_util", "//tensorflow/python/training/saving:saveable_object_util",
"//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_export",
"@absl_py//absl/logging",
], ],
) )
@ -704,6 +706,7 @@ tf_py_test(
name = "metrics_test", name = "metrics_test",
srcs = ["metrics_test.py"], srcs = ["metrics_test.py"],
deps = [ deps = [
":fingerprinting",
":pywrap_saved_model", ":pywrap_saved_model",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
], ],
@ -784,6 +787,7 @@ py_strict_library(
srcs = ["fingerprinting.py"], srcs = ["fingerprinting.py"],
deps = [ deps = [
":pywrap_saved_model", ":pywrap_saved_model",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_export",
], ],
) )

View File

@ -18,6 +18,7 @@ This module contains classes and functions for reading the SavedModel
fingerprint. fingerprint.
""" """
from tensorflow.core.protobuf import fingerprint_pb2
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as fingerprinting_pywrap from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as fingerprinting_pywrap
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -65,6 +66,16 @@ class Fingerprint(object):
self.checkpoint_hash = checkpoint_hash self.checkpoint_hash = checkpoint_hash
self.version = version self.version = version
@classmethod
def from_proto(cls, proto):
return Fingerprint(
proto.saved_model_checksum,
proto.graph_def_program_hash,
proto.signature_def_hash,
proto.saved_object_graph_hash,
proto.checkpoint_hash,
proto.version)
@tf_export("saved_model.experimental.read_fingerprint", v1=[]) @tf_export("saved_model.experimental.read_fingerprint", v1=[])
def read_fingerprint(export_dir): def read_fingerprint(export_dir):
@ -73,7 +84,9 @@ def read_fingerprint(export_dir):
Returns a `tf.saved_model.experimental.Fingerprint` object that contains Returns a `tf.saved_model.experimental.Fingerprint` object that contains
the values of the SavedModel fingerprint, which is persisted on disk in the the values of the SavedModel fingerprint, which is persisted on disk in the
`fingerprint.pb` file in the `export_dir`. `fingerprint.pb` file in the `export_dir`.
TODO(b/265199038): Add link to TensorFlow SavedModel guide.
Read more about fingerprints in the SavedModel guide at
https://www.tensorflow.org/guide/saved_model.
Args: Args:
export_dir: The directory that contains the SavedModel. export_dir: The directory that contains the SavedModel.
@ -82,16 +95,11 @@ def read_fingerprint(export_dir):
A `tf.saved_model.experimental.Fingerprint`. A `tf.saved_model.experimental.Fingerprint`.
Raises: Raises:
ValueError: If no or an invalid fingerprint is found. FingerprintException: If no or an invalid fingerprint is found.
""" """
fingerprint_map = fingerprinting_pywrap.GetFingerprintMap(export_dir) try:
if not fingerprint_map: fingerprint = fingerprinting_pywrap.ReadSavedModelFingerprint(export_dir)
raise ValueError(f"No or invalid fingerprint found in: {export_dir}.") except fingerprinting_pywrap.FingerprintException as e:
return Fingerprint( raise FileNotFoundError(f"SavedModel Fingerprint Error: {e}") from None # pylint: disable=raise-missing-from
fingerprint_map["saved_model_checksum"], return Fingerprint.from_proto(
fingerprint_map["graph_def_program_hash"], fingerprint_pb2.FingerprintDef().FromString(fingerprint))
fingerprint_map["signature_def_hash"],
fingerprint_map["saved_object_graph_hash"],
fingerprint_map["checkpoint_hash"],
fingerprint_map["version"],
)

View File

@ -142,10 +142,13 @@ class FingerprintingTest(test.TestCase):
self.assertEqual( self.assertEqual(
fingerprint.checkpoint_hash, fingerprint_def.checkpoint_hash fingerprint.checkpoint_hash, fingerprint_def.checkpoint_hash
) )
self.assertEqual(fingerprint.version, fingerprint_def.version.producer) self.assertEqual(
fingerprint.version.producer, fingerprint_def.version.producer
)
def test_read_fingerprint_api_invalid(self): def test_read_fingerprint_api_invalid(self):
with self.assertRaisesRegex(ValueError, "No or invalid fingerprint"): with self.assertRaisesRegex(FileNotFoundError,
"SavedModel Fingerprint Error"):
read_fingerprint("foo") read_fingerprint("foo")

View File

@ -19,6 +19,8 @@ import functools
import os import os
import sys import sys
from absl import logging
from tensorflow.core.function.capture import restore_captures from tensorflow.core.function.capture import restore_captures
from tensorflow.core.protobuf import graph_debug_info_pb2 from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.python.checkpoint import checkpoint from tensorflow.python.checkpoint import checkpoint
@ -42,6 +44,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.saved_model import fingerprinting
from tensorflow.python.saved_model import function_deserialization from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_options from tensorflow.python.saved_model import load_options
from tensorflow.python.saved_model import load_v1_in_v2 from tensorflow.python.saved_model import load_v1_in_v2
@ -50,7 +53,6 @@ from tensorflow.python.saved_model import path_helpers
from tensorflow.python.saved_model import registration from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model import revived_types from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting
from tensorflow.python.saved_model.pywrap_saved_model import metrics from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import asset from tensorflow.python.trackable import asset
from tensorflow.python.trackable import autotrackable from tensorflow.python.trackable import autotrackable
@ -991,9 +993,14 @@ def load_partial(export_dir, filters, tags=None, options=None):
metrics.SetReadPath(saved_model_path=str(export_dir)) metrics.SetReadPath(saved_model_path=str(export_dir))
# Read and log SavedModel checksum, if it is nonzero. # Read and log SavedModel checksum, if it is nonzero.
saved_model_checksum = fingerprinting.MaybeReadSavedModelChecksum(export_dir) try:
if saved_model_checksum != 0: fingerprint = fingerprinting.read_fingerprint(export_dir)
metrics.SetReadFingerprint(saved_model_checksum=str(saved_model_checksum)) if fingerprint.saved_model_checksum != 0:
metrics.SetReadFingerprint(
saved_model_checksum=str(fingerprint.saved_model_checksum))
except FileNotFoundError:
logging.error("Unable to load fingerprint when loading saved model.",
exc_info=True)
if filters: if filters:
return {node_id: loader.get(node_id) for node_id in filters} return {node_id: loader.get(node_id) for node_id in filters}

View File

@ -25,11 +25,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import fingerprinting
from tensorflow.python.saved_model import load from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import load_v1_in_v2 from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import save from tensorflow.python.saved_model import save
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting
from tensorflow.python.saved_model.pywrap_saved_model import metrics from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import autotrackable from tensorflow.python.trackable import autotrackable
@ -114,18 +114,20 @@ class MetricsTests(test.TestCase):
def test_save_sets_write_fingerprint_metric(self): def test_save_sets_write_fingerprint_metric(self):
exported_dir = self._create_save_v2_model() exported_dir = self._create_save_v2_model()
fingerprint = fingerprinting.read_fingerprint(exported_dir)
self.assertEqual( self.assertEqual(
metrics.GetWriteFingerprint(), metrics.GetWriteFingerprint(),
str(fingerprinting.MaybeReadSavedModelChecksum(exported_dir))) str(fingerprint.saved_model_checksum))
def test_load_sets_read_fingerprint_metric(self): def test_load_sets_read_fingerprint_metric(self):
exported_dir = self._create_save_v2_model() exported_dir = self._create_save_v2_model()
load.load(exported_dir) load.load(exported_dir)
fingerprint = fingerprinting.read_fingerprint(exported_dir)
self.assertEqual( self.assertEqual(
metrics.GetWriteFingerprint(), metrics.GetWriteFingerprint(),
str(fingerprinting.MaybeReadSavedModelChecksum(exported_dir))) str(fingerprint.saved_model_checksum))
def test_save_sets_write_path_metric(self): def test_save_sets_write_path_metric(self):
exported_dir = self._create_save_v2_model() exported_dir = self._create_save_v2_model()

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <exception>
#include <string> #include <string>
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
@ -27,11 +28,31 @@ namespace python {
namespace py = pybind11; namespace py = pybind11;
class FingerprintException : public std::exception {
public:
explicit FingerprintException(const char *m) : message_{m} {}
const char *what() const noexcept override { return message_.c_str(); }
private:
std::string message_ = "";
};
void DefineFingerprintingModule(py::module main_module) { void DefineFingerprintingModule(py::module main_module) {
auto m = main_module.def_submodule("fingerprinting"); auto m = main_module.def_submodule("fingerprinting");
m.doc() = "Python bindings for TensorFlow SavedModel Fingerprinting."; m.doc() = "Python bindings for TensorFlow SavedModel Fingerprinting.";
static py::exception<FingerprintException> ex(m, "FingerprintException");
py::register_exception_translator([](std::exception_ptr p) {
try {
if (p) {
std::rethrow_exception(p);
}
} catch (const FingerprintException &e) {
ex(e.what());
}
});
m.def( m.def(
"CreateFingerprintDef", "CreateFingerprintDef",
[](std::string serialized_saved_model, std::string export_dir) { [](std::string serialized_saved_model, std::string export_dir) {
@ -39,42 +60,37 @@ void DefineFingerprintingModule(py::module main_module) {
SavedModel saved_model_pb; SavedModel saved_model_pb;
saved_model_pb.ParseFromString(serialized_saved_model); saved_model_pb.ParseFromString(serialized_saved_model);
return py::bytes( StatusOr<FingerprintDef> fingerprint =
fingerprinting::CreateFingerprintDef(saved_model_pb, export_dir) fingerprinting::CreateFingerprintDef(saved_model_pb, export_dir);
.SerializeAsString()); if (fingerprint.ok()) {
return py::bytes(fingerprint.value().SerializeAsString());
}
throw FingerprintException(
std::string("Could not create fingerprint in directory: " +
export_dir)
.c_str());
}, },
py::arg("saved_model"), py::arg("export_dir"), py::arg("saved_model"), py::arg("export_dir"),
py::doc( py::doc(
"Returns the serialized FingerprintDef of a serialized SavedModel.")); "Returns the serialized FingerprintDef of a serialized SavedModel."));
m.def( m.def(
"MaybeReadSavedModelChecksum", "ReadSavedModelFingerprint",
[](std::string export_dir) { [](std::string export_dir) {
StatusOr<FingerprintDef> fingerprint = StatusOr<FingerprintDef> fingerprint =
fingerprinting::ReadSavedModelFingerprint(export_dir); fingerprinting::ReadSavedModelFingerprint(export_dir);
if (fingerprint.ok()) { if (fingerprint.ok()) {
return fingerprint->saved_model_checksum(); return py::bytes(fingerprint.value().SerializeAsString());
} }
return (uint64_t)0; throw FingerprintException(
std::string("Could not read fingerprint from directory: " +
export_dir)
.c_str());
}, },
py::arg("export_dir"), py::arg("export_dir"),
py::doc( py::doc(
"Reads the fingerprint checksum from SavedModel directory. Returns " "Loads the `fingerprint.pb` from `export_dir`, returns an error if "
"0 if an error occurs.")); "there is none."));
m.def(
"GetFingerprintMap",
[](std::string export_dir) {
StatusOr<FingerprintDef> fingerprint =
fingerprinting::ReadSavedModelFingerprint(export_dir);
if (fingerprint.ok()) {
return fingerprinting::MakeFingerprintMap(*fingerprint);
}
return std::unordered_map<std::string, uint64_t>();
},
py::arg("export_dir"),
py::doc("Returns the fingerprint protobuf as a dictionary. Returns "
"an empty dictionary if invalid fingerprint file."));
} }
} // namespace python } // namespace python

View File

@ -19,89 +19,46 @@ import os
from tensorflow.core.protobuf import fingerprint_pb2 from tensorflow.core.protobuf import fingerprint_pb2
from tensorflow.python.lib.io import file_io from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as pywrap_fingerprinting
class FingerprintingTest(test.TestCase): class FingerprintingTest(test.TestCase):
def test_create_fingerprint_def(self):
# Checks that the fingerprint values are preserved when passed from C++ to
# Python.
def test_fingerprint_def_is_deserialized_correctly(self):
export_dir = test.test_src_dir_path( export_dir = test.test_src_dir_path(
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph") "cc/saved_model/testdata/VarsAndArithmeticObjectGraph")
with file_io.FileIO(os.path.join(export_dir, "saved_model.pb"), "rb") as f: with file_io.FileIO(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
file_content = f.read() file_content = f.read()
fingerprint_def = fingerprint_pb2.FingerprintDef() fingerprint = fingerprint_pb2.FingerprintDef().FromString(
fingerprint_def.ParseFromString( pywrap_fingerprinting.CreateFingerprintDef(file_content, export_dir))
fingerprinting.CreateFingerprintDef(file_content, export_dir))
# We cannot check the value of the saved_model_checksum due to # We cannot check the value of the saved_model_checksum due to
# non-determinism in serialization. # non-determinism in serialization.
self.assertGreater(fingerprint_def.saved_model_checksum, 0) self.assertGreater(fingerprint.saved_model_checksum, 0)
self.assertEqual(fingerprint_def.graph_def_program_hash, self.assertEqual(fingerprint.graph_def_program_hash, 10127142238652115842)
10127142238652115842) self.assertEqual(fingerprint.signature_def_hash, 5693392539583495303)
self.assertEqual(fingerprint_def.signature_def_hash, 5693392539583495303) self.assertEqual(fingerprint.saved_object_graph_hash, 3678101440349108924)
self.assertEqual(fingerprint_def.saved_object_graph_hash,
3678101440349108924)
# TODO(b/242348400): The checkpoint hash is non-deterministic, so we cannot # TODO(b/242348400): The checkpoint hash is non-deterministic, so we cannot
# check its value here. # check its value here.
self.assertGreater(fingerprint_def.checkpoint_hash, 0) self.assertGreater(fingerprint.checkpoint_hash, 0)
def test_read_fingerprint_from_file(self): def test_read_saved_model_fingerprint(self):
export_dir = test.test_src_dir_path( export_dir = test.test_src_dir_path(
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph") "cc/saved_model/testdata/VarsAndArithmeticObjectGraph")
self.assertEqual( fingerprint = fingerprint_pb2.FingerprintDef().FromString(
fingerprinting.MaybeReadSavedModelChecksum(export_dir), pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir))
15788619162413586750) self.assertGreater(fingerprint.saved_model_checksum, 0)
self.assertEqual(fingerprint.graph_def_program_hash, 706963557435316516)
self.assertEqual(fingerprint.signature_def_hash, 5693392539583495303)
self.assertEqual(fingerprint.saved_object_graph_hash, 12074714563970609759)
self.assertGreater(fingerprint.checkpoint_hash, 0)
self.assertEqual(fingerprint.version.producer, 1)
def test_read_nonexistent_fingerprint_from_file(self): def test_read_nonexistent_fingerprint(self):
export_dir = test.test_src_dir_path("cc/saved_model/testdata/AssetModule") export_dir = test.test_src_dir_path("cc/saved_model/testdata/AssetModule")
self.assertEqual(fingerprinting.MaybeReadSavedModelChecksum(export_dir), 0) with self.assertRaises(
pywrap_fingerprinting.FingerprintException) as excinfo:
def test_get_fingerprint_map_valid(self): pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir)
export_dir = test.test_src_dir_path( self.assertRegex(str(excinfo.exception), "Could not read fingerprint.")
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph"
)
fingerprint_map = fingerprinting.GetFingerprintMap(export_dir)
fingerprint_def = fingerprint_pb2.FingerprintDef()
with file_io.FileIO(os.path.join(export_dir, "fingerprint.pb"), "rb") as f:
fingerprint_def.ParseFromString(f.read())
self.assertEqual(
fingerprint_map["saved_model_checksum"],
fingerprint_def.saved_model_checksum,
)
self.assertEqual(
fingerprint_map["graph_def_program_hash"],
fingerprint_def.graph_def_program_hash,
)
self.assertEqual(
fingerprint_map["signature_def_hash"],
fingerprint_def.signature_def_hash,
)
self.assertEqual(
fingerprint_map["saved_object_graph_hash"],
fingerprint_def.saved_object_graph_hash,
)
self.assertEqual(
fingerprint_map["checkpoint_hash"], fingerprint_def.checkpoint_hash
)
self.assertEqual(
fingerprint_map["version"], fingerprint_def.version.producer
)
def test_get_fingerprint_map_nonexistent(self):
export_dir = test.test_src_dir_path("cc/saved_model/testdata/AssetModule")
fingerprint_map = fingerprinting.GetFingerprintMap(export_dir)
self.assertEmpty(fingerprint_map)
def test_get_fingerprint_map_invalid_saved_model(self):
export_dir = test.test_src_dir_path("not_a_saved_model")
fingerprint_map = fingerprinting.GetFingerprintMap(export_dir)
self.assertEmpty(fingerprint_map)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -6,4 +6,8 @@ tf_class {
name: "__init__" name: "__init__"
argspec: "args=[\'self\', \'saved_model_checksum\', \'graph_def_program_hash\', \'signature_def_hash\', \'saved_object_graph_hash\', \'checkpoint_hash\', \'version\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'saved_model_checksum\', \'graph_def_program_hash\', \'signature_def_hash\', \'saved_object_graph_hash\', \'checkpoint_hash\', \'version\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
} }
member_method {
name: "from_proto"
argspec: "args=[\'cls\', \'proto\'], varargs=None, keywords=None, defaults=None"
}
} }

View File

@ -459,7 +459,6 @@ tensorflow::metrics::CheckpointSize
[//tensorflow/cc/saved_model:fingerprinting_impl] # SavedModel Fingerprinting [//tensorflow/cc/saved_model:fingerprinting_impl] # SavedModel Fingerprinting
tensorflow::saved_model::fingerprinting::CreateFingerprintDef tensorflow::saved_model::fingerprinting::CreateFingerprintDef
tensorflow::saved_model::fingerprinting::ReadSavedModelFingerprint tensorflow::saved_model::fingerprinting::ReadSavedModelFingerprint
tensorflow::saved_model::fingerprinting::MakeFingerprintMap
[//tensorflow/compiler/jit:flags] # tfe [//tensorflow/compiler/jit:flags] # tfe