[ROCm] Use rocm manylinux builder image for triton wheels (#107600)

Update to ROCm triton pinned commit for the 2.1 branch cut off.

As part of this we are updating `build_triton_wheel.py` and `build-triton-wheel.yml` to support building ROCm triton wheels through pytorch/manylinux-rocm to avoid the need of slowly downloading rpm libraries for ROCm in the cpu manylinux builder image and avoiding the need to maintain a conditional file with hard coded repositories from radeon.org for every ROCm release.

This new approach will allow us to build wheels faster in a more easily maintainable way.

This PR also brings in a required change as Triton on ROCm requires device_type to be set to hip so we can pass down the correct device type to triton (https://github.com/ROCmSoftwarePlatform/triton/pull/284).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107600
Approved by: https://github.com/jansel, https://github.com/jithunnair-amd
This commit is contained in:
Jack Taylor 2023-08-25 10:25:29 +00:00 committed by PyTorch MergeBot
parent 39854df1d3
commit 196ef78b90
4 changed files with 12 additions and 3 deletions

View File

@ -1 +1 @@
34887ff8ca7a264c2c75972f5421a1ed3b7d8f6c
05d67b9418cacda0d356c2102d7c1a887948b013

View File

@ -138,6 +138,7 @@ def build_triton(
if build_rocm:
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
print("ROCm libraries setup for triton installation...")
check_call(
[sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
@ -147,7 +148,7 @@ def build_triton(
shutil.copy(whl_path, Path.cwd())
if build_rocm:
check_call(".github/scripts/fix_so.sh", cwd=triton_basedir, shell=True)
check_call("scripts/amd/fix_so.sh", cwd=triton_basedir, shell=True)
return Path.cwd() / whl_path.name

View File

@ -31,9 +31,14 @@ jobs:
matrix:
py_vers: [ "3.8", "3.9", "3.10", "3.11" ]
device: ["cuda", "rocm"]
include:
- device: "rocm"
rocm_version: "5.6"
- device: "cuda"
rocm_version: ""
timeout-minutes: 40
env:
DOCKER_IMAGE: pytorch/manylinux-builder:cpu
DOCKER_IMAGE: ${{ matrix.device == 'rocm' && format('pytorch/manylinux-rocm:{0}', matrix.rocm_version) || 'pytorch/manylinux-builder:cpu' }}
PY_VERS: ${{ matrix.py_vers }}
BUILD_DEVICE: ${{ matrix.device }}
steps:

View File

@ -182,6 +182,9 @@ class CachingAutotuner(KernelInterface):
config.triton.assert_indirect_indexing and torch.version.hip is None
)
# Setting device_type="hip" required on ROCm to pass down to triton
compile_meta["device_type"] = "cuda" if torch.version.hip is None else "hip"
if warm_cache_only_with_cc:
triton.compile(
self.fn,