mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Docker] Install Trition deps (#90841)
Triton needs a working gcc, so install one from apt Also, copy `ptxas` and `cuda.h` from conda to `/usr/local/cuda` Add `torchaudio` to the matrix Fix typo in workflow file Fixes https://github.com/pytorch/pytorch/issues/90377 Pull Request resolved: https://github.com/pytorch/pytorch/pull/90841 Approved by: https://github.com/ngimel
This commit is contained in:
parent
7dd5e55497
commit
c6cba1865f
2
.github/workflows/docker-release.yml
vendored
2
.github/workflows/docker-release.yml
vendored
|
|
@ -75,7 +75,7 @@ jobs:
|
|||
- name: Setup job specific variables
|
||||
run: |
|
||||
set -eou pipefail
|
||||
# To get QEMU binaries in our PATh
|
||||
# To get QEMU binaries in our PATH
|
||||
echo "${RUNNER_TEMP}/bin" >> "${GITHUB_PATH}"
|
||||
# Generate PyTorch version to use
|
||||
echo "PYTORCH_VERSION=$(python3 .github/scripts/generate_pytorch_version.py)" >> "${GITHUB_ENV}"
|
||||
|
|
|
|||
21
Dockerfile
21
Dockerfile
|
|
@ -66,26 +66,35 @@ ARG INSTALL_CHANNEL=pytorch-nightly
|
|||
RUN /opt/conda/bin/conda update -y conda
|
||||
RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERSION}
|
||||
ARG TARGETPLATFORM
|
||||
ARG TRITON_VERSION
|
||||
|
||||
# On arm64 we can only install wheel packages
|
||||
RUN case ${TARGETPLATFORM} in \
|
||||
"linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchtext ;; \
|
||||
*) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchtext "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
|
||||
"linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio torchtext ;; \
|
||||
*) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchaudio torchtext "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
|
||||
esac && \
|
||||
/opt/conda/bin/conda clean -ya
|
||||
RUN /opt/conda/bin/pip install torchelastic
|
||||
RUN if test -n "${TRITON_VERSION}" -a "${TARGETPLATFORM}" != "linux/arm64"; then /opt/conda/bin/pip install "torchtriton==${TRITON_VERSION}" --extra-index-url https://download.pytorch.org/whl/nightly/cpu ; fi
|
||||
|
||||
FROM ${BASE_IMAGE} as official
|
||||
ARG PYTORCH_VERSION
|
||||
ARG TRITON_VERSION
|
||||
ARG TARGETPLATFORM
|
||||
ARG CUDA_VERSION
|
||||
LABEL com.nvidia.volumes.needed="nvidia_driver"
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
libjpeg-dev \
|
||||
libpng-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
libpng-dev
|
||||
COPY --from=conda-installs /opt/conda /opt/conda
|
||||
RUN if test -n "${TRITON_VERSION}" -a "${TARGETPLATFORM}" != "linux/arm64"; then \
|
||||
apt install -y --no-install-recommends gcc; \
|
||||
CU_VER=$(echo $CUDA_VERSION | cut -d'.' -f 1-2) && \
|
||||
mkdir -p /usr/local/triton-min-cuda-${CU_VER} && \
|
||||
ln -s /usr/local/triton-min-cuda-${CU_VER} /usr/local/cuda; \
|
||||
mkdir -p /usr/local/cuda/bin; cp /opt/conda/bin/ptxas /usr/local/cuda/bin; \
|
||||
mkdir -p /usr/local/cuda/include; cp /opt/conda/include/cuda.h /usr/local/cuda/include; \
|
||||
fi
|
||||
RUN rm -rf /var/lib/apt/lists/*
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
ENV NVIDIA_VISIBLE_DEVICES all
|
||||
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user