From b01db583ac0da99e343863f30c0fcbd53c516eb4 Mon Sep 17 00:00:00 2001 From: Adam Cogdell Date: Thu, 9 Mar 2023 13:54:28 -0800 Subject: [PATCH] Replace GetFingerprintMap() and MaybeReadSavedModelChecksum() with ReadSavedModelFingerprint(). PiperOrigin-RevId: 515432111 --- RELEASE.md | 6 ++ tensorflow/cc/saved_model/fingerprinting.cc | 36 ++------ tensorflow/cc/saved_model/fingerprinting.h | 8 +- .../cc/saved_model/fingerprinting_test.cc | 39 +++----- tensorflow/python/saved_model/BUILD | 4 + .../python/saved_model/fingerprinting.py | 34 ++++--- .../python/saved_model/fingerprinting_test.py | 7 +- tensorflow/python/saved_model/load.py | 15 +++- tensorflow/python/saved_model/metrics_test.py | 8 +- .../pywrap_saved_model_fingerprinting.cc | 60 ++++++++----- .../pywrap_saved_model_fingerprinting_test.py | 89 +++++-------------- ...aved_model.experimental.-fingerprint.pbtxt | 4 + .../tools/def_file_filter/symbols_pybind.txt | 1 - 13 files changed, 141 insertions(+), 170 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 379bcba08bc..4ff7a5e3a4d 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -90,6 +90,12 @@ * `tf.data.Dataset.zip` now supports Python-style zipping, i.e. `Dataset.zip(a, b, c)`. +* `tf.SavedModel` + + * Introduce class method + `tf.saved_model.experimental.Fingerprint.from_proto(proto)`, which can + be used to construct a `Fingerprint` object directly from a protobuf. + ## Bug Fixes and Other Changes * diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index 7d2893f6199..d7f6aeabaee 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/util/tensor_bundle/naming.h" -#include "tensorflow/tsl/lib/strings/proto_serialization.h" namespace tensorflow::saved_model::fingerprinting { @@ -71,7 +70,7 @@ uint64 RegularizeAndHashSignatureDefs( // The SavedObjectGraph contains two parts: the list of nodes and the map of // concrete functions. Regularization treats these two parts separately. -uint64 RegularizeAndHashSavedObjectGraph( +StatusOr RegularizeAndHashSavedObjectGraph( const SavedObjectGraph& object_graph_def) { // Sort `concrete_functions`, which is an unordered map from function names to // SavedConcreteFunction, using the suffix UID of the function name. Assumes @@ -80,13 +79,9 @@ uint64 RegularizeAndHashSavedObjectGraph( absl::btree_map uid_to_function_names; for (const auto& [name, concrete_function] : object_graph_def.concrete_functions()) { - StatusOr uid = graph_regularization::GetSuffixUID(name); // All valid function names should end in an UID. - if (uid.ok()) { - uid_to_function_names.insert({*uid, name}); - } else { - LOG(ERROR) << uid.status().error_message(); - } + TF_ASSIGN_OR_RETURN(int uid, graph_regularization::GetSuffixUID(name)); + uid_to_function_names.insert({uid, name}); } uint64 result_hash = 0; for (const auto& [uid, function_name] : uid_to_function_names) { @@ -115,15 +110,14 @@ uint64 HashCheckpointIndexFile(absl::string_view model_dir) { if (read_status.ok()) { return tensorflow::Fingerprint64(data); } else { - LOG(WARNING) << read_status.error_message(); return 0; } } } // namespace -FingerprintDef CreateFingerprintDef(const SavedModel& saved_model, - absl::string_view export_dir) { +StatusOr CreateFingerprintDef(const SavedModel& saved_model, + absl::string_view export_dir) { // Create a copy of `metagraph` which will be used and mutated for fingerprint // computation. MetaGraphDef metagraph_copy = saved_model.meta_graphs(0); @@ -138,10 +132,10 @@ FingerprintDef CreateFingerprintDef(const SavedModel& saved_model, fingerprint_def.set_signature_def_hash( RegularizeAndHashSignatureDefs(metagraph_copy.signature_def())); // Set fingerprint field #4. - StatusOr object_graph_hash = - RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()); - fingerprint_def.set_saved_object_graph_hash( + TF_ASSIGN_OR_RETURN( + StatusOr object_graph_hash, RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def())); + fingerprint_def.set_saved_object_graph_hash(object_graph_hash.value()); // Set fingerprint field #5. fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); // Set version of the fingerprint. @@ -168,18 +162,4 @@ StatusOr ReadSavedModelFingerprint( return found_pb; } -std::unordered_map MakeFingerprintMap( - const FingerprintDef& fingerprint) { - std::unordered_map fingerprint_map; - fingerprint_map["saved_model_checksum"] = fingerprint.saved_model_checksum(); - fingerprint_map["graph_def_program_hash"] = - fingerprint.graph_def_program_hash(); - fingerprint_map["signature_def_hash"] = fingerprint.signature_def_hash(); - fingerprint_map["saved_object_graph_hash"] = - fingerprint.saved_object_graph_hash(); - fingerprint_map["checkpoint_hash"] = fingerprint.checkpoint_hash(); - fingerprint_map["version"] = fingerprint.version().producer(); - return fingerprint_map; -} - } // namespace tensorflow::saved_model::fingerprinting diff --git a/tensorflow/cc/saved_model/fingerprinting.h b/tensorflow/cc/saved_model/fingerprinting.h index 15790ed61e9..da0dd6b4a01 100644 --- a/tensorflow/cc/saved_model/fingerprinting.h +++ b/tensorflow/cc/saved_model/fingerprinting.h @@ -28,18 +28,14 @@ namespace tensorflow::saved_model::fingerprinting { // Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file // (.index) in `export_dir`. -FingerprintDef CreateFingerprintDef(const SavedModel& saved_model, - absl::string_view export_dir); +StatusOr CreateFingerprintDef(const SavedModel& saved_model, + absl::string_view export_dir); // Loads the `fingerprint.pb` from `export_dir`, returns an error if there is // none. StatusOr ReadSavedModelFingerprint( absl::string_view export_dir); -// Converts the fingerprint into a dictionary mapping field names to values. -std::unordered_map MakeFingerprintMap( - const FingerprintDef& fingerprint); - } // namespace tensorflow::saved_model::fingerprinting #endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_ diff --git a/tensorflow/cc/saved_model/fingerprinting_test.cc b/tensorflow/cc/saved_model/fingerprinting_test.cc index 0db31bbf17a..6453cac143a 100644 --- a/tensorflow/cc/saved_model/fingerprinting_test.cc +++ b/tensorflow/cc/saved_model/fingerprinting_test.cc @@ -52,8 +52,8 @@ TEST(FingerprintingTest, TestCreateFingerprint) { "VarsAndArithmeticObjectGraph"); TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); - FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb, export_dir); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def, + CreateFingerprintDef(saved_model_pb, export_dir)); EXPECT_GT(fingerprint_def.saved_model_checksum(), 0); EXPECT_EQ(fingerprint_def.graph_def_program_hash(), 10127142238652115842U); @@ -72,15 +72,15 @@ TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) { TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); - FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb, export_dir); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def, + CreateFingerprintDef(saved_model_pb, export_dir)); const std::string export_dir2 = io::JoinPath( testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert2"); TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb2, ReadSavedModel(export_dir2)); - FingerprintDef fingerprint_def2 = - CreateFingerprintDef(saved_model_pb2, export_dir2); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def2, + CreateFingerprintDef(saved_model_pb2, export_dir2)); EXPECT_EQ(fingerprint_def.graph_def_program_hash(), fingerprint_def2.graph_def_program_hash()); @@ -95,10 +95,10 @@ TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) { testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1"); TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); - FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb, export_dir); - FingerprintDef fingerprint_def2 = - CreateFingerprintDef(saved_model_pb, export_dir); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def, + CreateFingerprintDef(saved_model_pb, export_dir)); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def2, + CreateFingerprintDef(saved_model_pb, export_dir)); EXPECT_EQ(fingerprint_def.saved_model_checksum(), fingerprint_def2.saved_model_checksum()); @@ -109,8 +109,8 @@ TEST(FingerprintingTest, TestFingerprintHasVersion) { testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1"); TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); - FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb, export_dir); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def, + CreateFingerprintDef(saved_model_pb, export_dir)); EXPECT_EQ(fingerprint_def.version().producer(), 1); } @@ -119,8 +119,8 @@ TEST(FingerprintingTest, TestHashCheckpointForModelWithNoVariables) { testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1"); TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); - FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb, export_dir); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def, + CreateFingerprintDef(saved_model_pb, export_dir)); EXPECT_EQ(fingerprint_def.checkpoint_hash(), 0); } @@ -139,16 +139,5 @@ TEST(FingerprintingTest, TestReadNonexistentFingerprint) { EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok()); } -TEST(FingerprintingTest, TestMakeFingerprintMap) { - const std::string export_dir = - io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", - "VarsAndArithmeticObjectGraph"); - TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_pb, - ReadSavedModelFingerprint(export_dir)); - auto fingerprint_map = MakeFingerprintMap(fingerprint_pb); - EXPECT_EQ(fingerprint_pb.saved_model_checksum(), - fingerprint_map["saved_model_checksum"]); -} - } // namespace } // namespace tensorflow::saved_model::fingerprinting diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 9c17f37bd85..1868630dd62 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -426,6 +426,7 @@ py_library( "ignore_for_dep=third_party.py.keras.optimizers.optimizer_v2", ], deps = [ + ":fingerprinting", ":function_deserialization", ":load_options", ":load_v1_in_v2", @@ -466,6 +467,7 @@ py_library( "//tensorflow/python/trackable:trackable_utils", "//tensorflow/python/training/saving:saveable_object_util", "//tensorflow/python/util:tf_export", + "@absl_py//absl/logging", ], ) @@ -704,6 +706,7 @@ tf_py_test( name = "metrics_test", srcs = ["metrics_test.py"], deps = [ + ":fingerprinting", ":pywrap_saved_model", "//tensorflow/python/eager:test", ], @@ -784,6 +787,7 @@ py_strict_library( srcs = ["fingerprinting.py"], deps = [ ":pywrap_saved_model", + "//tensorflow/core:protos_all_py", "//tensorflow/python/util:tf_export", ], ) diff --git a/tensorflow/python/saved_model/fingerprinting.py b/tensorflow/python/saved_model/fingerprinting.py index a727d6c523b..7a435d6a602 100644 --- a/tensorflow/python/saved_model/fingerprinting.py +++ b/tensorflow/python/saved_model/fingerprinting.py @@ -18,6 +18,7 @@ This module contains classes and functions for reading the SavedModel fingerprint. """ +from tensorflow.core.protobuf import fingerprint_pb2 from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as fingerprinting_pywrap from tensorflow.python.util.tf_export import tf_export @@ -65,6 +66,16 @@ class Fingerprint(object): self.checkpoint_hash = checkpoint_hash self.version = version + @classmethod + def from_proto(cls, proto): + return Fingerprint( + proto.saved_model_checksum, + proto.graph_def_program_hash, + proto.signature_def_hash, + proto.saved_object_graph_hash, + proto.checkpoint_hash, + proto.version) + @tf_export("saved_model.experimental.read_fingerprint", v1=[]) def read_fingerprint(export_dir): @@ -73,7 +84,9 @@ def read_fingerprint(export_dir): Returns a `tf.saved_model.experimental.Fingerprint` object that contains the values of the SavedModel fingerprint, which is persisted on disk in the `fingerprint.pb` file in the `export_dir`. - TODO(b/265199038): Add link to TensorFlow SavedModel guide. + + Read more about fingerprints in the SavedModel guide at + https://www.tensorflow.org/guide/saved_model. Args: export_dir: The directory that contains the SavedModel. @@ -82,16 +95,11 @@ def read_fingerprint(export_dir): A `tf.saved_model.experimental.Fingerprint`. Raises: - ValueError: If no or an invalid fingerprint is found. + FingerprintException: If no or an invalid fingerprint is found. """ - fingerprint_map = fingerprinting_pywrap.GetFingerprintMap(export_dir) - if not fingerprint_map: - raise ValueError(f"No or invalid fingerprint found in: {export_dir}.") - return Fingerprint( - fingerprint_map["saved_model_checksum"], - fingerprint_map["graph_def_program_hash"], - fingerprint_map["signature_def_hash"], - fingerprint_map["saved_object_graph_hash"], - fingerprint_map["checkpoint_hash"], - fingerprint_map["version"], - ) + try: + fingerprint = fingerprinting_pywrap.ReadSavedModelFingerprint(export_dir) + except fingerprinting_pywrap.FingerprintException as e: + raise FileNotFoundError(f"SavedModel Fingerprint Error: {e}") from None # pylint: disable=raise-missing-from + return Fingerprint.from_proto( + fingerprint_pb2.FingerprintDef().FromString(fingerprint)) diff --git a/tensorflow/python/saved_model/fingerprinting_test.py b/tensorflow/python/saved_model/fingerprinting_test.py index 7f98cca3430..121fd7cdd8b 100644 --- a/tensorflow/python/saved_model/fingerprinting_test.py +++ b/tensorflow/python/saved_model/fingerprinting_test.py @@ -142,10 +142,13 @@ class FingerprintingTest(test.TestCase): self.assertEqual( fingerprint.checkpoint_hash, fingerprint_def.checkpoint_hash ) - self.assertEqual(fingerprint.version, fingerprint_def.version.producer) + self.assertEqual( + fingerprint.version.producer, fingerprint_def.version.producer + ) def test_read_fingerprint_api_invalid(self): - with self.assertRaisesRegex(ValueError, "No or invalid fingerprint"): + with self.assertRaisesRegex(FileNotFoundError, + "SavedModel Fingerprint Error"): read_fingerprint("foo") diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 612e8258971..90203c0a972 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -19,6 +19,8 @@ import functools import os import sys +from absl import logging + from tensorflow.core.function.capture import restore_captures from tensorflow.core.protobuf import graph_debug_info_pb2 from tensorflow.python.checkpoint import checkpoint @@ -42,6 +44,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables +from tensorflow.python.saved_model import fingerprinting from tensorflow.python.saved_model import function_deserialization from tensorflow.python.saved_model import load_options from tensorflow.python.saved_model import load_v1_in_v2 @@ -50,7 +53,6 @@ from tensorflow.python.saved_model import path_helpers from tensorflow.python.saved_model import registration from tensorflow.python.saved_model import revived_types from tensorflow.python.saved_model import utils_impl as saved_model_utils -from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting from tensorflow.python.saved_model.pywrap_saved_model import metrics from tensorflow.python.trackable import asset from tensorflow.python.trackable import autotrackable @@ -991,9 +993,14 @@ def load_partial(export_dir, filters, tags=None, options=None): metrics.SetReadPath(saved_model_path=str(export_dir)) # Read and log SavedModel checksum, if it is nonzero. - saved_model_checksum = fingerprinting.MaybeReadSavedModelChecksum(export_dir) - if saved_model_checksum != 0: - metrics.SetReadFingerprint(saved_model_checksum=str(saved_model_checksum)) + try: + fingerprint = fingerprinting.read_fingerprint(export_dir) + if fingerprint.saved_model_checksum != 0: + metrics.SetReadFingerprint( + saved_model_checksum=str(fingerprint.saved_model_checksum)) + except FileNotFoundError: + logging.error("Unable to load fingerprint when loading saved model.", + exc_info=True) if filters: return {node_id: loader.get(node_id) for node_id in filters} diff --git a/tensorflow/python/saved_model/metrics_test.py b/tensorflow/python/saved_model/metrics_test.py index 8f9b50938f3..5cb53b4c3a6 100644 --- a/tensorflow/python/saved_model/metrics_test.py +++ b/tensorflow/python/saved_model/metrics_test.py @@ -25,11 +25,11 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import builder_impl +from tensorflow.python.saved_model import fingerprinting from tensorflow.python.saved_model import load from tensorflow.python.saved_model import load_v1_in_v2 from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import save -from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting from tensorflow.python.saved_model.pywrap_saved_model import metrics from tensorflow.python.trackable import autotrackable @@ -114,18 +114,20 @@ class MetricsTests(test.TestCase): def test_save_sets_write_fingerprint_metric(self): exported_dir = self._create_save_v2_model() + fingerprint = fingerprinting.read_fingerprint(exported_dir) self.assertEqual( metrics.GetWriteFingerprint(), - str(fingerprinting.MaybeReadSavedModelChecksum(exported_dir))) + str(fingerprint.saved_model_checksum)) def test_load_sets_read_fingerprint_metric(self): exported_dir = self._create_save_v2_model() load.load(exported_dir) + fingerprint = fingerprinting.read_fingerprint(exported_dir) self.assertEqual( metrics.GetWriteFingerprint(), - str(fingerprinting.MaybeReadSavedModelChecksum(exported_dir))) + str(fingerprint.saved_model_checksum)) def test_save_sets_write_path_metric(self): exported_dir = self._create_save_v2_model() diff --git a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc index f925fc2392d..e412729a7c9 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc +++ b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "absl/strings/string_view.h" @@ -27,11 +28,31 @@ namespace python { namespace py = pybind11; +class FingerprintException : public std::exception { + public: + explicit FingerprintException(const char *m) : message_{m} {} + const char *what() const noexcept override { return message_.c_str(); } + + private: + std::string message_ = ""; +}; + void DefineFingerprintingModule(py::module main_module) { auto m = main_module.def_submodule("fingerprinting"); m.doc() = "Python bindings for TensorFlow SavedModel Fingerprinting."; + static py::exception ex(m, "FingerprintException"); + py::register_exception_translator([](std::exception_ptr p) { + try { + if (p) { + std::rethrow_exception(p); + } + } catch (const FingerprintException &e) { + ex(e.what()); + } + }); + m.def( "CreateFingerprintDef", [](std::string serialized_saved_model, std::string export_dir) { @@ -39,42 +60,37 @@ void DefineFingerprintingModule(py::module main_module) { SavedModel saved_model_pb; saved_model_pb.ParseFromString(serialized_saved_model); - return py::bytes( - fingerprinting::CreateFingerprintDef(saved_model_pb, export_dir) - .SerializeAsString()); + StatusOr fingerprint = + fingerprinting::CreateFingerprintDef(saved_model_pb, export_dir); + if (fingerprint.ok()) { + return py::bytes(fingerprint.value().SerializeAsString()); + } + throw FingerprintException( + std::string("Could not create fingerprint in directory: " + + export_dir) + .c_str()); }, py::arg("saved_model"), py::arg("export_dir"), py::doc( "Returns the serialized FingerprintDef of a serialized SavedModel.")); m.def( - "MaybeReadSavedModelChecksum", + "ReadSavedModelFingerprint", [](std::string export_dir) { StatusOr fingerprint = fingerprinting::ReadSavedModelFingerprint(export_dir); if (fingerprint.ok()) { - return fingerprint->saved_model_checksum(); + return py::bytes(fingerprint.value().SerializeAsString()); } - return (uint64_t)0; + throw FingerprintException( + std::string("Could not read fingerprint from directory: " + + export_dir) + .c_str()); }, py::arg("export_dir"), py::doc( - "Reads the fingerprint checksum from SavedModel directory. Returns " - "0 if an error occurs.")); - - m.def( - "GetFingerprintMap", - [](std::string export_dir) { - StatusOr fingerprint = - fingerprinting::ReadSavedModelFingerprint(export_dir); - if (fingerprint.ok()) { - return fingerprinting::MakeFingerprintMap(*fingerprint); - } - return std::unordered_map(); - }, - py::arg("export_dir"), - py::doc("Returns the fingerprint protobuf as a dictionary. Returns " - "an empty dictionary if invalid fingerprint file.")); + "Loads the `fingerprint.pb` from `export_dir`, returns an error if " + "there is none.")); } } // namespace python diff --git a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py index 289787b0648..41b6a6ed287 100644 --- a/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py +++ b/tensorflow/python/saved_model/pywrap_saved_model_fingerprinting_test.py @@ -19,89 +19,46 @@ import os from tensorflow.core.protobuf import fingerprint_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.platform import test -from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting +from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as pywrap_fingerprinting class FingerprintingTest(test.TestCase): - - # Checks that the fingerprint values are preserved when passed from C++ to - # Python. - def test_fingerprint_def_is_deserialized_correctly(self): + def test_create_fingerprint_def(self): export_dir = test.test_src_dir_path( "cc/saved_model/testdata/VarsAndArithmeticObjectGraph") with file_io.FileIO(os.path.join(export_dir, "saved_model.pb"), "rb") as f: file_content = f.read() - fingerprint_def = fingerprint_pb2.FingerprintDef() - fingerprint_def.ParseFromString( - fingerprinting.CreateFingerprintDef(file_content, export_dir)) + fingerprint = fingerprint_pb2.FingerprintDef().FromString( + pywrap_fingerprinting.CreateFingerprintDef(file_content, export_dir)) # We cannot check the value of the saved_model_checksum due to # non-determinism in serialization. - self.assertGreater(fingerprint_def.saved_model_checksum, 0) - self.assertEqual(fingerprint_def.graph_def_program_hash, - 10127142238652115842) - self.assertEqual(fingerprint_def.signature_def_hash, 5693392539583495303) - self.assertEqual(fingerprint_def.saved_object_graph_hash, - 3678101440349108924) + self.assertGreater(fingerprint.saved_model_checksum, 0) + self.assertEqual(fingerprint.graph_def_program_hash, 10127142238652115842) + self.assertEqual(fingerprint.signature_def_hash, 5693392539583495303) + self.assertEqual(fingerprint.saved_object_graph_hash, 3678101440349108924) # TODO(b/242348400): The checkpoint hash is non-deterministic, so we cannot # check its value here. - self.assertGreater(fingerprint_def.checkpoint_hash, 0) + self.assertGreater(fingerprint.checkpoint_hash, 0) - def test_read_fingerprint_from_file(self): + def test_read_saved_model_fingerprint(self): export_dir = test.test_src_dir_path( "cc/saved_model/testdata/VarsAndArithmeticObjectGraph") - self.assertEqual( - fingerprinting.MaybeReadSavedModelChecksum(export_dir), - 15788619162413586750) + fingerprint = fingerprint_pb2.FingerprintDef().FromString( + pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir)) + self.assertGreater(fingerprint.saved_model_checksum, 0) + self.assertEqual(fingerprint.graph_def_program_hash, 706963557435316516) + self.assertEqual(fingerprint.signature_def_hash, 5693392539583495303) + self.assertEqual(fingerprint.saved_object_graph_hash, 12074714563970609759) + self.assertGreater(fingerprint.checkpoint_hash, 0) + self.assertEqual(fingerprint.version.producer, 1) - def test_read_nonexistent_fingerprint_from_file(self): + def test_read_nonexistent_fingerprint(self): export_dir = test.test_src_dir_path("cc/saved_model/testdata/AssetModule") - self.assertEqual(fingerprinting.MaybeReadSavedModelChecksum(export_dir), 0) - - def test_get_fingerprint_map_valid(self): - export_dir = test.test_src_dir_path( - "cc/saved_model/testdata/VarsAndArithmeticObjectGraph" - ) - fingerprint_map = fingerprinting.GetFingerprintMap(export_dir) - - fingerprint_def = fingerprint_pb2.FingerprintDef() - with file_io.FileIO(os.path.join(export_dir, "fingerprint.pb"), "rb") as f: - fingerprint_def.ParseFromString(f.read()) - - self.assertEqual( - fingerprint_map["saved_model_checksum"], - fingerprint_def.saved_model_checksum, - ) - self.assertEqual( - fingerprint_map["graph_def_program_hash"], - fingerprint_def.graph_def_program_hash, - ) - self.assertEqual( - fingerprint_map["signature_def_hash"], - fingerprint_def.signature_def_hash, - ) - self.assertEqual( - fingerprint_map["saved_object_graph_hash"], - fingerprint_def.saved_object_graph_hash, - ) - self.assertEqual( - fingerprint_map["checkpoint_hash"], fingerprint_def.checkpoint_hash - ) - self.assertEqual( - fingerprint_map["version"], fingerprint_def.version.producer - ) - - -def test_get_fingerprint_map_nonexistent(self): - export_dir = test.test_src_dir_path("cc/saved_model/testdata/AssetModule") - fingerprint_map = fingerprinting.GetFingerprintMap(export_dir) - self.assertEmpty(fingerprint_map) - - -def test_get_fingerprint_map_invalid_saved_model(self): - export_dir = test.test_src_dir_path("not_a_saved_model") - fingerprint_map = fingerprinting.GetFingerprintMap(export_dir) - self.assertEmpty(fingerprint_map) + with self.assertRaises( + pywrap_fingerprinting.FingerprintException) as excinfo: + pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir) + self.assertRegex(str(excinfo.exception), "Could not read fingerprint.") if __name__ == "__main__": diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.experimental.-fingerprint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.experimental.-fingerprint.pbtxt index 33475723ae7..d9d4eb9b024 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.experimental.-fingerprint.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.experimental.-fingerprint.pbtxt @@ -6,4 +6,8 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'saved_model_checksum\', \'graph_def_program_hash\', \'signature_def_hash\', \'saved_object_graph_hash\', \'checkpoint_hash\', \'version\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "from_proto" + argspec: "args=[\'cls\', \'proto\'], varargs=None, keywords=None, defaults=None" + } } diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt index d6679ee0a9a..3db515cfbce 100644 --- a/tensorflow/tools/def_file_filter/symbols_pybind.txt +++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt @@ -459,7 +459,6 @@ tensorflow::metrics::CheckpointSize [//tensorflow/cc/saved_model:fingerprinting_impl] # SavedModel Fingerprinting tensorflow::saved_model::fingerprinting::CreateFingerprintDef tensorflow::saved_model::fingerprinting::ReadSavedModelFingerprint -tensorflow::saved_model::fingerprinting::MakeFingerprintMap [//tensorflow/compiler/jit:flags] # tfe