mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Automated Code Change
PiperOrigin-RevId: 826298597
This commit is contained in:
parent
b2334ac330
commit
e61bac51b1
|
|
@ -43,26 +43,27 @@ namespace {
|
|||
|
||||
constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt";
|
||||
|
||||
string DefaultApiDefDir() {
|
||||
std::string DefaultApiDefDir() {
|
||||
return GetDataDependencyFilepath(
|
||||
io::JoinPath("tensorflow", "core", "api_def", "base_api"));
|
||||
}
|
||||
|
||||
string PythonApiDefDir() {
|
||||
std::string PythonApiDefDir() {
|
||||
return GetDataDependencyFilepath(
|
||||
io::JoinPath("tensorflow", "core", "api_def", "python_api"));
|
||||
}
|
||||
|
||||
// Reads golden ApiDef files and returns a map from file name to ApiDef file
|
||||
// contents.
|
||||
void GetGoldenApiDefs(Env* env, const string& api_files_dir,
|
||||
std::unordered_map<string, ApiDef>* name_to_api_def) {
|
||||
std::vector<string> matching_paths;
|
||||
void GetGoldenApiDefs(
|
||||
Env* env, const std::string& api_files_dir,
|
||||
std::unordered_map<std::string, ApiDef>* name_to_api_def) {
|
||||
std::vector<std::string> matching_paths;
|
||||
TF_CHECK_OK(env->GetMatchingPaths(
|
||||
io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths));
|
||||
|
||||
for (auto& file_path : matching_paths) {
|
||||
string file_contents;
|
||||
std::string file_contents;
|
||||
TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents));
|
||||
file_contents = PBTxtFromMultiline(file_contents);
|
||||
|
||||
|
|
@ -76,8 +77,9 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir,
|
|||
}
|
||||
|
||||
void TestAllApiDefsHaveCorrespondingOp(
|
||||
const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
|
||||
std::unordered_set<string> op_names;
|
||||
const OpList& ops,
|
||||
const std::unordered_map<std::string, ApiDef>& api_defs_map) {
|
||||
std::unordered_set<std::string> op_names;
|
||||
for (const auto& op : ops.op()) {
|
||||
op_names.insert(op.name());
|
||||
}
|
||||
|
|
@ -89,7 +91,8 @@ void TestAllApiDefsHaveCorrespondingOp(
|
|||
}
|
||||
|
||||
void TestAllApiDefInputArgsAreValid(
|
||||
const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
|
||||
const OpList& ops,
|
||||
const std::unordered_map<std::string, ApiDef>& api_defs_map) {
|
||||
for (const auto& op : ops.op()) {
|
||||
const auto api_def_iter = api_defs_map.find(op.name());
|
||||
if (api_def_iter == api_defs_map.end()) {
|
||||
|
|
@ -113,7 +116,8 @@ void TestAllApiDefInputArgsAreValid(
|
|||
}
|
||||
|
||||
void TestAllApiDefOutputArgsAreValid(
|
||||
const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
|
||||
const OpList& ops,
|
||||
const std::unordered_map<std::string, ApiDef>& api_defs_map) {
|
||||
for (const auto& op : ops.op()) {
|
||||
const auto api_def_iter = api_defs_map.find(op.name());
|
||||
if (api_def_iter == api_defs_map.end()) {
|
||||
|
|
@ -137,7 +141,8 @@ void TestAllApiDefOutputArgsAreValid(
|
|||
}
|
||||
|
||||
void TestAllApiDefAttributeNamesAreValid(
|
||||
const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
|
||||
const OpList& ops,
|
||||
const std::unordered_map<std::string, ApiDef>& api_defs_map) {
|
||||
for (const auto& op : ops.op()) {
|
||||
const auto api_def_iter = api_defs_map.find(op.name());
|
||||
if (api_def_iter == api_defs_map.end()) {
|
||||
|
|
@ -159,7 +164,7 @@ void TestAllApiDefAttributeNamesAreValid(
|
|||
}
|
||||
|
||||
void TestDeprecatedAttributesSetCorrectly(
|
||||
const std::unordered_map<string, ApiDef>& api_defs_map) {
|
||||
const std::unordered_map<std::string, ApiDef>& api_defs_map) {
|
||||
for (const auto& name_and_api_def : api_defs_map) {
|
||||
int num_deprecated_endpoints = 0;
|
||||
const auto& api_def = name_and_api_def.second;
|
||||
|
|
@ -186,7 +191,7 @@ void TestDeprecatedAttributesSetCorrectly(
|
|||
}
|
||||
|
||||
void TestDeprecationVersionSetCorrectly(
|
||||
const std::unordered_map<string, ApiDef>& api_defs_map) {
|
||||
const std::unordered_map<std::string, ApiDef>& api_defs_map) {
|
||||
for (const auto& name_and_api_def : api_defs_map) {
|
||||
const auto& name = name_and_api_def.first;
|
||||
const auto& api_def = name_and_api_def.second;
|
||||
|
|
@ -205,13 +210,13 @@ class BaseApiTest : public ::testing::Test {
|
|||
protected:
|
||||
BaseApiTest() {
|
||||
OpRegistry::Global()->Export(false, &ops_);
|
||||
const std::vector<string> multi_line_fields = {"description"};
|
||||
const std::vector<std::string> multi_line_fields = {"description"};
|
||||
|
||||
Env* env = Env::Default();
|
||||
GetGoldenApiDefs(env, DefaultApiDefDir(), &api_defs_map_);
|
||||
}
|
||||
OpList ops_;
|
||||
std::unordered_map<string, ApiDef> api_defs_map_;
|
||||
std::unordered_map<std::string, ApiDef> api_defs_map_;
|
||||
};
|
||||
|
||||
// Check that all ops have an ApiDef.
|
||||
|
|
@ -233,7 +238,7 @@ TEST_F(BaseApiTest, AllApiDefsHaveCorrespondingOp) {
|
|||
TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_);
|
||||
}
|
||||
|
||||
string GetOpDefHasDocStringError(const string& op_name) {
|
||||
std::string GetOpDefHasDocStringError(const std::string& op_name) {
|
||||
return strings::Printf(
|
||||
"OpDef for %s has a doc string. "
|
||||
"Doc strings must be defined in ApiDef instead of OpDef. "
|
||||
|
|
@ -301,13 +306,13 @@ class PythonApiTest : public ::testing::Test {
|
|||
protected:
|
||||
PythonApiTest() {
|
||||
OpRegistry::Global()->Export(false, &ops_);
|
||||
const std::vector<string> multi_line_fields = {"description"};
|
||||
const std::vector<std::string> multi_line_fields = {"description"};
|
||||
|
||||
Env* env = Env::Default();
|
||||
GetGoldenApiDefs(env, PythonApiDefDir(), &api_defs_map_);
|
||||
}
|
||||
OpList ops_;
|
||||
std::unordered_map<string, ApiDef> api_defs_map_;
|
||||
std::unordered_map<std::string, ApiDef> api_defs_map_;
|
||||
};
|
||||
|
||||
// Check that ApiDefs have a corresponding op.
|
||||
|
|
|
|||
|
|
@ -124,7 +124,7 @@ bool CheckDocsMatch(const OpDef& op1, const OpDef& op2) {
|
|||
|
||||
// Returns true if descriptions and summaries in op match a
|
||||
// given single doc-string.
|
||||
bool ValidateOpDocs(const OpDef& op, const string& doc) {
|
||||
bool ValidateOpDocs(const OpDef& op, const std::string& doc) {
|
||||
OpDefBuilder b(op.name());
|
||||
// We don't really care about type we use for arguments and
|
||||
// attributes. We just want to make sure attribute and argument names
|
||||
|
|
@ -146,28 +146,28 @@ bool ValidateOpDocs(const OpDef& op, const string& doc) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
string RemoveDoc(const OpDef& op, const string& file_contents,
|
||||
size_t start_location) {
|
||||
std::string RemoveDoc(const OpDef& op, const std::string& file_contents,
|
||||
size_t start_location) {
|
||||
// Look for a line starting with .Doc( after the REGISTER_OP.
|
||||
const auto doc_start_location = file_contents.find(kDocStart, start_location);
|
||||
const string format_error = strings::Printf(
|
||||
const std::string format_error = strings::Printf(
|
||||
"Could not find %s doc for removal. Make sure the doc is defined with "
|
||||
"'%s' prefix and '%s' suffix or remove the doc manually.",
|
||||
op.name().c_str(), kDocStart, kDocEnd);
|
||||
if (doc_start_location == string::npos) {
|
||||
if (doc_start_location == std::string::npos) {
|
||||
std::cerr << format_error << std::endl;
|
||||
LOG(ERROR) << "Didn't find doc start";
|
||||
return file_contents;
|
||||
}
|
||||
const auto doc_end_location = file_contents.find(kDocEnd, doc_start_location);
|
||||
if (doc_end_location == string::npos) {
|
||||
if (doc_end_location == std::string::npos) {
|
||||
LOG(ERROR) << "Didn't find doc start";
|
||||
std::cerr << format_error << std::endl;
|
||||
return file_contents;
|
||||
}
|
||||
|
||||
const auto doc_start_size = sizeof(kDocStart) - 1;
|
||||
string doc_text = file_contents.substr(
|
||||
std::string doc_text = file_contents.substr(
|
||||
doc_start_location + doc_start_size,
|
||||
doc_end_location - doc_start_location - doc_start_size);
|
||||
|
||||
|
|
@ -189,12 +189,12 @@ namespace {
|
|||
// Remove .Doc calls that follow REGISTER_OP calls for the given ops.
|
||||
// We search for REGISTER_OP calls in the given op_files list.
|
||||
void RemoveDocs(const std::vector<const OpDef*>& ops,
|
||||
const std::vector<string>& op_files) {
|
||||
const std::vector<std::string>& op_files) {
|
||||
// Set of ops that we already found REGISTER_OP calls for.
|
||||
std::set<string> processed_ops;
|
||||
std::set<std::string> processed_ops;
|
||||
|
||||
for (const auto& file : op_files) {
|
||||
string file_contents;
|
||||
std::string file_contents;
|
||||
bool file_contents_updated = false;
|
||||
TF_CHECK_OK(ReadFileToString(Env::Default(), file, &file_contents));
|
||||
|
||||
|
|
@ -203,11 +203,11 @@ void RemoveDocs(const std::vector<const OpDef*>& ops,
|
|||
// We already found REGISTER_OP call for this op in another file.
|
||||
continue;
|
||||
}
|
||||
string register_call =
|
||||
std::string register_call =
|
||||
strings::Printf("REGISTER_OP(\"%s\")", op->name().c_str());
|
||||
const auto register_call_location = file_contents.find(register_call);
|
||||
// Find REGISTER_OP(OpName) call.
|
||||
if (register_call_location == string::npos) {
|
||||
if (register_call_location == std::string::npos) {
|
||||
continue;
|
||||
}
|
||||
std::cout << "Removing .Doc call for " << op->name() << " from " << file
|
||||
|
|
@ -228,11 +228,11 @@ void RemoveDocs(const std::vector<const OpDef*>& ops,
|
|||
|
||||
// Returns ApiDefs text representation in multi-line format
|
||||
// constructed based on the given op.
|
||||
string CreateApiDef(const OpDef& op) {
|
||||
std::string CreateApiDef(const OpDef& op) {
|
||||
ApiDefs api_defs;
|
||||
FillBaseApiDef(api_defs.add_op(), op);
|
||||
|
||||
const std::vector<string> multi_line_fields = {"description"};
|
||||
const std::vector<std::string> multi_line_fields = {"description"};
|
||||
std::string new_api_defs_str;
|
||||
::tensorflow::protobuf::TextFormat::PrintToString(api_defs,
|
||||
&new_api_defs_str);
|
||||
|
|
@ -242,8 +242,8 @@ string CreateApiDef(const OpDef& op) {
|
|||
// Creates ApiDef files for any new ops.
|
||||
// If op_file_pattern is not empty, then also removes .Doc calls from
|
||||
// new op registrations in these files.
|
||||
void CreateApiDefs(const OpList& ops, const string& api_def_dir,
|
||||
const string& op_file_pattern) {
|
||||
void CreateApiDefs(const OpList& ops, const std::string& api_def_dir,
|
||||
const std::string& op_file_pattern) {
|
||||
auto* excluded_ops = GetExcludedOps();
|
||||
std::vector<const OpDef*> new_ops_with_docs;
|
||||
|
||||
|
|
@ -252,8 +252,8 @@ void CreateApiDefs(const OpList& ops, const string& api_def_dir,
|
|||
continue;
|
||||
}
|
||||
// Form the expected ApiDef path.
|
||||
string file_path =
|
||||
io::JoinPath(tensorflow::string(api_def_dir), kApiDefFileFormat);
|
||||
std::string file_path =
|
||||
io::JoinPath(std::string(api_def_dir), kApiDefFileFormat);
|
||||
file_path = strings::Printf(file_path.c_str(), op.name().c_str());
|
||||
|
||||
// Create ApiDef if it doesn't exist.
|
||||
|
|
@ -268,7 +268,7 @@ void CreateApiDefs(const OpList& ops, const string& api_def_dir,
|
|||
}
|
||||
}
|
||||
if (!op_file_pattern.empty()) {
|
||||
std::vector<string> op_files;
|
||||
std::vector<std::string> op_files;
|
||||
TF_CHECK_OK(Env::Default()->GetMatchingPaths(op_file_pattern, &op_files));
|
||||
RemoveDocs(new_ops_with_docs, op_files);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@ namespace tensorflow {
|
|||
|
||||
// Returns ApiDefs text representation in multi-line format
|
||||
// constructed based on the given op.
|
||||
string CreateApiDef(const OpDef& op);
|
||||
std::string CreateApiDef(const OpDef& op);
|
||||
|
||||
// Removes .Doc call for the given op.
|
||||
// If unsuccessful, returns original file_contents and prints an error.
|
||||
// start_location - We search for .Doc call starting at this location
|
||||
// in file_contents.
|
||||
string RemoveDoc(const OpDef& op, const string& file_contents,
|
||||
size_t start_location);
|
||||
std::string RemoveDoc(const OpDef& op, const std::string& file_contents,
|
||||
size_t start_location);
|
||||
|
||||
// Creates api_def_*.pbtxt files for any new ops (i.e. ops that don't have an
|
||||
// api_def_*.pbtxt file yet).
|
||||
|
|
@ -38,8 +38,8 @@ string RemoveDoc(const OpDef& op, const string& file_contents,
|
|||
// look for a REGISTER_OP call for the new ops and removes corresponding
|
||||
// .Doc() calls since the newly generated api_def_*.pbtxt files will
|
||||
// store the doc strings.
|
||||
void CreateApiDefs(const OpList& ops, const string& api_def_dir,
|
||||
const string& op_file_pattern);
|
||||
void CreateApiDefs(const OpList& ops, const std::string& api_def_dir,
|
||||
const std::string& op_file_pattern);
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ limitations under the License.
|
|||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::string api_files_dir;
|
||||
tensorflow::string op_file_pattern;
|
||||
std::string api_files_dir;
|
||||
std::string op_file_pattern;
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
tensorflow::Flag("api_def_dir", &api_files_dir,
|
||||
"Base directory of api_def*.pbtxt files."),
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ namespace tensorflow {
|
|||
namespace {
|
||||
|
||||
TEST(UpdateApiDefTest, TestRemoveDocSingleOp) {
|
||||
const string op_def_text = R"opdef(
|
||||
const std::string op_def_text = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.Output("output: T")
|
||||
|
|
@ -32,7 +32,7 @@ REGISTER_OP("Op1")
|
|||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
)opdef";
|
||||
|
||||
const string op_def_text_with_doc = R"opdef(
|
||||
const std::string op_def_text_with_doc = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.Output("output: T")
|
||||
|
|
@ -50,7 +50,7 @@ output: Description for output.
|
|||
)doc");
|
||||
)opdef";
|
||||
|
||||
const string op_text = R"(
|
||||
const std::string op_text = R"(
|
||||
name: "Op1"
|
||||
input_arg {
|
||||
name: "a"
|
||||
|
|
@ -75,7 +75,7 @@ description: "Description\nfor Op1."
|
|||
}
|
||||
|
||||
TEST(UpdateApiDefTest, TestRemoveDocMultipleOps) {
|
||||
const string op_def_text = R"opdef(
|
||||
const std::string op_def_text = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
|
@ -89,7 +89,7 @@ REGISTER_OP("Op3")
|
|||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
)opdef";
|
||||
|
||||
const string op_def_text_with_doc = R"opdef(
|
||||
const std::string op_def_text_with_doc = R"opdef(
|
||||
REGISTER_OP("Op1")
|
||||
.Input("a: T")
|
||||
.Doc(R"doc(
|
||||
|
|
@ -112,21 +112,21 @@ Summary for Op3.
|
|||
)doc");
|
||||
)opdef";
|
||||
|
||||
const string op1_text = R"(
|
||||
const std::string op1_text = R"(
|
||||
name: "Op1"
|
||||
input_arg {
|
||||
name: "a"
|
||||
}
|
||||
summary: "Summary for Op1."
|
||||
)";
|
||||
const string op2_text = R"(
|
||||
const std::string op2_text = R"(
|
||||
name: "Op2"
|
||||
input_arg {
|
||||
name: "a"
|
||||
}
|
||||
summary: "Summary for Op2."
|
||||
)";
|
||||
const string op3_text = R"(
|
||||
const std::string op3_text = R"(
|
||||
name: "Op3"
|
||||
input_arg {
|
||||
name: "c"
|
||||
|
|
@ -138,12 +138,12 @@ summary: "Summary for Op3."
|
|||
protobuf::TextFormat::ParseFromString(op2_text, &op2); // NOLINT
|
||||
protobuf::TextFormat::ParseFromString(op3_text, &op3); // NOLINT
|
||||
|
||||
string updated_text =
|
||||
std::string updated_text =
|
||||
RemoveDoc(op2, op_def_text_with_doc,
|
||||
op_def_text_with_doc.find("Op2") /* start_location */);
|
||||
EXPECT_EQ(string::npos, updated_text.find("Summary for Op2"));
|
||||
EXPECT_NE(string::npos, updated_text.find("Summary for Op1"));
|
||||
EXPECT_NE(string::npos, updated_text.find("Summary for Op3"));
|
||||
EXPECT_EQ(std::string::npos, updated_text.find("Summary for Op2"));
|
||||
EXPECT_NE(std::string::npos, updated_text.find("Summary for Op1"));
|
||||
EXPECT_NE(std::string::npos, updated_text.find("Summary for Op3"));
|
||||
|
||||
updated_text = RemoveDoc(op3, updated_text,
|
||||
updated_text.find("Op3") /* start_location */);
|
||||
|
|
@ -153,7 +153,7 @@ summary: "Summary for Op3."
|
|||
}
|
||||
|
||||
TEST(UpdateApiDefTest, TestCreateApiDef) {
|
||||
const string op_text = R"(
|
||||
const std::string op_text = R"(
|
||||
name: "Op1"
|
||||
input_arg {
|
||||
name: "a"
|
||||
|
|
@ -173,7 +173,7 @@ description: "Description\nfor Op1."
|
|||
OpDef op;
|
||||
protobuf::TextFormat::ParseFromString(op_text, &op); // NOLINT
|
||||
|
||||
const string expected_api_def = R"(op {
|
||||
const std::string expected_api_def = R"(op {
|
||||
graph_op_name: "Op1"
|
||||
in_arg {
|
||||
name: "a"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user