mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Switch hub to use requests because of SSL (#25083)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25083 I missed this in the last PR Test Plan: Imported from OSS Differential Revision: D17005372 Pulled By: jamesr66a fbshipit-source-id: 1200a6cd88fb9051aed8baf3162a9f8ffbf65189
This commit is contained in:
parent
85bca16a61
commit
f71ddd4292
|
|
@ -106,6 +106,11 @@ if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
|
|||
export LC_ALL=C.UTF-8
|
||||
export LANG=C.UTF-8
|
||||
fi
|
||||
|
||||
if [[ "$BUILD_ENVIRONMENT" == *py2* ]]; then
|
||||
pip install --user requests
|
||||
fi
|
||||
|
||||
pip install --user pytest-sugar
|
||||
"$PYTHON" \
|
||||
-m pytest \
|
||||
|
|
@ -137,3 +142,4 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
|
|||
fi
|
||||
"$ROOT_DIR/scripts/onnx/test.sh"
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,10 @@ if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]]; then
|
|||
pip_install --user mypy || true
|
||||
fi
|
||||
|
||||
if [[ $PYTHON_VERSION == "2" ]]; then
|
||||
pip_install --user requests
|
||||
fi
|
||||
|
||||
# faulthandler become built-in since 3.3
|
||||
if [[ ! $(python -c "import sys; print(int(sys.version_info >= (3, 3)))") == "1" ]]; then
|
||||
pip_install --user faulthandler
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -352,6 +352,9 @@ install_requires = []
|
|||
if sys.version_info <= (2, 7):
|
||||
install_requires += ['future']
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
install_requires += ['requests']
|
||||
|
||||
missing_pydep = '''
|
||||
Missing build dependency: Unable to `import {importname}`.
|
||||
Please install it via `conda install {module}` or `pip install {module}`
|
||||
|
|
|
|||
42
torch/hub.py
42
torch/hub.py
|
|
@ -4,7 +4,6 @@ import hashlib
|
|||
import os
|
||||
import re
|
||||
import shutil
|
||||
import ssl
|
||||
import sys
|
||||
import tempfile
|
||||
import torch
|
||||
|
|
@ -13,7 +12,7 @@ import zipfile
|
|||
|
||||
if sys.version_info[0] == 2:
|
||||
from urlparse import urlparse
|
||||
from urllib2 import urlopen # noqa f811
|
||||
import requests
|
||||
else:
|
||||
from urllib.request import urlopen
|
||||
from urllib.parse import urlparse # noqa: F401
|
||||
|
|
@ -94,19 +93,12 @@ def _git_archive_link(repo_owner, repo_name, branch):
|
|||
|
||||
def _download_archive_zip(url, filename):
|
||||
sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
|
||||
# TODO: This is to get around CA issues on Python2, where urllib can't
|
||||
# verify the cert from the github server. Another solution is to do:
|
||||
#
|
||||
# import certifi
|
||||
# ...
|
||||
# urlopen(url, cafile=certifi.where())
|
||||
#
|
||||
# But it requires adding a dependency on the `certifi` package
|
||||
# We use a different API for python2 since urllib(2) doesn't recognize the CA
|
||||
# certificates in older Python
|
||||
if sys.version_info[0] == 2:
|
||||
context = ssl._create_unverified_context()
|
||||
response = requests.get(url, stream=True).raw
|
||||
else:
|
||||
context = None
|
||||
response = urlopen(url, context=context)
|
||||
response = urlopen(url)
|
||||
with open(filename, 'wb') as f:
|
||||
while True:
|
||||
data = response.read(READ_DATA_CHUNK)
|
||||
|
|
@ -382,14 +374,24 @@ def load(github, model, *args, **kwargs):
|
|||
|
||||
def _download_url_to_file(url, dst, hash_prefix, progress):
|
||||
file_size = None
|
||||
u = urlopen(url)
|
||||
meta = u.info()
|
||||
if hasattr(meta, 'getheaders'):
|
||||
content_length = meta.getheaders("Content-Length")
|
||||
# We use a different API for python2 since urllib(2) doesn't recognize the CA
|
||||
# certificates in older Python
|
||||
if sys.version_info[0] == 2:
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
content_length = response.headers['Content-Length']
|
||||
file_size = content_length
|
||||
u = response.raw
|
||||
else:
|
||||
content_length = meta.get_all("Content-Length")
|
||||
if content_length is not None and len(content_length) > 0:
|
||||
file_size = int(content_length[0])
|
||||
u = urlopen(url)
|
||||
|
||||
meta = u.info()
|
||||
if hasattr(meta, 'getheaders'):
|
||||
content_length = meta.getheaders("Content-Length")
|
||||
else:
|
||||
content_length = meta.get_all("Content-Length")
|
||||
if content_length is not None and len(content_length) > 0:
|
||||
file_size = int(content_length[0])
|
||||
|
||||
# We deliberately save it in a temp file and move it after
|
||||
# download is complete. This prevents a local working checkpoint
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user