mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Fix GRUBlockCell parameter naming inconsistency (#13153)
* Fix GRUBlockCell parameter naming inconsistency This fix tries to fix the issue in 13137 where parameter `cell_size` is used instead of `num_units`. This is inconsistent with other RNN cells. This fix adds support of `num_units` while at the same time maintains backward compatiblility for `cell_size`. This fix fixes 13137. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add `@deprecated_args` for 'cell_size' in `GRUBlockCell` This commit adds `@deprecated_args` for 'cell_size' in `GRUBlockCell` Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Address review comment Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
02a2eba057
commit
e0501bc4d0
|
|
@ -27,6 +27,7 @@ from tensorflow.python.ops import nn_ops
|
|||
from tensorflow.python.ops import rnn_cell_impl
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import resource_loader
|
||||
from tensorflow.python.util.deprecation import deprecated_args
|
||||
|
||||
_gru_ops_so = loader.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_gru_ops.so"))
|
||||
|
|
@ -129,13 +130,24 @@ class GRUBlockCell(rnn_cell_impl.RNNCell):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, cell_size):
|
||||
@deprecated_args(None, "cell_size is deprecated, use num_units instead",
|
||||
"cell_size")
|
||||
def __init__(self, num_units=None, cell_size=None):
|
||||
"""Initialize the Block GRU cell.
|
||||
|
||||
Args:
|
||||
cell_size: int, GRU cell size.
|
||||
num_units: int, The number of units in the GRU cell.
|
||||
cell_size: int, The old (deprecated) name for `num_units`.
|
||||
|
||||
Raises:
|
||||
ValueError: if both cell_size and num_units are not None;
|
||||
or both are None.
|
||||
"""
|
||||
self._cell_size = cell_size
|
||||
if (cell_size is None) == (num_units is None):
|
||||
raise ValueError("Exactly one of num_units or cell_size must be provided.")
|
||||
if num_units is None:
|
||||
num_units = cell_size
|
||||
self._cell_size = num_units
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user