pytorch/scripts/model_zoo/update-caffe2-models.py
Aaron Gokaslan 6d725e7d66 [BE]: enable ruff rules PLR1722 and PLW3301 (#109461)
Enables two ruff rules derived from pylint:
* PLR1722 replaces any exit() calls with sys.exit(). exit() is only designed to be used in REPL contexts and may not always be available by default; always using the version from the sys module is more reliable.
* PLW3301 replaces nested min/max calls with flattened versions (i.e. `min(a, min(b, c))` => `min(a, b, c)`). The flattened form is more idiomatic and more efficient. A minimal before/after sketch of both rules follows.
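The sketch below uses placeholder values for illustration only; it is not code taken from the changed files:

    import sys

    a, b, c = 3, 1, 2  # placeholder values for the example

    # PLW3301: flatten nested min()/max() calls
    smallest = min(a, b, c)   # before: min(a, min(b, c))
    largest = max(a, b, c)    # before: max(max(a, b), c)

    # PLR1722: call sys.exit() rather than the site-provided exit()
    sys.exit(0)               # before: exit(0)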

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461
Approved by: https://github.com/ezyang
2023-09-18 02:07:21 +00:00


#! /usr/bin/env python3
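# Overview (a summary derived from the functions below; the subcommand is
# dispatched on sys.argv[1]):
#   download - fetch each model's Caffe2 files (predict_net.pb, init_net.pb,
#              value_info.json) into ~/.caffe2/models and, if missing, the
#              existing ONNX tarball into ~/.onnx/models
#   generate - run convert-caffe2-to-onnx on the Caffe2 files and tar the
#              resulting ONNX models under ~/.onnx/models
#   upload   - copy the <model>.tar.gz archives to s3://download.onnx/models/
#              (requires `aws configure` to have been run)
#   cleanup  - remove the local <model>.tar.gz archives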
import os
import subprocess
import sys
import tarfile
import tempfile
from urllib.request import urlretrieve
from caffe2.python.models.download import (
    deleteDirectory,
    downloadFromURLToFile,
    getURLFromName,
)


class SomeClass:
    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        model_dir = self._caffe2_model_dir(model)
        assert not os.path.exists(model_dir)
        os.makedirs(model_dir)
        for f in ["predict_net.pb", "init_net.pb", "value_info.json"]:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest, show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print(f"Abort: {e}")
                print("Cleaning up...")
                deleteDirectory(model_dir)
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        caffe2_home = os.path.expanduser("~/.caffe2")
        models_dir = os.path.join(caffe2_home, "models")
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        onnx_home = os.path.expanduser("~/.onnx")
        models_dir = os.path.join(onnx_home, "models")
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = f"https://s3.amazonaws.com/download.onnx/models/{model}.tar.gz"
        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            print(f"Start downloading model {model} from {url}")
            urlretrieve(url, download_file.name)
            print("Done")
            with tarfile.open(download_file.name) as t:
                t.extractall(models_dir)
        except Exception as e:
            print(f"Failed to prepare data for model {model}: {e}")
            raise
        finally:
            os.remove(download_file.name)


models = [
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    # TODO currently onnx can't translate squeezenet :(
    # 'squeezenet',
    "vgg16",
    # TODO currently vgg19 doesn't work in the CI environment,
    # possibly due to OOM
    # 'vgg19'
]


def download_models():
    sc = SomeClass()
    for model in models:
        print("update-caffe2-models.py: downloading", model)
        caffe2_model_dir = sc._caffe2_model_dir(model)
        onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
        if not os.path.exists(caffe2_model_dir):
            sc._download(model)
        if not os.path.exists(onnx_model_dir):
            sc._prepare_model_data(model)


def generate_models():
    sc = SomeClass()
    for model in models:
        print("update-caffe2-models.py: generating", model)
        caffe2_model_dir = sc._caffe2_model_dir(model)
        onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
        subprocess.check_call(["echo", model])
        with open(os.path.join(caffe2_model_dir, "value_info.json"), "r") as f:
            value_info = f.read()
        subprocess.check_call(
            [
                "convert-caffe2-to-onnx",
                "--caffe2-net-name",
                model,
                "--caffe2-init-net",
                os.path.join(caffe2_model_dir, "init_net.pb"),
                "--value-info",
                value_info,
                "-o",
                os.path.join(onnx_model_dir, "model.pb"),
                os.path.join(caffe2_model_dir, "predict_net.pb"),
            ]
        )
        subprocess.check_call(
            ["tar", "-czf", model + ".tar.gz", model], cwd=onnx_models_dir
        )


def upload_models():
    sc = SomeClass()
    for model in models:
        print("update-caffe2-models.py: uploading", model)
        onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
        subprocess.check_call(
            [
                "aws",
                "s3",
                "cp",
                model + ".tar.gz",
                f"s3://download.onnx/models/{model}.tar.gz",
                "--acl",
                "public-read",
            ],
            cwd=onnx_models_dir,
        )


def cleanup():
    sc = SomeClass()
    for model in models:
        onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
        os.remove(os.path.join(os.path.dirname(onnx_model_dir), model + ".tar.gz"))


if __name__ == "__main__":
    try:
        subprocess.check_call(["aws", "sts", "get-caller-identity"])
    except Exception:
        print(
            "update-caffe2-models.py: please run `aws configure` manually to set up credentials"
        )
        sys.exit(1)
    if sys.argv[1] == "download":
        download_models()
    elif sys.argv[1] == "generate":
        generate_models()
    elif sys.argv[1] == "upload":
        upload_models()
    elif sys.argv[1] == "cleanup":
        cleanup()