mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
Bugfix to LSTMBlockCell and friends: clipping is off by default.
* Rename broken API argu clip_cell boolean to cell_clip value. * Make default no clipping. PiperOrigin-RevId: 170960975
This commit is contained in:
parent
bfaaefa9ec
commit
f9f037c1c4
|
|
@ -65,7 +65,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
|
|||
|
||||
def __init__(self, num_units, reuse=None):
|
||||
super(CudnnCompatibleLSTMCell, self).__init__(
|
||||
num_units, forget_bias=0, clip_cell=False, use_peephole=False,
|
||||
num_units, forget_bias=0, cell_clip=None, use_peephole=False,
|
||||
reuse=reuse)
|
||||
self._names.update({"scope": "cudnn_compatible_lstm_cell"})
|
||||
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ def _lstm_block_cell(x,
|
|||
wco: A `Tensor`. Must have the same type as `x`.
|
||||
The weight matrix for output gate peephole connection.
|
||||
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
|
||||
cell_clip: An optional `float`. Defaults to `3`.
|
||||
cell_clip: An optional `float`. Defaults to `-1` (no clipping).
|
||||
Value to clip the 'cs' value to. Disable by setting to negative value.
|
||||
use_peephole: An optional `bool`. Defaults to `False`.
|
||||
Whether to use peephole weights.
|
||||
|
|
@ -130,7 +130,7 @@ def _lstm_block_cell(x,
|
|||
wcf=wcf,
|
||||
b=b,
|
||||
forget_bias=forget_bias,
|
||||
cell_clip=cell_clip,
|
||||
cell_clip=cell_clip if cell_clip is not None else -1,
|
||||
use_peephole=use_peephole,
|
||||
name=name)
|
||||
# pylint: enable=protected-access
|
||||
|
|
@ -162,7 +162,7 @@ def _block_lstm(seq_len_max,
|
|||
wcf: A `Tensor`. Must have the same type as `x`.
|
||||
wco: A `Tensor`. Must have the same type as `x`.
|
||||
forget_bias: An optional `float`. Defaults to `1`.
|
||||
cell_clip: An optional `float`. Defaults to `3`.
|
||||
cell_clip: An optional `float`. Defaults to `-1` (no clipping).
|
||||
use_peephole: An optional `bool`. Defaults to `False`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
|
|
@ -216,7 +216,7 @@ def _block_lstm(seq_len_max,
|
|||
wcf=wcf,
|
||||
b=b,
|
||||
forget_bias=forget_bias,
|
||||
cell_clip=cell_clip,
|
||||
cell_clip=cell_clip if cell_clip is not None else -1,
|
||||
name=name,
|
||||
use_peephole=use_peephole)
|
||||
|
||||
|
|
@ -341,7 +341,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
|||
def __init__(self,
|
||||
num_units,
|
||||
forget_bias=1.0,
|
||||
clip_cell=True,
|
||||
cell_clip=None,
|
||||
use_peephole=False,
|
||||
reuse=None):
|
||||
"""Initialize the basic LSTM cell.
|
||||
|
|
@ -349,8 +349,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
|||
Args:
|
||||
num_units: int, The number of units in the LSTM cell.
|
||||
forget_bias: float, The bias added to forget gates (see above).
|
||||
clip_cell: boolean, whether to apply cell clipping. See
|
||||
`_lstm_block_cell()` for details.
|
||||
cell_clip: An optional `float`. Defaults to `-1` (no clipping).
|
||||
use_peephole: Whether to use peephole connections or not.
|
||||
reuse: (optional) boolean describing whether to reuse variables in an
|
||||
existing scope. If not `True`, and the existing scope already has the
|
||||
|
|
@ -363,7 +362,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
|||
self._num_units = num_units
|
||||
self._forget_bias = forget_bias
|
||||
self._use_peephole = use_peephole
|
||||
self._clip_cell = clip_cell
|
||||
self._cell_clip = cell_clip if cell_clip is not None else -1
|
||||
self._names = {
|
||||
"W": "kernel",
|
||||
"b": "bias",
|
||||
|
|
@ -412,7 +411,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
|||
wco=wco,
|
||||
wcf=wcf,
|
||||
forget_bias=self._forget_bias,
|
||||
cell_clip=None if self._clip_cell else -1,
|
||||
cell_clip=self._cell_clip,
|
||||
use_peephole=self._use_peephole)
|
||||
|
||||
new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
|
||||
|
|
@ -594,12 +593,12 @@ class LSTMBlockFusedCell(LSTMBlockWrapper):
|
|||
Args:
|
||||
num_units: int, The number of units in the LSTM cell.
|
||||
forget_bias: float, The bias added to forget gates (see above).
|
||||
cell_clip: clip the cell to this value. Defaults to `3`.
|
||||
cell_clip: clip the cell to this value. Default is no cell clipping.
|
||||
use_peephole: Whether to use peephole connections or not.
|
||||
"""
|
||||
self._num_units = num_units
|
||||
self._forget_bias = forget_bias
|
||||
self._cell_clip = cell_clip
|
||||
self._cell_clip = cell_clip if cell_clip is not None else -1
|
||||
self._use_peephole = use_peephole
|
||||
|
||||
@property
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user