mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[tf contrib seq2seq] Bugfixes to BeamSearchDecoder
Implementation by Cinjon Resnick. He can't push this since he's traveling. I just copied the fix and added some small syntax tweaks to make the unit tests pass. More comprehensive unit tests will come in the near future. Fixes at least part of #9904. BeamSearchDecoder: 1. Fix the bug where we don't pass the next cell state through. 2. Gather the cell state (and attention if that's a part of the model as an AttentionWrapper on the cell) according to the next_beam_ids. PiperOrigin-RevId: 157415564
This commit is contained in:
parent
f7ae1461c2
commit
34dcd5b493
|
|
@ -141,6 +141,7 @@ class TestBeamStep(test.TestCase):
|
|||
outputs, next_beam_state = beam_search_decoder._beam_search_step(
|
||||
time=2,
|
||||
logits=logits,
|
||||
next_cell_state=dummy_cell_state,
|
||||
beam_state=beam_state,
|
||||
batch_size=ops.convert_to_tensor(self.batch_size),
|
||||
beam_width=self.beam_width,
|
||||
|
|
@ -195,6 +196,7 @@ class TestBeamStep(test.TestCase):
|
|||
outputs, next_beam_state = beam_search_decoder._beam_search_step(
|
||||
time=2,
|
||||
logits=logits,
|
||||
next_cell_state=dummy_cell_state,
|
||||
beam_state=beam_state,
|
||||
batch_size=ops.convert_to_tensor(self.batch_size),
|
||||
beam_width=self.beam_width,
|
||||
|
|
|
|||
|
|
@ -110,6 +110,15 @@ def tile_batch(t, multiplier, name=None):
|
|||
return tiled
|
||||
|
||||
|
||||
def _check_maybe(t):
|
||||
if isinstance(t, tensor_array_ops.TensorArray):
|
||||
raise TypeError(
|
||||
"TensorArray state is not supported by BeamSearchDecoder: %s" % t.name)
|
||||
if t.shape.ndims is None:
|
||||
raise ValueError(
|
||||
"Expected tensor (%s) to have known rank, but ndims == None." % t)
|
||||
|
||||
|
||||
class BeamSearchDecoder(decoder.Decoder):
|
||||
"""BeamSearch sampling decoder."""
|
||||
|
||||
|
|
@ -351,13 +360,7 @@ class BeamSearchDecoder(decoder.Decoder):
|
|||
TypeError: If t is an instance of TensorArray.
|
||||
ValueError: If the rank of t is not statically known.
|
||||
"""
|
||||
if isinstance(t, tensor_array_ops.TensorArray):
|
||||
raise TypeError(
|
||||
"TensorArray state is not supported by BeamSearchDecoder: %s"
|
||||
% t.name)
|
||||
if t.shape.ndims is None:
|
||||
raise ValueError(
|
||||
"Expected tensor (%s) to have known rank, but ndims == None." % t)
|
||||
_check_maybe(t)
|
||||
if t.shape.ndims >= 1:
|
||||
return self._split_batch_beams(t, s)
|
||||
else:
|
||||
|
|
@ -380,13 +383,7 @@ class BeamSearchDecoder(decoder.Decoder):
|
|||
TypeError: If t is an instance of TensorArray.
|
||||
ValueError: If the rank of t is not statically known.
|
||||
"""
|
||||
if isinstance(t, tensor_array_ops.TensorArray):
|
||||
raise TypeError(
|
||||
"TensorArray state is not supported by BeamSearchDecoder: %s"
|
||||
% t.name)
|
||||
if t.shape.ndims is None:
|
||||
raise ValueError(
|
||||
"Expected tensor (%s) to have known rank, but ndims == None." % t)
|
||||
_check_maybe(t)
|
||||
if t.shape.ndims >= 2:
|
||||
return self._merge_batch_beams(t, s)
|
||||
else:
|
||||
|
|
@ -417,7 +414,6 @@ class BeamSearchDecoder(decoder.Decoder):
|
|||
self._maybe_merge_batch_beams,
|
||||
cell_state, self._cell.state_size)
|
||||
cell_outputs, next_cell_state = self._cell(inputs, cell_state)
|
||||
|
||||
cell_outputs = nest.map_structure(
|
||||
lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
|
||||
next_cell_state = nest.map_structure(
|
||||
|
|
@ -430,11 +426,13 @@ class BeamSearchDecoder(decoder.Decoder):
|
|||
beam_search_output, beam_search_state = _beam_search_step(
|
||||
time=time,
|
||||
logits=cell_outputs,
|
||||
next_cell_state=next_cell_state,
|
||||
beam_state=state,
|
||||
batch_size=batch_size,
|
||||
beam_width=beam_width,
|
||||
end_token=end_token,
|
||||
length_penalty_weight=length_penalty_weight)
|
||||
|
||||
finished = beam_search_state.finished
|
||||
sample_ids = beam_search_output.predicted_ids
|
||||
next_inputs = control_flow_ops.cond(
|
||||
|
|
@ -444,8 +442,8 @@ class BeamSearchDecoder(decoder.Decoder):
|
|||
return (beam_search_output, beam_search_state, next_inputs, finished)
|
||||
|
||||
|
||||
def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
|
||||
end_token, length_penalty_weight):
|
||||
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
|
||||
beam_width, end_token, length_penalty_weight):
|
||||
"""Performs a single step of Beam Search Decoding.
|
||||
|
||||
Args:
|
||||
|
|
@ -454,6 +452,8 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
|
|||
continuations.
|
||||
logits: Logits at the current time step. A tensor of shape
|
||||
`[batch_size, beam_width, vocab_size]`
|
||||
next_cell_state: The next state from the cell, e.g. an instance of
|
||||
AttentionWrapperState if the cell is attentional.
|
||||
beam_state: Current state of the beam search.
|
||||
An instance of `BeamSearchDecoderState`.
|
||||
batch_size: The batch size for this input.
|
||||
|
|
@ -520,8 +520,7 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
|
|||
gather_from=total_probs,
|
||||
range_input=batch_size,
|
||||
range_size=beam_width * vocab_size,
|
||||
final_shape=[static_batch_size, beam_width])
|
||||
|
||||
gather_shape=[-1])
|
||||
next_word_ids = math_ops.to_int32(word_indices % vocab_size)
|
||||
next_beam_ids = math_ops.to_int32(word_indices / vocab_size)
|
||||
|
||||
|
|
@ -531,7 +530,7 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
|
|||
gather_from=previously_finished,
|
||||
range_input=batch_size,
|
||||
range_size=beam_width,
|
||||
final_shape=[static_batch_size, beam_width])
|
||||
gather_shape=[-1])
|
||||
next_finished = math_ops.logical_or(previously_finished,
|
||||
math_ops.equal(next_word_ids, end_token))
|
||||
|
||||
|
|
@ -547,11 +546,26 @@ def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
|
|||
gather_from=beam_state.lengths,
|
||||
range_input=batch_size,
|
||||
range_size=beam_width,
|
||||
final_shape=[static_batch_size, beam_width])
|
||||
gather_shape=[-1])
|
||||
next_prediction_len += lengths_to_add
|
||||
|
||||
# Pick out the cell_states according to the next_beam_ids. We use a
|
||||
# different gather_shape here because the cell_state tensors, i.e.
|
||||
# the tensors that would be gathered from, all have dimension
|
||||
# greater than two and we need to preserve those dimensions.
|
||||
# pylint: disable=g-long-lambda
|
||||
next_cell_state = nest.map_structure(
|
||||
lambda gather_from: _maybe_tensor_gather_helper(
|
||||
gather_indices=next_beam_ids,
|
||||
gather_from=gather_from,
|
||||
range_input=batch_size,
|
||||
range_size=beam_width,
|
||||
gather_shape=[batch_size * beam_width, -1]),
|
||||
next_cell_state)
|
||||
# pylint: enable=g-long-lambda
|
||||
|
||||
next_state = BeamSearchDecoderState(
|
||||
cell_state=beam_state.cell_state,
|
||||
cell_state=next_cell_state,
|
||||
log_probs=next_beam_probs,
|
||||
lengths=next_prediction_len,
|
||||
finished=next_finished)
|
||||
|
|
@ -637,12 +651,67 @@ def _mask_probs(probs, eos_token, finished):
|
|||
return finished_examples + non_finished_examples
|
||||
|
||||
|
||||
def _maybe_tensor_gather_helper(gather_indices, gather_from, range_input,
|
||||
range_size, gather_shape):
|
||||
"""Maybe applies _tensor_gather_helper.
|
||||
|
||||
This applies _tensor_gather_helper when the gather_from dims is at least as
|
||||
big as the length of gather_shape. This is used in conjunction with nest so
|
||||
that we don't apply _tensor_gather_helper to inapplicable values like scalars.
|
||||
|
||||
Args:
|
||||
gather_indices: The tensor indices that we use to gather.
|
||||
gather_from: The tensor that we are gathering from.
|
||||
range_input: The input range to use. Likely equal to batch_size.
|
||||
range_size: The number of values in each range. Likely equal to beam_width.
|
||||
gather_shape: What we should reshape gather_from to in order to preserve the
|
||||
correct values. An example is when gather_from is the attention from an
|
||||
AttentionWrapperState with shape [batch_size, beam_width, attention_size].
|
||||
There, we want to preserve the attention_size elements, so gather_shape is
|
||||
[batch_size * beam_width, -1]. Then, upon reshape, we still have the
|
||||
attention_size as desired.
|
||||
|
||||
Returns:
|
||||
output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
|
||||
or the original tensor if its dimensions are too small.
|
||||
"""
|
||||
_check_maybe(gather_from)
|
||||
if gather_from.shape.ndims >= len(gather_shape):
|
||||
return _tensor_gather_helper(gather_indices, gather_from, range_input,
|
||||
range_size, gather_shape)
|
||||
else:
|
||||
return gather_from
|
||||
|
||||
|
||||
def _tensor_gather_helper(gather_indices, gather_from, range_input, range_size,
|
||||
final_shape):
|
||||
gather_shape):
|
||||
"""Helper for gathering the right indices from the tensor.
|
||||
|
||||
This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
|
||||
gathering from that according to the gather_indices, which are offset by
|
||||
the right amounts in order to preserve the batch order.
|
||||
|
||||
Args:
|
||||
gather_indices: The tensor indices that we use to gather.
|
||||
gather_from: The tensor that we are gathering from.
|
||||
range_input: The input range to use. Likely equal to batch_size.
|
||||
range_size: The number of values in each range. Likely equal to beam_width.
|
||||
gather_shape: What we should reshape gather_from to in order to preserve the
|
||||
correct values. An example is when gather_from is the attention from an
|
||||
AttentionWrapperState with shape [batch_size, beam_width, attention_size].
|
||||
There, we want to preserve the attention_size elements, so gather_shape is
|
||||
[batch_size * beam_width, -1]. Then, upon reshape, we still have the
|
||||
attention_size as desired.
|
||||
|
||||
Returns:
|
||||
output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
|
||||
"""
|
||||
range_ = array_ops.expand_dims(math_ops.range(range_input) * range_size, 1)
|
||||
gather_indices = array_ops.reshape(gather_indices + range_, [-1])
|
||||
output = array_ops.gather(
|
||||
array_ops.reshape(gather_from, [-1]), gather_indices)
|
||||
array_ops.reshape(gather_from, gather_shape), gather_indices)
|
||||
final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
|
||||
final_static_shape = gather_from.shape[:1 + len(gather_shape)]
|
||||
output = array_ops.reshape(output, final_shape)
|
||||
output.set_shape(final_shape)
|
||||
output.set_shape(final_static_shape)
|
||||
return output
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user