mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Implement CRF decode (Viterbi decode) for tensor (#12056)
* Implement CRF decoding for tensors * add test code for tensor version's CRF decoding * made modifications according to pylint * add some comments for crf decode * remove useless code * add comments at the top comment of crf module and add more comments in crf_test * capitalize first char of first word in comments * replace crf_decode test code with a deterministic example
This commit is contained in:
parent
2b374f7d4b
commit
26719d29fb
|
|
@ -23,6 +23,7 @@ import itertools
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.crf.python.ops import crf
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
|
@ -199,6 +200,52 @@ class CrfTest(test.TestCase):
|
|||
self.assertEqual(actual_max_sequence,
|
||||
expected_max_sequence[:sequence_lengths])
|
||||
|
||||
def testCrfDecode(self):
  """Checks crf_decode against an exhaustive search over tag sequences.

  Every possible tagging of the first `sequence_lengths` words is scored with
  crf_sequence_score; the best one must match what crf_decode returns, both
  in score and in tag indices.
  """
  inputs = np.array(
      [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
  transition_params = np.array(
      [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
  sequence_lengths = np.array(3, dtype=np.int32)
  num_words, num_tags = inputs.shape
  # Positions past the sequence length are filled with tag 0.
  padding = [0] * (num_words - sequence_lengths)

  with self.test_session() as sess:
    candidate_sequences = []
    candidate_scores = []

    # Brute force: build one score op per candidate tag sequence.
    for prefix in itertools.product(range(num_tags), repeat=sequence_lengths):
      padded_tags = list(prefix) + padding
      candidate_sequences.append(padded_tags)
      batched_score = crf.crf_sequence_score(
          inputs=array_ops.expand_dims(inputs, 0),
          tag_indices=array_ops.expand_dims(padded_tags, 0),
          sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
          transition_params=constant_op.constant(transition_params))
      candidate_scores.append(array_ops.squeeze(batched_score, [0]))

    brute_force_scores = sess.run(candidate_scores)
    best_index = np.argmax(brute_force_scores)
    expected_max_sequence = candidate_sequences[best_index]
    expected_max_score = brute_force_scores[best_index]

    # Dynamic program: decode the same (batch of one) example in the graph.
    decoded_tags, decoded_score = crf.crf_decode(
        array_ops.expand_dims(inputs, 0),
        constant_op.constant(transition_params),
        array_ops.expand_dims(sequence_lengths, 0))
    tf_actual_max_sequence, tf_actual_max_score = sess.run([
        array_ops.squeeze(decoded_tags, [0]),
        array_ops.squeeze(decoded_score, [0])
    ])

    self.assertAllClose(tf_actual_max_score, expected_max_score)
    # Only the first `sequence_lengths` tags are compared.
    self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]),
                     expected_max_sequence[:sequence_lengths])
|
||||
|
||||
|
||||
# Script entry point: hand control to the test module's main() when this file
# is executed directly rather than imported.
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
|||
|
|
@ -16,13 +16,24 @@
|
|||
|
||||
The following snippet is an example of a CRF layer on top of a batched sequence
|
||||
of unary scores (logits for every word). This example also decodes the most
|
||||
likely sequence at test time:
|
||||
likely sequence at test time. There are two ways to do decoding. One
|
||||
is using crf_decode to do decoding in Tensorflow, and the other one is using
|
||||
viterbi_decode in Numpy.
|
||||
|
||||
log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
|
||||
unary_scores, gold_tags, sequence_lengths)
|
||||
|
||||
loss = tf.reduce_mean(-log_likelihood)
|
||||
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
|
||||
|
||||
# Decoding in Tensorflow.
|
||||
viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
|
||||
unary_scores, transition_params, sequence_lengths)
|
||||
|
||||
tf_viterbi_sequence, tf_viterbi_score, _ = session.run(
|
||||
[viterbi_sequence, viterbi_score, train_op])
|
||||
|
||||
# Decoding in Numpy.
|
||||
tf_unary_scores, tf_sequence_lengths, tf_transition_params, _ = session.run(
|
||||
[unary_scores, sequence_lengths, transition_params, train_op])
|
||||
for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
|
||||
|
|
@ -31,7 +42,7 @@ for tf_unary_scores_, tf_sequence_length_ in zip(tf_unary_scores,
|
|||
tf_unary_scores_ = tf_unary_scores_[:tf_sequence_length_]
|
||||
|
||||
# Compute the highest score and its tag sequence.
|
||||
viterbi_sequence, viterbi_score = tf.contrib.crf.viterbi_decode(
|
||||
tf_viterbi_sequence, tf_viterbi_score = tf.contrib.crf.viterbi_decode(
|
||||
tf_unary_scores_, tf_transition_params)
|
||||
"""
|
||||
|
||||
|
|
@ -43,6 +54,7 @@ import numpy as np
|
|||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn
|
||||
from tensorflow.python.ops import rnn_cell
|
||||
|
|
@ -50,7 +62,9 @@ from tensorflow.python.ops import variable_scope as vs
|
|||
|
||||
__all__ = [
|
||||
"crf_sequence_score", "crf_log_norm", "crf_log_likelihood",
|
||||
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell", "viterbi_decode"
|
||||
"crf_unary_score", "crf_binary_score", "CrfForwardRnnCell",
|
||||
"viterbi_decode", "crf_decode", "CrfDecodeForwardRnnCell",
|
||||
"CrfDecodeBackwardRnnCell"
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -310,3 +324,154 @@ def viterbi_decode(score, transition_params):
|
|||
|
||||
viterbi_score = np.max(trellis[-1])
|
||||
return viterbi, viterbi_score
|
||||
|
||||
|
||||
class CrfDecodeForwardRnnCell(rnn_cell.RNNCell):
  """Computes the forward decoding in a linear-chain CRF.

  Each step keeps, for every tag, the best score reachable at that tag and
  emits the index of the previous tag that achieved it (the backpointer).
  """

  def __init__(self, transition_params):
    """Initialize the CrfDecodeForwardRnnCell.

    Args:
      transition_params: A [num_tags, num_tags] matrix of binary
        potentials. It is stored with an extra leading axis, as a
        [1, num_tags, num_tags] tensor, so that it broadcasts over the
        batch inside the cell.
    """
    self._transition_params = array_ops.expand_dims(transition_params, 0)
    self._num_tags = transition_params.get_shape()[0].value

  @property
  def state_size(self):
    return self._num_tags

  @property
  def output_size(self):
    return self._num_tags

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeForwardRnnCell.

    Args:
      inputs: A [batch_size, num_tags] matrix of unary potentials.
      state: A [batch_size, num_tags] matrix containing the previous step's
        score values.
      scope: Unused variable scope of this cell.

    Returns:
      backpointers: [batch_size, num_tags] int32 tensor; for every tag, the
        index along the previous-tag axis with the highest transition score.
      new_state: [batch_size, num_tags], containing new score values.
    """
    # Shape shorthand used below: B = batch_size, O = num_tags.
    prev_scores = array_ops.expand_dims(state, 2)  # [B, O, 1]

    # Broadcasting: prev_scores is repeated along axis 2 and
    # self._transition_params along axis 0.
    # [B, O, 1] + [1, O, O] -> [B, O, O]
    candidate_scores = prev_scores + self._transition_params

    # Axis 1 indexes the previous tag, so reducing over it leaves one best
    # value (and one best predecessor) per current tag.
    best_candidate = math_ops.reduce_max(candidate_scores, [1])  # [B, O]
    new_state = inputs + best_candidate  # [B, O]
    backpointers = math_ops.cast(
        math_ops.argmax(candidate_scores, 1), dtype=dtypes.int32)  # [B, O]
    return backpointers, new_state
|
||||
|
||||
|
||||
class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
  """Computes backward decoding in a linear-chain CRF.

  Each step follows one backpointer: given the tag chosen at the next
  position, it looks up the tag to use at the current position.
  """

  def __init__(self, num_tags):
    """Initialize the CrfDecodeBackwardRnnCell.

    Args:
      num_tags: The number of tags. It is stored on the cell; the state and
        output sizes are both fixed at 1 regardless of its value.
    """
    self._num_tags = num_tags

  @property
  def state_size(self):
    return 1

  @property
  def output_size(self):
    return 1

  def __call__(self, inputs, state, scope=None):
    """Build the CrfDecodeBackwardRnnCell.

    Args:
      inputs: [batch_size, num_tags], backpointer of next step (in time order).
      state: [batch_size, 1], next position's tag index.
      scope: Unused variable scope of this cell.

    Returns:
      new_tags, new_tags: The same [batch_size, 1] tensor returned twice, as
        both the cell output and the new state. It holds the tag index read
        from `inputs` at the position given by `state`.
    """
    # Shape shorthand used below: B = batch_size.
    next_tag = array_ops.squeeze(state, axis=[1])  # [B]
    batch_range = math_ops.range(array_ops.shape(inputs)[0])  # [B]

    # Pair each batch row with its selected tag so that gather_nd picks
    # inputs[b, next_tag[b]] for every b.
    gather_indices = array_ops.stack([batch_range, next_tag], axis=1)  # [B, 2]
    selected = gen_array_ops.gather_nd(inputs, gather_indices)  # [B]
    new_tags = array_ops.expand_dims(selected, axis=-1)  # [B, 1]

    return new_tags, new_tags
|
||||
|
||||
|
||||
def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This is the in-graph (tensor) counterpart of viterbi_decode. A forward RNN
  pass records the best score per tag together with backpointers, and a
  backward RNN pass follows the backpointers from the best final tag.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of
      unary potentials.
    transition_params: A [num_tags, num_tags] tensor, matrix of
      binary potentials.
    sequence_length: A [batch_size] tensor, containing sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
      Contains the highest scoring tag indices.
    best_score: A [batch_size] tensor, containing the score of decode_tags.
  """
  # For simplicity, in shape comments, denote:
  # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
  num_tags = potentials.get_shape()[2].value

  # The first time step seeds the recursion, so both RNN passes run over one
  # step fewer than the sequence length. Clamp at zero: a zero-length example
  # would otherwise produce -1, which is not a valid sequence length. The
  # tensor is passed first so the literal 0 takes on its dtype.
  # NOTE(review): when max_seq_len is 1 the recurrent input below has no time
  # steps at all; confirm dynamic_rnn accepts that before relying on it.
  sequence_length_less_one = math_ops.maximum(sequence_length - 1, 0)  # [B]

  # Computes forward decoding. Get last score and backpointers.
  crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
  initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
  initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
  inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
  backpointers, last_score = rnn.dynamic_rnn(
      crf_fwd_cell,
      inputs=inputs,
      sequence_length=sequence_length_less_one,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T - 1, O], [B, O]
  # Put the backpointers in last-to-first order so the backward cell can
  # consume them with an ordinary forward-in-time RNN.
  backpointers = gen_array_ops.reverse_sequence(
      backpointers, sequence_length_less_one, seq_dim=1)  # [B, T-1, O]

  # Computes backward decoding. Extract tag indices from backpointers.
  crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
  initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                dtype=dtypes.int32)  # [B]
  initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
  decode_tags, _ = rnn.dynamic_rnn(
      crf_bwd_cell,
      inputs=backpointers,
      sequence_length=sequence_length_less_one,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)  # [B, T - 1, 1]
  decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
  # The best final tag comes first, followed by the tags recovered while
  # walking backwards; reversing restores time order.
  decode_tags = array_ops.concat([initial_state, decode_tags], axis=1)  # [B, T]
  decode_tags = gen_array_ops.reverse_sequence(
      decode_tags, sequence_length, seq_dim=1)  # [B, T]

  best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
  return decode_tags, best_score
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user