mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Added PlaceholderWithDefault to the list of known placeholder types
Use PartialTensorShape instead of TensorShapes to better handle partially known shapes PiperOrigin-RevId: 157657664
This commit is contained in:
parent
0462416f64
commit
b09932d749
|
|
@ -42,7 +42,8 @@ bool IsMerge(const NodeDef& node) {
|
|||
|
||||
bool IsPlaceholder(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Placeholder" || op == "PlaceholderV2";
|
||||
return op == "Placeholder" || op == "PlaceholderV2" ||
|
||||
op == "PlaceholderWithDefault";
|
||||
}
|
||||
|
||||
bool IsTranspose(const NodeDef& node) {
|
||||
|
|
|
|||
|
|
@ -118,10 +118,10 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
|
|||
std::vector<OpInfo::TensorProperties> input =
|
||||
properties.GetInputProperties(node.name());
|
||||
CHECK_EQ(1, input.size());
|
||||
const TensorShapeProto shape = input[0].shape();
|
||||
|
||||
const TensorShapeProto shape = input[0].shape();
|
||||
// Materialize the shapes using constants whenever possible.
|
||||
TensorShape shp(shape);
|
||||
PartialTensorShape shp(shape);
|
||||
if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
|
||||
bool valid = true;
|
||||
Tensor value(type);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user