[tf contrib seq2seq] Bugfixes to BeamSearchDecoder

Implementation by Cinjon Resnick.  He can't push this since he's traveling.
I just copied the fix and added some small syntax tweaks to make the unit
tests pass.  More comprehensive unit tests will come in the near future.

Fixes at least part of #9904.

BeamSearchDecoder:
1. Fix the bug where we don't pass the next cell state through.
2. Gather the cell state (and the attention, if the cell is wrapped in an
   AttentionWrapper) according to the next_beam_ids; see the sketch below.
PiperOrigin-RevId: 157415564
Eugene Brevdo, 2017-05-29 15:14:46 -07:00, committed by TensorFlower Gardener
parent f7ae1461c2, commit 34dcd5b493
2 changed files with 96 additions and 25 deletions
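
For intuition, here is a minimal, self-contained sketch of fix 2 (not part of the commit; the toy shapes, the example indices, and the name `depth` are made up). Each decoding step keeps beam_width survivors per batch entry, and next_beam_ids says which parent beam each survivor extends, so every per-beam state tensor has to be reordered along the beam axis with an offset gather:

import tensorflow as tf  # TF 1.x, matching the era of this commit

batch_size, beam_width, depth = 2, 3, 4

# Stand-in for one cell-state component, laid out [batch_size, beam_width, depth].
next_cell_state = tf.reshape(
    tf.range(batch_size * beam_width * depth, dtype=tf.float32),
    [batch_size, beam_width, depth])

# next_beam_ids[b][k]: the parent beam that survivor k of batch entry b extends.
next_beam_ids = tf.constant([[2, 0, 0], [1, 1, 2]])

# Offset each batch entry's beam ids into the flattened [batch * beam, ...]
# layout so a single flat gather keeps batch entries separate.
offsets = tf.expand_dims(tf.range(batch_size) * beam_width, 1)
flat_ids = tf.reshape(next_beam_ids + offsets, [-1])
flat_state = tf.reshape(next_cell_state, [batch_size * beam_width, -1])
reordered = tf.reshape(
    tf.gather(flat_state, flat_ids), [batch_size, beam_width, depth])

with tf.Session() as sess:
  print(sess.run(reordered))  # beams permuted within each batch entry

Before this change, the state written back into BeamSearchDecoderState was the previous step's beam_state.cell_state, un-reordered (see the cell_state=beam_state.cell_state line in the diff below).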

tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py

@@ -141,6 +141,7 @@ class TestBeamStep(test.TestCase):
     outputs, next_beam_state = beam_search_decoder._beam_search_step(
         time=2,
         logits=logits,
+        next_cell_state=dummy_cell_state,
         beam_state=beam_state,
         batch_size=ops.convert_to_tensor(self.batch_size),
         beam_width=self.beam_width,
@@ -195,6 +196,7 @@ class TestBeamStep(test.TestCase):
     outputs, next_beam_state = beam_search_decoder._beam_search_step(
         time=2,
         logits=logits,
+        next_cell_state=dummy_cell_state,
         beam_state=beam_state,
         batch_size=ops.convert_to_tensor(self.batch_size),
         beam_width=self.beam_width,

tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py

@@ -110,6 +110,15 @@ def tile_batch(t, multiplier, name=None):
   return tiled


+def _check_maybe(t):
+  if isinstance(t, tensor_array_ops.TensorArray):
+    raise TypeError(
+        "TensorArray state is not supported by BeamSearchDecoder: %s" % t.name)
+  if t.shape.ndims is None:
+    raise ValueError(
+        "Expected tensor (%s) to have known rank, but ndims == None." % t)
+
+
 class BeamSearchDecoder(decoder.Decoder):
   """BeamSearch sampling decoder."""
@@ -351,13 +360,7 @@ class BeamSearchDecoder(decoder.Decoder):
       TypeError: If t is an instance of TensorArray.
       ValueError: If the rank of t is not statically known.
     """
-    if isinstance(t, tensor_array_ops.TensorArray):
-      raise TypeError(
-          "TensorArray state is not supported by BeamSearchDecoder: %s"
-          % t.name)
-    if t.shape.ndims is None:
-      raise ValueError(
-          "Expected tensor (%s) to have known rank, but ndims == None." % t)
+    _check_maybe(t)
     if t.shape.ndims >= 1:
       return self._split_batch_beams(t, s)
     else:
@@ -380,13 +383,7 @@ class BeamSearchDecoder(decoder.Decoder):
       TypeError: If t is an instance of TensorArray.
       ValueError: If the rank of t is not statically known.
     """
-    if isinstance(t, tensor_array_ops.TensorArray):
-      raise TypeError(
-          "TensorArray state is not supported by BeamSearchDecoder: %s"
-          % t.name)
-    if t.shape.ndims is None:
-      raise ValueError(
-          "Expected tensor (%s) to have known rank, but ndims == None." % t)
+    _check_maybe(t)
     if t.shape.ndims >= 2:
       return self._merge_batch_beams(t, s)
     else:
@@ -417,7 +414,6 @@ class BeamSearchDecoder(decoder.Decoder):
           self._maybe_merge_batch_beams,
           cell_state, self._cell.state_size)
       cell_outputs, next_cell_state = self._cell(inputs, cell_state)
-
       cell_outputs = nest.map_structure(
           lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
       next_cell_state = nest.map_structure(
@@ -430,11 +426,13 @@ class BeamSearchDecoder(decoder.Decoder):
       beam_search_output, beam_search_state = _beam_search_step(
           time=time,
           logits=cell_outputs,
+          next_cell_state=next_cell_state,
           beam_state=state,
           batch_size=batch_size,
           beam_width=beam_width,
           end_token=end_token,
           length_penalty_weight=length_penalty_weight)
+
       finished = beam_search_state.finished
       sample_ids = beam_search_output.predicted_ids
       next_inputs = control_flow_ops.cond(
@@ -444,8 +442,8 @@ class BeamSearchDecoder(decoder.Decoder):
     return (beam_search_output, beam_search_state, next_inputs, finished)


-def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
-                      end_token, length_penalty_weight):
+def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
+                      beam_width, end_token, length_penalty_weight):
   """Performs a single step of Beam Search Decoding.

   Args:
@@ -454,6 +452,8 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
       continuations.
     logits: Logits at the current time step. A tensor of shape
       `[batch_size, beam_width, vocab_size]`
+    next_cell_state: The next state from the cell, e.g. an instance of
+      AttentionWrapperState if the cell is attentional.
     beam_state: Current state of the beam search.
       An instance of `BeamSearchDecoderState`.
     batch_size: The batch size for this input.
@@ -520,8 +520,7 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
       gather_from=total_probs,
       range_input=batch_size,
       range_size=beam_width * vocab_size,
-      final_shape=[static_batch_size, beam_width])
+      gather_shape=[-1])

   next_word_ids = math_ops.to_int32(word_indices % vocab_size)
   next_beam_ids = math_ops.to_int32(word_indices / vocab_size)
@@ -531,7 +530,7 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
       gather_from=previously_finished,
       range_input=batch_size,
       range_size=beam_width,
-      final_shape=[static_batch_size, beam_width])
+      gather_shape=[-1])
   next_finished = math_ops.logical_or(previously_finished,
                                       math_ops.equal(next_word_ids, end_token))
@@ -547,11 +546,26 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
       gather_from=beam_state.lengths,
       range_input=batch_size,
       range_size=beam_width,
-      final_shape=[static_batch_size, beam_width])
+      gather_shape=[-1])
   next_prediction_len += lengths_to_add

+  # Pick out the cell_states according to the next_beam_ids. We use a
+  # different gather_shape here because the cell_state tensors, i.e.
+  # the tensors that would be gathered from, all have dimension
+  # greater than two and we need to preserve those dimensions.
+  # pylint: disable=g-long-lambda
+  next_cell_state = nest.map_structure(
+      lambda gather_from: _maybe_tensor_gather_helper(
+          gather_indices=next_beam_ids,
+          gather_from=gather_from,
+          range_input=batch_size,
+          range_size=beam_width,
+          gather_shape=[batch_size * beam_width, -1]),
+      next_cell_state)
+  # pylint: enable=g-long-lambda
+
   next_state = BeamSearchDecoderState(
-      cell_state=beam_state.cell_state,
+      cell_state=next_cell_state,
       log_probs=next_beam_probs,
       lengths=next_prediction_len,
       finished=next_finished)
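
The nest.map_structure call above relies on the helper added in the next hunk to skip structure components that are too small to gather from. A hedged sketch of that dispatch, written against the public TF 1.x API with a made-up two-component state (a rank-3 attention-like tensor plus a rank-0 step counter):

import tensorflow as tf
from tensorflow.python.util import nest

batch_size, beam_width = 2, 3
next_beam_ids = tf.constant([[1, 1, 0], [2, 0, 0]])
# Made-up state tuple: the rank-3 tensor gets reordered along the beam
# axis, while the rank-0 component must pass through untouched.
state = (tf.random_normal([batch_size, beam_width, 5]), tf.constant(1))

def maybe_reorder(t):
  gather_shape = [batch_size * beam_width, -1]
  if t.shape.ndims is None or t.shape.ndims < len(gather_shape):
    return t  # too small to gather from: leave as-is
  offsets = tf.expand_dims(tf.range(batch_size) * beam_width, 1)
  ids = tf.reshape(next_beam_ids + offsets, [-1])
  flat = tf.reshape(t, gather_shape)  # trailing dims collapse into -1
  return tf.reshape(tf.gather(flat, ids), tf.shape(t))

new_state = nest.map_structure(maybe_reorder, state)  # scalar unchanged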
@@ -637,12 +651,67 @@ def _mask_probs(probs, eos_token, finished):
   return finished_examples + non_finished_examples


+def _maybe_tensor_gather_helper(gather_indices, gather_from, range_input,
+                                range_size, gather_shape):
+  """Maybe applies _tensor_gather_helper.
+
+  This applies _tensor_gather_helper when the gather_from dims is at least as
+  big as the length of gather_shape. This is used in conjunction with nest so
+  that we don't apply _tensor_gather_helper to inapplicable values like scalars.
+
+  Args:
+    gather_indices: The tensor indices that we use to gather.
+    gather_from: The tensor that we are gathering from.
+    range_input: The input range to use. Likely equal to batch_size.
+    range_size: The number of values in each range. Likely equal to beam_width.
+    gather_shape: What we should reshape gather_from to in order to preserve the
+      correct values. An example is when gather_from is the attention from an
+      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
+      There, we want to preserve the attention_size elements, so gather_shape is
+      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
+      attention_size as desired.
+
+  Returns:
+    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
+      or the original tensor if its dimensions are too small.
+  """
+  _check_maybe(gather_from)
+  if gather_from.shape.ndims >= len(gather_shape):
+    return _tensor_gather_helper(gather_indices, gather_from, range_input,
+                                 range_size, gather_shape)
+  else:
+    return gather_from
+
+
 def _tensor_gather_helper(gather_indices, gather_from, range_input, range_size,
-                          final_shape):
+                          gather_shape):
+  """Helper for gathering the right indices from the tensor.
+
+  This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
+  gathering from that according to the gather_indices, which are offset by
+  the right amounts in order to preserve the batch order.
+
+  Args:
+    gather_indices: The tensor indices that we use to gather.
+    gather_from: The tensor that we are gathering from.
+    range_input: The input range to use. Likely equal to batch_size.
+    range_size: The number of values in each range. Likely equal to beam_width.
+    gather_shape: What we should reshape gather_from to in order to preserve the
+      correct values. An example is when gather_from is the attention from an
+      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
+      There, we want to preserve the attention_size elements, so gather_shape is
+      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
+      attention_size as desired.
+
+  Returns:
+    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
+  """
   range_ = array_ops.expand_dims(math_ops.range(range_input) * range_size, 1)
   gather_indices = array_ops.reshape(gather_indices + range_, [-1])
   output = array_ops.gather(
-      array_ops.reshape(gather_from, [-1]), gather_indices)
+      array_ops.reshape(gather_from, gather_shape), gather_indices)
+  final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
+  final_static_shape = gather_from.shape[:1 + len(gather_shape)]
   output = array_ops.reshape(output, final_shape)
-  output.set_shape(final_shape)
+  output.set_shape(final_static_shape)
   return output
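
To see what the switch from final_shape to gather_shape buys, here is a hedged, standalone transcription of the patched helper onto the public TF 1.x API (the toy shapes and variable names are mine, not the commit's):

import tensorflow as tf

def tensor_gather_helper(gather_indices, gather_from, range_input, range_size,
                         gather_shape):
  # Same offset trick as _tensor_gather_helper: bias each batch entry's
  # indices into the flattened layout, gather once, then restore the
  # leading shape from gather_from itself.
  range_ = tf.expand_dims(tf.range(range_input) * range_size, 1)
  gather_indices = tf.reshape(gather_indices + range_, [-1])
  output = tf.gather(tf.reshape(gather_from, gather_shape), gather_indices)
  final_shape = tf.shape(gather_from)[:1 + len(gather_shape)]
  return tf.reshape(output, final_shape)

batch_size, beam_width, attention_size = 2, 3, 5
beam_ids = tf.constant([[1, 1, 0], [2, 0, 0]])

# gather_shape=[-1]: flatten fully, as used for [batch, beam] tensors such
# as log probs and lengths; the result is again [2, 3].
lengths = tf.zeros([batch_size, beam_width])
g1 = tensor_gather_helper(beam_ids, lengths, batch_size, beam_width, [-1])

# gather_shape=[batch_size * beam_width, -1]: collapse only the leading two
# axes, so the trailing attention_size survives; the result is [2, 3, 5].
attention = tf.zeros([batch_size, beam_width, attention_size])
g2 = tensor_gather_helper(beam_ids, attention, batch_size, beam_width,
                          [batch_size * beam_width, -1])

The old final_shape argument hard-coded the output to [static_batch_size, beam_width] and always gathered from a fully flattened tensor, which is why rank-3 state such as attention could not be gathered before this change.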