[Docker] Install Trition deps (#90841)

Triton needs a working gcc, so install one from apt
Also, copy `ptxas` and `cuda.h` from conda to `/usr/local/cuda`
Add `torchaudio` to the matrix
Fix typo in workflow file

Fixes https://github.com/pytorch/pytorch/issues/90377

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90841
Approved by: https://github.com/ngimel
This commit is contained in:
Nikita Shulga 2022-12-16 06:35:40 +00:00 committed by PyTorch MergeBot
parent 7dd5e55497
commit c6cba1865f
2 changed files with 16 additions and 7 deletions

View File

@ -75,7 +75,7 @@ jobs:
- name: Setup job specific variables
run: |
set -eou pipefail
# To get QEMU binaries in our PATh
# To get QEMU binaries in our PATH
echo "${RUNNER_TEMP}/bin" >> "${GITHUB_PATH}"
# Generate PyTorch version to use
echo "PYTORCH_VERSION=$(python3 .github/scripts/generate_pytorch_version.py)" >> "${GITHUB_ENV}"

View File

@ -66,26 +66,35 @@ ARG INSTALL_CHANNEL=pytorch-nightly
RUN /opt/conda/bin/conda update -y conda
RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERSION}
ARG TARGETPLATFORM
ARG TRITON_VERSION
# On arm64 we can only install wheel packages
RUN case ${TARGETPLATFORM} in \
"linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchtext ;; \
*) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchtext "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
"linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio torchtext ;; \
*) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchaudio torchtext "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
esac && \
/opt/conda/bin/conda clean -ya
RUN /opt/conda/bin/pip install torchelastic
RUN if test -n "${TRITON_VERSION}" -a "${TARGETPLATFORM}" != "linux/arm64"; then /opt/conda/bin/pip install "torchtriton==${TRITON_VERSION}" --extra-index-url https://download.pytorch.org/whl/nightly/cpu ; fi
FROM ${BASE_IMAGE} as official
ARG PYTORCH_VERSION
ARG TRITON_VERSION
ARG TARGETPLATFORM
ARG CUDA_VERSION
LABEL com.nvidia.volumes.needed="nvidia_driver"
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
libjpeg-dev \
libpng-dev && \
rm -rf /var/lib/apt/lists/*
libpng-dev
COPY --from=conda-installs /opt/conda /opt/conda
RUN if test -n "${TRITON_VERSION}" -a "${TARGETPLATFORM}" != "linux/arm64"; then \
apt install -y --no-install-recommends gcc; \
CU_VER=$(echo $CUDA_VERSION | cut -d'.' -f 1-2) && \
mkdir -p /usr/local/triton-min-cuda-${CU_VER} && \
ln -s /usr/local/triton-min-cuda-${CU_VER} /usr/local/cuda; \
mkdir -p /usr/local/cuda/bin; cp /opt/conda/bin/ptxas /usr/local/cuda/bin; \
mkdir -p /usr/local/cuda/include; cp /opt/conda/include/cuda.h /usr/local/cuda/include; \
fi
RUN rm -rf /var/lib/apt/lists/*
ENV PATH /opt/conda/bin:$PATH
ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility