mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Add Flatten to core layers.
PiperOrigin-RevId: 168254118
This commit is contained in:
parent
a6223c01a6
commit
80ed8afc02
|
|
@ -45,7 +45,7 @@ class ConditioningUtilsTest(test.TestCase):
|
||||||
array_ops.placeholder(dtypes.float32, (5, None)),
|
array_ops.placeholder(dtypes.float32, (5, None)),
|
||||||
array_ops.placeholder(dtypes.float32, (5, 1)))
|
array_ops.placeholder(dtypes.float32, (5, 1)))
|
||||||
|
|
||||||
with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
|
with self.assertRaisesRegexp(ValueError, 'expected min_ndim=2'):
|
||||||
conditioning_utils.condition_tensor(
|
conditioning_utils.condition_tensor(
|
||||||
array_ops.placeholder(dtypes.float32, (5, 2)),
|
array_ops.placeholder(dtypes.float32, (5, 2)),
|
||||||
array_ops.placeholder(dtypes.float32, (5)))
|
array_ops.placeholder(dtypes.float32, (5)))
|
||||||
|
|
|
||||||
|
|
@ -1435,30 +1435,7 @@ def flatten(inputs,
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
|
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
|
||||||
inputs = ops.convert_to_tensor(inputs)
|
inputs = ops.convert_to_tensor(inputs)
|
||||||
inputs_rank = inputs.get_shape().ndims
|
outputs = core_layers.flatten(inputs)
|
||||||
if (inputs_rank is None) or (inputs_rank < 2):
|
|
||||||
raise ValueError('Inputs must have a least 2 dimensions.')
|
|
||||||
|
|
||||||
inputs_shape = array_ops.shape(inputs)
|
|
||||||
|
|
||||||
batch_dim = array_ops.slice(inputs_shape, [0], [1])
|
|
||||||
spatial_dims = array_ops.slice(inputs_shape, [1], [inputs_rank - 1])
|
|
||||||
|
|
||||||
flat_spatial_dim = math_ops.reduce_prod(spatial_dims)
|
|
||||||
flat_spatial_dim = array_ops.expand_dims(flat_spatial_dim, 0)
|
|
||||||
flat_shape = array_ops.concat([batch_dim, flat_spatial_dim], 0)
|
|
||||||
|
|
||||||
outputs = array_ops.reshape(inputs, flat_shape)
|
|
||||||
|
|
||||||
# Attempt to propagate shape information, if it is defined.
|
|
||||||
input_shape = inputs.get_shape().as_list()
|
|
||||||
batch_dim, spatial_dims = input_shape[0], input_shape[1:]
|
|
||||||
if all(spatial_dims):
|
|
||||||
outputs.set_shape([batch_dim,
|
|
||||||
functools.reduce(lambda x, y: x * y, spatial_dims)])
|
|
||||||
else:
|
|
||||||
outputs.set_shape([batch_dim, None])
|
|
||||||
|
|
||||||
return utils.collect_named_outputs(outputs_collections, sc, outputs)
|
return utils.collect_named_outputs(outputs_collections, sc, outputs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1399,7 +1399,7 @@ class FlattenTest(test.TestCase):
|
||||||
inputs = array_ops.placeholder(dtype=dtypes.float32)
|
inputs = array_ops.placeholder(dtype=dtypes.float32)
|
||||||
inputs.set_shape(tensor_shape.TensorShape((5,)))
|
inputs.set_shape(tensor_shape.TensorShape((5,)))
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
'must have a least 2 dimensions'):
|
'incompatible with the layer'):
|
||||||
_layers.flatten(inputs)
|
_layers.flatten(inputs)
|
||||||
|
|
||||||
def testUnknownLastDim(self):
|
def testUnknownLastDim(self):
|
||||||
|
|
|
||||||
|
|
@ -456,7 +456,7 @@ class Permute(Layer):
|
||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
|
|
||||||
class Flatten(Layer):
|
class Flatten(tf_core_layers.Flatten, Layer):
|
||||||
"""Flattens the input. Does not affect the batch size.
|
"""Flattens the input. Does not affect the batch size.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
@ -472,26 +472,7 @@ class Flatten(Layer):
|
||||||
# now: model.output_shape == (None, 65536)
|
# now: model.output_shape == (None, 65536)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(Flatten, self).__init__(**kwargs)
|
|
||||||
self.input_spec = InputSpec(min_ndim=3)
|
|
||||||
|
|
||||||
def _compute_output_shape(self, input_shape):
|
|
||||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
|
||||||
if not all(input_shape[1:]):
|
|
||||||
raise ValueError('The shape of the input to "Flatten" '
|
|
||||||
'is not fully defined '
|
|
||||||
'(got ' + str(input_shape[1:]) + '. '
|
|
||||||
'Make sure to pass a complete "input_shape" '
|
|
||||||
'or "batch_input_shape" argument to the first '
|
|
||||||
'layer in your model.')
|
|
||||||
return tensor_shape.TensorShape([input_shape[0], np.prod(input_shape[1:])])
|
|
||||||
|
|
||||||
def call(self, inputs):
|
|
||||||
outputs = K.batch_flatten(inputs)
|
|
||||||
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class RepeatVector(Layer):
|
class RepeatVector(Layer):
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
from tensorflow.python.ops import standard_ops
|
from tensorflow.python.ops import standard_ops
|
||||||
from tensorflow.python.ops import variable_scope as vs
|
from tensorflow.python.ops import variable_scope as vs
|
||||||
|
|
@ -337,6 +338,67 @@ def dropout(inputs,
|
||||||
return layer.apply(inputs, training=training)
|
return layer.apply(inputs, training=training)
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(base.Layer):
|
||||||
|
"""Flattens an input tensor while preserving the batch axis (axis 0).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```
|
||||||
|
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
|
||||||
|
y = Flatten()(x)
|
||||||
|
# now `y` has shape `(None, 16)`
|
||||||
|
|
||||||
|
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
|
||||||
|
y = Flatten()(x)
|
||||||
|
# now `y` has shape `(None, None)`
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(Flatten, self).__init__(**kwargs)
|
||||||
|
self.input_spec = base.InputSpec(min_ndim=2)
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
|
||||||
|
outputs.set_shape(self._compute_output_shape(inputs.get_shape()))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _compute_output_shape(self, input_shape):
|
||||||
|
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||||
|
output_shape = [input_shape[0]]
|
||||||
|
if all(input_shape[1:]):
|
||||||
|
output_shape += [np.prod(input_shape[1:])]
|
||||||
|
else:
|
||||||
|
output_shape += [None]
|
||||||
|
return tensor_shape.TensorShape(output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten(inputs, name=None):
|
||||||
|
"""Flattens an input tensor while preserving the batch axis (axis 0).
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
inputs: Tensor input.
|
||||||
|
name: The name of the layer (string).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reshaped tensor.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```
|
||||||
|
x = tf.placeholder(shape=(None, 4, 4), dtype='float32')
|
||||||
|
y = flatten(x)
|
||||||
|
# now `y` has shape `(None, 16)`
|
||||||
|
|
||||||
|
x = tf.placeholder(shape=(None, 3, None), dtype='float32')
|
||||||
|
y = flatten(x)
|
||||||
|
# now `y` has shape `(None, None)`
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
layer = Flatten(name=name)
|
||||||
|
return layer.apply(inputs)
|
||||||
|
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
|
||||||
FullyConnected = Dense
|
FullyConnected = Dense
|
||||||
|
|
|
||||||
|
|
@ -391,5 +391,56 @@ class DropoutTest(test.TestCase):
|
||||||
self.assertAllClose(np.ones((5, 5)), np_output)
|
self.assertAllClose(np.ones((5, 5)), np_output)
|
||||||
|
|
||||||
|
|
||||||
|
class FlattenTest(test.TestCase):
|
||||||
|
|
||||||
|
def testCreateFlatten(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
||||||
|
y = core_layers.Flatten()(x)
|
||||||
|
np_output = sess.run(y, feed_dict={x: np.zeros((3, 2, 3))})
|
||||||
|
self.assertEqual(list(np_output.shape), [3, 6])
|
||||||
|
self.assertEqual(y.get_shape().as_list(), [None, 6])
|
||||||
|
|
||||||
|
x = array_ops.placeholder(shape=(1, 2, 3, 2), dtype='float32')
|
||||||
|
y = core_layers.Flatten()(x)
|
||||||
|
np_output = sess.run(y, feed_dict={x: np.zeros((1, 2, 3, 2))})
|
||||||
|
self.assertEqual(list(np_output.shape), [1, 12])
|
||||||
|
self.assertEqual(y.get_shape().as_list(), [1, 12])
|
||||||
|
|
||||||
|
def testComputeShape(self):
|
||||||
|
shape = core_layers.Flatten()._compute_output_shape((1, 2, 3, 2))
|
||||||
|
self.assertEqual(shape.as_list(), [1, 12])
|
||||||
|
|
||||||
|
shape = core_layers.Flatten()._compute_output_shape((None, 3, 2))
|
||||||
|
self.assertEqual(shape.as_list(), [None, 6])
|
||||||
|
|
||||||
|
shape = core_layers.Flatten()._compute_output_shape((None, 3, None))
|
||||||
|
self.assertEqual(shape.as_list(), [None, None])
|
||||||
|
|
||||||
|
def testFunctionalFlatten(self):
|
||||||
|
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
|
||||||
|
y = core_layers.flatten(x, name='flatten')
|
||||||
|
self.assertEqual(y.get_shape().as_list(), [None, 6])
|
||||||
|
|
||||||
|
def testFlattenValueError(self):
|
||||||
|
x = array_ops.placeholder(shape=(None,), dtype='float32')
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
core_layers.Flatten()(x)
|
||||||
|
|
||||||
|
def testFlattenUnknownAxes(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
x = array_ops.placeholder(shape=(5, None, None), dtype='float32')
|
||||||
|
y = core_layers.Flatten()(x)
|
||||||
|
np_output = sess.run(y, feed_dict={x: np.zeros((5, 2, 3))})
|
||||||
|
self.assertEqual(list(np_output.shape), [5, 6])
|
||||||
|
self.assertEqual(y.get_shape().as_list(), [5, None])
|
||||||
|
|
||||||
|
x = array_ops.placeholder(shape=(5, None, 2), dtype='float32')
|
||||||
|
y = core_layers.Flatten()(x)
|
||||||
|
np_output = sess.run(y, feed_dict={x: np.zeros((5, 3, 2))})
|
||||||
|
self.assertEqual(list(np_output.shape), [5, 6])
|
||||||
|
self.assertEqual(y.get_shape().as_list(), [5, None])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
@@Dense
|
@@Dense
|
||||||
@@Dropout
|
@@Dropout
|
||||||
|
@@Flatten
|
||||||
@@Conv1D
|
@@Conv1D
|
||||||
@@Conv2D
|
@@Conv2D
|
||||||
@@Conv3D
|
@@Conv3D
|
||||||
|
|
@ -39,6 +40,7 @@
|
||||||
|
|
||||||
@@dense
|
@@dense
|
||||||
@@dropout
|
@@dropout
|
||||||
|
@@flatten
|
||||||
@@conv1d
|
@@conv1d
|
||||||
@@conv2d
|
@@conv2d
|
||||||
@@conv3d
|
@@conv3d
|
||||||
|
|
@ -71,9 +73,11 @@ from tensorflow.python.layers.base import InputSpec
|
||||||
# Core layers.
|
# Core layers.
|
||||||
from tensorflow.python.layers.core import Dense
|
from tensorflow.python.layers.core import Dense
|
||||||
from tensorflow.python.layers.core import Dropout
|
from tensorflow.python.layers.core import Dropout
|
||||||
|
from tensorflow.python.layers.core import Flatten
|
||||||
|
|
||||||
from tensorflow.python.layers.core import dense
|
from tensorflow.python.layers.core import dense
|
||||||
from tensorflow.python.layers.core import dropout
|
from tensorflow.python.layers.core import dropout
|
||||||
|
from tensorflow.python.layers.core import flatten
|
||||||
|
|
||||||
# Convolutional layers.
|
# Convolutional layers.
|
||||||
from tensorflow.python.layers.convolutional import SeparableConv2D
|
from tensorflow.python.layers.convolutional import SeparableConv2D
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
path: "tensorflow.keras.layers.Flatten"
|
path: "tensorflow.keras.layers.Flatten"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Flatten\'>"
|
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.core.Flatten\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
|
||||||
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
|
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
|
||||||
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
|
|
|
||||||
118
tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
Normal file
118
tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
path: "tensorflow.layers.Flatten"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.layers.core.Flatten\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "graph"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "input"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "input_shape"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "losses"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "non_trainable_variables"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "non_trainable_weights"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "output"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "output_shape"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "scope_name"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "trainable_variables"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "trainable_weights"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "updates"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "variables"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "weights"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "add_loss"
|
||||||
|
argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "add_update"
|
||||||
|
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "add_variable"
|
||||||
|
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "apply"
|
||||||
|
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "build"
|
||||||
|
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "call"
|
||||||
|
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "count_params"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_input_at"
|
||||||
|
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_input_shape_at"
|
||||||
|
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_losses_for"
|
||||||
|
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_output_at"
|
||||||
|
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_output_shape_at"
|
||||||
|
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_updates_for"
|
||||||
|
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -44,6 +44,10 @@ tf_module {
|
||||||
name: "Dropout"
|
name: "Dropout"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "Flatten"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "InputSpec"
|
name: "InputSpec"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
|
|
@ -120,6 +124,10 @@ tf_module {
|
||||||
name: "dropout"
|
name: "dropout"
|
||||||
argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], "
|
argspec: "args=[\'inputs\', \'rate\', \'noise_shape\', \'seed\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\', \'None\', \'False\', \'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "flatten"
|
||||||
|
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "max_pooling1d"
|
name: "max_pooling1d"
|
||||||
argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], "
|
argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], "
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user