pytorch/torch/utils/_triton.py
Sam Larsen 358da54be5 [inductor] Better messaging when triton version is too old (#130403)
Summary:
If triton is available, but we can't import triton.compiler.compiler.triton_key, then we see some annoying behavior:
1) If we don't actually need to compile triton, the subprocess pool will still spew error messages about the import failure; it's unclear to users if this is an actual problem.
2) If we do need to compile triton, we a) see the error messages from above and b) get a vanilla import exception without the helpful "RuntimeError: Cannot find a working triton installation ..."

Test Plan: Ran with and without torch.compile for a) recent version of triton, b) triton 2.2, and c) no triton. In all cases, verified expected output (success or meaningful error message)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130403
Approved by: https://github.com/eellison
2024-07-10 23:45:50 +00:00

88 lines
2.4 KiB
Python

# mypy: allow-untyped-defs
import functools
import hashlib
@functools.lru_cache(None)
def has_triton_package() -> bool:
try:
from triton.compiler.compiler import triton_key
return triton_key is not None
except ImportError:
return False
@functools.lru_cache(None)
def has_triton() -> bool:
from torch._dynamo.device_interface import get_interface_for_device
def cuda_extra_check(device_interface):
return device_interface.Worker.get_device_properties().major >= 7
def _return_true(device_interface):
return True
triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true}
def is_device_compatible_with_triton():
for device, extra_check in triton_supported_devices.items():
device_interface = get_interface_for_device(device)
if device_interface.is_available() and extra_check(device_interface):
return True
return False
return is_device_compatible_with_triton() and has_triton_package()
@functools.lru_cache(None)
def triton_backend():
import torch
if torch.version.hip:
# Does not work with ROCm
return None
from triton.compiler.compiler import make_backend
from triton.runtime.driver import driver
target = driver.active.get_current_target()
return make_backend(target)
@functools.lru_cache(None)
def triton_hash_with_backend():
import torch
if torch.version.hip:
# Does not work with ROCm
return None
from triton.compiler.compiler import triton_key
backend = triton_backend()
key = f"{triton_key()}-{backend.hash()}"
# Hash is upper case so that it can't contain any Python keywords.
return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()
def dtype_to_string(dtype):
if dtype.name.startswith("fp"):
suffix = "float" + dtype.name[2:]
elif dtype.name.startswith("bf"):
suffix = "bfloat" + dtype.name[2:]
else:
suffix = dtype.name
return "triton.language." + suffix
def patch_triton_dtype_repr():
import triton
# Hack to get triton dtype repr to produce an evaluatable expression
# triton.language.float32 emits triton.language.fp32 which does not
# exist
# REMOVE when https://github.com/openai/triton/pull/3342 lands
triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)