Implement CRF decode (Viterbi decode) for tensor (#12056)

* Implement CRF decoding for tensors

* add test code for tensor version's CRF decoding

* made modifications according to pylint

* add some comments for crf decode

* remove useless code

* add comments at the top comment of crf module and add more comments in crf_test

* capitalize first char of first word in comments

* replace crf_decode test code with a deterministic example
This commit is contained in:
QingYing Chen 2017-08-11 11:31:50 +08:00 committed by Rasmus Munk Larsen
parent 2b374f7d4b
commit 26719d29fb
2 changed files with 215 additions and 3 deletions

View File

@ -23,6 +23,7 @@ import itertools
import numpy as np import numpy as np
from tensorflow.contrib.crf.python.ops import crf from tensorflow.contrib.crf.python.ops import crf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -199,6 +200,52 @@ class CrfTest(test.TestCase):
self.assertEqual(actual_max_sequence, self.assertEqual(actual_max_sequence,
expected_max_sequence[:sequence_lengths]) expected_max_sequence[:sequence_lengths])
def testCrfDecode(self):
  """Checks crf_decode against an exhaustive search over tag sequences."""
  # Deterministic example: four words, three tags, only the first three
  # words belong to the sequence.
  inputs = np.array(
      [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
  transition_params = np.array(
      [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
  sequence_lengths = np.array(3, dtype=np.int32)
  num_words, num_tags = inputs.shape[0], inputs.shape[1]

  with self.test_session() as sess:
    # Batched (batch_size == 1) views shared by every scoring call below.
    batched_inputs = array_ops.expand_dims(inputs, 0)
    batched_lengths = array_ops.expand_dims(sequence_lengths, 0)

    # Brute force: score every possible tagging of the unpadded prefix.
    candidate_sequences = []
    candidate_scores = []
    for candidate in itertools.product(
        range(num_tags), repeat=sequence_lengths):
      tag_indices = list(candidate)
      # Pad with tag 0 up to the full number of words.
      tag_indices.extend([0] * (num_words - sequence_lengths))
      candidate_sequences.append(tag_indices)
      score = crf.crf_sequence_score(
          inputs=batched_inputs,
          tag_indices=array_ops.expand_dims(tag_indices, 0),
          sequence_lengths=batched_lengths,
          transition_params=constant_op.constant(transition_params))
      candidate_scores.append(array_ops.squeeze(score, [0]))

    tf_candidate_scores = sess.run(candidate_scores)
    best_index = np.argmax(tf_candidate_scores)
    expected_max_sequence = candidate_sequences[best_index]
    expected_max_score = tf_candidate_scores[best_index]

    # Dynamic program under test.
    decoded_tags, decoded_score = crf.crf_decode(
        batched_inputs,
        constant_op.constant(transition_params),
        batched_lengths)
    tf_actual_max_sequence, tf_actual_max_score = sess.run([
        array_ops.squeeze(decoded_tags, [0]),
        array_ops.squeeze(decoded_score, [0])
    ])

    self.assertAllClose(tf_actual_max_score, expected_max_score)
    # Only the first `sequence_lengths` positions are meaningful.
    self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
                     expected_max_sequence[:sequence_lengths])
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -16,13 +16,24 @@
The following snippet is an example of a CRF layer on top of a batched sequence The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most of unary scores (logits for every word). This example also decodes the most
likely sequence at test time: likely sequence at test time. There are two ways to do decoding. One
is using crf_decode to do decoding in TensorFlow, and the other one is using
viterbi_decode in Numpy.
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood( log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
unary_scores, gold_tags, sequence_lengths) unary_scores, gold_tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood) loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss) train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# Decoding in Tensorflow.
viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
unary_scores, transition_params, sequence_lengths)
tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
[viterbi_sequence, viterbi_score, train_op])
# Decoding in Numpy.
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run( tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
[unary_scores, sequence_lengths, transition_params, train_op]) [unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores, for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
@ -31,7 +42,7 @@ for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_] tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
# Compute the highest score and its tag sequence. # Compute the highest score and its tag sequence.
viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode( tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
tf_unary_scores_, tf_transition_params) tf_unary_scores_, tf_transition_params)
""" """
@ -43,6 +54,7 @@ import numpy as np
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import rnn_cell
@ -50,7 +62,9 @@ from tensorflow.python.ops import variable_scope as vs
__all__ = [ __all__ = [
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood", "crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode" "crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
"viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
"CrfDecodeBackwardRnnCell"
] ]
@ -310,3 +324,154 @@ def viterbi_decode(score, transition_params):
viterbi_score = np.max(trellis[-1]) viterbi_score = np.max(trellis[-1])
return viterbi, viterbi_score return viterbi, viterbi_score
class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
  """Forward (max-score) step of decoding in a linear-chain CRF.

  At each step the cell emits, for every tag, the index of the previous tag
  that maximizes the running score, and carries the maximized scores forward
  as its state.
  """

  def __init__(self, transition_params):
    """Initialize the CrfDecodeForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary potentials.
        It is stored with a leading axis of size one, i.e. as
        [1, num_tags, num_tags], so that it broadcasts against the batched
        state inside the cell.
    """
    self._num_tags = transition_params.get_shape()[0].value
    self._transition_params = array_ops.expand_dims(transition_params, 0)

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous step's
        score values.
      scope: Unused variable scope of this cell.

    Returns:
      backpointers: [batch_size, num_tags] int32 tensor holding, for each
        current tag, the index of the best previous tag.
      new_state: [batch_size, num_tags], containing new score values.
    """
    # Shape shorthand: 'batch_size' is 'B', 'num_tags' is 'O' (output).
    previous_scores = array_ops.expand_dims(state, 2)  # [B, O, 1]

    # Broadcast sum: self._transition_params is repeated along axis 0 and the
    # previous scores along axis 2.
    # [B, O, 1] + [1, O, O] -> [B, O, O]
    candidate_scores = previous_scores + self._transition_params

    # Maximize over the previous-tag axis (axis 1).
    best_scores = math_ops.reduce_max(candidate_scores, [1])  # [B, O]
    next_state = inputs + best_scores  # [B, O]

    best_previous_tags = math_ops.cast(
        math_ops.argmax(candidate_scores, 1), dtype=dtypes.int32)  # [B, O]
    return best_previous_tags, next_state
class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
  """Backward step of decoding in a linear-chain CRF.

  Follows backpointers one step at a time to recover tag indices.
  """

  def __init__(self, num_tags):
    """Initialize the CrfDecodeBackwardRnnCell.

    Args:
      num_tags: The number of tags. It is stored on the cell; state and
        output sizes are always 1.
    """
    self._num_tags = num_tags

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeBackwardRnnCell.

    Args:
      inputs: [batch_size, num_tags], backpointer of next step (in time order).
      state: [batch_size, 1], next position's tag index.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: The same [batch_size, 1] tensor twice (as output
        and as new state), holding the tag index selected by the backpointers.
    """
    # Shape shorthand: 'batch_size' is 'B'.
    next_tags = array_ops.squeeze(state, axis=[1])  # [B]
    row_ids = math_ops.range(array_ops.shape(inputs)[0])  # [B]

    # Pair every batch row with its tag so that gather_nd reads
    # inputs[row, tag] for each row.
    lookup = array_ops.stack([row_ids, next_tags], axis=1)  # [B, 2]
    picked = gen_array_ops.gather_nd(inputs, lookup)  # [B]
    new_tags = array_ops.expand_dims(picked, axis=-1)  # [B, 1]
    return new_tags, new_tags
def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This is a function for tensor.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
      unary potentials.
    transition_params: A [num_tags, num_tags] tensor, matrix of
      binary potentials.
    sequence_length: A [batch_size] tensor, containing sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
      Contains the highest scoring tag indices.
    best_score: A [batch_size] tensor, containing the score of decode_tags.
  """
  # For simplicity, in shape comments, denote:
  # 'batch_size' by 'B', 'max_seq_len' by 'T', 'num_tags' by 'O' (output).
  static_shape = potentials.get_shape()
  num_tags = static_shape[2].value

  # When the time dimension is statically 1 there is no transition to score:
  # the answer is the argmax / max of the single step's potentials. Returning
  # it directly avoids running the RNNs below over an empty [B, 0, O] input.
  if static_shape[1].value == 1:
    single_step = array_ops.squeeze(potentials, axis=[1])  # [B, O]
    decode_tags = math_ops.cast(math_ops.argmax(single_step, axis=1),
                                dtype=dtypes.int32)  # [B]
    decode_tags = array_ops.expand_dims(decode_tags, axis=-1)  # [B, 1]
    best_score = math_ops.reduce_max(single_step, axis=1)  # [B]
    return decode_tags, best_score

  # The first step seeds the state, so the RNNs run over one step fewer.
  # Clamp at zero so an empty sequence does not produce a length of -1.
  sequence_length_less_one = math_ops.maximum(sequence_length - 1, 0)

  # Computes forward decoding. Get last score and backpointers.
  crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
  initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
  initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
  inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
  backpointers, last_score = rnn.dynamic_rnn(
      crf_fwd_cell,
      inputs=inputs,
      sequence_length=sequence_length_less_one,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T - 1, O], [B, O]
  # Reverse in time so the backward pass can consume the backpointers with a
  # regular (forward-running) dynamic_rnn.
  backpointers = gen_array_ops.reverse_sequence(
      backpointers, sequence_length_less_one, seq_dim=1)  # [B, T-1, O]

  # Computes backward decoding. Extract tag indices from backpointers.
  crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
  initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                dtype=dtypes.int32)  # [B]
  initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
  decode_tags, _ = rnn.dynamic_rnn(
      crf_bwd_cell,
      inputs=backpointers,
      sequence_length=sequence_length_less_one,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T - 1, 1]
  decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
  # Prepend the last tag, then undo the time reversal to restore time order.
  decode_tags = array_ops.concat([initial_state, decode_tags], axis=1)  # [B, T]
  decode_tags = gen_array_ops.reverse_sequence(
      decode_tags, sequence_length, seq_dim=1)  # [B, T]

  best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
  return decode_tags, best_score