mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enables two ruff rules derived from pylint: * PLR1722 replaces any exit() calls with sys.exit(). exit() is only intended for use in REPL contexts and may not always be available by default. This change always uses the version in the sys module, which is more reliable. * PLW3301 replaces nested min / max calls with simplified versions (i.e. `min(a, min(b, c))` => `min(a, b, c)`). The new version is more idiomatic and more efficient. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109461 Approved by: https://github.com/ezyang
176 lines
5.5 KiB
Python
Executable File
176 lines
5.5 KiB
Python
Executable File
#! /usr/bin/env python3
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import tarfile
|
|
import tempfile
|
|
|
|
from urllib.request import urlretrieve
|
|
|
|
from caffe2.python.models.download import (
|
|
deleteDirectory,
|
|
downloadFromURLToFile,
|
|
getURLFromName,
|
|
)
|
|
|
|
|
|
class SomeClass:
    """Helpers for locating, downloading and unpacking Caffe2/ONNX test models.

    Caffe2 model files are kept under ``~/.caffe2/models/<model>`` and ONNX
    test data under ``~/.onnx/models/<model>``.
    """

    # largely copied from
    # https://github.com/onnx/onnx-caffe2/blob/master/tests/caffe2_ref_test.py
    def _download(self, model):
        """Download the Caffe2 files for *model* into its Caffe2 model directory.

        Fetches ``predict_net.pb``, ``init_net.pb`` and ``value_info.json``.
        If a download fails, the partially populated directory is deleted and
        the process exits with status 1.

        Raises:
            FileExistsError: if the Caffe2 model directory already exists.
        """
        model_dir = self._caffe2_model_dir(model)
        # os.makedirs itself raises FileExistsError when model_dir is already
        # present, so no separate (``python -O``-strippable) assert is needed.
        os.makedirs(model_dir)
        for f in ["predict_net.pb", "init_net.pb", "value_info.json"]:
            url = getURLFromName(model, f)
            dest = os.path.join(model_dir, f)
            try:
                try:
                    downloadFromURLToFile(url, dest, show_progress=False)
                except TypeError:
                    # show_progress not supported prior to
                    # Caffe2 78c014e752a374d905ecfb465d44fa16e02a28f1
                    # (Sep 17, 2017)
                    downloadFromURLToFile(url, dest)
            except Exception as e:
                print(f"Abort: {e}")
                print("Cleaning up...")
                deleteDirectory(model_dir)
                sys.exit(1)

    def _caffe2_model_dir(self, model):
        """Return the Caffe2 directory for *model*: ``~/.caffe2/models/<model>``."""
        caffe2_home = os.path.expanduser("~/.caffe2")
        models_dir = os.path.join(caffe2_home, "models")
        return os.path.join(models_dir, model)

    def _onnx_model_dir(self, model):
        """Return ``(model_dir, models_dir)`` for *model*'s ONNX test data.

        ``model_dir`` is ``~/.onnx/models/<model>`` and ``models_dir`` is its
        parent, ``~/.onnx/models``.
        """
        onnx_home = os.path.expanduser("~/.onnx")
        models_dir = os.path.join(onnx_home, "models")
        model_dir = os.path.join(models_dir, model)
        return model_dir, os.path.dirname(model_dir)

    # largely copied from
    # https://github.com/onnx/onnx/blob/master/onnx/backend/test/runner/__init__.py
    def _prepare_model_data(self, model):
        """Download and unpack the ONNX test-data tarball for *model*.

        Does nothing if the ONNX model directory already exists. Otherwise the
        directory is created, ``<model>.tar.gz`` is fetched from the
        ``download.onnx`` S3 bucket into a temporary file, and the archive is
        extracted into the parent ``models`` directory.

        If any step fails, the (possibly partial) model directory is removed,
        so that a later run retries instead of treating the model as already
        prepared. The original exception is then re-raised.
        """
        import shutil

        model_dir, models_dir = self._onnx_model_dir(model)
        if os.path.exists(model_dir):
            return
        os.makedirs(model_dir)
        url = f"https://s3.amazonaws.com/download.onnx/models/{model}.tar.gz"

        # On Windows, NamedTemporaryFile cannot be opened for a
        # second time
        download_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            download_file.close()
            print(f"Start downloading model {model} from {url}")
            urlretrieve(url, download_file.name)
            print("Done")
            with tarfile.open(download_file.name) as t:
                if hasattr(tarfile, "data_filter"):
                    # The archive comes from the network: when the runtime
                    # supports extraction filters, refuse absolute paths,
                    # ``..`` traversal, and special files.
                    t.extractall(models_dir, filter="data")
                else:
                    t.extractall(models_dir)
        except Exception as e:
            print(f"Failed to prepare data for model {model}: {e}")
            # Best-effort cleanup. A leftover empty/partial directory would
            # make every later run skip this model. ignore_errors keeps the
            # original exception from being masked.
            shutil.rmtree(model_dir, ignore_errors=True)
            raise
        finally:
            os.remove(download_file.name)
|
|
|
|
|
|
# Models refreshed by this script, in processing order.
# Temporarily disabled (TODO re-enable):
#   squeezenet - onnx can't translate it yet
#   vgg19      - doesn't work in the CI environment, possibly due to OOM
models = [
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    "vgg16",
]
|
|
|
|
|
|
def download_models():
    """Fetch every model in ``models`` that is not already cached locally.

    For each model, the Caffe2 files are downloaded only if the Caffe2 model
    directory is missing. Likewise, the ONNX test data is prepared only if
    the ONNX model directory is missing.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py: downloading", name)
        if not os.path.exists(helper._caffe2_model_dir(name)):
            helper._download(name)
        onnx_dir, _ = helper._onnx_model_dir(name)
        if not os.path.exists(onnx_dir):
            helper._prepare_model_data(name)
|
|
|
|
|
|
def generate_models():
    """Convert each cached Caffe2 model to ONNX and re-pack its test tarball.

    For every entry in ``models``:

    1. Run ``convert-caffe2-to-onnx`` on the Caffe2 ``predict_net.pb`` and
       ``init_net.pb``, passing the contents of ``value_info.json`` on the
       command line.
    2. Write the result to ``model.pb`` in the ONNX model directory.
    3. Create ``<model>.tar.gz`` next to that directory.

    Expects both directories to be populated already, e.g. by a prior
    ``download`` run.

    Raises:
        subprocess.CalledProcessError: if the converter or ``tar`` fails.
    """
    sc = SomeClass()
    for model in models:
        print("update-caffe2-models.py: generating", model)
        caffe2_model_dir = sc._caffe2_model_dir(model)
        onnx_model_dir, onnx_models_dir = sc._onnx_model_dir(model)
        # Print the bare model name in-process instead of spawning `echo`,
        # which is typically a shell builtin rather than an executable on
        # Windows. Flushing keeps it ordered ahead of the child processes'
        # output.
        print(model, flush=True)
        value_info_path = os.path.join(caffe2_model_dir, "value_info.json")
        # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
        with open(value_info_path, encoding="utf-8") as f:
            value_info = f.read()
        subprocess.check_call(
            [
                "convert-caffe2-to-onnx",
                "--caffe2-net-name",
                model,
                "--caffe2-init-net",
                os.path.join(caffe2_model_dir, "init_net.pb"),
                # The converter receives the JSON text itself, not a path.
                "--value-info",
                value_info,
                "-o",
                os.path.join(onnx_model_dir, "model.pb"),
                os.path.join(caffe2_model_dir, "predict_net.pb"),
            ]
        )
        # Re-pack the ONNX model directory as <model>.tar.gz beside it.
        subprocess.check_call(
            ["tar", "-czf", model + ".tar.gz", model], cwd=onnx_models_dir
        )
|
|
|
|
|
|
def upload_models():
    """Upload each ``<model>.tar.gz`` to the ``download.onnx`` S3 bucket.

    The tarball is taken from the ONNX ``models`` directory. It is copied
    with the ``aws`` CLI using a ``public-read`` ACL. A failing upload raises
    ``subprocess.CalledProcessError``.
    """
    helper = SomeClass()
    for name in models:
        print("update-caffe2-models.py: uploading", name)
        _, tarball_dir = helper._onnx_model_dir(name)
        tarball = name + ".tar.gz"
        destination = f"s3://download.onnx/models/{name}.tar.gz"
        command = ["aws", "s3", "cp", tarball, destination, "--acl", "public-read"]
        subprocess.check_call(command, cwd=tarball_dir)
|
|
|
|
|
|
def cleanup():
    """Delete the ``<model>.tar.gz`` file for every model in ``models``.

    The tarballs are looked up in the ONNX ``models`` directory. ``os.remove``
    raises ``FileNotFoundError`` if one of them is missing.
    """
    helper = SomeClass()
    for name in models:
        _, tarball_dir = helper._onnx_model_dir(name)
        os.remove(os.path.join(tarball_dir, f"{name}.tar.gz"))
|
|
|
|
|
|
if __name__ == "__main__":
    # Sub-command name -> function implementing it.
    commands = {
        "download": download_models,
        "generate": generate_models,
        "upload": upload_models,
        "cleanup": cleanup,
    }
    # Validate the command line before touching the network. A missing or
    # misspelled command previously raised IndexError or silently did nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        choices = "|".join(commands)
        print(f"usage: {sys.argv[0]} [{choices}]", file=sys.stderr)
        sys.exit(1)
    try:
        subprocess.check_call(["aws", "sts", "get-caller-identity"])
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: the aws CLI ran but the identity check failed.
        # OSError (e.g. FileNotFoundError): the aws CLI could not be started.
        # Catching these instead of a bare ``except:`` lets Ctrl-C propagate.
        print(
            "update-caffe2-models.py: please run `aws configure` manually to set up credentials",
            file=sys.stderr,
        )
        sys.exit(1)
    commands[sys.argv[1]]()
|