mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Support CUDA nightly package in tools/nightly.py (#131133)
Add a new option `--cuda` to `tools/nightly.py` to pull the nightly packages with CUDA support. ```bash # installs pytorch-nightly with cpuonly tools/nightly.py pull # The following only available on Linux and Windows # installs pytorch-nightly with latest CUDA we support tools/nightly.py pull --cuda # installs pytorch-nightly with CUDA 12.1 tools/nightly.py pull --cuda 12.1 ``` Also add targets in `Makefile` and instructions in constribution guidelines. ```bash # setup conda environment with pytorch-nightly make setup-env # setup conda environment with pytorch-nightly with CUDA support make setup-env-cuda ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/131133 Approved by: https://github.com/ezyang
This commit is contained in:
parent
ceab3121de
commit
42a4df9447
|
|
@ -5,7 +5,7 @@ git submodule sync
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
|
|
||||||
# This takes some time
|
# This takes some time
|
||||||
make setup_lint
|
make setup-lint
|
||||||
|
|
||||||
# Add CMAKE_PREFIX_PATH to bashrc
|
# Add CMAKE_PREFIX_PATH to bashrc
|
||||||
echo 'export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}' >> ~/.bashrc
|
echo 'export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}' >> ~/.bashrc
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ aspects of contributing to PyTorch.
|
||||||
<!-- toc -->
|
<!-- toc -->
|
||||||
|
|
||||||
- [Developing PyTorch](#developing-pytorch)
|
- [Developing PyTorch](#developing-pytorch)
|
||||||
|
- [Setup the development environment](#setup-the-development-environment)
|
||||||
- [Tips and Debugging](#tips-and-debugging)
|
- [Tips and Debugging](#tips-and-debugging)
|
||||||
- [Nightly Checkout & Pull](#nightly-checkout--pull)
|
- [Nightly Checkout & Pull](#nightly-checkout--pull)
|
||||||
- [Codebase structure](#codebase-structure)
|
- [Codebase structure](#codebase-structure)
|
||||||
|
|
@ -64,8 +65,24 @@ aspects of contributing to PyTorch.
|
||||||
<!-- tocstop -->
|
<!-- tocstop -->
|
||||||
|
|
||||||
## Developing PyTorch
|
## Developing PyTorch
|
||||||
|
|
||||||
Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.
|
Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.
|
||||||
|
|
||||||
|
### Setup the development environment
|
||||||
|
|
||||||
|
First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to setup your SSH authentication credentials.
|
||||||
|
|
||||||
|
Then clone the PyTorch project and setup the development environment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone git@github.com:<USERNAME>/pytorch.git
|
||||||
|
cd pytorch
|
||||||
|
git remote add origin git@github.com:pytorch/pytorch.git
|
||||||
|
|
||||||
|
make setup-env # or make setup-env-cuda for pre-built CUDA binaries
|
||||||
|
conda activate pytorch-deps
|
||||||
|
```
|
||||||
|
|
||||||
### Tips and Debugging
|
### Tips and Debugging
|
||||||
|
|
||||||
* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.
|
* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.
|
||||||
|
|
@ -175,6 +192,13 @@ the regular environment parameters (`--name` or `--prefix`):
|
||||||
conda activate my-env
|
conda activate my-env
|
||||||
```
|
```
|
||||||
|
|
||||||
|
To install the nightly binaries built with CUDA, you can pass in the flag `--cuda`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./tools/nightly.py checkout -b my-nightly-branch --cuda
|
||||||
|
conda activate pytorch-deps
|
||||||
|
```
|
||||||
|
|
||||||
You can also use this tool to pull the nightly commits into the current branch:
|
You can also use this tool to pull the nightly commits into the current branch:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
@ -325,7 +349,7 @@ command runs tests such as `TestNN.test_BCELoss` and
|
||||||
Install all prerequisites by running
|
Install all prerequisites by running
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make setup_lint
|
make setup-lint
|
||||||
```
|
```
|
||||||
|
|
||||||
You can now run the same linting steps that are used in CI locally via `make`:
|
You can now run the same linting steps that are used in CI locally via `make`:
|
||||||
|
|
|
||||||
22
Makefile
22
Makefile
|
|
@ -1,6 +1,7 @@
|
||||||
# This makefile does nothing but delegating the actual building to cmake.
|
# This makefile does nothing but delegating the actual building to cmake.
|
||||||
PYTHON = python3
|
PYTHON = python3
|
||||||
PIP = pip3
|
PIP = $(PYTHON) -m pip
|
||||||
|
NIGHTLY_TOOL_OPTS := pull
|
||||||
|
|
||||||
all:
|
all:
|
||||||
@mkdir -p build && cd build && cmake .. $(shell $(PYTHON) ./scripts/get_python_cmake_flags.py) && $(MAKE)
|
@mkdir -p build && cd build && cmake .. $(shell $(PYTHON) ./scripts/get_python_cmake_flags.py) && $(MAKE)
|
||||||
|
|
@ -22,10 +23,27 @@ linecount:
|
||||||
echo "Cloc is not available on the machine. You can install cloc with " && \
|
echo "Cloc is not available on the machine. You can install cloc with " && \
|
||||||
echo " sudo apt-get install cloc"
|
echo " sudo apt-get install cloc"
|
||||||
|
|
||||||
setup_lint:
|
ensure-branch-clean:
|
||||||
|
@if [ -n "$(shell git status --porcelain)" ]; then \
|
||||||
|
echo "Please commit or stash all changes before running this script"; \
|
||||||
|
exit 1; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
setup-env: ensure-branch-clean
|
||||||
|
$(PYTHON) tools/nightly.py $(NIGHTLY_TOOL_OPTS)
|
||||||
|
|
||||||
|
setup-env-cuda:
|
||||||
|
$(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --cuda"
|
||||||
|
|
||||||
|
setup_env: setup-env
|
||||||
|
setup_env_cuda: setup-env-cuda
|
||||||
|
|
||||||
|
setup-lint:
|
||||||
$(PIP) install lintrunner
|
$(PIP) install lintrunner
|
||||||
lintrunner init
|
lintrunner init
|
||||||
|
|
||||||
|
setup_lint: setup-lint
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
lintrunner
|
lintrunner
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ if __name__ == "__main__":
|
||||||
name="command-failed",
|
name="command-failed",
|
||||||
original=None,
|
original=None,
|
||||||
replacement=None,
|
replacement=None,
|
||||||
description="Lintrunner is not installed, did you forget to run `make setup_lint && make lint`?",
|
description="Lintrunner is not installed, did you forget to run `make setup-lint && make lint`?",
|
||||||
)
|
)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
|
||||||
120
tools/nightly.py
120
tools/nightly.py
|
|
@ -15,22 +15,29 @@ the regular environment parameters (--name or --prefix)::
|
||||||
$ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
|
$ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
|
||||||
$ conda activate my-env
|
$ conda activate my-env
|
||||||
|
|
||||||
|
To install the nightly binaries built with CUDA, you can pass in the flag --cuda::
|
||||||
|
|
||||||
|
$ ./tools/nightly.py checkout -b my-nightly-branch --cuda
|
||||||
|
$ conda activate pytorch-deps
|
||||||
|
|
||||||
You can also use this tool to pull the nightly commits into the current branch as
|
You can also use this tool to pull the nightly commits into the current branch as
|
||||||
well. This can be done with
|
well. This can be done with::
|
||||||
|
|
||||||
$ ./tools/nightly.py pull -n my-env
|
$ ./tools/nightly.py pull -n my-env
|
||||||
$ conda activate my-env
|
$ conda activate my-env
|
||||||
|
|
||||||
Pulling will reinstalle the conda dependencies as well as the nightly binaries into
|
Pulling will reinstall the conda dependencies as well as the nightly binaries into
|
||||||
the repo directory.
|
the repo directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
import contextlib
|
import contextlib
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
import glob
|
import glob
|
||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
@ -41,8 +48,8 @@ import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from argparse import ArgumentParser
|
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
|
from platform import system as platform_system
|
||||||
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
|
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -52,7 +59,7 @@ DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
|
||||||
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
|
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
|
||||||
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
|
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
|
||||||
LOG_DIRNAME_RE = re.compile(
|
LOG_DIRNAME_RE = re.compile(
|
||||||
r"(\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_" r"[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12}"
|
r"(\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12}"
|
||||||
)
|
)
|
||||||
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")
|
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")
|
||||||
|
|
||||||
|
|
@ -261,6 +268,8 @@ def _make_channel_args(
|
||||||
|
|
||||||
@timed("Solving conda environment")
|
@timed("Solving conda environment")
|
||||||
def conda_solve(
|
def conda_solve(
|
||||||
|
specs: Iterable[str],
|
||||||
|
*,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
prefix: str | None = None,
|
prefix: str | None = None,
|
||||||
channels: Iterable[str] = ("pytorch-nightly",),
|
channels: Iterable[str] = ("pytorch-nightly",),
|
||||||
|
|
@ -302,12 +311,13 @@ def conda_solve(
|
||||||
channels=channels, override_channels=override_channels
|
channels=channels, override_channels=override_channels
|
||||||
)
|
)
|
||||||
cmd.extend(channel_args)
|
cmd.extend(channel_args)
|
||||||
cmd.extend(SPECS_TO_INSTALL)
|
cmd.extend(specs)
|
||||||
p = subprocess.run(cmd, capture_output=True, check=True)
|
p = subprocess.run(cmd, capture_output=True, check=True)
|
||||||
# parse solution
|
# parse solution
|
||||||
solve = json.loads(p.stdout)
|
solve = json.loads(p.stdout)
|
||||||
link = solve["actions"]["LINK"]
|
link = solve["actions"]["LINK"]
|
||||||
deps = []
|
deps = []
|
||||||
|
pytorch, platform = "", ""
|
||||||
for pkg in link:
|
for pkg in link:
|
||||||
url = URL_FORMAT.format(**pkg)
|
url = URL_FORMAT.format(**pkg)
|
||||||
if pkg["name"] == "pytorch":
|
if pkg["name"] == "pytorch":
|
||||||
|
|
@ -315,6 +325,8 @@ def conda_solve(
|
||||||
platform = pkg["platform"]
|
platform = pkg["platform"]
|
||||||
else:
|
else:
|
||||||
deps.append(url)
|
deps.append(url)
|
||||||
|
assert pytorch, "PyTorch package not found in solve"
|
||||||
|
assert platform, "Platform not found in solve"
|
||||||
return deps, pytorch, platform, existing_env, env_opts
|
return deps, pytorch, platform, existing_env, env_opts
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -412,23 +424,33 @@ def pull_nightly_version(spdir: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
def _get_listing_linux(source_dir: str) -> list[str]:
|
def _get_listing_linux(source_dir: str) -> list[str]:
|
||||||
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
return list(
|
||||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
|
itertools.chain(
|
||||||
return listing
|
glob.iglob(os.path.join(source_dir, "*.so")),
|
||||||
|
glob.iglob(os.path.join(source_dir, "lib", "*.so")),
|
||||||
|
glob.iglob(os.path.join(source_dir, "lib", "*.so.*")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_listing_osx(source_dir: str) -> list[str]:
|
def _get_listing_osx(source_dir: str) -> list[str]:
|
||||||
# oddly, these are .so files even on Mac
|
# oddly, these are .so files even on Mac
|
||||||
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
return list(
|
||||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
|
itertools.chain(
|
||||||
return listing
|
glob.iglob(os.path.join(source_dir, "*.so")),
|
||||||
|
glob.iglob(os.path.join(source_dir, "lib", "*.dylib")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_listing_win(source_dir: str) -> list[str]:
|
def _get_listing_win(source_dir: str) -> list[str]:
|
||||||
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
|
return list(
|
||||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
|
itertools.chain(
|
||||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
|
glob.iglob(os.path.join(source_dir, "*.pyd")),
|
||||||
return listing
|
glob.iglob(os.path.join(source_dir, "lib", "*.lib")),
|
||||||
|
glob.iglob(os.path.join(source_dir, "lib", "*.dll")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _glob_pyis(d: str) -> set[str]:
|
def _glob_pyis(d: str) -> set[str]:
|
||||||
|
|
@ -480,6 +502,8 @@ def _move_single(
|
||||||
is_dir = os.path.isdir(src)
|
is_dir = os.path.isdir(src)
|
||||||
relpath = os.path.relpath(src, source_dir)
|
relpath = os.path.relpath(src, source_dir)
|
||||||
trg = os.path.join(target_dir, relpath)
|
trg = os.path.join(target_dir, relpath)
|
||||||
|
src = os.path.normpath(src)
|
||||||
|
trg = os.path.normpath(trg)
|
||||||
_remove_existing(trg, is_dir)
|
_remove_existing(trg, is_dir)
|
||||||
# move over new files
|
# move over new files
|
||||||
if is_dir:
|
if is_dir:
|
||||||
|
|
@ -488,8 +512,8 @@ def _move_single(
|
||||||
relroot = os.path.relpath(root, src)
|
relroot = os.path.relpath(root, src)
|
||||||
for name in files:
|
for name in files:
|
||||||
relname = os.path.join(relroot, name)
|
relname = os.path.join(relroot, name)
|
||||||
s = os.path.join(src, relname)
|
s = os.path.normpath(os.path.join(src, relname))
|
||||||
t = os.path.join(trg, relname)
|
t = os.path.normpath(os.path.join(trg, relname))
|
||||||
print(f"{verb} {s} -> {t}")
|
print(f"{verb} {s} -> {t}")
|
||||||
mover(s, t)
|
mover(s, t)
|
||||||
for name in dirs:
|
for name in dirs:
|
||||||
|
|
@ -515,7 +539,9 @@ def move_nightly_files(spdir: str, platform: str) -> None:
|
||||||
"""Moves PyTorch files from temporary installed location to repo."""
|
"""Moves PyTorch files from temporary installed location to repo."""
|
||||||
# get file listing
|
# get file listing
|
||||||
source_dir = os.path.join(spdir, "torch")
|
source_dir = os.path.join(spdir, "torch")
|
||||||
target_dir = os.path.abspath("torch")
|
target_dir = os.path.abspath(
|
||||||
|
os.path.join(os.path.dirname(os.path.dirname(__file__)), "torch")
|
||||||
|
)
|
||||||
listing = _get_listing(source_dir, target_dir, platform)
|
listing = _get_listing(source_dir, target_dir, platform)
|
||||||
# copy / link files
|
# copy / link files
|
||||||
if platform.startswith("win"):
|
if platform.startswith("win"):
|
||||||
|
|
@ -569,6 +595,7 @@ def write_pth(env_opts: list[str], platform: str) -> None:
|
||||||
|
|
||||||
|
|
||||||
def install(
|
def install(
|
||||||
|
specs: Iterable[str],
|
||||||
*,
|
*,
|
||||||
logger: logging.Logger,
|
logger: logging.Logger,
|
||||||
subcommand: str = "checkout",
|
subcommand: str = "checkout",
|
||||||
|
|
@ -579,8 +606,13 @@ def install(
|
||||||
override_channels: bool = False,
|
override_channels: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Development install of PyTorch"""
|
"""Development install of PyTorch"""
|
||||||
|
specs = list(specs)
|
||||||
deps, pytorch, platform, existing_env, env_opts = conda_solve(
|
deps, pytorch, platform, existing_env, env_opts = conda_solve(
|
||||||
name=name, prefix=prefix, channels=channels, override_channels=override_channels
|
specs=specs,
|
||||||
|
name=name,
|
||||||
|
prefix=prefix,
|
||||||
|
channels=channels,
|
||||||
|
override_channels=override_channels,
|
||||||
)
|
)
|
||||||
if deps:
|
if deps:
|
||||||
deps_install(deps, existing_env, env_opts)
|
deps_install(deps, existing_env, env_opts)
|
||||||
|
|
@ -602,12 +634,12 @@ def install(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_parser() -> ArgumentParser:
|
def make_parser() -> argparse.ArgumentParser:
|
||||||
p = ArgumentParser("nightly")
|
p = argparse.ArgumentParser()
|
||||||
# subcommands
|
# subcommands
|
||||||
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
|
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
|
||||||
co = subcmd.add_parser("checkout", help="checkout a new branch")
|
checkout = subcmd.add_parser("checkout", help="checkout a new branch")
|
||||||
co.add_argument(
|
checkout.add_argument(
|
||||||
"-b",
|
"-b",
|
||||||
"--branch",
|
"--branch",
|
||||||
help="Branch name to checkout",
|
help="Branch name to checkout",
|
||||||
|
|
@ -619,9 +651,9 @@ def make_parser() -> ArgumentParser:
|
||||||
"pull", help="pulls the nightly commits into the current branch"
|
"pull", help="pulls the nightly commits into the current branch"
|
||||||
)
|
)
|
||||||
# general arguments
|
# general arguments
|
||||||
subps = [co, pull]
|
subparsers = [checkout, pull]
|
||||||
for subp in subps:
|
for subparser in subparsers:
|
||||||
subp.add_argument(
|
subparser.add_argument(
|
||||||
"-n",
|
"-n",
|
||||||
"--name",
|
"--name",
|
||||||
help="Name of environment",
|
help="Name of environment",
|
||||||
|
|
@ -629,7 +661,7 @@ def make_parser() -> ArgumentParser:
|
||||||
default=None,
|
default=None,
|
||||||
metavar="ENVIRONMENT",
|
metavar="ENVIRONMENT",
|
||||||
)
|
)
|
||||||
subp.add_argument(
|
subparser.add_argument(
|
||||||
"-p",
|
"-p",
|
||||||
"--prefix",
|
"--prefix",
|
||||||
help="Full path to environment location (i.e. prefix)",
|
help="Full path to environment location (i.e. prefix)",
|
||||||
|
|
@ -637,7 +669,7 @@ def make_parser() -> ArgumentParser:
|
||||||
default=None,
|
default=None,
|
||||||
metavar="PATH",
|
metavar="PATH",
|
||||||
)
|
)
|
||||||
subp.add_argument(
|
subparser.add_argument(
|
||||||
"-v",
|
"-v",
|
||||||
"--verbose",
|
"--verbose",
|
||||||
help="Provide debugging info",
|
help="Provide debugging info",
|
||||||
|
|
@ -645,21 +677,36 @@ def make_parser() -> ArgumentParser:
|
||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
)
|
)
|
||||||
subp.add_argument(
|
subparser.add_argument(
|
||||||
"--override-channels",
|
"--override-channels",
|
||||||
help="Do not search default or .condarc channels.",
|
help="Do not search default or .condarc channels.",
|
||||||
dest="override_channels",
|
dest="override_channels",
|
||||||
default=False,
|
default=False,
|
||||||
action="store_true",
|
action="store_true",
|
||||||
)
|
)
|
||||||
subp.add_argument(
|
subparser.add_argument(
|
||||||
"-c",
|
"-c",
|
||||||
"--channel",
|
"--channel",
|
||||||
help="Additional channel to search for packages. 'pytorch-nightly' will always be prepended to this list.",
|
help=(
|
||||||
|
"Additional channel to search for packages. "
|
||||||
|
"'pytorch-nightly' will always be prepended to this list."
|
||||||
|
),
|
||||||
dest="channels",
|
dest="channels",
|
||||||
action="append",
|
action="append",
|
||||||
metavar="CHANNEL",
|
metavar="CHANNEL",
|
||||||
)
|
)
|
||||||
|
if platform_system() in {"Linux", "Windows"}:
|
||||||
|
subparser.add_argument(
|
||||||
|
"--cuda",
|
||||||
|
help=(
|
||||||
|
"CUDA version to install "
|
||||||
|
"(defaults to the latest version available on the platform)"
|
||||||
|
),
|
||||||
|
dest="cuda",
|
||||||
|
nargs="?",
|
||||||
|
default=argparse.SUPPRESS,
|
||||||
|
metavar="VERSION",
|
||||||
|
)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -673,12 +720,23 @@ def main(args: Sequence[str] | None = None) -> None:
|
||||||
status = status or check_branch(ns.subcmd, ns.branch)
|
status = status or check_branch(ns.subcmd, ns.branch)
|
||||||
if status:
|
if status:
|
||||||
sys.exit(status)
|
sys.exit(status)
|
||||||
|
specs = list(SPECS_TO_INSTALL)
|
||||||
channels = ["pytorch-nightly"]
|
channels = ["pytorch-nightly"]
|
||||||
|
if hasattr(ns, "cuda"):
|
||||||
|
if ns.cuda is not None:
|
||||||
|
specs.append(f"pytorch-cuda={ns.cuda}")
|
||||||
|
else:
|
||||||
|
specs.append("pytorch-cuda")
|
||||||
|
specs.append("pytorch-mutex=*=*cuda*")
|
||||||
|
channels.append("nvidia")
|
||||||
|
else:
|
||||||
|
specs.append("pytorch-mutex=*=*cpu*")
|
||||||
if ns.channels:
|
if ns.channels:
|
||||||
channels.extend(ns.channels)
|
channels.extend(ns.channels)
|
||||||
with logging_manager(debug=ns.verbose) as logger:
|
with logging_manager(debug=ns.verbose) as logger:
|
||||||
LOGGER = logger
|
LOGGER = logger
|
||||||
install(
|
install(
|
||||||
|
specs=specs,
|
||||||
subcommand=ns.subcmd,
|
subcommand=ns.subcmd,
|
||||||
branch=ns.branch,
|
branch=ns.branch,
|
||||||
name=ns.name,
|
name=ns.name,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user