mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add canonical singleprint method to c++/python to easily and uniquely identify a SavedModel.
PiperOrigin-RevId: 516332423
This commit is contained in:
parent
a0203c1c68
commit
03fd965e1e
|
|
@ -99,6 +99,9 @@
|
|||
* Introduce class method
|
||||
`tf.saved_model.experimental.Fingerprint.from_proto(proto)`, which can
|
||||
be used to construct a `Fingerprint` object directly from a protobuf.
|
||||
* Introduce member method
|
||||
`tf.saved_model.experimental.Fingerprint.singleprint()`, which provides
|
||||
a convenient way to uniquely identify a SavedModel.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
|
|
|
|||
|
|
@ -395,8 +395,8 @@ cc_library(
|
|||
"//tensorflow/core/graph/regularization:simple_delete",
|
||||
"//tensorflow/core/graph/regularization:util",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//tensorflow/core/util/tensor_bundle:naming",
|
||||
"//tensorflow/tsl/platform:types",
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
"@com_google_absl//absl/container:btree",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <unordered_map>
|
||||
|
||||
#include "absl/container/btree_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/strip.h"
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
|
@ -38,6 +39,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
|
||||
#include "tensorflow/core/util/tensor_bundle/naming.h"
|
||||
#include "tensorflow/tsl/platform/types.h"
|
||||
|
||||
namespace tensorflow::saved_model::fingerprinting {
|
||||
|
||||
|
|
@ -162,4 +164,25 @@ StatusOr<FingerprintDef> ReadSavedModelFingerprint(
|
|||
return found_pb;
|
||||
}
|
||||
|
||||
std::string Singleprint(uint64 graph_def_program_hash,
|
||||
uint64 signature_def_hash,
|
||||
uint64 saved_object_graph_hash,
|
||||
uint64 checkpoint_hash) {
|
||||
return std::to_string(graph_def_program_hash) + "/" +
|
||||
std::to_string(signature_def_hash) + "/" +
|
||||
std::to_string(saved_object_graph_hash) + "/" +
|
||||
std::to_string(checkpoint_hash);
|
||||
}
|
||||
|
||||
std::string Singleprint(const FingerprintDef& fingerprint) {
|
||||
return Singleprint(
|
||||
fingerprint.graph_def_program_hash(), fingerprint.signature_def_hash(),
|
||||
fingerprint.saved_object_graph_hash(), fingerprint.checkpoint_hash());
|
||||
}
|
||||
|
||||
std::string Singleprint(absl::string_view export_dir) {
|
||||
FingerprintDef fingerprint = ReadSavedModelFingerprint(export_dir).value();
|
||||
return Singleprint(fingerprint);
|
||||
}
|
||||
|
||||
} // namespace tensorflow::saved_model::fingerprinting
|
||||
|
|
|
|||
|
|
@ -36,6 +36,13 @@ StatusOr<FingerprintDef> CreateFingerprintDef(const SavedModel& saved_model,
|
|||
StatusOr<FingerprintDef> ReadSavedModelFingerprint(
|
||||
absl::string_view export_dir);
|
||||
|
||||
// Canonical fingerprinting ID for a SavedModel.
|
||||
std::string Singleprint(uint64 graph_def_program_hash,
|
||||
uint64 signature_def_hash,
|
||||
uint64 saved_object_graph_hash, uint64 checkpoint_hash);
|
||||
std::string Singleprint(const FingerprintDef& fingerprint);
|
||||
std::string Singleprint(absl::string_view export_dir);
|
||||
|
||||
} // namespace tensorflow::saved_model::fingerprinting
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_
|
||||
|
|
|
|||
|
|
@ -139,5 +139,23 @@ TEST(FingerprintingTest, TestReadNonexistentFingerprint) {
|
|||
EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok());
|
||||
}
|
||||
|
||||
TEST(FingerprintingTest, TestSingleprint) {
|
||||
const std::string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
|
||||
"VarsAndArithmeticObjectGraph");
|
||||
const std::string const_singleprint =
|
||||
"706963557435316516/5693392539583495303/12074714563970609759/"
|
||||
"10788359570789890102";
|
||||
EXPECT_EQ(Singleprint(export_dir), const_singleprint);
|
||||
TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_pb,
|
||||
ReadSavedModelFingerprint(export_dir));
|
||||
EXPECT_EQ(Singleprint(fingerprint_pb), const_singleprint);
|
||||
EXPECT_EQ(Singleprint(fingerprint_pb.graph_def_program_hash(),
|
||||
fingerprint_pb.signature_def_hash(),
|
||||
fingerprint_pb.saved_object_graph_hash(),
|
||||
fingerprint_pb.checkpoint_hash()),
|
||||
const_singleprint);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow::saved_model::fingerprinting
|
||||
|
|
|
|||
|
|
@ -76,6 +76,36 @@ class Fingerprint(object):
|
|||
proto.checkpoint_hash,
|
||||
proto.version)
|
||||
|
||||
def singleprint(self):
|
||||
"""Canonical fingerprinting ID for a SavedModel.
|
||||
|
||||
Uniquely identifies a SavedModel based on the regularized fingerprint
|
||||
attributes. (saved_model_checksum is sensitive to immaterial changes and
|
||||
thus non-deterministic.)
|
||||
|
||||
Returns:
|
||||
The string concatenation of `graph_def_program_hash`,
|
||||
`signature_def_hash`, and `saved_object_graph_hash`
|
||||
fingerprint attributes (separated by '/').
|
||||
|
||||
Raises:
|
||||
ValueError: If the fingerprint fields cannot be used to construct the
|
||||
singleprint.
|
||||
"""
|
||||
try:
|
||||
return fingerprinting_pywrap.Singleprint(self.graph_def_program_hash,
|
||||
self.signature_def_hash,
|
||||
self.saved_object_graph_hash,
|
||||
self.checkpoint_hash)
|
||||
except (TypeError, fingerprinting_pywrap.FingerprintException) as e:
|
||||
raise ValueError(
|
||||
f"Encounted invalid fingerprint values when constructing singleprint."
|
||||
f"graph_def_program_hash: {self.graph_def_program_hash}"
|
||||
f"signature_def_hash: {self.signature_def_hash}"
|
||||
f"saved_object_graph_hash: {self.saved_object_graph_hash}"
|
||||
f"checkpoint_hash: {self.checkpoint_hash}"
|
||||
f"{e}") from None
|
||||
|
||||
|
||||
@tf_export("saved_model.experimental.read_fingerprint", v1=[])
|
||||
def read_fingerprint(export_dir):
|
||||
|
|
@ -95,7 +125,7 @@ def read_fingerprint(export_dir):
|
|||
A `tf.saved_model.experimental.Fingerprint`.
|
||||
|
||||
Raises:
|
||||
FingerprintException: If no or an invalid fingerprint is found.
|
||||
FileNotFoundError: If no or an invalid fingerprint is found.
|
||||
"""
|
||||
try:
|
||||
fingerprint = fingerprinting_pywrap.ReadSavedModelFingerprint(export_dir)
|
||||
|
|
|
|||
|
|
@ -24,11 +24,12 @@ from tensorflow.core.config import flags
|
|||
from tensorflow.core.protobuf import fingerprint_pb2
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.saved_model import fingerprinting
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.saved_model.fingerprinting import read_fingerprint
|
||||
from tensorflow.python.saved_model.pywrap_saved_model import constants
|
||||
from tensorflow.python.trackable import autotrackable
|
||||
|
||||
|
|
@ -54,6 +55,14 @@ class FingerprintingTest(test.TestCase):
|
|||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
return root
|
||||
|
||||
def _create_model_with_data(self):
|
||||
root = autotrackable.AutoTrackable()
|
||||
root.x = constant_op.constant(1.0, dtype=dtypes.float32)
|
||||
root.f = def_function.function(
|
||||
lambda x: root.x * x,
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
|
||||
return root
|
||||
|
||||
def _read_fingerprint(self, filename):
|
||||
fingerprint_def = fingerprint_pb2.FingerprintDef()
|
||||
with file_io.FileIO(filename, "rb") as f:
|
||||
|
|
@ -119,7 +128,7 @@ class FingerprintingTest(test.TestCase):
|
|||
|
||||
def test_read_fingerprint_api(self):
|
||||
save_dir = self._create_saved_model()
|
||||
fingerprint = read_fingerprint(save_dir)
|
||||
fingerprint = fingerprinting.read_fingerprint(save_dir)
|
||||
|
||||
fingerprint_def = self._read_fingerprint(
|
||||
file_io.join(save_dir, constants.FINGERPRINT_FILENAME)
|
||||
|
|
@ -149,7 +158,26 @@ class FingerprintingTest(test.TestCase):
|
|||
def test_read_fingerprint_api_invalid(self):
|
||||
with self.assertRaisesRegex(FileNotFoundError,
|
||||
"SavedModel Fingerprint Error"):
|
||||
read_fingerprint("foo")
|
||||
fingerprinting.read_fingerprint("foo")
|
||||
|
||||
def test_valid_singleprint(self):
|
||||
save_dir = os.path.join(self.get_temp_dir(), "singleprint_model")
|
||||
save.save(self._create_model_with_data(), save_dir)
|
||||
fingerprint = fingerprinting.read_fingerprint(save_dir)
|
||||
singleprint = fingerprint.singleprint()
|
||||
|
||||
# checkpoint_hash is non-deterministic and not included
|
||||
self.assertRegex(singleprint,
|
||||
"/".join(["8947653168630125217", # graph_def_program_hash
|
||||
"13520770727385282311", # signature_def_hash
|
||||
"1613952301283913051" # saved_object_graph_hash
|
||||
]))
|
||||
|
||||
def test_invalid_singleprint(self):
|
||||
fingerprint = fingerprinting.Fingerprint()
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"Encounted invalid fingerprint values"):
|
||||
fingerprint.singleprint()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -91,6 +91,24 @@ void DefineFingerprintingModule(py::module main_module) {
|
|||
py::doc(
|
||||
"Loads the `fingerprint.pb` from `export_dir`, returns an error if "
|
||||
"there is none."));
|
||||
|
||||
m.def(
|
||||
"Singleprint",
|
||||
[](uint64 graph_def_program_hash, uint64 signature_def_hash,
|
||||
uint64 saved_object_graph_hash, uint64 checkpoint_hash) {
|
||||
StatusOr<std::string> singleprint = fingerprinting::Singleprint(
|
||||
graph_def_program_hash, signature_def_hash, saved_object_graph_hash,
|
||||
checkpoint_hash);
|
||||
if (singleprint.ok()) {
|
||||
return py::str(singleprint.value());
|
||||
}
|
||||
throw FingerprintException(
|
||||
std::string("Could not create singleprint from given values.")
|
||||
.c_str());
|
||||
},
|
||||
py::arg("graph_def_program_hash"), py::arg("signature_def_hash"),
|
||||
py::arg("saved_object_graph_hash"), py::arg("checkpoint_hash"),
|
||||
py::doc("Canonical fingerprinting ID for a SavedModel."));
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
|
|
|
|||
|
|
@ -60,6 +60,24 @@ class FingerprintingTest(test.TestCase):
|
|||
pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir)
|
||||
self.assertRegex(str(excinfo.exception), "Could not read fingerprint.")
|
||||
|
||||
def test_read_saved_model_singleprint(self):
|
||||
export_dir = test.test_src_dir_path(
|
||||
"cc/saved_model/testdata/VarsAndArithmeticObjectGraph")
|
||||
fingerprint = fingerprint_pb2.FingerprintDef().FromString(
|
||||
pywrap_fingerprinting.ReadSavedModelFingerprint(export_dir))
|
||||
singleprint = pywrap_fingerprinting.Singleprint(
|
||||
fingerprint.graph_def_program_hash,
|
||||
fingerprint.signature_def_hash,
|
||||
fingerprint.saved_object_graph_hash,
|
||||
fingerprint.checkpoint_hash)
|
||||
# checkpoint_hash is non-deterministic and not included
|
||||
self.assertRegex(singleprint,
|
||||
"/".join([
|
||||
"706963557435316516", # graph_def_program_hash
|
||||
"5693392539583495303", # signature_def_hash
|
||||
"12074714563970609759", # saved_object_graph_hash
|
||||
]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -10,4 +10,8 @@ tf_class {
|
|||
name: "from_proto"
|
||||
argspec: "args=[\'cls\', \'proto\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "singleprint"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -459,6 +459,7 @@ tensorflow::metrics::CheckpointSize
|
|||
[//tensorflow/cc/saved_model:fingerprinting_impl] # SavedModel Fingerprinting
|
||||
tensorflow::saved_model::fingerprinting::CreateFingerprintDef
|
||||
tensorflow::saved_model::fingerprinting::ReadSavedModelFingerprint
|
||||
tensorflow::saved_model::fingerprinting::Singleprint
|
||||
|
||||
|
||||
[//tensorflow/compiler/jit:flags] # tfe
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user