mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Bugfix to LSTMBlockCell and friends: clipping is off by default.
* Rename broken API argu clip_cell boolean to cell_clip value. * Make default no clipping. PiperOrigin-RevId: 170960975
This commit is contained in:
parent
bfaaefa9ec
commit
f9f037c1c4
|
|
@ -65,7 +65,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
|
||||||
|
|
||||||
def __init__(self, num_units, reuse=None):
|
def __init__(self, num_units, reuse=None):
|
||||||
super(CudnnCompatibleLSTMCell, self).__init__(
|
super(CudnnCompatibleLSTMCell, self).__init__(
|
||||||
num_units, forget_bias=0, clip_cell=False, use_peephole=False,
|
num_units, forget_bias=0, cell_clip=None, use_peephole=False,
|
||||||
reuse=reuse)
|
reuse=reuse)
|
||||||
self._names.update({"scope": "cudnn_compatible_lstm_cell"})
|
self._names.update({"scope": "cudnn_compatible_lstm_cell"})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ def _lstm_block_cell(x,
|
||||||
wco: A `Tensor`. Must have the same type as `x`.
|
wco: A `Tensor`. Must have the same type as `x`.
|
||||||
The weight matrix for output gate peephole connection.
|
The weight matrix for output gate peephole connection.
|
||||||
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
|
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
|
||||||
cell_clip: An optional `float`. Defaults to `3`.
|
cell_clip: An optional `float`. Defaults to `-1` (no clipping).
|
||||||
Value to clip the 'cs' value to. Disable by setting to negative value.
|
Value to clip the 'cs' value to. Disable by setting to negative value.
|
||||||
use_peephole: An optional `bool`. Defaults to `False`.
|
use_peephole: An optional `bool`. Defaults to `False`.
|
||||||
Whether to use peephole weights.
|
Whether to use peephole weights.
|
||||||
|
|
@ -130,7 +130,7 @@ def _lstm_block_cell(x,
|
||||||
wcf=wcf,
|
wcf=wcf,
|
||||||
b=b,
|
b=b,
|
||||||
forget_bias=forget_bias,
|
forget_bias=forget_bias,
|
||||||
cell_clip=cell_clip,
|
cell_clip=cell_clip if cell_clip is not None else -1,
|
||||||
use_peephole=use_peephole,
|
use_peephole=use_peephole,
|
||||||
name=name)
|
name=name)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
@ -162,7 +162,7 @@ def _block_lstm(seq_len_max,
|
||||||
wcf: A `Tensor`. Must have the same type as `x`.
|
wcf: A `Tensor`. Must have the same type as `x`.
|
||||||
wco: A `Tensor`. Must have the same type as `x`.
|
wco: A `Tensor`. Must have the same type as `x`.
|
||||||
forget_bias: An optional `float`. Defaults to `1`.
|
forget_bias: An optional `float`. Defaults to `1`.
|
||||||
cell_clip: An optional `float`. Defaults to `3`.
|
cell_clip: An optional `float`. Defaults to `-1` (no clipping).
|
||||||
use_peephole: An optional `bool`. Defaults to `False`.
|
use_peephole: An optional `bool`. Defaults to `False`.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
|
@ -216,7 +216,7 @@ def _block_lstm(seq_len_max,
|
||||||
wcf=wcf,
|
wcf=wcf,
|
||||||
b=b,
|
b=b,
|
||||||
forget_bias=forget_bias,
|
forget_bias=forget_bias,
|
||||||
cell_clip=cell_clip,
|
cell_clip=cell_clip if cell_clip is not None else -1,
|
||||||
name=name,
|
name=name,
|
||||||
use_peephole=use_peephole)
|
use_peephole=use_peephole)
|
||||||
|
|
||||||
|
|
@ -341,7 +341,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_units,
|
num_units,
|
||||||
forget_bias=1.0,
|
forget_bias=1.0,
|
||||||
clip_cell=True,
|
cell_clip=None,
|
||||||
use_peephole=False,
|
use_peephole=False,
|
||||||
reuse=None):
|
reuse=None):
|
||||||
"""Initialize the basic LSTM cell.
|
"""Initialize the basic LSTM cell.
|
||||||
|
|
@ -349,8 +349,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
||||||
Args:
|
Args:
|
||||||
num_units: int, The number of units in the LSTM cell.
|
num_units: int, The number of units in the LSTM cell.
|
||||||
forget_bias: float, The bias added to forget gates (see above).
|
forget_bias: float, The bias added to forget gates (see above).
|
||||||
clip_cell: boolean, whether to apply cell clipping. See
|
cell_clip: An optional `float`. Defaults to `-1` (no clipping).
|
||||||
`_lstm_block_cell()` for details.
|
|
||||||
use_peephole: Whether to use peephole connections or not.
|
use_peephole: Whether to use peephole connections or not.
|
||||||
reuse: (optional) boolean describing whether to reuse variables in an
|
reuse: (optional) boolean describing whether to reuse variables in an
|
||||||
existing scope. If not `True`, and the existing scope already has the
|
existing scope. If not `True`, and the existing scope already has the
|
||||||
|
|
@ -363,7 +362,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
||||||
self._num_units = num_units
|
self._num_units = num_units
|
||||||
self._forget_bias = forget_bias
|
self._forget_bias = forget_bias
|
||||||
self._use_peephole = use_peephole
|
self._use_peephole = use_peephole
|
||||||
self._clip_cell = clip_cell
|
self._cell_clip = cell_clip if cell_clip is not None else -1
|
||||||
self._names = {
|
self._names = {
|
||||||
"W": "kernel",
|
"W": "kernel",
|
||||||
"b": "bias",
|
"b": "bias",
|
||||||
|
|
@ -412,7 +411,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
|
||||||
wco=wco,
|
wco=wco,
|
||||||
wcf=wcf,
|
wcf=wcf,
|
||||||
forget_bias=self._forget_bias,
|
forget_bias=self._forget_bias,
|
||||||
cell_clip=None if self._clip_cell else -1,
|
cell_clip=self._cell_clip,
|
||||||
use_peephole=self._use_peephole)
|
use_peephole=self._use_peephole)
|
||||||
|
|
||||||
new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
|
new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
|
||||||
|
|
@ -594,12 +593,12 @@ class LSTMBlockFusedCell(LSTMBlockWrapper):
|
||||||
Args:
|
Args:
|
||||||
num_units: int, The number of units in the LSTM cell.
|
num_units: int, The number of units in the LSTM cell.
|
||||||
forget_bias: float, The bias added to forget gates (see above).
|
forget_bias: float, The bias added to forget gates (see above).
|
||||||
cell_clip: clip the cell to this value. Defaults to `3`.
|
cell_clip: clip the cell to this value. Default is no cell clipping.
|
||||||
use_peephole: Whether to use peephole connections or not.
|
use_peephole: Whether to use peephole connections or not.
|
||||||
"""
|
"""
|
||||||
self._num_units = num_units
|
self._num_units = num_units
|
||||||
self._forget_bias = forget_bias
|
self._forget_bias = forget_bias
|
||||||
self._cell_clip = cell_clip
|
self._cell_clip = cell_clip if cell_clip is not None else -1
|
||||||
self._use_peephole = use_peephole
|
self._use_peephole = use_peephole
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user