Fix: add GDN to __init__. Also put it in alphabetical order.

PiperOrigin-RevId: 163842410
This commit is contained in:
A. Unique TensorFlower 2017-08-01 09:44:30 -07:00 committed by TensorFlower Gardener
parent db0e1c6c8e
commit 1a44996072
2 changed files with 314 additions and 294 deletions

View File

@ -32,6 +32,8 @@ See the @{$python/contrib.layers} guide.
@@embedding_lookup_unique
@@flatten
@@fully_connected
@@GDN
@@gdn
@@layer_norm
@@linear
@@max_pool2d

View File

@ -71,6 +71,8 @@ __all__ = ['avg_pool2d',
'elu',
'flatten',
'fully_connected',
'GDN',
'gdn',
'layer_norm',
'linear',
'pool',
@ -1682,6 +1684,316 @@ def fully_connected(inputs,
outputs_collections, sc.original_name_scope, outputs)
class GDN(base.Layer):
  """Generalized divisive normalization layer.

  Based on the papers:

    "Density Modeling of Images using a Generalized Normalization
    Transformation"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1511.06281

    "End-to-end Optimized Image Compression"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1611.01704

  Implements an activation function that is essentially a multivariate
  generalization of a particular sigmoid-type function:

  ```
  y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))
  ```

  where `i` and `j` run over channels. This implementation never sums across
  spatial dimensions. It is similar to local response normalization, but much
  more flexible, as `beta` and `gamma` are trainable parameters.

  Arguments:
    inverse: If `False` (default), compute GDN response. If `True`, compute IGDN
      response (one step of fixed point iteration to invert GDN; the division
      is replaced by multiplication).
    beta_min: Lower bound for beta, to prevent numerical error from causing
      square root of zero or negative values.
    gamma_init: The gamma matrix will be initialized as the identity matrix
      multiplied with this value. If set to zero, the layer is effectively
      initialized to the identity operation, since beta is initialized as one.
      A good default setting is somewhere between 0 and 0.5.
    reparam_offset: Offset added to the reparameterization of beta and gamma.
      The reparameterization of beta and gamma as their square roots lets the
      training slow down when their values are close to zero, which is desirable
      as small values in the denominator can lead to a situation where gradient
      noise on beta/gamma leads to extreme amounts of noise in the GDN
      activations. However, without the offset, we would get zero gradients if
      any elements of beta or gamma were exactly zero, and thus the training
      could get stuck. To prevent this, we add this small constant. The default
      value was empirically determined as a good starting point. Making it
      bigger potentially leads to more gradient noise on the activations, making
      it too small may lead to numerical precision issues.
    data_format: Format of input tensor. Currently supports `'channels_first'`
      and `'channels_last'`.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True`, also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require `reuse=True` in such
      cases.

  Properties:
    inverse: Boolean, whether IGDN is computed (`True`) or GDN (`False`).
    data_format: Format of input tensor. Currently supports `'channels_first'`
      and `'channels_last'`.
    beta: The beta parameter as defined above (1D `Tensor`).
    gamma: The gamma parameter as defined above (2D `Tensor`).
  """

  def __init__(self,
               inverse=False,
               beta_min=1e-6,
               gamma_init=.1,
               reparam_offset=2 ** -18,
               data_format='channels_last',
               activity_regularizer=None,
               trainable=True,
               name=None,
               **kwargs):
    super(GDN, self).__init__(trainable=trainable, name=name, **kwargs)
    self.inverse = inverse
    self._beta_min = beta_min
    self._gamma_init = gamma_init
    self._reparam_offset = reparam_offset
    self.data_format = data_format
    self.activity_regularizer = activity_regularizer
    self._channel_axis()  # trigger ValueError early
    self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5)

  def _channel_axis(self):
    """Returns the channel axis index implied by `self.data_format`.

    Returns:
      `1` for `'channels_first'`, `-1` for `'channels_last'`.

    Raises:
      ValueError: If `self.data_format` is neither of the two supported values.
    """
    try:
      return {'channels_first': 1, 'channels_last': -1}[self.data_format]
    except KeyError:
      raise ValueError('Unsupported `data_format` for GDN layer: {}.'.format(
          self.data_format))

  @staticmethod
  def _lower_bound(inputs, bound, name=None):
    """Same as tf.maximum, but with helpful gradient for inputs < bound.

    The gradient is overwritten so that it is passed through if the input is not
    hitting the bound. If it is, only gradients that push `inputs` higher than
    the bound are passed through. No gradients are passed through to the bound.

    Args:
      inputs: input tensor
      bound: lower bound for the input tensor
      name: name for this op

    Returns:
      tf.maximum(inputs, bound)
    """
    with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope:
      inputs = ops.convert_to_tensor(inputs, name='inputs')
      bound = ops.convert_to_tensor(bound, name='bound')
      # Swap in the gradient registered under 'GDNLowerBound' for the
      # `Maximum` op created inside this context.
      with ops.get_default_graph().gradient_override_map(
          {'Maximum': 'GDNLowerBound'}):
        return math_ops.maximum(inputs, bound, name=scope)

  @staticmethod
  def _lower_bound_grad(op, grad):
    """Gradient for `_lower_bound`.

    Args:
      op: the tensorflow op for which to calculate a gradient
      grad: gradient with respect to the output of the op

    Returns:
      gradients with respect to the inputs of the op
    """
    inputs = op.inputs[0]
    bound = op.inputs[1]
    # Keep the gradient where the input is at or above the bound, or where the
    # incoming gradient is negative; zero it elsewhere.
    pass_through_if = math_ops.logical_or(inputs >= bound, grad < 0)
    # `None` for the second input: no gradient flows to the bound.
    return [math_ops.cast(pass_through_if, grad.dtype) * grad, None]

  def build(self, input_shape):
    """Creates the reparameterized `beta` and `gamma` variables.

    Args:
      input_shape: Shape of the inputs (anything `TensorShape` accepts).

    Raises:
      ValueError: If the channel dimension of `input_shape` is undefined.
    """
    channel_axis = self._channel_axis()
    input_shape = tensor_shape.TensorShape(input_shape)
    num_channels = input_shape[channel_axis].value
    if num_channels is None:
      raise ValueError('The channel dimension of the inputs to `GDN` '
                       'must be defined.')
    self._input_rank = input_shape.ndims
    self.input_spec = base.InputSpec(ndim=input_shape.ndims,
                                     axes={channel_axis: num_channels})

    # The variables store sqrt(value + pedestal); the effective parameters are
    # recovered below as square(variable) - pedestal.
    pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype)
    beta_bound = array_ops.constant(
        (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype)
    gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype)

    def beta_initializer(shape, dtype=None, partition_info=None):
      del partition_info  # unused
      return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal)

    def gamma_initializer(shape, dtype=None, partition_info=None):
      del partition_info  # unused
      assert len(shape) == 2
      assert shape[0] == shape[1]
      eye = linalg_ops.eye(shape[0], dtype=dtype)
      return math_ops.sqrt(self._gamma_init * eye + pedestal)

    beta = self.add_variable('reparam_beta',
                             shape=[num_channels],
                             initializer=beta_initializer,
                             dtype=self.dtype,
                             trainable=True)
    beta = self._lower_bound(beta, beta_bound)
    self.beta = math_ops.square(beta) - pedestal

    gamma = self.add_variable('reparam_gamma',
                              shape=[num_channels, num_channels],
                              initializer=gamma_initializer,
                              dtype=self.dtype,
                              trainable=True)
    gamma = self._lower_bound(gamma, gamma_bound)
    self.gamma = math_ops.square(gamma) - pedestal

    self.built = True

  def call(self, inputs):
    """Applies GDN (or IGDN if `self.inverse`) to `inputs`.

    Args:
      inputs: Input tensor; converted to `self.dtype`.

    Returns:
      `inputs / norm` (GDN) or `inputs * norm` (IGDN), where `norm` is
      `sqrt(beta + conv(square(inputs), gamma))`, with the static shape of
      `inputs`.
    """
    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
    ndim = self._input_rank

    # View gamma as a convolution kernel whose spatial extent is 1 everywhere,
    # so the convolution only mixes channels.
    kernel_shape = (ndim - 2) * [1] + self.gamma.get_shape().as_list()
    gamma = array_ops.reshape(self.gamma, kernel_shape)

    # Compute normalization pool.
    if self.data_format == 'channels_first':
      norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID',
                                 data_format='NC' + 'DHW'[-(ndim - 2):])
      # The bias is always added on a 4-D 'NCHW' view; 3-D and 5-D tensors are
      # temporarily reshaped (presumably because `bias_add` does not accept
      # these ranks with channels first -- TODO confirm).
      if ndim == 3:
        norm_pool = array_ops.expand_dims(norm_pool, 2)
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
        norm_pool = array_ops.squeeze(norm_pool, [2])
      elif ndim == 5:
        full_shape = array_ops.shape(norm_pool)
        # `full_shape` is a tensor, so `full_shape[:3] + [-1]` would be an
        # element-wise addition; concatenate to get [N, C, D, -1] instead.
        merged_shape = array_ops.concat([full_shape[:3], [-1]], 0)
        norm_pool = array_ops.reshape(norm_pool, merged_shape)
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
        norm_pool = array_ops.reshape(norm_pool, full_shape)
      else:  # ndim == 4
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
    else:  # channels_last
      norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID')
      norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NHWC')
    norm_pool = math_ops.sqrt(norm_pool)

    if self.inverse:
      outputs = inputs * norm_pool
    else:
      outputs = inputs / norm_pool
    outputs.set_shape(inputs.get_shape())
    return outputs

  def _compute_output_shape(self, input_shape):
    """Validates `input_shape` and returns it unchanged.

    Args:
      input_shape: Shape of the inputs (anything `TensorShape` accepts).

    Returns:
      `input_shape` as a `TensorShape`.

    Raises:
      ValueError: If the rank is unknown or outside 3..5, or if the channel
        dimension is undefined.
    """
    channel_axis = self._channel_axis()
    input_shape = tensor_shape.TensorShape(input_shape)
    # `ndims` is the attribute `build` uses; it is `None` for unknown rank.
    rank = input_shape.ndims
    if rank is None or not 3 <= rank <= 5:
      raise ValueError('`input_shape` must be of rank 3 to 5, inclusive.')
    if input_shape[channel_axis].value is None:
      raise ValueError(
          'The channel dimension of `input_shape` must be defined.')
    return input_shape


# Register the gradient under the name that `GDN._lower_bound` maps onto the
# `Maximum` op. Looking the static method up on the class yields the plain
# function, which is what gets registered.
ops.RegisterGradient('GDNLowerBound')(GDN._lower_bound_grad)  # pylint:disable=protected-access
def gdn(inputs,
        inverse=False,
        beta_min=1e-6,
        gamma_init=.1,
        reparam_offset=2 ** -18,
        data_format='channels_last',
        activity_regularizer=None,
        trainable=True,
        name=None,
        reuse=None):
  """Functional interface for GDN layer.

  Based on the papers:

    "Density Modeling of Images using a Generalized Normalization
    Transformation"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1511.06281

    "End-to-end Optimized Image Compression"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1611.01704

  Implements an activation function that is essentially a multivariate
  generalization of a particular sigmoid-type function:

  ```
  y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))
  ```

  where `i` and `j` run over channels. This implementation never sums across
  spatial dimensions. It is similar to local response normalization, but much
  more flexible, as `beta` and `gamma` are trainable parameters.

  Arguments:
    inputs: Tensor input.
    inverse: If `False` (default), compute GDN response. If `True`, compute IGDN
      response (one step of fixed point iteration to invert GDN; the division
      is replaced by multiplication).
    beta_min: Lower bound for beta, to prevent numerical error from causing
      square root of zero or negative values.
    gamma_init: The gamma matrix will be initialized as the identity matrix
      multiplied with this value. If set to zero, the layer is effectively
      initialized to the identity operation, since beta is initialized as one.
      A good default setting is somewhere between 0 and 0.5.
    reparam_offset: Offset added to the reparameterization of beta and gamma.
      The reparameterization of beta and gamma as their square roots lets the
      training slow down when their values are close to zero, which is desirable
      as small values in the denominator can lead to a situation where gradient
      noise on beta/gamma leads to extreme amounts of noise in the GDN
      activations. However, without the offset, we would get zero gradients if
      any elements of beta or gamma were exactly zero, and thus the training
      could get stuck. To prevent this, we add this small constant. The default
      value was empirically determined as a good starting point. Making it
      bigger potentially leads to more gradient noise on the activations, making
      it too small may lead to numerical precision issues.
    data_format: Format of input tensor. Currently supports `'channels_first'`
      and `'channels_last'`.
    activity_regularizer: Regularizer function for the output.
    trainable: Boolean, if `True`, also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require `reuse=True` in such
      cases.
    reuse: Boolean, whether to reuse the weights of a previous layer by the same
      name.

  Returns:
    Output tensor.
  """
  # The layer takes its dtype from the inputs; `name` doubles as the variable
  # scope so that `reuse` can find the variables of an earlier call.
  layer = GDN(inverse=inverse,
              beta_min=beta_min,
              gamma_init=gamma_init,
              reparam_offset=reparam_offset,
              data_format=data_format,
              activity_regularizer=activity_regularizer,
              trainable=trainable,
              name=name,
              dtype=inputs.dtype.base_dtype,
              _scope=name,
              _reuse=reuse)
  return layer.apply(inputs)
@add_arg_scope
def layer_norm(inputs,
center=True,
@ -1812,300 +2124,6 @@ def layer_norm(inputs,
outputs)
class GDN(base.Layer):
  """Generalized divisive normalization layer.

  Based on the papers:

    "Density Modeling of Images using a Generalized Normalization
    Transformation"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1511.06281

    "End-to-end Optimized Image Compression"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1611.01704

  Implements an activation function that is essentially a multivariate
  generalization of a particular sigmoid-type function:

    y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))

  where i and j run over channels. This implementation never sums across spatial
  dimensions. It is similar to local response normalization, but more powerful,
  as beta and gamma are trainable parameters.

  Arguments:
    inverse: If False (default), compute GDN response. If True, compute IGDN
      response (one step of fixed point iteration to invert GDN; the division
      is replaced by multiplication).
    beta_min: Lower bound for beta, to prevent numerical error from causing
      square root of zero or negative values.
    gamma_init: The gamma matrix will be initialized as the identity matrix
      multiplied with this value. If set to zero, the layer is effectively
      initialized to the identity operation, since beta is initialized as one.
      A good default setting is somewhere between 0 and 0.5.
    reparam_offset: Offset added to the reparameterization of beta and gamma.
      The reparameterization of beta and gamma as their square roots lets the
      training slow down when their values are close to zero, which is desirable
      as small values in the denominator can lead to a situation where gradient
      noise on beta/gamma leads to extreme amounts of noise in the GDN
      activations. However, without the offset, we would get zero gradients if
      any elements of beta or gamma were exactly zero, and thus the training
      could get stuck. To prevent this, we add this small constant. The default
      value was empirically determined as a good starting point. Making it
      bigger potentially leads to more gradient noise on the activations, making
      it too small may lead to numerical precision issues.
    data_format: Format of input tensor. Currently supports 'channels_first' and
      'channels_last'.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such cases.
    **kwargs: Forwarded unchanged to the `base.Layer` constructor.

  Properties:
    inverse: Boolean, whether IGDN is computed (True) or GDN (False).
    data_format: Format of input tensor. Currently supports 'channels_first' and
      'channels_last'.
    beta: The beta parameter as defined above (1D TensorFlow tensor).
    gamma: The gamma parameter as defined above (2D TensorFlow tensor).
  """

  def __init__(self,
               inverse=False,
               beta_min=1e-6,
               gamma_init=.1,
               reparam_offset=2 ** -18,
               data_format='channels_last',
               trainable=True,
               name=None,
               **kwargs):
    super(GDN, self).__init__(trainable=trainable, name=name, **kwargs)
    self.inverse = inverse
    self._beta_min = beta_min
    self._gamma_init = gamma_init
    self._reparam_offset = reparam_offset
    self.data_format = data_format
    self._channel_axis()  # trigger ValueError early
    self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5)

  def _channel_axis(self):
    """Returns the channel axis index implied by `self.data_format`.

    Returns:
      `1` for 'channels_first', `-1` for 'channels_last'.

    Raises:
      ValueError: If `self.data_format` is neither of the two supported values.
    """
    try:
      return {'channels_first': 1, 'channels_last': -1}[self.data_format]
    except KeyError:
      raise ValueError('Unsupported `data_format` for GDN layer: {}.'.format(
          self.data_format))

  @staticmethod
  def _lower_bound(inputs, bound, name=None):
    """Same as tf.maximum, but with helpful gradient for inputs < bound.

    The gradient is overwritten so that it is passed through if the input is not
    hitting the bound. If it is, only gradients that push `inputs` higher than
    the bound are passed through. No gradients are passed through to the bound.

    Args:
      inputs: input tensor
      bound: lower bound for the input tensor
      name: name for this op

    Returns:
      tf.maximum(inputs, bound)
    """
    with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope:
      inputs = ops.convert_to_tensor(inputs, name='inputs')
      bound = ops.convert_to_tensor(bound, name='bound')
      # Swap in the gradient registered under 'GDNLowerBound' for the
      # `Maximum` op created inside this context.
      with ops.get_default_graph().gradient_override_map(
          {'Maximum': 'GDNLowerBound'}):
        return math_ops.maximum(inputs, bound, name=scope)

  # `RegisterGradient` must be the inner decorator so that it receives the
  # plain function; applied on top of `staticmethod` it would register the
  # non-callable staticmethod wrapper instead.
  @staticmethod
  @ops.RegisterGradient('GDNLowerBound')
  def _lower_bound_grad(op, grad):
    """Gradient for `_lower_bound`.

    Args:
      op: the tensorflow op for which to calculate a gradient
      grad: gradient with respect to the output of the op

    Returns:
      gradients with respect to the inputs of the op
    """
    inputs = op.inputs[0]
    bound = op.inputs[1]
    # Keep the gradient where the input is at or above the bound, or where the
    # incoming gradient is negative; zero it elsewhere.
    pass_through_if = math_ops.logical_or(inputs >= bound, grad < 0)
    # `None` for the second input: no gradient flows to the bound.
    return [math_ops.cast(pass_through_if, grad.dtype) * grad, None]

  def build(self, input_shape):
    """Creates the reparameterized `beta` and `gamma` variables.

    Args:
      input_shape: Shape of the inputs (anything `TensorShape` accepts).

    Raises:
      ValueError: If the channel dimension of `input_shape` is undefined.
    """
    channel_axis = self._channel_axis()
    input_shape = tensor_shape.TensorShape(input_shape)
    num_channels = input_shape[channel_axis].value
    if num_channels is None:
      raise ValueError('The channel dimension of the inputs to `GDN` '
                       'must be defined.')
    self._input_rank = input_shape.ndims
    self.input_spec = base.InputSpec(ndim=input_shape.ndims,
                                     axes={channel_axis: num_channels})

    # The variables store sqrt(value + pedestal); the effective parameters are
    # recovered below as square(variable) - pedestal.
    pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype)
    beta_bound = array_ops.constant(
        (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype)
    gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype)

    def beta_initializer(shape, dtype=None, partition_info=None):
      del partition_info  # unused
      return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal)

    def gamma_initializer(shape, dtype=None, partition_info=None):
      del partition_info  # unused
      assert len(shape) == 2
      assert shape[0] == shape[1]
      eye = linalg_ops.eye(shape[0], dtype=dtype)
      return math_ops.sqrt(self._gamma_init * eye + pedestal)

    beta = self.add_variable('reparam_beta',
                             shape=[num_channels],
                             initializer=beta_initializer,
                             dtype=self.dtype,
                             trainable=True)
    beta = self._lower_bound(beta, beta_bound)
    self.beta = math_ops.square(beta) - pedestal

    gamma = self.add_variable('reparam_gamma',
                              shape=[num_channels, num_channels],
                              initializer=gamma_initializer,
                              dtype=self.dtype,
                              trainable=True)
    gamma = self._lower_bound(gamma, gamma_bound)
    self.gamma = math_ops.square(gamma) - pedestal

    self.built = True

  def call(self, inputs):
    """Applies GDN (or IGDN if `self.inverse`) to `inputs`.

    Args:
      inputs: Input tensor; converted to `self.dtype`.

    Returns:
      `inputs / norm` (GDN) or `inputs * norm` (IGDN), where `norm` is
      `sqrt(beta + conv(square(inputs), gamma))`, with the static shape of
      `inputs`.
    """
    inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
    ndim = self._input_rank

    # View gamma as a convolution kernel whose spatial extent is 1 everywhere,
    # so the convolution only mixes channels.
    kernel_shape = (ndim - 2) * [1] + self.gamma.get_shape().as_list()
    gamma = array_ops.reshape(self.gamma, kernel_shape)

    # Compute normalization pool.
    if self.data_format == 'channels_first':
      norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID',
                                 data_format='NC' + 'DHW'[-(ndim - 2):])
      # The bias is always added on a 4-D 'NCHW' view; 3-D and 5-D tensors are
      # temporarily reshaped (presumably because `bias_add` does not accept
      # these ranks with channels first -- TODO confirm).
      if ndim == 3:
        norm_pool = array_ops.expand_dims(norm_pool, 2)
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
        norm_pool = array_ops.squeeze(norm_pool, [2])
      elif ndim == 5:
        full_shape = array_ops.shape(norm_pool)
        # `full_shape` is a tensor, so `full_shape[:3] + [-1]` would be an
        # element-wise addition; concatenate to get [N, C, D, -1] instead.
        merged_shape = array_ops.concat([full_shape[:3], [-1]], 0)
        norm_pool = array_ops.reshape(norm_pool, merged_shape)
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
        norm_pool = array_ops.reshape(norm_pool, full_shape)
      else:  # ndim == 4
        norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW')
    else:  # channels_last
      norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID')
      norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NHWC')
    norm_pool = math_ops.sqrt(norm_pool)

    if self.inverse:
      outputs = inputs * norm_pool
    else:
      outputs = inputs / norm_pool
    outputs.set_shape(inputs.get_shape())
    return outputs

  def _compute_output_shape(self, input_shape):
    """Validates `input_shape` and returns it unchanged.

    Args:
      input_shape: Shape of the inputs (anything `TensorShape` accepts).

    Returns:
      `input_shape` as a `TensorShape`.

    Raises:
      ValueError: If the rank is unknown or outside 3..5, or if the channel
        dimension is undefined.
    """
    channel_axis = self._channel_axis()
    input_shape = tensor_shape.TensorShape(input_shape)
    # `ndims` is the attribute `build` uses; it is `None` for unknown rank.
    rank = input_shape.ndims
    if rank is None or not 3 <= rank <= 5:
      raise ValueError('`input_shape` must be of rank 3 to 5, inclusive.')
    if input_shape[channel_axis].value is None:
      raise ValueError(
          'The channel dimension of `input_shape` must be defined.')
    return input_shape
def gdn(inputs,
        inverse=False,
        beta_min=1e-6,
        gamma_init=.1,
        reparam_offset=2 ** -18,
        data_format='channels_last',
        trainable=True,
        name=None,
        reuse=None):
  """Functional interface for GDN layer.

  Based on the papers:

    "Density Modeling of Images using a Generalized Normalization
    Transformation"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1511.06281

    "End-to-end Optimized Image Compression"

    Johannes Ballé, Valero Laparra, Eero P. Simoncelli

    https://arxiv.org/abs/1611.01704

  Implements an activation function that is essentially a multivariate
  generalization of a particular sigmoid-type function:

    y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))

  where i and j run over channels. This implementation never sums across spatial
  dimensions. It is similar to local response normalization, but more powerful,
  as beta and gamma are trainable parameters.

  Arguments:
    inputs: Tensor input.
    inverse: If False (default), compute GDN response. If True, compute IGDN
      response (one step of fixed point iteration to invert GDN; the division
      is replaced by multiplication).
    beta_min: Lower bound for beta, to prevent numerical error from causing
      square root of zero or negative values.
    gamma_init: The gamma matrix will be initialized as the identity matrix
      multiplied with this value. If set to zero, the layer is effectively
      initialized to the identity operation, since beta is initialized as one.
      A good default setting is somewhere between 0 and 0.5.
    reparam_offset: Offset added to the reparameterization of beta and gamma.
      The reparameterization of beta and gamma as their square roots lets the
      training slow down when their values are close to zero, which is desirable
      as small values in the denominator can lead to a situation where gradient
      noise on beta/gamma leads to extreme amounts of noise in the GDN
      activations. However, without the offset, we would get zero gradients if
      any elements of beta or gamma were exactly zero, and thus the training
      could get stuck. To prevent this, we add this small constant. The default
      value was empirically determined as a good starting point. Making it
      bigger potentially leads to more gradient noise on the activations, making
      it too small may lead to numerical precision issues.
    data_format: Format of input tensor. Currently supports 'channels_first' and
      'channels_last'.
    trainable: Boolean, if `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such cases.
    reuse: Boolean, whether to reuse the weights of a previous layer
      by the same name.

  Returns:
    Output tensor.
  """
  # The layer takes its dtype from the inputs; `name` doubles as the variable
  # scope so that `reuse` can find the variables of an earlier call.
  layer = GDN(inverse=inverse,
              beta_min=beta_min,
              gamma_init=gamma_init,
              reparam_offset=reparam_offset,
              data_format=data_format,
              trainable=trainable,
              name=name,
              dtype=inputs.dtype.base_dtype,
              _scope=name,
              _reuse=reuse)
  return layer.apply(inputs)
@add_arg_scope
def max_pool2d(inputs,
kernel_size,