Add canonical singleprint method to c++/python to easily and uniquely identify a SavedModel.

PiperOrigin-RevId: 516332423
This commit is contained in:
Adam Cogdell 2023-03-13 15:01:15 -07:00 committed by TensorFlower Gardener
parent a0203c1c68
commit 03fd965e1e
11 changed files with 155 additions and 5 deletions

View File

@ -99,6 +99,9 @@
* Introduce class method
`tf.saved_model.experimental.Fingerprint.from_proto(proto)`, which can
be used to construct a `Fingerprint` object directly from a protobuf.
* Introduce member method
`tf.saved_model.experimental.Fingerprint.singleprint()`, which provides
a convenient way to uniquely identify a SavedModel.
## Bug Fixes and Other Changes

View File

@ -395,8 +395,8 @@ cc_library(
"//tensorflow/core/graph/regularization:simple_delete",
"//tensorflow/core/graph/regularization:util",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:types",
"//tensorflow/core/util/tensor_bundle:naming",
"//tensorflow/tsl/platform:types",
"@com_google_protobuf//:protobuf_headers",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/strings",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include "absl/container/btree_map.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/function.pb.h"
@ -38,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/tsl/platform/types.h"
namespace tensorflow::saved_model::fingerprinting {
@ -162,4 +164,25 @@ StatusOr<FingerprintDef> ReadSavedModelFingerprint(
return found_pb;
}
std::string Singleprint(uint64 graph_def_program_hash,
uint64 signature_def_hash,
uint64 saved_object_graph_hash,
uint64 checkpoint_hash) {
return std::to_string(graph_def_program_hash) + "/" +
std::to_string(signature_def_hash) + "/" +
std::to_string(saved_object_graph_hash) + "/" +
std::to_string(checkpoint_hash);
}
std::string Singleprint(const FingerprintDef& fingerprint) {
return Singleprint(
fingerprint.graph_def_program_hash(), fingerprint.signature_def_hash(),
fingerprint.saved_object_graph_hash(), fingerprint.checkpoint_hash());
}
std::string Singleprint(absl::string_view export_dir) {
FingerprintDef fingerprint = ReadSavedModelFingerprint(export_dir).value();
return Singleprint(fingerprint);
}
} // namespace tensorflow::saved_model::fingerprinting

View File

@ -36,6 +36,13 @@ StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
StatusOr<FingerprintDef> ReadSavedModelFingerprint(
absl::string_view export_dir);
// Canonical fingerprinting ID for a SavedModel.
std::string Singleprint(uint64 graph_def_program_hash,
uint64 signature_def_hash,
uint64 saved_object_graph_hash, uint64 checkpoint_hash);
std::string Singleprint(const FingerprintDef& fingerprint);
std::string Singleprint(absl::string_view export_dir);
} // namespace tensorflow::saved_model::fingerprinting
#endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_

View File

@ -139,5 +139,23 @@ TEST(FingerprintingTest, TestReadNonexistentFingerprint) {
EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok());
}
TEST(FingerprintingTest, TestSingleprint) {
const std::string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
"VarsAndArithmeticObjectGraph");
const std::string const_singleprint =
"706963557435316516/5693392539583495303/12074714563970609759/"
"10788359570789890102";
EXPECT_EQ(Singleprint(export_dir), const_singleprint);
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_pb,
ReadSavedModelFingerprint(export_dir));
EXPECT_EQ(Singleprint(fingerprint_pb), const_singleprint);
EXPECT_EQ(Singleprint(fingerprint_pb.graph_def_program_hash(),
fingerprint_pb.signature_def_hash(),
fingerprint_pb.saved_object_graph_hash(),
fingerprint_pb.checkpoint_hash()),
const_singleprint);
}
} // namespace
} // namespace tensorflow::saved_model::fingerprinting

View File

@ -76,6 +76,36 @@ class Fingerprint(object):
proto.checkpoint_hash,
proto.version)
def singleprint(self):
"""Canonical fingerprinting ID for a SavedModel.
Uniquely identifies a SavedModel based on the regularized fingerprint
attributes. (saved_model_checksum is sensitive to immaterial changes and
thus non-deterministic.)
Returns:
The string concatenation of `graph_def_program_hash`,
`signature_def_hash`, and `saved_object_graph_hash`
fingerprint attributes (separated by '/').
Raises:
ValueError: If the fingerprint fields cannot be used to construct the
singleprint.
"""
try:
return fingerprinting_pywrap.Singleprint(self.graph_def_program_hash,
self.signature_def_hash,
self.saved_object_graph_hash,
self.checkpoint_hash)
except (TypeError, fingerprinting_pywrap.FingerprintException) as e:
raise ValueError(
f"Encounted invalid fingerprint values when constructing singleprint."
f"graph_def_program_hash: {self.graph_def_program_hash}"
f"signature_def_hash: {self.signature_def_hash}"
f"saved_object_graph_hash: {self.saved_object_graph_hash}"
f"checkpoint_hash: {self.checkpoint_hash}"
f"{e}") from None
@tf_export("saved_model.experimental.read_fingerprint", v1=[])
def read_fingerprint(export_dir):
@ -95,7 +125,7 @@ def read_fingerprint(export_dir):
A `tf.saved_model.experimental.Fingerprint`.
Raises:
FingerprintException: If no or an invalid fingerprint is found.
FileNotFoundError: If no or an invalid fingerprint is found.
"""
try:
fingerprint = fingerprinting_pywrap.ReadSavedModelFingerprint(export_dir)

View File

@ -24,11 +24,12 @@ from tensorflow.core.config import flags
from tensorflow.core.protobuf import fingerprint_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import fingerprinting
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model.fingerprinting import read_fingerprint
from tensorflow.python.saved_model.pywrap_saved_model import constants
from tensorflow.python.trackable import autotrackable
@ -54,6 +55,14 @@ class FingerprintingTest(test.TestCase):
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
return root
def _create_model_with_data(self):
root = autotrackable.AutoTrackable()
root.x = constant_op.constant(1.0, dtype=dtypes.float32)
root.f = def_function.function(
lambda x: root.x * x,
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
return root
def _read_fingerprint(self, filename):
fingerprint_def = fingerprint_pb2.FingerprintDef()
with file_io.FileIO(filename, "rb") as f:
@ -119,7 +128,7 @@ class FingerprintingTest(test.TestCase):
def test_read_fingerprint_api(self):
save_dir = self._create_saved_model()
fingerprint = read_fingerprint(save_dir)
fingerprint = fingerprinting.read_fingerprint(save_dir)
fingerprint_def = self._read_fingerprint(
file_io.join(save_dir, constants.FINGERPRINT_FILENAME)
@ -149,7 +158,26 @@ class FingerprintingTest(test.TestCase):
def test_read_fingerprint_api_invalid(self):
with self.assertRaisesRegex(FileNotFoundError,
"SavedModel Fingerprint Error"):
read_fingerprint("foo")
fingerprinting.read_fingerprint("foo")
def test_valid_singleprint(self):
save_dir = os.path.join(self.get_temp_dir(), "singleprint_model")
save.save(self._create_model_with_data(), save_dir)
fingerprint = fingerprinting.read_fingerprint(save_dir)
singleprint = fingerprint.singleprint()
# checkpoint_hash is non-deterministic and not included
self.assertRegex(singleprint,
"/".join(["8947653168630125217", # graph_def_program_hash
"13520770727385282311", # signature_def_hash
"1613952301283913051" # saved_object_graph_hash
]))
def test_invalid_singleprint(self):
fingerprint = fingerprinting.Fingerprint()
with self.assertRaisesRegex(ValueError,
"Encounted invalid fingerprint values"):
fingerprint.singleprint()
if __name__ == "__main__":

View File

@ -91,6 +91,24 @@ void DefineFingerprintingModule(py::module main_module) {
py::doc(
"Loads the `fingerprint.pb` from `export_dir`, returns an error if "
"there is none."));
m.def(
"Singleprint",
[](uint64 graph_def_program_hash, uint64 signature_def_hash,
uint64 saved_object_graph_hash, uint64 checkpoint_hash) {
StatusOr<std::string> singleprint = fingerprinting::Singleprint(
graph_def_program_hash, signature_def_hash, saved_object_graph_hash,
checkpoint_hash);
if (singleprint.ok()) {
return py::str(singleprint.value());
}
throw FingerprintException(
std::string("Could not create singleprint from given values.")
.c_str());
},
py::arg("graph_def_program_hash"), py::arg("signature_def_hash"),
py::arg("saved_object_graph_hash"), py::arg("checkpoint_hash"),
py::doc("Canonical fingerprinting ID for a SavedModel."));
}
} // namespace python

View File

@ -60,6 +60,24 @@ class FingerprintingTest(test.TestCase):
pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir)
self.assertRegex(str(excinfo.exception), "Could not read fingerprint.")
def test_read_saved_model_singleprint(self):
export_dir = test.test_src_dir_path(
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph")
fingerprint = fingerprint_pb2.FingerprintDef().FromString(
pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir))
singleprint = pywrap_fingerprinting.Singleprint(
fingerprint.graph_def_program_hash,
fingerprint.signature_def_hash,
fingerprint.saved_object_graph_hash,
fingerprint.checkpoint_hash)
# checkpoint_hash is non-deterministic and not included
self.assertRegex(singleprint,
"/".join([
"706963557435316516", # graph_def_program_hash
"5693392539583495303", # signature_def_hash
"12074714563970609759", # saved_object_graph_hash
]))
if __name__ == "__main__":
test.main()

View File

@ -10,4 +10,8 @@ tf_class {
name: "from_proto"
argspec: "args=[\'cls\', \'proto\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "singleprint"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -459,6 +459,7 @@ tensorflow::metrics::CheckpointSize
[//tensorflow/cc/saved_model:fingerprinting_impl] # SavedModel Fingerprinting
tensorflow::saved_model::fingerprinting::CreateFingerprintDef
tensorflow::saved_model::fingerprinting::ReadSavedModelFingerprint
tensorflow::saved_model::fingerprinting::Singleprint
[//tensorflow/compiler/jit:flags] # tfe