from caffe2.python import workspace, crf
from caffe2.python.cnn import CNNModelHelper
from caffe2.python.crf_predict import crf_update_predictions
from caffe2.python.test_util import TestCase
import hypothesis.strategies as st
from hypothesis import given, settings
import numpy as np


class TestCrfDecode(TestCase):
    """Consistency test between the two ways of updating CRF predictions."""

    @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15))
    @settings(deadline=2000)
    def test_crf_viterbi(self, num_tags, num_words):
        """Check ``crf_update_predictions`` against the layer's own method.

        Random float32 predictions of shape ``(num_words, num_tags)`` and a
        random float32 transition matrix of shape
        ``(num_tags + 2, num_tags + 2)`` are fed into the workspace. The
        predictions are then updated twice inside the same net -- once via
        ``crf_update_predictions`` and once via
        ``CRFWithLoss.update_predictions`` -- and both outputs must agree
        within ``atol = rtol = 1e-4``.

        Args:
            num_tags: Number of tags, drawn by hypothesis from [2, 4].
            num_words: Number of words, drawn by hypothesis from [2, 15].
        """
        model = CNNModelHelper(name='external')

        # Draw the inputs in this order (predictions first, transitions
        # second) so the NumPy RNG stream is consumed the same way each run.
        emission_scores = np.random.randn(num_words, num_tags).astype(
            np.float32
        )
        # NOTE(review): the two extra rows/columns are presumably for
        # begin/end-of-sequence tags -- confirm against the crf module.
        transition_shape = (num_tags + 2, num_tags + 2)
        transition_scores = np.random.uniform(
            low=-1, high=1, size=transition_shape
        ).astype(np.float32)

        predictions_blob, transitions_blob = model.net.AddExternalInputs(
            'predictions', 'crf_transitions'
        )
        for blob, value in (
            (transitions_blob, transition_scores),
            (predictions_blob, emission_scores),
        ):
            workspace.FeedBlob(str(blob), value)

        crf_layer = crf.CRFWithLoss(model, num_tags, transitions_blob)

        # Both update paths are added to the same net so one run yields
        # the candidate and the reference output.
        candidate_blob = crf_update_predictions(
            model, crf_layer, predictions_blob
        )
        reference_blob = crf_layer.update_predictions(predictions_blob)

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        candidate = workspace.FetchBlob(str(candidate_blob))
        reference = workspace.FetchBlob(str(reference_blob))
        np.testing.assert_allclose(
            candidate,
            reference,
            atol=1e-4,
            rtol=1e-4,
            err_msg='Mismatch in CRF predictions',
        )