pytorch/tools/download_mnist.py
zhouzhuojie cb6841b263 Fix ConnectionError in download_mnist (#61789)
Summary:
Fixes issues like the following error. Note that `ConnectionResetError` is a subclass of `ConnectionError`.

```
+ python tools/download_mnist.py --quiet -d test/cpp/api/mnist
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz ...
Traceback (most recent call last):
  File "tools/download_mnist.py", line 93, in <module>
    main()
  File "tools/download_mnist.py", line 86, in main
    download(path, resource, options.quiet)
  File "tools/download_mnist.py", line 42, in download
    urlretrieve(url, destination_path, reporthook=hook)
  File "/opt/conda/lib/python3.6/urllib/request.py", line 277, in urlretrieve
    block = fp.read(bs)
  File "/opt/conda/lib/python3.6/http/client.py", line 463, in read
    n = self.readinto(b)
  File "/opt/conda/lib/python3.6/http/client.py", line 507, in readinto
    n = self.fp.readinto(b)
  File "/opt/conda/lib/python3.6/socket.py", line 586, in readinto
    return self._sock.recv_into(b)
ConnectionResetError: [Errno 104] Connection reset by peer
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61789

Reviewed By: dreiss

Differential Revision: D29745459

Pulled By: zhouzhuojie

fbshipit-source-id: 2deb668bd74478f32bd01704d4362e8a4d95087b
2021-07-16 17:02:13 -07:00

94 lines
2.8 KiB
Python

import argparse
import gzip
import os
import shutil
import sys
from urllib.error import URLError
from urllib.request import urlretrieve
# Base URLs tried in order by download(). Each RESOURCES entry is appended
# directly (mirror + resource), so every base URL must end with '/'.
# The HTTPS S3 mirror is tried first. The legacy yann.lecun.com host is plain
# HTTP and has been observed resetting connections, so trying it first costs
# a failed attempt on many runs and downloads over an unauthenticated
# transport whenever it does answer.
MIRRORS = [
    'https://ossci-datasets.s3.amazonaws.com/mnist/',
    'http://yann.lecun.com/exdb/mnist/',
]

# Gzipped MNIST files: training and t10k (test) images and labels.
RESOURCES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]
def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    """Redraw a 64-column progress bar; intended as ``urlretrieve``'s reporthook.

    Args:
        chunk_number: Number of blocks transferred so far.
        chunk_size: Size of each block in bytes.
        file_size: Total size in bytes, or -1 when the server did not report it.
    """
    # Draw nothing when the total size is unknown (-1) or zero. A zero size
    # would otherwise raise ZeroDivisionError below.
    if file_size > 0:
        # Clamp at 100%: the last block may overshoot the total size.
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
        # The bar is redrawn in place with '\r' and never ends a line, so flush
        # explicitly; otherwise buffered stdout may hide it until much later.
        sys.stdout.flush()
def _remove_if_present(path: str) -> None:
    """Delete ``path`` if it exists; a missing file is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def download(destination_path: str, resource: str, quiet: bool) -> None:
    """Download ``resource`` to ``destination_path``, trying each mirror in turn.

    The download is skipped entirely if ``destination_path`` already exists.
    Otherwise the base URLs in ``MIRRORS`` are tried in order and the first
    successful download wins.

    An existing file is always trusted as complete. For that to be safe, any
    partial file is deleted whenever ``urlretrieve`` raises, whether the
    failure moves on to the next mirror or aborts the run.

    Args:
        destination_path: Local file path to write.
        resource: File name appended to each mirror's base URL.
        quiet: If true, suppress the progress bar, its terminating newline
            and the "already exists" message. Per-mirror "Downloading" and
            failure messages are still printed.

    Raises:
        RuntimeError: If every mirror fails with ``URLError`` or
            ``ConnectionError``. Any other exception propagates unchanged.
    """
    if os.path.exists(destination_path):
        if not quiet:
            print('{} already exists, skipping ...'.format(destination_path))
        return

    for mirror in MIRRORS:
        url = mirror + resource
        print('Downloading {} ...'.format(url))
        try:
            hook = None if quiet else report_download_progress
            urlretrieve(url, destination_path, reporthook=hook)
        except (URLError, ConnectionError) as e:
            # A mid-transfer failure (e.g. connection reset, short read) can
            # leave a truncated file at the destination; drop it before retrying.
            _remove_if_present(destination_path)
            print('Failed to download (trying next):\n{}'.format(e))
            continue
        except BaseException:
            # Any other failure (timeout, Ctrl-C, ...) aborts the run. Still
            # discard the partial file so a rerun does not skip it.
            _remove_if_present(destination_path)
            raise
        finally:
            if not quiet:
                # End the progress-bar line.
                print()
        break
    else:
        raise RuntimeError('Error downloading resource!')
def unzip(zipped_path: str, quiet: bool) -> None:
    """Decompress a gzip file next to itself, dropping its last extension.

    Nothing is done if the decompressed file already exists. Output is
    streamed to a temporary ``<name>.part`` file and renamed into place only
    after decompression succeeds. A truncated or corrupt archive, or an
    interrupt, therefore never leaves an incomplete file that a later run
    would skip as "already exists".

    Args:
        zipped_path: Path to the ``.gz`` archive.
        quiet: If true, suppress the informational messages.

    Raises:
        Whatever ``gzip``/file I/O raises (e.g. ``EOFError`` for a truncated
        archive). No output file is left behind in that case.
    """
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print('{} already exists, skipping ... '.format(unzipped_path))
        return

    partial_path = unzipped_path + '.part'
    try:
        with gzip.open(zipped_path, 'rb') as zipped_file:
            with open(partial_path, 'wb') as unzipped_file:
                # Stream in chunks instead of holding the whole payload in memory.
                shutil.copyfileobj(zipped_file, unzipped_file)
    except BaseException:
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
        raise
    # Publish the finished file with a single rename.
    os.replace(partial_path, unzipped_path)
    if not quiet:
        print('Unzipped {} ...'.format(zipped_path))
def main() -> None:
    """Parse the command line, then download and decompress every MNIST file.

    Exits with a non-zero status if interrupted with Ctrl-C.
    """
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help="Don't report about progress")
    options = parser.parse_args()

    # exist_ok avoids the race between a separate exists() check and creation.
    os.makedirs(options.destination, exist_ok=True)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print('Interrupted')
        # Exit non-zero (128 + SIGINT, the shell convention) so calling
        # scripts such as CI jobs don't treat an interrupted run as success.
        sys.exit(130)


if __name__ == '__main__':
    main()