pytorch/tools/download_mnist.py
Huy Do 347b036350 Apply ufmt linter to all py files under tools (#81285)
With ufmt in place (https://github.com/pytorch/pytorch/pull/81157), we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on.

This batch (as copied from the current BLACK linter config):
* `tools/**/*.py`

Upcoming batches:
* `torchgen/**/*.py`
* `torch/package/**/*.py`
* `torch/onnx/**/*.py`
* `torch/_refs/**/*.py`
* `torch/_prims/**/*.py`
* `torch/_meta_registrations.py`
* `torch/_decomp/**/*.py`
* `test/onnx/**/*.py`

Once they are all formatted, the BLACK linter will be removed.
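
For reference, here is a minimal sketch (not part of this PR) of formatting one batch locally. It assumes ufmt is installed (`pip install ufmt`) and uses its documented `check`/`format` subcommands; the `run_ufmt` helper and `BATCH` list are illustrative only. In the repo itself, formatting runs through the linter setup rather than an ad-hoc script like this:

```python
# Hypothetical helper, not part of this PR: format one batch with the ufmt CLI.
# Assumes ufmt is installed and that its "check"/"format" subcommands accept
# directories as well as individual files.
import subprocess
import sys

BATCH = ["tools"]  # this PR's batch: tools/**/*.py


def run_ufmt(fix: bool) -> int:
    # "check" reports files that would change; "format" rewrites them in place.
    cmd = ["ufmt", "format" if fix else "check", *BATCH]
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    sys.exit(run_ufmt(fix="--fix" in sys.argv))
```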

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285
Approved by: https://github.com/suo
2022-07-13 07:59:22 +00:00


import argparse
import gzip
import os
import sys
from urllib.error import URLError
from urllib.request import urlretrieve

MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]

RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = "#" * int(64 * percent)
        sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))


def download(destination_path: str, resource: str, quiet: bool) -> None:
    if os.path.exists(destination_path):
        if not quiet:
            print("{} already exists, skipping ...".format(destination_path))
    else:
        for mirror in MIRRORS:
            url = mirror + resource
            print("Downloading {} ...".format(url))
            try:
                hook = None if quiet else report_download_progress
                urlretrieve(url, destination_path, reporthook=hook)
            except (URLError, ConnectionError) as e:
                print("Failed to download (trying next):\n{}".format(e))
                continue
            finally:
                if not quiet:
                    # Just a newline.
                    print()
            break
        else:
            raise RuntimeError("Error downloading resource!")


def unzip(zipped_path: str, quiet: bool) -> None:
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print("{} already exists, skipping ... ".format(unzipped_path))
        return
    with gzip.open(zipped_path, "rb") as zipped_file:
        with open(unzipped_path, "wb") as unzipped_file:
            unzipped_file.write(zipped_file.read())
            if not quiet:
                print("Unzipped {} ...".format(zipped_path))


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    options = parser.parse_args()

    if not os.path.exists(options.destination):
        os.makedirs(options.destination)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print("Interrupted")


if __name__ == "__main__":
    main()
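
A usage sketch, separate from the file above: invoking the downloader from another Python script with the flags defined by its argparse parser (`--destination`, `--quiet`). The `data` directory name is just an example, and the snippet assumes it is run from the repository root.

```python
# Example only, not part of tools/download_mnist.py.
import subprocess
import sys

# Mirrors `python tools/download_mnist.py --destination data --quiet` on the
# command line; both flags come from the argparse parser in the script above.
subprocess.run(
    [sys.executable, "tools/download_mnist.py", "--destination", "data", "--quiet"],
    check=True,
)
```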