Implement CRF decode (Viterbi decode) for tensor (#12056)

* Implement CRF decoding for tensors

* add test code for tensor version's CRF decoding

* made modifications according to pylint

* add some comments for crf decode

* remove useless code

* add comments at the top comment of crf module and add more comments in crf_test

* capitalize first char of first word in comments

* replace crf_decode test code with a deterministic example
This commit is contained in:
QingYing Chen 2017-08-11 11:31:50 +08:00 committed by Rasmus Munk Larsen
parent 2b374f7d4b
commit 26719d29fb
2 changed files with 215 additions and 3 deletions

View File

@ -23,6 +23,7 @@ import itertools
import numpy as np
from tensorflow.contrib.crf.python.ops import crf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -199,6 +200,52 @@ class CrfTest(test.TestCase):
self.assertEqual(actual_max_sequence,
expected_max_sequence[:sequence_lengths])
def testCrfDecode(self):
inputs = np.array(
[[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
with self.test_session() as sess:
all_sequence_scores = []
all_sequences = []
# Compare the dynamic program with brute force computation.
for tag_indices in itertools.product(
range(num_tags), repeat=sequence_lengths):
tag_indices = list(tag_indices)
tag_indices.extend([0] * (num_words - sequence_lengths))
all_sequences.append(tag_indices)
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
transition_params=constant_op.constant(transition_params))
sequence_score = array_ops.squeeze(sequence_score, [0])
all_sequence_scores.append(sequence_score)
tf_all_sequence_scores = sess.run(all_sequence_scores)
expected_max_sequence_index = np.argmax(tf_all_sequence_scores)
expected_max_sequence = all_sequences[expected_max_sequence_index]
expected_max_score = tf_all_sequence_scores[expected_max_sequence_index]
actual_max_sequence, actual_max_score = crf.crf_decode(
array_ops.expand_dims(inputs, 0),
constant_op.constant(transition_params),
array_ops.expand_dims(sequence_lengths, 0))
actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0])
actual_max_score = array_ops.squeeze(actual_max_score, [0])
tf_actual_max_sequence, tf_actual_max_score = sess.run(
[actual_max_sequence, actual_max_score])
self.assertAllClose(tf_actual_max_score, expected_max_score)
self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
expected_max_sequence[:sequence_lengths])
if __name__ == "__main__":
test.main()

View File

@ -16,13 +16,24 @@
The following snippet is an example of a CRF layer on top of a batched sequence
of unary scores (logits for every word). This example also decodes the most
likely sequence at test time:
likely sequence at test time. There are two ways to do decoding. One
is using crf_decode to do decoding in Tensorflow , and the other one is using
viterbi_decode in Numpy.
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
unary_scores, gold_tags, sequence_lengths)
loss = tf.reduce_mean(-log_likelihood)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# Decoding in Tensorflow.
viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
unary_scores, transition_params, sequence_lengths)
tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
[viterbi_sequence, viterbi_score, train_op])
# Decoding in Numpy.
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
[unary_scores, sequence_lengths, transition_params, train_op])
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
@ -31,7 +42,7 @@ for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
# Compute the highest score and its tag sequence.
viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
tf_unary_scores_, tf_transition_params)
"""
@ -43,6 +54,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
@ -50,7 +62,9 @@ from tensorflow.python.ops import variable_scope as vs
__all__ = [
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode"
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
"viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
"CrfDecodeBackwardRnnCell"
]
@ -310,3 +324,154 @@ def viterbi_decode(score, transition_params):
viterbi_score = np.max(trellis[-1])
return viterbi, viterbi_score
class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
"""Computes the forward decoding in a linear-chain CRF.
"""
def __init__(self, transition_params):
"""Initialize the CrfDecodeForwardRnnCell.
Args:
transition_params: A [num_tags, num_tags] matrix of binary
potentials. This matrix is expanded into a
[1, num_tags, num_tags] in preparation for the broadcast
summation occurring within the cell.
"""
self._transition_params = array_ops.expand_dims(transition_params, 0)
self._num_tags = transition_params.get_shape()[0].value
@property
def state_size(self):
return self._num_tags
@property
def output_size(self):
return self._num_tags
def __call__(self, inputs, state, scope=None):
"""Build the CrfDecodeForwardRnnCell.
Args:
inputs: A [batch_size, num_tags] matrix of unary potentials.
state: A [batch_size, num_tags] matrix containing the previous step's
score values.
scope: Unused variable scope of this cell.
Returns:
backpointers: [batch_size, num_tags], containing backpointers.
new_state: [batch_size, num_tags], containing new score values.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
state = array_ops.expand_dims(state, 2) # [B, O, 1]
# This addition op broadcasts self._transitions_params along the zeroth
# dimension and state along the second dimension.
# [B, O, 1] + [1, O, O] -> [B, O, O]
transition_scores = state + self._transition_params # [B, O, O]
new_state = inputs + math_ops.reduce_max(transition_scores, [1]) # [B, O]
backpointers = math_ops.argmax(transition_scores, 1)
backpointers = math_ops.cast(backpointers, dtype=dtypes.int32) # [B, O]
return backpointers, new_state
class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
"""Computes backward decoding in a linear-chain CRF.
"""
def __init__(self, num_tags):
"""Initialize the CrfDecodeBackwardRnnCell.
Args:
num_tags
"""
self._num_tags = num_tags
@property
def state_size(self):
return 1
@property
def output_size(self):
return 1
def __call__(self, inputs, state, scope=None):
"""Build the CrfDecodeBackwardRnnCell.
Args:
inputs: [batch_size, num_tags], backpointer of next step (in time order).
state: [batch_size, 1], next position's tag index.
scope: Unused variable scope of this cell.
Returns:
new_tags, new_tags: A pair of [batch_size, num_tags]
tensors containing the new tag indices.
"""
state = array_ops.squeeze(state, axis=[1]) # [B]
batch_size = array_ops.shape(inputs)[0]
b_indices = math_ops.range(batch_size) # [B]
indices = array_ops.stack([b_indices, state], axis=1) # [B, 2]
new_tags = array_ops.expand_dims(
gen_array_ops.gather_nd(inputs, indices), # [B]
axis=-1) # [B, 1]
return new_tags, new_tags
def crf_decode(potentials, transition_params, sequence_length):
"""Decode the highest scoring sequence of tags in TensorFlow.
This is a function for tensor.
Args:
potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
unary potentials.
transition_params: A [num_tags, num_tags] tensor, matrix of
binary potentials.
sequence_length: A [batch_size] tensor, containing sequence lengths.
Returns:
decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
Contains the highest scoring tag indicies.
best_score: A [batch_size] tensor, containing the score of decode_tags.
"""
# For simplicity, in shape comments, denote:
# 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
num_tags = potentials.get_shape()[2].value
# Computes forward decoding. Get last score and backpointers.
crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O]
inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O]
backpointers, last_score = rnn.dynamic_rnn(
crf_fwd_cell,
inputs=inputs,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, O], [B, O]
backpointers = gen_array_ops.reverse_sequence(
backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O]
# Computes backward decoding. Extract tag indices from backpointers.
crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
dtype=dtypes.int32) # [B]
initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1]
decode_tags, _ = rnn.dynamic_rnn(
crf_bwd_cell,
inputs=backpointers,
sequence_length=sequence_length - 1,
initial_state=initial_state,
time_major=False,
dtype=dtypes.int32) # [B, T - 1, 1]
decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1]
decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T]
decode_tags = gen_array_ops.reverse_sequence(
decode_tags, sequence_length, seq_dim=1) # [B, T]
best_score = math_ops.reduce_max(last_score, axis=1) # [B]
return decode_tags, best_score