#! /usr/bin/env python3
"""Regenerate the ONNX model tarballs from the Caffe2 model zoo.

Usage: update-caffe2-models.py {download|generate|upload|cleanup}

* download -- fetch each Caffe2 model into ~/.caffe2/models/<model> and the
  current ONNX model into ~/.onnx/models/<model>, skipping existing ones.
* generate -- run convert-caffe2-to-onnx for each model and pack the result
  as ~/.onnx/models/<model>.tar.gz.
* upload   -- copy each tarball to s3://download.onnx/models/ (public-read).
* cleanup  -- delete the generated tarballs.

Every command first checks that AWS credentials are configured.
"""

import os
import shutil
import subprocess
import sys
import tarfile
import tempfile
from urllib.request import urlretrieve

from caffe2.python.models.download import downloadFromURLToFile, getURLFromName, deleteDirectory


def _extract_tarball(path, dest):
    """Extract the tar archive at *path* into directory *dest*.

    The archive is downloaded from the network, so when the interpreter
    supports extraction filters (PEP 706) the 'data' filter is used to reject
    absolute paths, '..' components and other unsafe members.
    """
    with tarfile.open(path) as archive:
        if hasattr(tarfile, 'data_filter'):
            archive.extractall(dest, filter='data')
        else:
            # NOTE(review): this Python predates extraction filters; the
            # archive is trusted as-is here.
            archive.extractall(dest)


class SomeClass:
    """Helpers that locate, download and unpack per-model files."""

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download the Caffe2 files for *model* into its ~/.caffe2 directory.

        Fetches predict_net.pb, init_net.pb and value_info.json. On any
        download failure the partially populated directory is deleted and the
        process exits with status 1.
        """
        model_dir = self._caffe2_model_dir(model)
        assert not os.path.exists(model_dir)
        os.makedirs(model_dir)
        for f in ['predict_net.pb', 'init_net.pb', 'value_info.json']:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest, show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print("Abort: {reason}".format(reason=e))
                print("Cleaning up...")
                deleteDirectory(model_dir)
                # sys.exit rather than the site-provided exit() builtin,
                # which is unavailable under `python -S`.
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return ~/.caffe2/models/<model> (expanded)."""
        caffe2_home = os.path.expanduser('~/.caffe2')
        models_dir = os.path.join(caffe2_home, 'models')
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return (~/.onnx/models/<model>, ~/.onnx/models), both expanded."""
        onnx_home = os.path.expanduser('~/.onnx')
        models_dir = os.path.join(onnx_home, 'models')
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the published ONNX tarball for *model*.

        Does nothing if ~/.onnx/models/<model> already exists. The tarball is
        fetched from the download.onnx S3 bucket into a temporary file and
        extracted into ~/.onnx/models. On failure the model directory is
        removed and the exception is re-raised.
        """
        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = 'https://s3.amazonaws.com/download.onnx/models/{}.tar.gz'.format(model)

        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            print('Start downloading model {} from {}'.format(model, url))
            urlretrieve(url, download_file.name)
            print('Done')
            _extract_tarball(download_file.name, models_dir)
        except Exception as e:
            print('Failed to prepare data for model {}: {}'.format(model, e))
            # Drop the empty or partially extracted directory. Otherwise the
            # exists() check above would skip this model on every later run.
            shutil.rmtree(model_dir, ignore_errors=True)
            raise
        finally:
            os.remove(download_file.name)


models = [
    'bvlc_alexnet',
    'densenet121',
    'inception_v1',
    'inception_v2',
    'resnet50',

    # TODO currently onnx can't translate squeezenet :(
    # 'squeezenet',

    'vgg16',

    # TODO currently vgg19 doesn't work in the CI environment,
    # possibly due to OOM
    # 'vgg19'
]


def download_models():
    """Fetch the Caffe2 and ONNX copies of every model that is missing locally."""
    sc = SomeClass()
    for model in models:
        print('update-caffe2-models.py: downloading', model)
        caffe2_model_dir = sc._caffe2_model_dir(model)
        onnx_model_dir, _ = sc._onnx_model_dir(model)
        if not os.path.exists(caffe2_model_dir):
            sc._download(model)
        if not os.path.exists(onnx_model_dir):
            sc._prepare_model_data(model)


def generate_models():
    """Convert each Caffe2 model to ONNX and pack it as <model>.tar.gz.

    Writes ~/.onnx/models/<model>/model.pb via the convert-caffe2-to-onnx
    command, then tars the model directory inside ~/.onnx/models.
    """
    sc = SomeClass()
    for model in models:
        print('update-caffe2-models.py: generating', model)
        caffe2_model_dir = sc._caffe2_model_dir(model)
        onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
        subprocess.check_call(['echo', model])
        with open(os.path.join(caffe2_model_dir, 'value_info.json'), 'r') as f:
            value_info = f.read()
        subprocess.check_call([
            'convert-caffe2-to-onnx',
            '--caffe2-net-name', model,
            '--caffe2-init-net', os.path.join(caffe2_model_dir, 'init_net.pb'),
            '--value-info', value_info,
            '-o', os.path.join(onnx_model_dir, 'model.pb'),
            os.path.join(caffe2_model_dir, 'predict_net.pb')
        ])
        subprocess.check_call([
            'tar',
            '-czf',
            model + '.tar.gz',
            model
        ], cwd=onnx_models_dir)


def upload_models():
    """Copy each generated <model>.tar.gz to the public download.onnx bucket."""
    sc = SomeClass()
    for model in models:
        print('update-caffe2-models.py: uploading', model)
        _, onnx_models_dir = sc._onnx_model_dir(model)
        subprocess.check_call([
            'aws',
            's3',
            'cp',
            model + '.tar.gz',
            "s3://download.onnx/models/{}.tar.gz".format(model),
            '--acl', 'public-read'
        ], cwd=onnx_models_dir)


def cleanup():
    """Delete the generated ~/.onnx/models/<model>.tar.gz files."""
    sc = SomeClass()
    for model in models:
        _, onnx_models_dir = sc._onnx_model_dir(model)
        os.remove(os.path.join(onnx_models_dir, model + '.tar.gz'))


# Command name (first CLI argument) -> action.
_COMMANDS = {
    'download': download_models,
    'generate': generate_models,
    'upload': upload_models,
    'cleanup': cleanup,
}


if __name__ == '__main__':
    # Validate the command before any network call, and fail loudly on a
    # missing or unknown one instead of crashing or silently doing nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in _COMMANDS:
        print('usage: update-caffe2-models.py {{{}}}'.format('|'.join(_COMMANDS)),
              file=sys.stderr)
        sys.exit(2)

    try:
        subprocess.check_call(['aws', 'sts', 'get-caller-identity'])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: aws ran but has no valid credentials.
        # OSError: the aws CLI is not installed / not on PATH.
        print('update-caffe2-models.py: please run `aws configure` manually to set up credentials')
        sys.exit(1)

    _COMMANDS[sys.argv[1]]()