With ufmt in place (https://github.com/pytorch/pytorch/pull/81157), we can now use it to gradually format all files. I'm breaking this down into multiple smaller batches to avoid too many merge conflicts later on.

This batch (as copied from the current BLACK linter config):
* `tools/**/*.py`

Upcoming batches:
* `torchgen/**/*.py`
* `torch/package/**/*.py`
* `torch/onnx/**/*.py`
* `torch/_refs/**/*.py`
* `torch/_prims/**/*.py`
* `torch/_meta_registrations.py`
* `torch/_decomp/**/*.py`
* `test/onnx/**/*.py`

Once they are all formatted, the BLACK linter will be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81285
Approved by: https://github.com/suo
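For anyone reproducing a batch locally: ufmt wraps µsort (import sorting) and Black (code formatting) behind a single command, so each batch is one pass over the globs above. A sketch of the manual invocation for this batch (the repo's CI runs this through a lintrunner adapter instead, so treat these commands as the standalone equivalent):

    pip install ufmt
    ufmt check tools     # report files whose formatting would change
    ufmt format tools    # rewrite them in place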
94 lines · 2.8 KiB · Python
import argparse
import gzip
import os
import sys
from urllib.error import URLError
from urllib.request import urlretrieve

# Mirrors are tried in order; the S3 bucket is the fallback for the
# original LeCun host.
MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]

# The four gzipped IDX files that make up MNIST.
RESOURCES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]


def report_download_progress(
    chunk_number: int,
    chunk_size: int,
    file_size: int,
) -> None:
    # Reporthook for urlretrieve: draws a 64-character progress bar.
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = "#" * int(64 * percent)
        sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))


def download(destination_path: str, resource: str, quiet: bool) -> None:
    # Fetch `resource` from the first mirror that responds, skipping the
    # download entirely if the file is already on disk.
    if os.path.exists(destination_path):
        if not quiet:
            print("{} already exists, skipping ...".format(destination_path))
    else:
        for mirror in MIRRORS:
            url = mirror + resource
            print("Downloading {} ...".format(url))
            try:
                hook = None if quiet else report_download_progress
                urlretrieve(url, destination_path, reporthook=hook)
            except (URLError, ConnectionError) as e:
                print("Failed to download (trying next):\n{}".format(e))
                continue
            finally:
                if not quiet:
                    # Just a newline.
                    print()
            break
        else:
            raise RuntimeError("Error downloading resource!")


def unzip(zipped_path: str, quiet: bool) -> None:
    # Decompress the archive next to itself, dropping the .gz suffix.
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print("{} already exists, skipping ... ".format(unzipped_path))
        return
    with gzip.open(zipped_path, "rb") as zipped_file:
        with open(unzipped_path, "wb") as unzipped_file:
            unzipped_file.write(zipped_file.read())
            if not quiet:
                print("Unzipped {} ...".format(zipped_path))


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Download the MNIST dataset from the internet"
    )
    parser.add_argument(
        "-d", "--destination", default=".", help="Destination directory"
    )
    parser.add_argument(
        "-q", "--quiet", action="store_true", help="Don't report about progress"
    )
    options = parser.parse_args()

    if not os.path.exists(options.destination):
        os.makedirs(options.destination)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            download(path, resource, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print("Interrupted")


if __name__ == "__main__":
    main()
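A quick usage note, since the script is self-contained: assuming it is saved as download_mnist.py, fetching and unpacking all four files into a data/ directory is

    python download_mnist.py --destination data

and -q/--quiet suppresses the progress bar. Already-downloaded and already-unzipped files are skipped, so reruns are cheap.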