[inductor] Update triton pin (#102736)
There is a bug in triton's handling of `tl.reduce` that breaks the variance PR, but it is fixed on the latest triton master.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102736
Approved by: https://github.com/huydhn, https://github.com/desertfire
parent 455f542ed9
commit 31ee1512d3
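For context on the commit message: the `tl.reduce` usage in question is Inductor's user-defined (Welford-style) reduction for variance, where a combine function is folded over per-element (mean, m2, weight) triples. The sketch below is only an illustration of that pattern, not code from this PR; it assumes a CUDA device and a Triton build with multi-input `tl.reduce`, and the names `welford_combine` and `var_kernel` are made up for the example.

    # Hypothetical sketch of a Welford-style variance reduction built on tl.reduce.
    # Not the code generated by Inductor; names are illustrative only.
    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def welford_combine(mean_a, m2_a, w_a, mean_b, m2_b, w_b):
        # Pairwise (Chan et al.) variance merge, applied by tl.reduce.
        delta = mean_b - mean_a
        w = w_a + w_b
        # Guard against merging two empty (zero-weight) groups.
        frac = tl.where(w == 0.0, 0.0, w_b / w)
        return (
            mean_a + delta * frac,
            m2_a + m2_b + delta * delta * w_a * frac,
            w,
        )


    @triton.jit
    def var_kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr):
        # Single-block sketch: assumes N <= BLOCK.
        offs = tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)
        m2_init = tl.zeros([BLOCK], dtype=tl.float32)
        weight = mask.to(tl.float32)
        # Fold (mean, m2, weight) triples down to one accumulator.
        mean, m2, w = tl.reduce((x, m2_init, weight), 0, welford_combine)
        tl.store(out_ptr, m2 / w)  # population variance


    if __name__ == "__main__" and torch.cuda.is_available():
        x = torch.randn(1000, device="cuda")
        out = torch.empty(1, device="cuda")
        var_kernel[(1,)](x, out, x.numel(), BLOCK=1024)
        # Should roughly agree with the eager result.
        print(out.item(), x.var(unbiased=False).item())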
@@ -1 +1 @@
-9820899b3845e461d9031dba66062efade65d420
+440fd1bf20697b0961ddb0822de86d151c58dd36
@@ -32,6 +32,10 @@ if [ -n "${CONDA_CMAKE}" ]; then
   NUMPY_VERSION=$(get_conda_version numpy)
 fi
 
+if [ -z "${MAX_JOBS}" ]; then
+  export MAX_JOBS=$(nproc)
+fi
+
 if [ -n "${GCC_VERSION}" ] && [[ "${GCC_VERSION}" == "7" ]]; then
   # Triton needs at least gcc-9 to build
   apt-get install -y g++-9
.github/scripts/build_triton_wheel.py (vendored): 12 changed lines
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import os
 import shutil
 import sys
 from pathlib import Path
@@ -60,6 +61,11 @@ def build_triton(
     build_rocm: bool = False,
     py_version: Optional[str] = None,
 ) -> Path:
+    env = os.environ.copy()
+    if "MAX_JOBS" not in env:
+        max_jobs = os.cpu_count() or 1
+        env["MAX_JOBS"] = str(max_jobs)
+
     with TemporaryDirectory() as tmpdir:
         triton_basedir = Path(tmpdir) / "triton"
         triton_pythondir = triton_basedir / "python"
@@ -81,6 +87,7 @@ def build_triton(
         print(
             "build:\n string: py{{py}}\n number: 1\n script: cd python; "
             "python setup.py install --single-version-externally-managed --record=record.txt\n",
+            " script_env:\n - MAX_JOBS\n",
             file=meta,
         )
         print(
@@ -113,6 +120,7 @@ def build_triton(
                 ".",
             ],
             cwd=triton_basedir,
+            env=env,
         )
         conda_path = list(Path(tmpdir).glob("linux-64/torchtriton*.bz2"))[0]
         shutil.copy(conda_path, Path.cwd())
@@ -131,7 +139,9 @@ def build_triton(
     if build_rocm:
         check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
 
-    check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
+    check_call(
+        [sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir, env=env
+    )
 
     whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
     shutil.copy(whl_path, Path.cwd())
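Taken together, the build_triton_wheel.py hunks above follow one pattern: copy the parent environment, default MAX_JOBS to the machine's CPU count when it is unset, and pass that same env to the conda and wheel build subprocesses so the build job cap is respected. A minimal, self-contained sketch of that pattern (the subprocess command is a placeholder, not the real setup.py invocation):

    # Minimal sketch of the MAX_JOBS propagation pattern added in this commit.
    # The subprocess command below is a stand-in for the real build steps.
    import os
    import subprocess
    import sys


    def run_build_with_max_jobs() -> None:
        env = os.environ.copy()
        if "MAX_JOBS" not in env:
            # Mirror of the diff: default to the CPU count when unset.
            env["MAX_JOBS"] = str(os.cpu_count() or 1)

        # Every build subprocess gets env=env, so setup.py sees MAX_JOBS.
        subprocess.check_call(
            [sys.executable, "-c", "import os; print('MAX_JOBS =', os.environ['MAX_JOBS'])"],
            env=env,
        )


    if __name__ == "__main__":
        run_build_with_max_jobs()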
@@ -41,7 +41,6 @@ inner(torch.randn(2, 2).to("{device}"))
 
     @requires_cuda()
     @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error")
-    @unittest.expectedFailure  # Skipping this test due to Triton bug, see https://github.com/openai/triton/issues/1704 for details
     def test_after_aot_cuda_runtime_error(self):
         self._test_after_aot_runtime_error("cuda", "device-side assert")
 