# This is a large test that goes through the translation of the bvlc caffenet
# model, runs an example through the whole model, and verifies numerically
# that all the results look right. By default, it is disabled unless you
# explicitly want to run it.

from google.protobuf import text_format
import numpy as np
import os
import sys

CAFFE_FOUND = False
try:
    from caffe.proto import caffe_pb2
    from caffe2.python import caffe_translator
    CAFFE_FOUND = True
except Exception as e:
    # Safeguard so that we only catch the caffe module not found exception.
    # NOTE(review): any other exception raised by these imports is swallowed
    # without a message and simply leaves CAFFE_FOUND as False.
    if ("'caffe'" in str(e)):
        print(
            "PyTorch/Caffe2 now requires a separate installation of caffe. "
            "Right now, this is not found, so we will skip the caffe "
            "translator test.")

from caffe2.python import utils, workspace, test_util

import unittest


def setUpModule():
    """Translate the reference CaffeNet model and run it in the workspace.

    Reads the network definition and the pretrained weights from
    ``data/testdata/caffe_translator``, translates them with
    ``caffe_translator.TranslateModel``, feeds the translated parameters and
    the dumped input data into the workspace, and runs the net once so that
    the test cases can fetch the resulting blobs.

    Returns immediately, doing nothing, when caffe could not be imported or
    the test data directory does not exist.
    """
    # Do nothing if caffe and test data is not found
    if not (CAFFE_FOUND and
            os.path.exists('data/testdata/caffe_translator')):
        return
    # We will do all the computation stuff in the global space.
    caffenet = caffe_pb2.NetParameter()
    caffenet_pretrained = caffe_pb2.NetParameter()
    # The deploy definition is a text-format protobuf, so text mode is fine.
    with open('data/testdata/caffe_translator/deploy.prototxt') as f:
        text_format.Merge(f.read(), caffenet)
    # The .caffemodel file is a binary serialized protobuf: it must be read
    # as bytes. Text mode would attempt to decode it (and translate newlines
    # on some platforms), corrupting the data or raising before parsing.
    with open('data/testdata/caffe_translator/'
              'bvlc_reference_caffenet.caffemodel', 'rb') as f:
        caffenet_pretrained.ParseFromString(f.read())
    # NOTE(review): the original indentation of this file was lost; every
    # statement down to RunNetOnce is assumed to belong to this loop body --
    # confirm against the upstream file.
    for remove_legacy_pad in [True, False]:
        net, pretrained_params = caffe_translator.TranslateModel(
            caffenet, caffenet_pretrained, is_test=True,
            remove_legacy_pad=remove_legacy_pad
        )
        # Dump the translated net as text next to the test data.
        with open('data/testdata/caffe_translator/'
                  'bvlc_reference_caffenet.translatedmodel', 'w') as fid:
            fid.write(str(net))
        for param in pretrained_params.protos:
            workspace.FeedBlob(
                param.name, utils.Caffe2TensorToNumpyArray(param))
        # Let's also feed in the data from the Caffe test code.
        data = np.load(
            'data/testdata/caffe_translator/data_dump.npy'
        ).astype(np.float32)
        workspace.FeedBlob('data', data)
        # Actually running the test.
        workspace.RunNetOnce(net.SerializeToString())


@unittest.skipIf(not CAFFE_FOUND,
                 'No Caffe installation found.')
@unittest.skipIf(not os.path.exists('data/testdata/caffe_translator'),
                 'No testdata existing for the caffe translator test. Exiting.')
class TestNumericalEquivalence(test_util.TestCase):
    """Compare blobs left in the workspace against the dumped references."""

    def testBlobs(self):
        """Check each listed blob against its ``<name>_dump.npy`` reference.

        For every blob the shapes must be equal, and the values must agree
        to 5 decimals after both arrays are divided by the maximum of the
        workspace result.
        """
        names = [
            "conv1", "pool1", "norm1", "conv2", "pool2", "norm2", "conv3",
            "conv4", "conv5", "pool5", "fc6", "fc7", "fc8", "prob"
        ]
        for name in names:
            print('Verifying {}'.format(name))
            caffe2_result = workspace.FetchBlob(name)
            reference = np.load(
                'data/testdata/caffe_translator/' + name + '_dump.npy'
            )
            self.assertEqual(caffe2_result.shape, reference.shape)
            # Normalize both sides by the same factor so the fixed decimal
            # tolerance is applied relative to the blob's largest value.
            scale = np.max(caffe2_result)
            np.testing.assert_almost_equal(
                caffe2_result / scale,
                reference / scale,
                decimal=5
            )


if __name__ == '__main__':
    if len(sys.argv) == 1:
        print(
            'If you do not explicitly ask to run this test, I will not run it. '
            'Pass in any argument to have the test run for you.'
        )
        sys.exit(0)
    unittest.main()