mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Added PlaceholderWithDefault to the list of known placeholder types
Use PartialTensorShape instead of TensorShapes to better handle partially known shapes PiperOrigin-RevId: 157657664
This commit is contained in:
parent
0462416f64
commit
b09932d749
|
|
@ -42,7 +42,8 @@ bool IsMerge(const NodeDef& node) {
|
||||||
|
|
||||||
bool IsPlaceholder(const NodeDef& node) {
|
bool IsPlaceholder(const NodeDef& node) {
|
||||||
const auto op = node.op();
|
const auto op = node.op();
|
||||||
return op == "Placeholder" || op == "PlaceholderV2";
|
return op == "Placeholder" || op == "PlaceholderV2" ||
|
||||||
|
op == "PlaceholderWithDefault";
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsTranspose(const NodeDef& node) {
|
bool IsTranspose(const NodeDef& node) {
|
||||||
|
|
|
||||||
|
|
@ -118,10 +118,10 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
|
||||||
std::vector<OpInfo::TensorProperties> input =
|
std::vector<OpInfo::TensorProperties> input =
|
||||||
properties.GetInputProperties(node.name());
|
properties.GetInputProperties(node.name());
|
||||||
CHECK_EQ(1, input.size());
|
CHECK_EQ(1, input.size());
|
||||||
const TensorShapeProto shape = input[0].shape();
|
|
||||||
|
|
||||||
|
const TensorShapeProto shape = input[0].shape();
|
||||||
// Materialize the shapes using constants whenever possible.
|
// Materialize the shapes using constants whenever possible.
|
||||||
TensorShape shp(shape);
|
PartialTensorShape shp(shape);
|
||||||
if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
|
if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
|
||||||
bool valid = true;
|
bool valid = true;
|
||||||
Tensor value(type);
|
Tensor value(type);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user