mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: If triton is available, but we can't import triton.compiler.compiler.triton_key, then we see some annoying behavior: 1) If we don't actually need to compile triton, the subprocess pool will still spew error messages about the import failure; it's unclear to users if this is an actual problem. 2) If we do need to compile triton, we a) see the error messages from above and b) get a vanilla import exception without the helpful "RuntimeError: Cannot find a working triton installation ..." Test Plan: Ran with and without torch.compile for a) recent version of triton, b) triton 2.2, and c) no triton. In all cases, verified expected output (success or meaningful error message) Pull Request resolved: https://github.com/pytorch/pytorch/pull/130403 Approved by: https://github.com/eellison
88 lines
2.4 KiB
Python
88 lines
2.4 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
import hashlib
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def has_triton_package() -> bool:
|
|
try:
|
|
from triton.compiler.compiler import triton_key
|
|
|
|
return triton_key is not None
|
|
except ImportError:
|
|
return False
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def has_triton() -> bool:
|
|
from torch._dynamo.device_interface import get_interface_for_device
|
|
|
|
def cuda_extra_check(device_interface):
|
|
return device_interface.Worker.get_device_properties().major >= 7
|
|
|
|
def _return_true(device_interface):
|
|
return True
|
|
|
|
triton_supported_devices = {"cuda": cuda_extra_check, "xpu": _return_true}
|
|
|
|
def is_device_compatible_with_triton():
|
|
for device, extra_check in triton_supported_devices.items():
|
|
device_interface = get_interface_for_device(device)
|
|
if device_interface.is_available() and extra_check(device_interface):
|
|
return True
|
|
return False
|
|
|
|
return is_device_compatible_with_triton() and has_triton_package()
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def triton_backend():
|
|
import torch
|
|
|
|
if torch.version.hip:
|
|
# Does not work with ROCm
|
|
return None
|
|
|
|
from triton.compiler.compiler import make_backend
|
|
from triton.runtime.driver import driver
|
|
|
|
target = driver.active.get_current_target()
|
|
return make_backend(target)
|
|
|
|
|
|
@functools.lru_cache(None)
|
|
def triton_hash_with_backend():
|
|
import torch
|
|
|
|
if torch.version.hip:
|
|
# Does not work with ROCm
|
|
return None
|
|
|
|
from triton.compiler.compiler import triton_key
|
|
|
|
backend = triton_backend()
|
|
key = f"{triton_key()}-{backend.hash()}"
|
|
|
|
# Hash is upper case so that it can't contain any Python keywords.
|
|
return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()
|
|
|
|
|
|
def dtype_to_string(dtype):
|
|
if dtype.name.startswith("fp"):
|
|
suffix = "float" + dtype.name[2:]
|
|
elif dtype.name.startswith("bf"):
|
|
suffix = "bfloat" + dtype.name[2:]
|
|
else:
|
|
suffix = dtype.name
|
|
return "triton.language." + suffix
|
|
|
|
|
|
def patch_triton_dtype_repr():
|
|
import triton
|
|
|
|
# Hack to get triton dtype repr to produce an evaluatable expression
|
|
# triton.language.float32 emits triton.language.fp32 which does not
|
|
# exist
|
|
# REMOVE when https://github.com/openai/triton/pull/3342 lands
|
|
triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)
|