mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Update triton pin (#102736)
There is a bug in Triton's handling of `tl.reduce` that breaks the variance PR; it is fixed on the latest Triton master. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102736 Approved by: https://github.com/huydhn, https://github.com/desertfire
This commit is contained in:
parent
455f542ed9
commit
31ee1512d3
|
|
@@ -1 +1 @@
|
||||||
9820899b3845e461d9031dba66062efade65d420
|
440fd1bf20697b0961ddb0822de86d151c58dd36
|
||||||
|
|
|
||||||
|
|
@@ -32,6 +32,10 @@ if [ -n "${CONDA_CMAKE}" ]; then
|
||||||
NUMPY_VERSION=$(get_conda_version numpy)
|
NUMPY_VERSION=$(get_conda_version numpy)
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ -z "${MAX_JOBS}" ]; then
|
||||||
|
export MAX_JOBS=$(nproc)
|
||||||
|
fi
|
||||||
|
|
||||||
if [ -n "${GCC_VERSION}" ] && [[ "${GCC_VERSION}" == "7" ]]; then
|
if [ -n "${GCC_VERSION}" ] && [[ "${GCC_VERSION}" == "7" ]]; then
|
||||||
# Triton needs at least gcc-9 to build
|
# Triton needs at least gcc-9 to build
|
||||||
apt-get install -y g++-9
|
apt-get install -y g++-9
|
||||||
|
|
|
||||||
12
.github/scripts/build_triton_wheel.py
vendored
12
.github/scripts/build_triton_wheel.py
vendored
|
|
@@ -1,4 +1,5 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@@ -60,6 +61,11 @@ def build_triton(
|
||||||
build_rocm: bool = False,
|
build_rocm: bool = False,
|
||||||
py_version: Optional[str] = None,
|
py_version: Optional[str] = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
|
env = os.environ.copy()
|
||||||
|
if "MAX_JOBS" not in env:
|
||||||
|
max_jobs = os.cpu_count() or 1
|
||||||
|
env["MAX_JOBS"] = str(max_jobs)
|
||||||
|
|
||||||
with TemporaryDirectory() as tmpdir:
|
with TemporaryDirectory() as tmpdir:
|
||||||
triton_basedir = Path(tmpdir) / "triton"
|
triton_basedir = Path(tmpdir) / "triton"
|
||||||
triton_pythondir = triton_basedir / "python"
|
triton_pythondir = triton_basedir / "python"
|
||||||
|
|
@@ -81,6 +87,7 @@ def build_triton(
|
||||||
print(
|
print(
|
||||||
"build:\n string: py{{py}}\n number: 1\n script: cd python; "
|
"build:\n string: py{{py}}\n number: 1\n script: cd python; "
|
||||||
"python setup.py install --single-version-externally-managed --record=record.txt\n",
|
"python setup.py install --single-version-externally-managed --record=record.txt\n",
|
||||||
|
" script_env:\n - MAX_JOBS\n",
|
||||||
file=meta,
|
file=meta,
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
|
|
@@ -113,6 +120,7 @@ def build_triton(
|
||||||
".",
|
".",
|
||||||
],
|
],
|
||||||
cwd=triton_basedir,
|
cwd=triton_basedir,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
conda_path = list(Path(tmpdir).glob("linux-64/torchtriton*.bz2"))[0]
|
conda_path = list(Path(tmpdir).glob("linux-64/torchtriton*.bz2"))[0]
|
||||||
shutil.copy(conda_path, Path.cwd())
|
shutil.copy(conda_path, Path.cwd())
|
||||||
|
|
@@ -131,7 +139,9 @@ def build_triton(
|
||||||
if build_rocm:
|
if build_rocm:
|
||||||
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
|
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
|
||||||
|
|
||||||
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
|
check_call(
|
||||||
|
[sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
|
||||||
|
)
|
||||||
|
|
||||||
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
|
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
|
||||||
shutil.copy(whl_path, Path.cwd())
|
shutil.copy(whl_path, Path.cwd())
|
||||||
|
|
|
||||||
|
|
@@ -41,7 +41,6 @@ inner(torch.randn(2, 2).to("{device}"))
|
||||||
|
|
||||||
@requires_cuda()
|
@requires_cuda()
|
||||||
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error")
|
@inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error")
|
||||||
@unittest.expectedFailure # Skipping this test due to Triton bug, see https://github.com/openai/triton/issues/1704 for details
|
|
||||||
def test_after_aot_cuda_runtime_error(self):
|
def test_after_aot_cuda_runtime_error(self):
|
||||||
self._test_after_aot_runtime_error("cuda", "device-side assert")
|
self._test_after_aot_runtime_error("cuda", "device-side assert")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user