[CI][docker] Use multistage build for triton (#149413)

Sees to reduce docker pull times by ~3 min if triton is requested, some compressed docker sizes seems to have decreased by 1/3 ish

Also add check that triton is installed/not installed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149413
Approved by: https://github.com/malfet
This commit is contained in:
Catherine Lee 2025-03-28 16:07:19 +00:00 committed by PyTorch MergeBot
parent 0ece461cca
commit d5a8bd0688
4 changed files with 48 additions and 18 deletions

View File

@ -515,7 +515,7 @@ docker build \
UBUNTU_VERSION=$(echo ${UBUNTU_VERSION} | sed 's/-rc$//')
function drun() {
docker run --rm "$tmp_tag" $*
docker run --rm "$tmp_tag" "$@"
}
if [[ "$OS" == "ubuntu" ]]; then
@ -563,3 +563,14 @@ if [ -n "$KATEX" ]; then
exit 1
fi
fi
HAS_TRITON=$(drun python -c "import triton" > /dev/null 2>&1 && echo "yes" || echo "no")
if [[ -n "$TRITON" || -n "$TRITON_CPU" ]]; then
if [ "$HAS_TRITON" = "no" ]; then
echo "expecting triton to be installed, but it is not"
exit 1
fi
elif [ "$HAS_TRITON" = "yes" ]; then
echo "expecting triton to not be installed, but it is"
exit 1
fi

View File

@ -2,6 +2,12 @@
set -ex
mkdir -p /opt/triton
if [ -z "${TRITON}" ] && [ -z "${TRITON_CPU}" ]; then
echo "TRITON and TRITON_CPU are not set. Exiting..."
exit 0
fi
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
get_conda_version() {
@ -52,6 +58,7 @@ cd triton
as_jenkins git checkout ${TRITON_PINNED_COMMIT}
as_jenkins git submodule update --init --recursive
cd python
pip_install pybind11==2.13.6
# TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527
as_jenkins sed -i -e 's/https:\/\/tritonlang.blob.core.windows.net\/llvm-builds/https:\/\/oaitriton.blob.core.windows.net\/public\/llvm-builds/g' setup.py
@ -60,17 +67,22 @@ if [ -n "${UBUNTU_VERSION}" ] && [ -n "${GCC_VERSION}" ] && [[ "${GCC_VERSION}"
# Triton needs at least gcc-9 to build
apt-get install -y g++-9
CXX=g++-9 pip_install .
CXX=g++-9 conda_run python setup.py bdist_wheel
elif [ -n "${UBUNTU_VERSION}" ] && [ -n "${CLANG_VERSION}" ]; then
# Triton needs <filesystem> which surprisingly is not available with clang-9 toolchain
add-apt-repository -y ppa:ubuntu-toolchain-r/test
apt-get install -y g++-9
CXX=g++-9 pip_install .
CXX=g++-9 conda_run python setup.py bdist_wheel
else
pip_install .
conda_run python setup.py bdist_wheel
fi
# Copy the wheel to /opt for multi stage docker builds
cp dist/*.whl /opt/triton
# Install the wheel for docker builds that don't use multi stage
pip_install dist/*.whl
if [ -n "${CONDA_CMAKE}" ]; then
# TODO: This is to make sure that the same cmake and numpy version from install conda
# script is used. Without this step, the newer cmake version (3.25.2) downloaded by

View File

@ -2,7 +2,7 @@ ARG UBUNTU_VERSION
ARG CUDA_VERSION
ARG IMAGE_NAME
FROM ${IMAGE_NAME}
FROM ${IMAGE_NAME} as base
ARG UBUNTU_VERSION
ARG CUDA_VERSION
@ -90,14 +90,20 @@ RUN if [ -n "${CMAKE_VERSION}" ]; then bash ./install_cmake.sh; fi
RUN rm install_cmake.sh
ARG TRITON
FROM base as triton-builder
# Install triton, this needs to be done before sccache because the latter will
# try to reach out to S3, which docker build runners don't have access
COPY ./common/install_triton.sh install_triton.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/triton.txt triton.txt
COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt
RUN bash ./install_triton.sh
FROM base as final
COPY --from=triton-builder /opt/triton /opt/triton
RUN if [ -n "${TRITON}" ]; then pip install /opt/triton/*.whl; chown -R jenkins:jenkins /opt/conda; fi
RUN rm -rf /opt/triton
ARG HALIDE
# Build and install halide

View File

@ -1,6 +1,6 @@
ARG UBUNTU_VERSION
FROM ubuntu:${UBUNTU_VERSION}
FROM ubuntu:${UBUNTU_VERSION} as base
ARG UBUNTU_VERSION
@ -108,20 +108,21 @@ RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_d
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
ARG TRITON
# Install triton, this needs to be done before sccache because the latter will
# try to reach out to S3, which docker build runners don't have access
ARG TRITON_CPU
# Create a separate stage for building Triton and Triton-CPU. install_triton
# will check for the presence of env vars
FROM base as triton-builder
COPY ./common/install_triton.sh install_triton.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/triton.txt triton.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton.txt
ARG TRITON_CPU
COPY ./common/install_triton.sh install_triton.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/triton-cpu.txt triton-cpu.txt
RUN if [ -n "${TRITON_CPU}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton-cpu.txt
RUN bash ./install_triton.sh
FROM base as final
COPY --from=triton-builder /opt/triton /opt/triton
RUN if [ -n "${TRITON}" ] || [ -n "${TRITON_CPU}" ]; then pip install /opt/triton/*.whl; chown -R jenkins:jenkins /opt/conda; fi
RUN rm -rf /opt/triton
ARG EXECUTORCH
# Build and install executorch