Add Flatten to core layers.

PiperOrigin-RevId: 168254118
Author: Francois Chollet (2017-09-11 11:00:26 -07:00), committed by TensorFlower Gardener
parent a6223c01a6
commit 80ed8afc02
10 changed files with 249 additions and 47 deletions

@@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase):
          array_ops.placeholder(dtypes.float32, (5, None)),
          array_ops.placeholder(dtypes.float32, (5, 1)))
-    with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
+    with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'):
      conditioning_utils.condition_tensor(
          array_ops.placeholder(dtypes.float32, (5, 2)),
          array_ops.placeholder(dtypes.float32, (5)))
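
The regex changes because the contrib `flatten` used by `condition_tensor` now delegates to the new core layer (see the next file), whose `InputSpec(min_ndim=2)` check produces the new message. A minimal sketch of that failure mode, assuming the `tf.layers.flatten` endpoint added in this commit:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(5,))  # rank-1 input
try:
  tf.layers.flatten(x)  # InputSpec(min_ndim=2) rejects inputs of rank < 2
except ValueError as e:
  print(e)  # message includes 'expected min_ndim=2'
```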

@@ -1435,30 +1435,7 @@ def flatten(inputs,
   """
   with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
     inputs = ops.convert_to_tensor(inputs)
-    inputs_rank = inputs.get_shape().ndims
-    if (inputs_rank is None) or (inputs_rank < 2):
-      raise ValueError('Inputs must have a least 2 dimensions.')
-
-    inputs_shape = array_ops.shape(inputs)
-
-    batch_dim = array_ops.slice(inputs_shape, [0], [1])
-    spatial_dims = array_ops.slice(inputs_shape, [1], [inputs_rank - 1])
-
-    flat_spatial_dim = math_ops.reduce_prod(spatial_dims)
-    flat_spatial_dim = array_ops.expand_dims(flat_spatial_dim, 0)
-    flat_shape = array_ops.concat([batch_dim, flat_spatial_dim], 0)
-
-    outputs = array_ops.reshape(inputs, flat_shape)
-
-    # Attempt to propagate shape information, if it is defined.
-    input_shape = inputs.get_shape().as_list()
-    batch_dim, spatial_dims = input_shape[0], input_shape[1:]
-    if all(spatial_dims):
-      outputs.set_shape([batch_dim,
-                         functools.reduce(lambda x, y: x * y, spatial_dims)])
-    else:
-      outputs.set_shape([batch_dim, None])
-
+    outputs = core_layers.flatten(inputs)
     return utils.collect_named_outputs(outputs_collections, sc, outputs)
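
With the body reduced to a delegation, the contrib endpoint and the new core layer are observably equivalent. A quick usage sketch, assuming a build that contains this commit:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 4, 4))
y = tf.contrib.layers.flatten(x)  # now a thin wrapper over the core Flatten layer
print(y.get_shape().as_list())    # [None, 16]
```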

@@ -1399,7 +1399,7 @@ class FlattenTest(test.TestCase):
     inputs = array_ops.placeholder(dtype=dtypes.float32)
     inputs.set_shape(tensor_shape.TensorShape((5,)))
     with self.assertRaisesRegexp(ValueError,
-                                 'must have a least 2 dimensions'):
+                                 'incompatible with the layer'):
       _layers.flatten(inputs)

   def testUnknownLastDim(self):

@@ -456,7 +456,7 @@ class Permute(Layer):
     return dict(list(base_config.items()) + list(config.items()))


-class Flatten(Layer):
+class Flatten(tf_core_layers.Flatten, Layer):
   """Flattens the input. Does not affect the batch size.

   Example:
@@ -472,26 +472,7 @@ class Flatten(Layer):
   # now: model.output_shape == (None, 65536)
   ```
   """
-
-  def __init__(self, **kwargs):
-    super(Flatten, self).__init__(**kwargs)
-    self.input_spec = InputSpec(min_ndim=3)
-
-  def _compute_output_shape(self, input_shape):
-    input_shape = tensor_shape.TensorShape(input_shape).as_list()
-    if not all(input_shape[1:]):
-      raise ValueError('The shape of the input to "Flatten" '
-                       'is not fully defined '
-                       '(got ' + str(input_shape[1:]) + '. '
-                       'Make sure to pass a complete "input_shape" '
-                       'or "batch_input_shape" argument to the first '
-                       'layer in your model.')
-    return tensor_shape.TensorShape([input_shape[0],
-                                     np.prod(input_shape[1:])])
-
-  def call(self, inputs):
-    outputs = K.batch_flatten(inputs)
-    outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
-    return outputs
+  pass


 class RepeatVector(Layer):
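
Since the Keras layer now just mixes in the core implementation, its documented behavior is unchanged. A minimal usage sketch, using the internal import path shown in the golden file further below:

```python
from tensorflow.python.keras._impl import keras

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=(4, 4)))
print(model.output_shape)  # (None, 16)
```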

@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import standard_ops
 from tensorflow.python.ops import variable_scope as vs
@@ -337,6 +338,67 @@ def dropout(inputs,
   return layer.apply(inputs, training=training)


+class Flatten(base.Layer):
+  """Flattens an input tensor while preserving the batch axis (axis 0).
+
+  Examples:
+
+  ```
+    x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
+    y = Flatten()(x)
+    # now `y` has shape `(None, 16)`
+
+    x = tf.placeholder(shape=(None, 3, None), dtype='float32')
+    y = Flatten()(x)
+    # now `y` has shape `(None, None)`
+  ```
+  """
+
+  def __init__(self, **kwargs):
+    super(Flatten, self).__init__(**kwargs)
+    self.input_spec = base.InputSpec(min_ndim=2)
+
+  def call(self, inputs):
+    outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
+    outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
+    return outputs
+
+  def _compute_output_shape(self, input_shape):
+    input_shape = tensor_shape.TensorShape(input_shape).as_list()
+    output_shape = [input_shape[0]]
+    if all(input_shape[1:]):
+      output_shape += [np.prod(input_shape[1:])]
+    else:
+      output_shape += [None]
+    return tensor_shape.TensorShape(output_shape)
+
+
+def flatten(inputs, name=None):
+  """Flattens an input tensor while preserving the batch axis (axis 0).
+
+  Arguments:
+    inputs: Tensor input.
+    name: The name of the layer (string).
+
+  Returns:
+    Reshaped tensor.
+
+  Examples:
+
+  ```
+    x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
+    y = flatten(x)
+    # now `y` has shape `(None, 16)`
+
+    x = tf.placeholder(shape=(None, 3, None), dtype='float32')
+    y = flatten(x)
+    # now `y` has shape `(None, None)`
+  ```
+  """
+  layer = Flatten(name=name)
+  return layer.apply(inputs)
+
+
 # Aliases

 FullyConnected = Dense
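
The static-shape rule in `_compute_output_shape` is the subtle part: keep axis 0, and multiply the remaining dims only when every one of them is known. A standalone sketch of just that rule, in plain NumPy with a hypothetical helper name:

```python
import numpy as np

def flatten_static_shape(input_shape):
  """Mirrors Flatten._compute_output_shape: (None, 3, 2) -> [None, 6]."""
  input_shape = list(input_shape)
  if all(input_shape[1:]):  # all non-batch dims known (None is falsy)
    return [input_shape[0], int(np.prod(input_shape[1:]))]
  return [input_shape[0], None]  # any unknown dim makes the product unknown

assert flatten_static_shape((1, 2, 3, 2)) == [1, 12]
assert flatten_static_shape((None, 3, 2)) == [None, 6]
assert flatten_static_shape((None, 3, None)) == [None, None]
```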

@@ -391,5 +391,56 @@ class DropoutTest(test.TestCase):
       self.assertAllClose(np.ones((5, 5)), np_output)


+class FlattenTest(test.TestCase):
+
+  def testCreateFlatten(self):
+    with self.test_session() as sess:
+      x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
+      y = core_layers.Flatten()(x)
+      np_output = sess.run(y, feed_dict={x: np.zeros((3, 2, 3))})
+      self.assertEqual(list(np_output.shape), [3, 6])
+      self.assertEqual(y.get_shape().as_list(), [None, 6])
+
+      x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
+      y = core_layers.Flatten()(x)
+      np_output = sess.run(y, feed_dict={x: np.zeros((1, 2, 3, 2))})
+      self.assertEqual(list(np_output.shape), [1, 12])
+      self.assertEqual(y.get_shape().as_list(), [1, 12])
+
+  def testComputeShape(self):
+    shape = core_layers.Flatten()._compute_output_shape((1, 2, 3, 2))
+    self.assertEqual(shape.as_list(), [1, 12])
+
+    shape = core_layers.Flatten()._compute_output_shape((None, 3, 2))
+    self.assertEqual(shape.as_list(), [None, 6])
+
+    shape = core_layers.Flatten()._compute_output_shape((None, 3, None))
+    self.assertEqual(shape.as_list(), [None, None])
+
+  def testFunctionalFlatten(self):
+    x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
+    y = core_layers.flatten(x, name='flatten')
+    self.assertEqual(y.get_shape().as_list(), [None, 6])
+
+  def testFlattenValueError(self):
+    x = array_ops.placeholder(shape=(None,), dtype='float32')
+    with self.assertRaises(ValueError):
+      core_layers.Flatten()(x)
+
+  def testFlattenUnknownAxes(self):
+    with self.test_session() as sess:
+      x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
+      y = core_layers.Flatten()(x)
+      np_output = sess.run(y, feed_dict={x: np.zeros((5, 2, 3))})
+      self.assertEqual(list(np_output.shape), [5, 6])
+      self.assertEqual(y.get_shape().as_list(), [5, None])
+
+      x = array_ops.placeholder(shape=(5, None, 2), dtype='float32')
+      y = core_layers.Flatten()(x)
+      np_output = sess.run(y, feed_dict={x: np.zeros((5, 3, 2))})
+      self.assertEqual(list(np_output.shape), [5, 6])
+      self.assertEqual(y.get_shape().as_list(), [5, None])
+
+
 if __name__ == '__main__':
   test.main()
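
The unknown-axes cases pass because `call` reshapes with a dynamic `-1`, so the runtime shape is exact even when the static shape is not. A sketch, again assuming the `tf.layers.flatten` endpoint from this commit:

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(5, None, None))
y = tf.layers.flatten(x)
print(y.get_shape().as_list())  # [5, None]  (static shape stays partially unknown)
with tf.Session() as sess:
  out = sess.run(y, feed_dict={x: np.zeros((5, 2, 3))})
  print(out.shape)              # (5, 6)     (runtime shape is exact)
```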

@@ -18,6 +18,7 @@
 @@Dense
 @@Dropout
+@@Flatten
 @@Conv1D
 @@Conv2D
 @@Conv3D
@@ -39,6 +40,7 @@
 @@dense
 @@dropout
+@@flatten
 @@conv1d
 @@conv2d
 @@conv3d
@@ -71,9 +73,11 @@ from tensorflow.python.layers.base import InputSpec
 # Core layers.
 from tensorflow.python.layers.core import Dense
 from tensorflow.python.layers.core import Dropout
+from tensorflow.python.layers.core import Flatten
 from tensorflow.python.layers.core import dense
 from tensorflow.python.layers.core import dropout
+from tensorflow.python.layers.core import flatten

 # Convolutional layers.
 from tensorflow.python.layers.convolutional import SeparableConv2D

@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.Flatten"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Flatten\'>"
+  is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
   is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"

@@ -0,0 +1,118 @@
+path: "tensorflow.layers.Flatten"
+tf_class {
+  is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
+  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "scope_name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+}

@@ -44,6 +44,10 @@ tf_module {
     name: "Dropout"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "Flatten"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "InputSpec"
     mtype: "<type \'type\'>"
@@ -120,6 +124,10 @@ tf_module {
     name: "dropout"
     argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], "
   }
+  member_method {
+    name: "flatten"
+    argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "max_pooling1d"
     argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], "