mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] Use rocm manylinux builder image for triton wheels (#107600)
Update to ROCm triton pinned commit for the 2.1 branch cut off. As part of this we are updating `build_triton_wheel.py` and `build-triton-wheel.yml` to support building ROCm triton wheels through pytorch/manylinux-rocm to avoid the need of slowly downloading rpm libraries for ROCm in the cpu manylinux builder image and avoiding the need to maintain a conditional file with hard coded repositories from radeon.org for every ROCm release. This new approach will allow us to build wheels faster in a more easily maintainable way. This PR also brings in a required change as Triton on ROCm requires device_type to be set to hip so we can pass down the correct device type to triton (https://github.com/ROCmSoftwarePlatform/triton/pull/284). Pull Request resolved: https://github.com/pytorch/pytorch/pull/107600 Approved by: https://github.com/jansel, https://github.com/jithunnair-amd
This commit is contained in:
parent
39854df1d3
commit
196ef78b90
|
|
@ -1 +1 @@
|
|||
34887ff8ca7a264c2c75972f5421a1ed3b7d8f6c
|
||||
05d67b9418cacda0d356c2102d7c1a887948b013
|
||||
|
|
|
|||
3
.github/scripts/build_triton_wheel.py
vendored
3
.github/scripts/build_triton_wheel.py
vendored
|
|
@ -138,6 +138,7 @@ def build_triton(
|
|||
|
||||
if build_rocm:
|
||||
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
|
||||
print("ROCm libraries setup for triton installation...")
|
||||
|
||||
check_call(
|
||||
[sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
|
||||
|
|
@ -147,7 +148,7 @@ def build_triton(
|
|||
shutil.copy(whl_path, Path.cwd())
|
||||
|
||||
if build_rocm:
|
||||
check_call(".github/scripts/fix_so.sh", cwd=triton_basedir, shell=True)
|
||||
check_call("scripts/amd/fix_so.sh", cwd=triton_basedir, shell=True)
|
||||
|
||||
return Path.cwd() / whl_path.name
|
||||
|
||||
|
|
|
|||
7
.github/workflows/build-triton-wheel.yml
vendored
7
.github/workflows/build-triton-wheel.yml
vendored
|
|
@ -31,9 +31,14 @@ jobs:
|
|||
matrix:
|
||||
py_vers: [ "3.8", "3.9", "3.10", "3.11" ]
|
||||
device: ["cuda", "rocm"]
|
||||
include:
|
||||
- device: "rocm"
|
||||
rocm_version: "5.6"
|
||||
- device: "cuda"
|
||||
rocm_version: ""
|
||||
timeout-minutes: 40
|
||||
env:
|
||||
DOCKER_IMAGE: pytorch/manylinux-builder:cpu
|
||||
DOCKER_IMAGE: ${{ matrix.device == 'rocm' && format('pytorch/manylinux-rocm:{0}', matrix.rocm_version) || 'pytorch/manylinux-builder:cpu' }}
|
||||
PY_VERS: ${{ matrix.py_vers }}
|
||||
BUILD_DEVICE: ${{ matrix.device }}
|
||||
steps:
|
||||
|
|
|
|||
|
|
@ -182,6 +182,9 @@ class CachingAutotuner(KernelInterface):
|
|||
config.triton.assert_indirect_indexing and torch.version.hip is None
|
||||
)
|
||||
|
||||
# Setting device_type="hip" required on ROCm to pass down to triton
|
||||
compile_meta["device_type"] = "cuda" if torch.version.hip is None else "hip"
|
||||
|
||||
if warm_cache_only_with_cc:
|
||||
triton.compile(
|
||||
self.fn,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user