Revert "has_triton: Use the device interface for detecting Triton availability (#139171)"

This reverts commit 48bfe9afc7.

Reverted https://github.com/pytorch/pytorch/pull/139171 on behalf of https://github.com/masnesral due to Performance regression for huggingface ([comment](https://github.com/pytorch/pytorch/pull/139171#issuecomment-2868939790))
Author: PyTorch MergeBot
Date:   2025-05-10 14:46:23 +00:00
Parent: 70c8047c2d
Commit: 01bb249978

9 changed files with 56 additions and 68 deletions
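For orientation, here is a minimal sketch of the two availability helpers this revert touches. The helper names come from the diff below; the call sites themselves are illustrative and not part of the commit:

    from torch.utils._triton import has_triton, has_triton_package

    # Package-level check: is the `triton` Python package importable at all?
    triton_installed = has_triton_package()

    # Restored by this revert: a cached check that triton is importable *and*
    # usable on at least one supported device type (cuda, xpu, or cpu).
    triton_usable = has_triton()

The reverted change had instead given has_triton() an optional device argument and delegated the per-device decision to the registered DeviceInterface implementations (see the final hunk below); the revert restores the previous hard-coded per-device checks.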


@@ -1548,7 +1548,7 @@ def record_compilation_metrics(
         "dynamo_config": _get_dynamo_config_for_logging(),
         "inductor_config": _scrubbed_inductor_config_for_logging(),
         "cuda_version": torch.version.cuda,
-        "triton_version": triton.__version__ if has_triton_package() else "",
+        "triton_version": triton.__version__ if has_triton() else "",
         "remote_cache_version": remote_cache_version,
         "inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type,
         "python_version": sys.version,
@@ -3830,14 +3830,15 @@ def build_checkpoint_variable(**options):
     )


-def is_compile_supported(device_type: str) -> bool:
+def is_compile_supported(device_type):
     from .eval_frame import is_dynamo_supported

+    type = torch.device(device_type).type
     compile_supported = is_dynamo_supported()
-    if device_type == "cpu":
+    if type == "cpu":
         pass
-    elif device_type in ["cuda", "xpu"] and compile_supported:
-        compile_supported = has_triton(device_type)
+    elif type in ["cuda", "xpu"] and compile_supported:
+        compile_supported = has_triton()
     else:
         compile_supported = False
     return compile_supported
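A small aside on the restored normalization in is_compile_supported (the snippet is illustrative only, not part of the diff): torch.device(device_type).type reduces a fully qualified device string to its type, so strings like "cuda:0" are handled the same as "cuda", whereas the reverted code compared the raw device_type string directly:

    import torch

    # torch.device(...) only parses the string; no GPU is required for this.
    assert torch.device("cuda:0").type == "cuda"
    assert torch.device("cpu").type == "cpu"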


@@ -596,11 +596,11 @@ class VariableBuilder:
     def _wrap(self, value):
         # import here to avoid circular dependencies
-        from torch.utils._triton import has_triton_package, has_triton_tma
+        from torch.utils._triton import has_triton, has_triton_tma
         from ..decorators import DynamoConfigPatchProxy

-        if has_triton_package():
+        if has_triton():
             from triton.runtime.autotuner import Autotuner
             from triton.runtime.jit import JITFunction
         else:


@@ -39,14 +39,14 @@ if TYPE_CHECKING:
     from torch._dynamo.variables.functions import TritonKernelVariable
     from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
     from torch.fx.proxy import Proxy

-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton

 TritonMetaParamsType = dict[str, int]
 TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
 TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
 TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]

-if has_triton_package():
+if has_triton():
     from triton.runtime.autotuner import Autotuner
     from triton.runtime.jit import JITFunction
 else:


@@ -95,7 +95,7 @@ from .._dynamo.backends.common import aot_autograd
 from .._dynamo.exc import ShortenTraceback, SkipFrame
 from ..fx._lazy_graph_module import _use_lazy_graph_module
 from ..fx.graph import _PyTreeCodeGen
-from ..utils._triton import has_triton_package
+from ..utils._triton import has_triton
 from . import config, metrics
 from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration
 from .debug import DebugContext

@@ -1857,7 +1857,7 @@ def get_cpp_wrapper_config() -> dict[str, object]:
         "triton.autotune_at_compile_time": (
             config.triton.autotune_at_compile_time
             if config.triton.autotune_at_compile_time is not None
-            else has_triton_package()
+            else has_triton()
         ),
         "triton.autotune_cublasLt": False,
         "triton.cudagraphs": False,  # TODO: to be removed


@@ -24,9 +24,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.utils._mode_utils import no_dispatch

-from ..codegen.common import get_scheduling_for_device
-from ..codegen.cuda_combined_scheduling import CUDACombinedScheduling
-from ..codegen.triton import TritonScheduling
+from ...utils._triton import has_triton
 from ..pattern_matcher import (
     fwd_only,
     gen_register_replacement,

@@ -460,11 +458,7 @@ def _should_pad_bench(
     ):
         return True

-    scheduling_factory = get_scheduling_for_device(mat1.device.type)
-    if scheduling_factory is None or not isinstance(
-        scheduling_factory(None),
-        (TritonScheduling, CUDACombinedScheduling),
-    ):
+    if not has_triton():
         return False

     if not is_mm_compute_bound(m, k, n, mat1.dtype):


@@ -11,7 +11,7 @@ from typing_extensions import override

 import torch
 from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton

 from ..remote_cache import (
     create_cache,

@@ -36,7 +36,7 @@ def inductor_meta_from_config() -> _InductorMetaTy:
     from torch._inductor import config

     backend_hash = None
-    if has_triton_package():
+    if has_triton():
         try:
             backend_hash = torch.utils._triton.triton_hash_with_backend()
         except RuntimeError:


@@ -1504,7 +1504,7 @@ def use_triton_template(

 def use_triton_tma_template(*matrices: IRNode) -> bool:
-    from torch.utils._triton import has_triton_tma
+    from torch.utils._triton import has_triton_tma_device

     from .virtualized import V

@@ -1535,7 +1535,7 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
     return (
         config.triton.enable_persistent_tma_matmul
-        and has_triton_tma()
+        and has_triton_tma_device()
         and all(_is_tma_compatible(m) for m in matrices)
     )


@@ -8,7 +8,7 @@ from typing import Optional

 import torch
 from torch._dynamo.utils import warn_once
-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton

 from ._triton_ops_meta import get_meta

@@ -1323,7 +1323,7 @@ def bsr_dense_addmm(
         return out_backup

-if has_triton_package():
+if has_triton():
     import triton
     import triton.language as tl


@@ -1,11 +1,6 @@
 # mypy: allow-untyped-defs
 import functools
 import hashlib

-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from torch.types import Device
-

 @functools.lru_cache(None)
@@ -14,47 +9,11 @@ def has_triton_package() -> bool:
         from triton.compiler.compiler import triton_key

         return triton_key is not None
-    except (ImportError, RuntimeError):
+    except ImportError:
         return False
-
-
-@functools.lru_cache(None)
-def has_triton(device: "Device" = None) -> bool:
-    """
-    Determine if Triton is available for use on this system for a given device
-    (if device is not None) or any available device type if no device is given.
-    """
-    import torch
-    from torch._dynamo.device_interface import (
-        DeviceInterface,
-        get_interface_for_device,
-        get_registered_device_interfaces,
-    )
-
-    if not has_triton_package():
+    except RuntimeError:
         return False
-
-    def device_has_triton(di: type[DeviceInterface]) -> bool:
-        if not di.is_available():
-            return False
-
-        try:
-            di.raise_if_triton_unavailable(device)
-        except RuntimeError:
-            return False
-
-        return True
-
-    if device is None:
-        return any(
-            device_has_triton(di) for _, di in get_registered_device_interfaces()
-        )
-
-    if not isinstance(device, (str, torch.device)):
-        device = torch.device(device)
-
-    return device_has_triton(get_interface_for_device(device))


 @functools.lru_cache(None)
 def has_triton_tma():
@@ -102,6 +61,40 @@ def has_triton_tma_device():
     return False


 @functools.lru_cache(None)
+def has_triton() -> bool:
+    if not has_triton_package():
+        return False
+
+    from torch._dynamo.device_interface import get_interface_for_device
+
+    def cuda_extra_check(device_interface):
+        return device_interface.Worker.get_device_properties().major >= 7
+
+    def cpu_extra_check(device_interface):
+        import triton.backends
+
+        return "cpu" in triton.backends.backends
+
+    def _return_true(device_interface):
+        return True
+
+    triton_supported_devices = {
+        "cuda": cuda_extra_check,
+        "xpu": _return_true,
+        "cpu": cpu_extra_check,
+    }
+
+    def is_device_compatible_with_triton():
+        for device, extra_check in triton_supported_devices.items():
+            device_interface = get_interface_for_device(device)
+            if device_interface.is_available() and extra_check(device_interface):
+                return True
+        return False
+
+    return is_device_compatible_with_triton()
+
+
+@functools.lru_cache(None)
 def triton_backend():
     from triton.compiler.compiler import make_backend
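One property of the restored implementation worth noting (illustrative snippet, not part of the diff): has_triton() and has_triton_package() are both wrapped in functools.lru_cache(None), so the per-device probing above runs at most once per process:

    from torch.utils._triton import has_triton

    first = has_triton()   # imports triton and probes cuda / xpu / cpu support once
    second = has_triton()  # served from the lru_cache without re-probing
    assert first == second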