mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Replace GetFingerprintMap() and MaybeReadSavedModelChecksum() with ReadSavedModelFingerprint().
PiperOrigin-RevId: 515432111
This commit is contained in:
parent
b523cef78b
commit
b01db583ac
|
|
@ -90,6 +90,12 @@
|
|||
* `tf.data.Dataset.zip` now supports Python-style zipping, i.e.
|
||||
`Dataset.zip(a, b, c)`.
|
||||
|
||||
* `tf.SavedModel`
|
||||
|
||||
* Introduce class method
|
||||
`tf.saved_model.experimental.Fingerprint.from_proto(proto)`, which can
|
||||
be used to construct a `Fingerprint` object directly from a protobuf.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/util/tensor_bundle/naming.h"
|
||||
#include "tensorflow/tsl/lib/strings/proto_serialization.h"
|
||||
|
||||
namespace tensorflow::saved_model::fingerprinting {
|
||||
|
||||
|
|
@ -71,7 +70,7 @@ uint64 RegularizeAndHashSignatureDefs(
|
|||
|
||||
// The SavedObjectGraph contains two parts: the list of nodes and the map of
|
||||
// concrete functions. Regularization treats these two parts separately.
|
||||
uint64 RegularizeAndHashSavedObjectGraph(
|
||||
StatusOr<uint64> RegularizeAndHashSavedObjectGraph(
|
||||
const SavedObjectGraph& object_graph_def) {
|
||||
// Sort `concrete_functions`, which is an unordered map from function names to
|
||||
// SavedConcreteFunction, using the suffix UID of the function name. Assumes
|
||||
|
|
@ -80,13 +79,9 @@ uint64 RegularizeAndHashSavedObjectGraph(
|
|||
absl::btree_map<int, std::string> uid_to_function_names;
|
||||
for (const auto& [name, concrete_function] :
|
||||
object_graph_def.concrete_functions()) {
|
||||
StatusOr<int> uid = graph_regularization::GetSuffixUID(name);
|
||||
// All valid function names should end in an UID.
|
||||
if (uid.ok()) {
|
||||
uid_to_function_names.insert({*uid, name});
|
||||
} else {
|
||||
LOG(ERROR) << uid.status().error_message();
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(int uid, graph_regularization::GetSuffixUID(name));
|
||||
uid_to_function_names.insert({uid, name});
|
||||
}
|
||||
uint64 result_hash = 0;
|
||||
for (const auto& [uid, function_name] : uid_to_function_names) {
|
||||
|
|
@ -115,14 +110,13 @@ uint64 HashCheckpointIndexFile(absl::string_view model_dir) {
|
|||
if (read_status.ok()) {
|
||||
return tensorflow::Fingerprint64(data);
|
||||
} else {
|
||||
LOG(WARNING) << read_status.error_message();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
FingerprintDef CreateFingerprintDef(const SavedModel& saved_model,
|
||||
StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
|
||||
absl::string_view export_dir) {
|
||||
// Create a copy of `metagraph` which will be used and mutated for fingerprint
|
||||
// computation.
|
||||
|
|
@ -138,10 +132,10 @@ FingerprintDef CreateFingerprintDef(const SavedModel& saved_model,
|
|||
fingerprint_def.set_signature_def_hash(
|
||||
RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
|
||||
// Set fingerprint field #4.
|
||||
StatusOr<uint64> object_graph_hash =
|
||||
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def());
|
||||
fingerprint_def.set_saved_object_graph_hash(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
StatusOr<uint64> object_graph_hash,
|
||||
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
|
||||
fingerprint_def.set_saved_object_graph_hash(object_graph_hash.value());
|
||||
// Set fingerprint field #5.
|
||||
fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir));
|
||||
// Set version of the fingerprint.
|
||||
|
|
@ -168,18 +162,4 @@ StatusOr<FingerprintDef> ReadSavedModelFingerprint(
|
|||
return found_pb;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, uint64_t> MakeFingerprintMap(
|
||||
const FingerprintDef& fingerprint) {
|
||||
std::unordered_map<std::string, uint64_t> fingerprint_map;
|
||||
fingerprint_map["saved_model_checksum"] = fingerprint.saved_model_checksum();
|
||||
fingerprint_map["graph_def_program_hash"] =
|
||||
fingerprint.graph_def_program_hash();
|
||||
fingerprint_map["signature_def_hash"] = fingerprint.signature_def_hash();
|
||||
fingerprint_map["saved_object_graph_hash"] =
|
||||
fingerprint.saved_object_graph_hash();
|
||||
fingerprint_map["checkpoint_hash"] = fingerprint.checkpoint_hash();
|
||||
fingerprint_map["version"] = fingerprint.version().producer();
|
||||
return fingerprint_map;
|
||||
}
|
||||
|
||||
} // namespace tensorflow::saved_model::fingerprinting
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ namespace tensorflow::saved_model::fingerprinting {
|
|||
|
||||
// Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file
|
||||
// (.index) in `export_dir`.
|
||||
FingerprintDef CreateFingerprintDef(const SavedModel& saved_model,
|
||||
StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
|
||||
absl::string_view export_dir);
|
||||
|
||||
// Loads the `fingerprint.pb` from `export_dir`, returns an error if there is
|
||||
|
|
@ -36,10 +36,6 @@ FingerprintDef CreateFingerprintDef(const SavedModel& saved_model,
|
|||
StatusOr<FingerprintDef> ReadSavedModelFingerprint(
|
||||
absl::string_view export_dir);
|
||||
|
||||
// Converts the fingerprint into a dictionary mapping field names to values.
|
||||
std::unordered_map<std::string, uint64_t> MakeFingerprintMap(
|
||||
const FingerprintDef& fingerprint);
|
||||
|
||||
} // namespace tensorflow::saved_model::fingerprinting
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_
|
||||
|
|
|
|||
|
|
@ -52,8 +52,8 @@ TEST(FingerprintingTest, TestCreateFingerprint) {
|
|||
"VarsAndArithmeticObjectGraph");
|
||||
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
|
||||
ReadSavedModel(export_dir));
|
||||
FingerprintDef fingerprint_def =
|
||||
CreateFingerprintDef(saved_model_pb, export_dir);
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
|
||||
CreateFingerprintDef(saved_model_pb, export_dir));
|
||||
|
||||
EXPECT_GT(fingerprint_def.saved_model_checksum(), 0);
|
||||
EXPECT_EQ(fingerprint_def.graph_def_program_hash(), 10127142238652115842U);
|
||||
|
|
@ -72,15 +72,15 @@ TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) {
|
|||
|
||||
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
|
||||
ReadSavedModel(export_dir));
|
||||
FingerprintDef fingerprint_def =
|
||||
CreateFingerprintDef(saved_model_pb, export_dir);
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
|
||||
CreateFingerprintDef(saved_model_pb, export_dir));
|
||||
|
||||
const std::string export_dir2 = io::JoinPath(
|
||||
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert2");
|
||||
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb2,
|
||||
ReadSavedModel(export_dir2));
|
||||
FingerprintDef fingerprint_def2 =
|
||||
CreateFingerprintDef(saved_model_pb2, export_dir2);
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def2,
|
||||
CreateFingerprintDef(saved_model_pb2, export_dir2));
|
||||
|
||||
EXPECT_EQ(fingerprint_def.graph_def_program_hash(),
|
||||
fingerprint_def2.graph_def_program_hash());
|
||||
|
|
@ -95,10 +95,10 @@ TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) {
|
|||
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1");
|
||||
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
|
||||
ReadSavedModel(export_dir));
|
||||
FingerprintDef fingerprint_def =
|
||||
CreateFingerprintDef(saved_model_pb, export_dir);
|
||||
FingerprintDef fingerprint_def2 =
|
||||
CreateFingerprintDef(saved_model_pb, export_dir);
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
|
||||
CreateFingerprintDef(saved_model_pb, export_dir));
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def2,
|
||||
CreateFingerprintDef(saved_model_pb, export_dir));
|
||||
|
||||
EXPECT_EQ(fingerprint_def.saved_model_checksum(),
|
||||
fingerprint_def2.saved_model_checksum());
|
||||
|
|
@ -109,8 +109,8 @@ TEST(FingerprintingTest, TestFingerprintHasVersion) {
|
|||
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1");
|
||||
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
|
||||
ReadSavedModel(export_dir));
|
||||
FingerprintDef fingerprint_def =
|
||||
CreateFingerprintDef(saved_model_pb, export_dir);
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
|
||||
CreateFingerprintDef(saved_model_pb, export_dir));
|
||||
EXPECT_EQ(fingerprint_def.version().producer(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -119,8 +119,8 @@ TEST(FingerprintingTest, TestHashCheckpointForModelWithNoVariables) {
|
|||
testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert1");
|
||||
TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb,
|
||||
ReadSavedModel(export_dir));
|
||||
FingerprintDef fingerprint_def =
|
||||
CreateFingerprintDef(saved_model_pb, export_dir);
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_def,
|
||||
CreateFingerprintDef(saved_model_pb, export_dir));
|
||||
EXPECT_EQ(fingerprint_def.checkpoint_hash(), 0);
|
||||
}
|
||||
|
||||
|
|
@ -139,16 +139,5 @@ TEST(FingerprintingTest, TestReadNonexistentFingerprint) {
|
|||
EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok());
|
||||
}
|
||||
|
||||
TEST(FingerprintingTest, TestMakeFingerprintMap) {
|
||||
const std::string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
|
||||
"VarsAndArithmeticObjectGraph");
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_pb,
|
||||
ReadSavedModelFingerprint(export_dir));
|
||||
auto fingerprint_map = MakeFingerprintMap(fingerprint_pb);
|
||||
EXPECT_EQ(fingerprint_pb.saved_model_checksum(),
|
||||
fingerprint_map["saved_model_checksum"]);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow::saved_model::fingerprinting
|
||||
|
|
|
|||
|
|
@ -426,6 +426,7 @@ py_library(
|
|||
"ignore_for_dep=third_party.py.keras.optimizers.optimizer_v2",
|
||||
],
|
||||
deps = [
|
||||
":fingerprinting",
|
||||
":function_deserialization",
|
||||
":load_options",
|
||||
":load_v1_in_v2",
|
||||
|
|
@ -466,6 +467,7 @@ py_library(
|
|||
"//tensorflow/python/trackable:trackable_utils",
|
||||
"//tensorflow/python/training/saving:saveable_object_util",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
"@absl_py//absl/logging",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -704,6 +706,7 @@ tf_py_test(
|
|||
name = "metrics_test",
|
||||
srcs = ["metrics_test.py"],
|
||||
deps = [
|
||||
":fingerprinting",
|
||||
":pywrap_saved_model",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
|
|
@ -784,6 +787,7 @@ py_strict_library(
|
|||
srcs = ["fingerprinting.py"],
|
||||
deps = [
|
||||
":pywrap_saved_model",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python/util:tf_export",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ This module contains classes and functions for reading the SavedModel
|
|||
fingerprint.
|
||||
"""
|
||||
|
||||
from tensorflow.core.protobuf import fingerprint_pb2
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as fingerprinting_pywrap
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
|
@ -65,6 +66,16 @@ class Fingerprint(object):
|
|||
self.checkpoint_hash = checkpoint_hash
|
||||
self.version = version
|
||||
|
||||
@classmethod
|
||||
def from_proto(cls, proto):
|
||||
return Fingerprint(
|
||||
proto.saved_model_checksum,
|
||||
proto.graph_def_program_hash,
|
||||
proto.signature_def_hash,
|
||||
proto.saved_object_graph_hash,
|
||||
proto.checkpoint_hash,
|
||||
proto.version)
|
||||
|
||||
|
||||
@tf_export("saved_model.experimental.read_fingerprint", v1=[])
|
||||
def read_fingerprint(export_dir):
|
||||
|
|
@ -73,7 +84,9 @@ def read_fingerprint(export_dir):
|
|||
Returns a `tf.saved_model.experimental.Fingerprint` object that contains
|
||||
the values of the SavedModel fingerprint, which is persisted on disk in the
|
||||
`fingerprint.pb` file in the `export_dir`.
|
||||
TODO(b/265199038): Add link to TensorFlow SavedModel guide.
|
||||
|
||||
Read more about fingerprints in the SavedModel guide at
|
||||
https://www.tensorflow.org/guide/saved_model.
|
||||
|
||||
Args:
|
||||
export_dir: The directory that contains the SavedModel.
|
||||
|
|
@ -82,16 +95,11 @@ def read_fingerprint(export_dir):
|
|||
A `tf.saved_model.experimental.Fingerprint`.
|
||||
|
||||
Raises:
|
||||
ValueError: If no or an invalid fingerprint is found.
|
||||
FingerprintException: If no or an invalid fingerprint is found.
|
||||
"""
|
||||
fingerprint_map = fingerprinting_pywrap.GetFingerprintMap(export_dir)
|
||||
if not fingerprint_map:
|
||||
raise ValueError(f"No or invalid fingerprint found in: {export_dir}.")
|
||||
return Fingerprint(
|
||||
fingerprint_map["saved_model_checksum"],
|
||||
fingerprint_map["graph_def_program_hash"],
|
||||
fingerprint_map["signature_def_hash"],
|
||||
fingerprint_map["saved_object_graph_hash"],
|
||||
fingerprint_map["checkpoint_hash"],
|
||||
fingerprint_map["version"],
|
||||
)
|
||||
try:
|
||||
fingerprint = fingerprinting_pywrap.ReadSavedModelFingerprint(export_dir)
|
||||
except fingerprinting_pywrap.FingerprintException as e:
|
||||
raise FileNotFoundError(f"SavedModel Fingerprint Error: {e}") from None # pylint: disable=raise-missing-from
|
||||
return Fingerprint.from_proto(
|
||||
fingerprint_pb2.FingerprintDef().FromString(fingerprint))
|
||||
|
|
|
|||
|
|
@ -142,10 +142,13 @@ class FingerprintingTest(test.TestCase):
|
|||
self.assertEqual(
|
||||
fingerprint.checkpoint_hash, fingerprint_def.checkpoint_hash
|
||||
)
|
||||
self.assertEqual(fingerprint.version, fingerprint_def.version.producer)
|
||||
self.assertEqual(
|
||||
fingerprint.version.producer, fingerprint_def.version.producer
|
||||
)
|
||||
|
||||
def test_read_fingerprint_api_invalid(self):
|
||||
with self.assertRaisesRegex(ValueError, "No or invalid fingerprint"):
|
||||
with self.assertRaisesRegex(FileNotFoundError,
|
||||
"SavedModel Fingerprint Error"):
|
||||
read_fingerprint("foo")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ import functools
|
|||
import os
|
||||
import sys
|
||||
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.core.function.capture import restore_captures
|
||||
from tensorflow.core.protobuf import graph_debug_info_pb2
|
||||
from tensorflow.python.checkpoint import checkpoint
|
||||
|
|
@ -42,6 +44,7 @@ from tensorflow.python.ops import control_flow_ops
|
|||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.saved_model import fingerprinting
|
||||
from tensorflow.python.saved_model import function_deserialization
|
||||
from tensorflow.python.saved_model import load_options
|
||||
from tensorflow.python.saved_model import load_v1_in_v2
|
||||
|
|
@ -50,7 +53,6 @@ from tensorflow.python.saved_model import path_helpers
|
|||
from tensorflow.python.saved_model import registration
|
||||
from tensorflow.python.saved_model import revived_types
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import metrics
|
||||
from tensorflow.python.trackable import asset
|
||||
from tensorflow.python.trackable import autotrackable
|
||||
|
|
@ -991,9 +993,14 @@ def load_partial(export_dir, filters, tags=None, options=None):
|
|||
metrics.SetReadPath(saved_model_path=str(export_dir))
|
||||
|
||||
# Read and log SavedModel checksum, if it is nonzero.
|
||||
saved_model_checksum = fingerprinting.MaybeReadSavedModelChecksum(export_dir)
|
||||
if saved_model_checksum != 0:
|
||||
metrics.SetReadFingerprint(saved_model_checksum=str(saved_model_checksum))
|
||||
try:
|
||||
fingerprint = fingerprinting.read_fingerprint(export_dir)
|
||||
if fingerprint.saved_model_checksum != 0:
|
||||
metrics.SetReadFingerprint(
|
||||
saved_model_checksum=str(fingerprint.saved_model_checksum))
|
||||
except FileNotFoundError:
|
||||
logging.error("Unable to load fingerprint when loading saved model.",
|
||||
exc_info=True)
|
||||
|
||||
if filters:
|
||||
return {node_id: loader.get(node_id) for node_id in filters}
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@ from tensorflow.python.framework import constant_op
|
|||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.saved_model import builder
|
||||
from tensorflow.python.saved_model import builder_impl
|
||||
from tensorflow.python.saved_model import fingerprinting
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import load_v1_in_v2
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import metrics
|
||||
from tensorflow.python.trackable import autotrackable
|
||||
|
||||
|
|
@ -114,18 +114,20 @@ class MetricsTests(test.TestCase):
|
|||
|
||||
def test_save_sets_write_fingerprint_metric(self):
|
||||
exported_dir = self._create_save_v2_model()
|
||||
fingerprint = fingerprinting.read_fingerprint(exported_dir)
|
||||
|
||||
self.assertEqual(
|
||||
metrics.GetWriteFingerprint(),
|
||||
str(fingerprinting.MaybeReadSavedModelChecksum(exported_dir)))
|
||||
str(fingerprint.saved_model_checksum))
|
||||
|
||||
def test_load_sets_read_fingerprint_metric(self):
|
||||
exported_dir = self._create_save_v2_model()
|
||||
load.load(exported_dir)
|
||||
fingerprint = fingerprinting.read_fingerprint(exported_dir)
|
||||
|
||||
self.assertEqual(
|
||||
metrics.GetWriteFingerprint(),
|
||||
str(fingerprinting.MaybeReadSavedModelChecksum(exported_dir)))
|
||||
str(fingerprint.saved_model_checksum))
|
||||
|
||||
def test_save_sets_write_path_metric(self):
|
||||
exported_dir = self._create_save_v2_model()
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
|
|
@ -27,11 +28,31 @@ namespace python {
|
|||
|
||||
namespace py = pybind11;
|
||||
|
||||
class FingerprintException : public std::exception {
|
||||
public:
|
||||
explicit FingerprintException(const char *m) : message_{m} {}
|
||||
const char *what() const noexcept override { return message_.c_str(); }
|
||||
|
||||
private:
|
||||
std::string message_ = "";
|
||||
};
|
||||
|
||||
void DefineFingerprintingModule(py::module main_module) {
|
||||
auto m = main_module.def_submodule("fingerprinting");
|
||||
|
||||
m.doc() = "Python bindings for TensorFlow SavedModel Fingerprinting.";
|
||||
|
||||
static py::exception<FingerprintException> ex(m, "FingerprintException");
|
||||
py::register_exception_translator([](std::exception_ptr p) {
|
||||
try {
|
||||
if (p) {
|
||||
std::rethrow_exception(p);
|
||||
}
|
||||
} catch (const FingerprintException &e) {
|
||||
ex(e.what());
|
||||
}
|
||||
});
|
||||
|
||||
m.def(
|
||||
"CreateFingerprintDef",
|
||||
[](std::string serialized_saved_model, std::string export_dir) {
|
||||
|
|
@ -39,42 +60,37 @@ void DefineFingerprintingModule(py::module main_module) {
|
|||
SavedModel saved_model_pb;
|
||||
saved_model_pb.ParseFromString(serialized_saved_model);
|
||||
|
||||
return py::bytes(
|
||||
fingerprinting::CreateFingerprintDef(saved_model_pb, export_dir)
|
||||
.SerializeAsString());
|
||||
StatusOr<FingerprintDef> fingerprint =
|
||||
fingerprinting::CreateFingerprintDef(saved_model_pb, export_dir);
|
||||
if (fingerprint.ok()) {
|
||||
return py::bytes(fingerprint.value().SerializeAsString());
|
||||
}
|
||||
throw FingerprintException(
|
||||
std::string("Could not create fingerprint in directory: " +
|
||||
export_dir)
|
||||
.c_str());
|
||||
},
|
||||
py::arg("saved_model"), py::arg("export_dir"),
|
||||
py::doc(
|
||||
"Returns the serialized FingerprintDef of a serialized SavedModel."));
|
||||
|
||||
m.def(
|
||||
"MaybeReadSavedModelChecksum",
|
||||
"ReadSavedModelFingerprint",
|
||||
[](std::string export_dir) {
|
||||
StatusOr<FingerprintDef> fingerprint =
|
||||
fingerprinting::ReadSavedModelFingerprint(export_dir);
|
||||
if (fingerprint.ok()) {
|
||||
return fingerprint->saved_model_checksum();
|
||||
return py::bytes(fingerprint.value().SerializeAsString());
|
||||
}
|
||||
return (uint64_t)0;
|
||||
throw FingerprintException(
|
||||
std::string("Could not read fingerprint from directory: " +
|
||||
export_dir)
|
||||
.c_str());
|
||||
},
|
||||
py::arg("export_dir"),
|
||||
py::doc(
|
||||
"Reads the fingerprint checksum from SavedModel directory. Returns "
|
||||
"0 if an error occurs."));
|
||||
|
||||
m.def(
|
||||
"GetFingerprintMap",
|
||||
[](std::string export_dir) {
|
||||
StatusOr<FingerprintDef> fingerprint =
|
||||
fingerprinting::ReadSavedModelFingerprint(export_dir);
|
||||
if (fingerprint.ok()) {
|
||||
return fingerprinting::MakeFingerprintMap(*fingerprint);
|
||||
}
|
||||
return std::unordered_map<std::string, uint64_t>();
|
||||
},
|
||||
py::arg("export_dir"),
|
||||
py::doc("Returns the fingerprint protobuf as a dictionary. Returns "
|
||||
"an empty dictionary if invalid fingerprint file."));
|
||||
"Loads the `fingerprint.pb` from `export_dir`, returns an error if "
|
||||
"there is none."));
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
|
|
|
|||
|
|
@ -19,89 +19,46 @@ import os
|
|||
from tensorflow.core.protobuf import fingerprint_pb2
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import fingerprinting as pywrap_fingerprinting
|
||||
|
||||
|
||||
class FingerprintingTest(test.TestCase):
|
||||
|
||||
# Checks that the fingerprint values are preserved when passed from C++ to
|
||||
# Python.
|
||||
def test_fingerprint_def_is_deserialized_correctly(self):
|
||||
def test_create_fingerprint_def(self):
|
||||
export_dir = test.test_src_dir_path(
|
||||
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph")
|
||||
with file_io.FileIO(os.path.join(export_dir, "saved_model.pb"), "rb") as f:
|
||||
file_content = f.read()
|
||||
|
||||
fingerprint_def = fingerprint_pb2.FingerprintDef()
|
||||
fingerprint_def.ParseFromString(
|
||||
fingerprinting.CreateFingerprintDef(file_content, export_dir))
|
||||
fingerprint = fingerprint_pb2.FingerprintDef().FromString(
|
||||
pywrap_fingerprinting.CreateFingerprintDef(file_content, export_dir))
|
||||
# We cannot check the value of the saved_model_checksum due to
|
||||
# non-determinism in serialization.
|
||||
self.assertGreater(fingerprint_def.saved_model_checksum, 0)
|
||||
self.assertEqual(fingerprint_def.graph_def_program_hash,
|
||||
10127142238652115842)
|
||||
self.assertEqual(fingerprint_def.signature_def_hash, 5693392539583495303)
|
||||
self.assertEqual(fingerprint_def.saved_object_graph_hash,
|
||||
3678101440349108924)
|
||||
self.assertGreater(fingerprint.saved_model_checksum, 0)
|
||||
self.assertEqual(fingerprint.graph_def_program_hash, 10127142238652115842)
|
||||
self.assertEqual(fingerprint.signature_def_hash, 5693392539583495303)
|
||||
self.assertEqual(fingerprint.saved_object_graph_hash, 3678101440349108924)
|
||||
# TODO(b/242348400): The checkpoint hash is non-deterministic, so we cannot
|
||||
# check its value here.
|
||||
self.assertGreater(fingerprint_def.checkpoint_hash, 0)
|
||||
self.assertGreater(fingerprint.checkpoint_hash, 0)
|
||||
|
||||
def test_read_fingerprint_from_file(self):
|
||||
def test_read_saved_model_fingerprint(self):
|
||||
export_dir = test.test_src_dir_path(
|
||||
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph")
|
||||
self.assertEqual(
|
||||
fingerprinting.MaybeReadSavedModelChecksum(export_dir),
|
||||
15788619162413586750)
|
||||
fingerprint = fingerprint_pb2.FingerprintDef().FromString(
|
||||
pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir))
|
||||
self.assertGreater(fingerprint.saved_model_checksum, 0)
|
||||
self.assertEqual(fingerprint.graph_def_program_hash, 706963557435316516)
|
||||
self.assertEqual(fingerprint.signature_def_hash, 5693392539583495303)
|
||||
self.assertEqual(fingerprint.saved_object_graph_hash, 12074714563970609759)
|
||||
self.assertGreater(fingerprint.checkpoint_hash, 0)
|
||||
self.assertEqual(fingerprint.version.producer, 1)
|
||||
|
||||
def test_read_nonexistent_fingerprint_from_file(self):
|
||||
def test_read_nonexistent_fingerprint(self):
|
||||
export_dir = test.test_src_dir_path("cc/saved_model/testdata/AssetModule")
|
||||
self.assertEqual(fingerprinting.MaybeReadSavedModelChecksum(export_dir), 0)
|
||||
|
||||
def test_get_fingerprint_map_valid(self):
|
||||
export_dir = test.test_src_dir_path(
|
||||
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph"
|
||||
)
|
||||
fingerprint_map = fingerprinting.GetFingerprintMap(export_dir)
|
||||
|
||||
fingerprint_def = fingerprint_pb2.FingerprintDef()
|
||||
with file_io.FileIO(os.path.join(export_dir, "fingerprint.pb"), "rb") as f:
|
||||
fingerprint_def.ParseFromString(f.read())
|
||||
|
||||
self.assertEqual(
|
||||
fingerprint_map["saved_model_checksum"],
|
||||
fingerprint_def.saved_model_checksum,
|
||||
)
|
||||
self.assertEqual(
|
||||
fingerprint_map["graph_def_program_hash"],
|
||||
fingerprint_def.graph_def_program_hash,
|
||||
)
|
||||
self.assertEqual(
|
||||
fingerprint_map["signature_def_hash"],
|
||||
fingerprint_def.signature_def_hash,
|
||||
)
|
||||
self.assertEqual(
|
||||
fingerprint_map["saved_object_graph_hash"],
|
||||
fingerprint_def.saved_object_graph_hash,
|
||||
)
|
||||
self.assertEqual(
|
||||
fingerprint_map["checkpoint_hash"], fingerprint_def.checkpoint_hash
|
||||
)
|
||||
self.assertEqual(
|
||||
fingerprint_map["version"], fingerprint_def.version.producer
|
||||
)
|
||||
|
||||
|
||||
def test_get_fingerprint_map_nonexistent(self):
|
||||
export_dir = test.test_src_dir_path("cc/saved_model/testdata/AssetModule")
|
||||
fingerprint_map = fingerprinting.GetFingerprintMap(export_dir)
|
||||
self.assertEmpty(fingerprint_map)
|
||||
|
||||
|
||||
def test_get_fingerprint_map_invalid_saved_model(self):
|
||||
export_dir = test.test_src_dir_path("not_a_saved_model")
|
||||
fingerprint_map = fingerprinting.GetFingerprintMap(export_dir)
|
||||
self.assertEmpty(fingerprint_map)
|
||||
with self.assertRaises(
|
||||
pywrap_fingerprinting.FingerprintException) as excinfo:
|
||||
pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir)
|
||||
self.assertRegex(str(excinfo.exception), "Could not read fingerprint.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -6,4 +6,8 @@ tf_class {
|
|||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'saved_model_checksum\', \'graph_def_program_hash\', \'signature_def_hash\', \'saved_object_graph_hash\', \'checkpoint_hash\', \'version\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_proto"
|
||||
argspec: "args=[\'cls\', \'proto\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -459,7 +459,6 @@ tensorflow::metrics::CheckpointSize
|
|||
[//tensorflow/cc/saved_model:fingerprinting_impl] # SavedModel Fingerprinting
|
||||
tensorflow::saved_model::fingerprinting::CreateFingerprintDef
|
||||
tensorflow::saved_model::fingerprinting::ReadSavedModelFingerprint
|
||||
tensorflow::saved_model::fingerprinting::MakeFingerprintMap
|
||||
|
||||
|
||||
[//tensorflow/compiler/jit:flags] # tfe
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user