mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Merge pull request #49879 from geetachavan1/cherrypicks_4KIML
Prevent memory overflow in ParseAttrValue from nested tensors.
This commit is contained in:
commit
b6e130be09
|
|
@ -38,6 +38,9 @@ namespace {
|
||||||
// Do not construct large tensors to compute their hash or compare for equality.
|
// Do not construct large tensors to compute their hash or compare for equality.
|
||||||
constexpr int kMaxAttrValueTensorByteSize = 32 * 1024 * 1024; // 32mb
|
constexpr int kMaxAttrValueTensorByteSize = 32 * 1024 * 1024; // 32mb
|
||||||
|
|
||||||
|
// Limit nesting of tensors to 100 deep to prevent memory overflow.
|
||||||
|
constexpr int kMaxTensorNestDepth = 100;
|
||||||
|
|
||||||
// Return the size of the tensor represented by this TensorProto. If shape is
|
// Return the size of the tensor represented by this TensorProto. If shape is
|
||||||
// not fully defined return -1.
|
// not fully defined return -1.
|
||||||
int64 TensorByteSize(const TensorProto& t) {
|
int64 TensorByteSize(const TensorProto& t) {
|
||||||
|
|
@ -224,6 +227,54 @@ string SummarizeFunc(const NameAttrList& func) {
|
||||||
return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
|
return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit, string to_parse) {
|
||||||
|
int nests = 0;
|
||||||
|
int maxed_out = to_parse.length();
|
||||||
|
int open_curly = to_parse.find('{');
|
||||||
|
int open_bracket = to_parse.find('<');
|
||||||
|
int close_curly = to_parse.find('}');
|
||||||
|
int close_bracket = to_parse.find('>');
|
||||||
|
if (open_curly == -1) {
|
||||||
|
open_curly = maxed_out;
|
||||||
|
}
|
||||||
|
if (open_bracket == -1) {
|
||||||
|
open_bracket = maxed_out;
|
||||||
|
}
|
||||||
|
int min = std::min(open_curly, open_bracket);
|
||||||
|
do {
|
||||||
|
if (open_curly == maxed_out && open_bracket == maxed_out) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (min == open_curly) {
|
||||||
|
nests += 1;
|
||||||
|
open_curly = to_parse.find('{', open_curly + 1);
|
||||||
|
if (open_curly == -1) {
|
||||||
|
open_curly = maxed_out;
|
||||||
|
}
|
||||||
|
} else if (min == open_bracket) {
|
||||||
|
nests += 1;
|
||||||
|
open_bracket = to_parse.find('<', open_bracket + 1);
|
||||||
|
if (open_bracket == -1) {
|
||||||
|
open_bracket = maxed_out;
|
||||||
|
}
|
||||||
|
} else if (min == close_curly) {
|
||||||
|
nests -= 1;
|
||||||
|
close_curly = to_parse.find('}', close_curly + 1);
|
||||||
|
if (close_curly == -1) {
|
||||||
|
close_curly = maxed_out;
|
||||||
|
}
|
||||||
|
} else if (min == close_bracket) {
|
||||||
|
nests -= 1;
|
||||||
|
close_bracket = to_parse.find('>', close_bracket + 1);
|
||||||
|
if (close_bracket == -1) {
|
||||||
|
close_bracket = maxed_out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
min = std::min({open_curly, open_bracket, close_curly, close_bracket});
|
||||||
|
} while (nests < 100);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
string SummarizeAttrValue(const AttrValue& attr_value) {
|
string SummarizeAttrValue(const AttrValue& attr_value) {
|
||||||
|
|
@ -448,7 +499,12 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
|
||||||
} else {
|
} else {
|
||||||
to_parse = strings::StrCat(field_name, ": ", text);
|
to_parse = strings::StrCat(field_name, ": ", text);
|
||||||
}
|
}
|
||||||
|
if (field_name == "tensor") {
|
||||||
|
if (!ParseAttrValueHelper_TensorNestsUnderLimit(kMaxTensorNestDepth,
|
||||||
|
to_parse)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
return ProtoParseFromString(to_parse, out);
|
return ProtoParseFromString(to_parse, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user