From 940b60db974f08a31c746eec2f9c399fc8a861ee Mon Sep 17 00:00:00 2001 From: George White Date: Tue, 11 Mar 2025 03:56:07 +0000 Subject: [PATCH] Use the device interface for detecting Triton availability (#139171) This allows for each device type to check current devices for Triton compatibility and ensure their Triton backend is present. This PR replaces the `has_triton()` global method which was previously used for this task, and moves the initial check for each Inductor backend on to their associated `BaseScheduler` subclass. This means that other backends, such as Halide, can also implement their own availability checks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139171 Approved by: https://github.com/jansel --- tools/dynamo/verify_dynamo.py | 28 ++--- torch/_dynamo/device_interface.py | 113 ++++++++++++++---- torch/_dynamo/eval_frame.py | 37 ++++-- torch/_dynamo/logging.py | 1 - torch/_dynamo/utils.py | 17 +-- torch/_dynamo/variables/builder.py | 4 +- torch/_higher_order_ops/triton_kernel_wrap.py | 4 +- .../codegen/cuda_combined_scheduling.py | 11 +- torch/_inductor/codegen/mps.py | 2 +- torch/_inductor/codegen/triton.py | 16 +++ torch/_inductor/compile_fx.py | 4 +- torch/_inductor/fx_passes/pad_mm.py | 10 +- torch/_inductor/runtime/autotune_cache.py | 4 +- torch/_inductor/scheduler.py | 30 ++--- torch/_inductor/utils.py | 4 +- torch/sparse/_triton_ops.py | 4 +- torch/utils/_content_store.py | 3 +- torch/utils/_triton.py | 79 ++++++------ 18 files changed, 243 insertions(+), 128 deletions(-) diff --git a/tools/dynamo/verify_dynamo.py b/tools/dynamo/verify_dynamo.py index ec09fbd2b48..9d1b4cc7f74 100644 --- a/tools/dynamo/verify_dynamo.py +++ b/tools/dynamo/verify_dynamo.py @@ -142,7 +142,7 @@ def check_rocm(): return rocm_ver if torch.version.hip else "None" -def check_dynamo(backend, device, err_msg) -> None: +def check_dynamo(backend: str, device: str, err_msg: str) -> None: import torch if device == "cuda" and not torch.cuda.is_available(): @@ -151,17 +151,15 @@ def check_dynamo(backend, device, err_msg) -> None: try: import torch._dynamo as dynamo + from torch._dynamo.eval_frame import raise_if_inductor_unavailable - if device == "cuda": - from torch.utils._triton import has_triton - - if not has_triton(): - print( - f"WARNING: CUDA available but triton cannot be used. " - f"Your GPU may not be supported. " - f"Skipping CUDA check on {backend} backend\n" - ) - return + try: + raise_if_inductor_unavailable(device) + except RuntimeError as e: + print( + f"WARNING: Inductor not available for {device} ({e}). Skipping check." + ) + return dynamo.reset() @@ -205,6 +203,8 @@ _SANITY_CHECK_ARGS = ( def main() -> None: + from torch._dynamo.eval_frame import is_dynamo_supported + python_ver = check_python() torch_ver = check_torch() cuda_ver = check_cuda() @@ -215,10 +215,10 @@ def main() -> None: f"CUDA version: {cuda_ver}\n" f"ROCM version: {rocm_ver}\n" ) + if not is_dynamo_supported(): + warnings.warn("Dynamo is not supported on this platform. Skipping check.") + return for args in _SANITY_CHECK_ARGS: - if sys.version_info >= (3, 13): - warnings.warn("Dynamo not yet supported in Python 3.13. 
Skipping check.") - continue check_dynamo(*args) print("All required checks passed") diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index d8610915ec3..ac51e35a2b9 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -17,6 +17,7 @@ The abstraction layer enables device-agnostic code in TorchDynamo while allowing specialized implementations for each hardware backend's unique features. """ +import inspect import time from collections.abc import Iterable from dataclasses import dataclass @@ -31,8 +32,6 @@ if torch.cuda._is_compiled(): else: get_cuda_stream = None -_device_t = Union[torch.device, str, int, None] - # Recording the device properties in the main process but used in worker process. caching_worker_device_properties: dict[str, Any] = {} caching_worker_current_devices: dict[str, int] = {} @@ -45,7 +44,7 @@ class DeviceInterface: """ class device: - def __new__(cls, device: _device_t): + def __new__(cls, device: torch.types.Device): raise NotImplementedError class Event: @@ -77,7 +76,7 @@ class DeviceInterface: raise NotImplementedError @staticmethod - def get_device_properties(device: _device_t = None): + def get_device_properties(device: torch.types.Device = None): raise NotImplementedError @staticmethod @@ -85,7 +84,7 @@ class DeviceInterface: raise NotImplementedError @staticmethod - def set_device(device: _device_t): + def set_device(device: torch.types.Device): raise NotImplementedError @staticmethod @@ -125,15 +124,15 @@ class DeviceInterface: raise NotImplementedError @staticmethod - def synchronize(device: _device_t = None): + def synchronize(device: torch.types.Device = None): raise NotImplementedError @classmethod - def get_device_properties(cls, device: _device_t = None): + def get_device_properties(cls, device: torch.types.Device = None): return cls.Worker.get_device_properties(device) @staticmethod - def get_compute_capability(device: _device_t = None): + def get_compute_capability(device: torch.types.Device = None): raise NotImplementedError @staticmethod @@ -147,9 +146,30 @@ class DeviceInterface: return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation) @staticmethod - def memory_allocated(device: _device_t = None) -> int: + def memory_allocated(device: torch.types.Device = None) -> int: raise NotImplementedError + @staticmethod + def is_triton_capable(device: torch.types.Device = None) -> bool: + """ + Returns True if the device has Triton support, False otherwise, even if + the appropriate Triton backend is not available. + """ + return False + + @classmethod + def raise_if_triton_unavailable(cls, device: torch.types.Device = None) -> None: + """ + Raises a `RuntimeError` with the appropriate human-readable instructions + to resolve the issue if Triton is not available for the given device, or + the default device if `device` is `None`. + + The caller should ensure the presence of the 'triton' package before + calling this method. 
+        """
+        if not cls.is_triton_capable(device):
+            raise RuntimeError("This device is not capable of supporting Triton")
+
 
 class DeviceGuard:
     """
@@ -198,7 +218,7 @@ class CudaInterface(DeviceInterface):
         return torch.cuda.current_device()
 
     @staticmethod
-    def get_device_properties(device: _device_t = None):
+    def get_device_properties(device: torch.types.Device = None):
         if device is not None:
             if isinstance(device, str):
                 device = torch.device(device)
@@ -238,13 +258,36 @@ class CudaInterface(DeviceInterface):
         return torch.cuda.is_available()
 
     @staticmethod
-    def get_compute_capability(device: _device_t = None):
+    def get_compute_capability(device: torch.types.Device = None):
         if torch.version.hip is None:
             major, min = torch.cuda.get_device_capability(device)
             return major * 10 + min
         else:
             return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]
 
+    @staticmethod
+    def is_triton_capable(device: torch.types.Device = None) -> bool:
+        return (
+            torch.version.hip is not None
+            or torch.cuda.get_device_properties(device).major >= 7
+        )
+
+    @staticmethod
+    def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
+        from torch._inductor.exc import GPUTooOldForTriton
+
+        if not CudaInterface.is_triton_capable(device):
+            device_props = torch.cuda.get_device_properties(device)
+            raise GPUTooOldForTriton(device_props, inspect.currentframe())
+
+        import triton.backends
+
+        if torch.version.hip is not None:
+            if "amd" not in triton.backends.backends:
+                raise RuntimeError("triton not built with the 'amd' backend")
+        elif "nvidia" not in triton.backends.backends:
+            raise RuntimeError("triton not built with the 'nvidia' backend")
+
 
 get_xpu_stream: Optional[Callable[[int], int]]
 if torch.xpu._is_compiled():
@@ -270,7 +313,7 @@ class XpuInterface(DeviceInterface):
         return torch.xpu.current_device()
 
     @staticmethod
-    def get_device_properties(device: _device_t = None):
+    def get_device_properties(device: torch.types.Device = None):
         if device is not None:
             if isinstance(device, str):
                 device = torch.device(device)
@@ -309,7 +352,7 @@ class XpuInterface(DeviceInterface):
         return torch.xpu.is_available()
 
     @staticmethod
-    def get_compute_capability(device: _device_t = None):
+    def get_compute_capability(device: torch.types.Device = None):
         cc = torch.xpu.get_device_capability(device)
         return cc
 
@@ -317,6 +360,17 @@ class XpuInterface(DeviceInterface):
     def is_bf16_supported(including_emulation: bool = False) -> bool:
         return torch.xpu.is_bf16_supported()
 
+    @staticmethod
+    def is_triton_capable(device: torch.types.Device = None) -> bool:
+        return True
+
+    @staticmethod
+    def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
+        import triton.backends
+
+        if "intel" not in triton.backends.backends:
+            raise RuntimeError("triton not built with the 'intel' backend")
+
 
 @dataclass
 class CpuDeviceProperties:
@@ -334,6 +388,14 @@ class CpuInterface(DeviceInterface):
         def record(self, stream=None):
             self.time = time.perf_counter()
 
+    class Worker:
+        @staticmethod
+        def get_device_properties(device: torch.types.Device = None):
+            import multiprocessing
+
+            cpu_count = multiprocessing.cpu_count()
+            return CpuDeviceProperties(cpu_count)
+
     @staticmethod
     def is_available() -> bool:
         return True
@@ -343,7 +405,7 @@ class CpuInterface(DeviceInterface):
         return True
 
     @staticmethod
-    def get_compute_capability(device: _device_t = None) -> str:
+    def get_compute_capability(device: torch.types.Device = None) -> str:
         return ""
 
     @staticmethod
@@ -355,16 +417,19 @@
         return 0
@staticmethod - def synchronize(device: _device_t = None): + def synchronize(device: torch.types.Device = None): pass - class Worker: - @staticmethod - def get_device_properties(device: _device_t = None): - import multiprocessing + @staticmethod + def is_triton_capable(device: torch.types.Device = None) -> bool: + return True - cpu_count = multiprocessing.cpu_count() - return CpuDeviceProperties(cpu_count) + @staticmethod + def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: + import triton.backends + + if "cpu" not in triton.backends.backends: + raise RuntimeError("triton not built with the 'cpu' backend") class MpsInterface(DeviceInterface): @@ -389,16 +454,16 @@ class MpsInterface(DeviceInterface): return 0 @staticmethod - def get_compute_capability(device: _device_t = None) -> str: + def get_compute_capability(device: torch.types.Device = None) -> str: return "" @staticmethod - def synchronize(device: _device_t = None): + def synchronize(device: torch.types.Device = None): torch.mps.synchronize() class Worker: @staticmethod - def get_device_properties(device: _device_t = None): + def get_device_properties(device: torch.types.Device = None): return {} @staticmethod diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index dcdc5b3874d..bd9fc5e2828 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -893,7 +893,7 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] return fn -def check_if_dynamo_supported(): +def raise_if_dynamo_unavailable() -> None: if sys.version_info >= (3, 14): raise RuntimeError("Python 3.14+ not yet supported for torch.compile") elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1: @@ -902,21 +902,40 @@ def check_if_dynamo_supported(): ) -def is_dynamo_supported(): +def is_dynamo_supported() -> bool: try: - check_if_dynamo_supported() + raise_if_dynamo_unavailable() return True except Exception: return False -def check_if_inductor_supported(): - check_if_dynamo_supported() +def raise_if_inductor_unavailable(device: torch.device | str | None = None) -> None: + from torch._inductor.codegen.common import ( + get_scheduling_for_device, + init_backend_registration, + ) + + raise_if_dynamo_unavailable() + + init_backend_registration() + + if device is None: + device = torch.get_default_device() + elif isinstance(device, str): + device = torch.device(device) + + scheduling_factory = get_scheduling_for_device(device.type) + if scheduling_factory is None: + raise RuntimeError( + f"No Inductor scheduling factory registered for {device.type}" + ) + scheduling_factory(None).raise_if_unavailable(device) -def is_inductor_supported(): +def is_inductor_supported(device: torch.device | str | None = None) -> bool: try: - check_if_inductor_supported() + raise_if_inductor_unavailable(device) return True except Exception: return False @@ -979,7 +998,7 @@ def _optimize( @torch._dynamo.optimize() def toy_example(a, b): ... """ - check_if_dynamo_supported() + raise_if_dynamo_unavailable() check_for_incompatible_configs() # Note: The hooks object could be global instead of passed around, *however* that would make # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls. 
@@ -1547,7 +1566,7 @@ def export( f = _f specialize_float = _specialize_float assume_static_by_default = _assume_static_by_default - check_if_dynamo_supported() + raise_if_dynamo_unavailable() torch._C._log_api_usage_once("torch._dynamo.export") if decomposition_table is not None: assert aten_graph, ( diff --git a/torch/_dynamo/logging.py b/torch/_dynamo/logging.py index 2d67665f5e9..56cf894f338 100644 --- a/torch/_dynamo/logging.py +++ b/torch/_dynamo/logging.py @@ -44,7 +44,6 @@ _step_counter = itertools.count(1) # Update num_steps if more phases are added: Dynamo, AOT, Backend # This is very inductor centric -# _inductor.utils.has_triton() gives a circular import error here if not disable_progress: try: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 7ebaa8dcaac..63e68a1624e 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -89,7 +89,7 @@ from torch._utils_internal import ( from torch.fx._utils import _format_graph_code, lazy_format_graph_code from torch.monitor import _WaitCounter from torch.nn.modules.lazy import LazyModuleMixin -from torch.utils._triton import has_triton, has_triton_package +from torch.utils._triton import has_triton_package from torch.utils.hooks import RemovableHandle @@ -1489,7 +1489,7 @@ def record_compilation_metrics( "dynamo_config": _get_dynamo_config_for_logging(), "inductor_config": _scrubbed_inductor_config_for_logging(), "cuda_version": torch.version.cuda, - "triton_version": triton.__version__ if has_triton() else "", + "triton_version": triton.__version__ if has_triton_package() else "", "remote_cache_version": remote_cache_version, "inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type, } @@ -3744,17 +3744,10 @@ def build_checkpoint_variable(**options): ) -def is_compile_supported(device_type): - from .eval_frame import is_dynamo_supported +def is_compile_supported(device_type: str) -> bool: + from .eval_frame import is_inductor_supported - compile_supported = is_dynamo_supported() - if device_type == "cpu": - pass - elif device_type in ["cuda", "xpu"] and compile_supported: - compile_supported = has_triton() - else: - compile_supported = False - return compile_supported + return is_inductor_supported(device_type) # The following 3.11 source code functions are adapted from diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 33093332f9a..2642f3af5a4 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -567,9 +567,9 @@ class VariableBuilder: def _wrap(self, value): # import here to avoid circular dependencies - from torch.utils._triton import has_triton, has_triton_tma + from torch.utils._triton import has_triton_package, has_triton_tma - if has_triton(): + if has_triton_package(): from triton.runtime.autotuner import Autotuner from triton.runtime.jit import JITFunction else: diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 4e1f76279bd..120a3cec08f 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -37,14 +37,14 @@ if TYPE_CHECKING: from torch._dynamo.variables.functions import TritonKernelVariable from torch._subclasses.functional_tensor import BaseFunctionalizeAPI from torch.fx.proxy import Proxy - from torch.utils._triton import has_triton + from torch.utils._triton import has_triton_package TritonMetaParamsType = dict[str, int] TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], 
...] TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]] TritonGridType = Union[TritonGridTupleType, TritonGridCallableType] - if has_triton(): + if has_triton_package(): from triton.runtime.autotuner import Autotuner, Config as TritonConfig from triton.runtime.jit import JITFunction else: diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 3af7d72f710..67ce4f2c303 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -45,8 +45,15 @@ class CUDACombinedScheduling(BaseScheduling): self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) - def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: - return self._triton_scheduling.get_backend_features(device) + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + return TritonScheduling.get_backend_features(device) + + @classmethod + def raise_if_unavailable( + cls, device: Union[str, torch.device, None] = None + ) -> None: + TritonScheduling.raise_if_unavailable(device) def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index ba1ad65e8dd..8bb6e07d25b 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -595,7 +595,7 @@ class MetalScheduling(SIMDScheduling): def __init__(self, scheduler: Optional[Scheduler]) -> None: super().__init__(scheduler) - wrapper = V.graph.wrapper_code + wrapper = getattr(V.graph, "wrapper_code", None) if wrapper is not None: wrapper.header.splice( "from torch._inductor.runtime.runtime_utils import compile_mps_shader" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index d6a05425204..498ced9e19e 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -5,6 +5,7 @@ import collections import contextlib import dataclasses import functools +import inspect import itertools import logging import math @@ -32,6 +33,7 @@ from ...utils._sympy.value_ranges import ValueRanges from .. 
import config, ir, metrics from ..async_compile import AsyncCompile from ..codecache import code_hash, get_path, PyCodeCache +from ..exc import TritonMissing from ..ops_handler import DefaultHandler from ..runtime.benchmarking import benchmarker from ..runtime.hints import ( @@ -4039,6 +4041,20 @@ class TritonScheduling(SIMDScheduling): ) return cls.backend_features + @classmethod + def raise_if_unavailable( + cls, device: Union[str, torch.device, None] = None + ) -> None: + if not has_triton_package(): + raise TritonMissing(inspect.currentframe()) + + from torch._dynamo.device_interface import get_interface_for_device + + if device is None: + device = torch.get_default_device() + + get_interface_for_device(device).raise_if_triton_unavailable(device) + def codegen_comment(self, node_schedule): wrapper = V.graph.wrapper_code origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 9f6adf19015..4ce39080b47 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -93,7 +93,7 @@ from .._dynamo.backends.common import aot_autograd from .._dynamo.exc import ShortenTraceback, SkipFrame from ..fx._lazy_graph_module import _use_lazy_graph_module from ..fx.graph import _PyTreeCodeGen -from ..utils._triton import has_triton +from ..utils._triton import has_triton_package from . import config, metrics from .debug import DebugContext from .decomposition import select_decomp_table @@ -1664,7 +1664,7 @@ def get_cpp_wrapper_config() -> dict[str, object]: "triton.autotune_at_compile_time": ( config.triton.autotune_at_compile_time if config.triton.autotune_at_compile_time is not None - else has_triton() + else has_triton_package() ), "triton.autotune_cublasLt": False, "triton.cudagraphs": False, # TODO: to be removed diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index a42296fe68a..4e88f109622 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -24,7 +24,9 @@ from torch._inductor.autoheuristic.autoheuristic_utils import ( from torch._subclasses.fake_tensor import FakeTensor from torch.utils._mode_utils import no_dispatch -from ...utils._triton import has_triton +from ..codegen.common import get_scheduling_for_device +from ..codegen.cuda_combined_scheduling import CUDACombinedScheduling +from ..codegen.triton import TritonScheduling from ..pattern_matcher import ( fwd_only, gen_register_replacement, @@ -458,7 +460,11 @@ def _should_pad_bench( ): return True - if not has_triton(): + scheduling_factory = get_scheduling_for_device(mat1.device.type) + if scheduling_factory is None or not isinstance( + scheduling_factory(None), + (TritonScheduling, CUDACombinedScheduling), + ): return False if not is_mm_compute_bound(m, k, n, mat1.dtype): diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 0c098f6afa4..0687d9e9da5 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -11,7 +11,7 @@ from typing_extensions import override import torch from torch.compiler._cache import CacheArtifactManager, CacheArtifactType -from torch.utils._triton import has_triton +from torch.utils._triton import has_triton_package from ..remote_cache import ( create_cache, @@ -36,7 +36,7 @@ def inductor_meta_from_config() -> _InductorMetaTy: from torch._inductor import config backend_hash = None - if has_triton(): + if has_triton_package(): 
try: backend_hash = torch.utils._triton.triton_hash_with_backend() except RuntimeError: diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 0b5b2101004..ee321dc2d19 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3,7 +3,6 @@ from __future__ import annotations import collections import dataclasses import functools -import inspect import itertools import logging import math @@ -31,14 +30,12 @@ from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.symbol import free_symbol_is_type, SymT -from torch.utils._triton import has_triton from . import comms, config, dependencies, ir, metrics from .analyze_preserves_zero_mask import can_codegen_without_upcasts from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel from .comm_analysis import estimate_nccl_collective_runtime from .dependencies import Dep, MemoryDep, StarDep, WeakDep -from .exc import GPUTooOldForTriton, TritonMissing from .ir import ( ComputedBuffer, get_device_type, @@ -3910,20 +3907,14 @@ class Scheduler: ) V.graph.add_device_info(device) - device_scheduling = get_scheduling_for_device(device.type) - if device_scheduling is None: + device_scheduling_type = get_scheduling_for_device(device.type) + if device_scheduling_type is None: raise RuntimeError(f"Unsupported device type: {device.type}") - if not has_triton(): - if ( - device.type == "cuda" - and (device_props := torch.cuda.get_device_properties(device)).major < 7 - ): - raise GPUTooOldForTriton(device_props, inspect.currentframe()) - elif is_gpu(device.type) and not device.type == "mps": - raise TritonMissing(inspect.currentframe()) + scheduling = device_scheduling_type(self) + scheduling.raise_if_unavailable(device) - return device_scheduling(self) + return scheduling def get_backend(self, device: Optional[torch.device]) -> BaseScheduling: assert device is not None @@ -4372,6 +4363,17 @@ class BaseScheduling: """Return a set of .codegen.common.BackendFeature()""" return OrderedSet() + @classmethod + def raise_if_unavailable( + cls, device: Union[str, torch.device, None] = None + ) -> None: + """ + Raises a RuntimeError if the given device does not support this codegen or required + prerequisites are not available with a useful description for the user. If None is given, + the default device is checked. 
+ """ + return None + def can_fuse_vertical( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 8e92150d5b0..4bda98cd35d 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1337,7 +1337,7 @@ def use_triton_template( def use_triton_tma_template(*matrices: IRNode) -> bool: - from torch.utils._triton import has_triton_tma_device + from torch.utils._triton import has_triton_tma from .virtualized import V @@ -1362,7 +1362,7 @@ def use_triton_tma_template(*matrices: IRNode) -> bool: return ( config.triton.enable_persistent_tma_matmul - and has_triton_tma_device() + and has_triton_tma() and all(_is_tma_compatible(m) for m in matrices) ) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index ce0e8446cba..2628422b430 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -8,7 +8,7 @@ from typing import Optional import torch from torch._dynamo.utils import warn_once -from torch.utils._triton import has_triton +from torch.utils._triton import has_triton_package from ._triton_ops_meta import get_meta @@ -1323,7 +1323,7 @@ def bsr_dense_addmm( return out_backup -if has_triton(): +if has_triton_package(): import triton import triton.language as tl diff --git a/torch/utils/_content_store.py b/torch/utils/_content_store.py index fab3730a43c..953befe6cbd 100644 --- a/torch/utils/_content_store.py +++ b/torch/utils/_content_store.py @@ -97,7 +97,8 @@ def hash_storage(storage: torch.UntypedStorage, *, stable_hash: bool = False) -> from torch._dynamo.utils import is_compile_supported device_type = storage.device.type - if stable_hash or not is_compile_supported(device_type): + # FIXME: MPS does not yet support some of the ops required for hashing + if stable_hash or not is_compile_supported(device_type) or device_type == "mps": cpu_storage = storage.cpu() # TODO: make storage support buffer protocol so this isn't # necessary diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 1609a3fe77c..f5329fa0007 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -1,6 +1,11 @@ # mypy: allow-untyped-defs import functools import hashlib +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from torch.types import Device @functools.lru_cache(None) @@ -9,11 +14,47 @@ def has_triton_package() -> bool: from triton.compiler.compiler import triton_key return triton_key is not None - except ImportError: + except (ImportError, RuntimeError): return False - except RuntimeError: + + +@functools.lru_cache(None) +def has_triton(device: "Device" = None) -> bool: + """ + Determine if Triton is available for use on this system for a given device + (if device is not None) or any available device type if no device is given. 
+ """ + import torch + from torch._dynamo.device_interface import ( + DeviceInterface, + get_interface_for_device, + get_registered_device_interfaces, + ) + + if not has_triton_package(): return False + def device_has_triton(di: type[DeviceInterface]) -> bool: + if not di.is_available(): + return False + + try: + di.raise_if_triton_unavailable(device) + except RuntimeError: + return False + + return True + + if device is None: + return any( + device_has_triton(di) for _, di in get_registered_device_interfaces() + ) + + if not isinstance(device, (str, torch.device)): + device = torch.device(device) + + return device_has_triton(get_interface_for_device(device)) + @functools.lru_cache(None) def has_triton_tma(): @@ -61,40 +102,6 @@ def has_triton_tma_device(): return False -@functools.lru_cache(None) -def has_triton() -> bool: - if not has_triton_package(): - return False - - from torch._dynamo.device_interface import get_interface_for_device - - def cuda_extra_check(device_interface): - return device_interface.Worker.get_device_properties().major >= 7 - - def cpu_extra_check(device_interface): - import triton.backends - - return "cpu" in triton.backends.backends - - def _return_true(device_interface): - return True - - triton_supported_devices = { - "cuda": cuda_extra_check, - "xpu": _return_true, - "cpu": cpu_extra_check, - } - - def is_device_compatible_with_triton(): - for device, extra_check in triton_supported_devices.items(): - device_interface = get_interface_for_device(device) - if device_interface.is_available() and extra_check(device_interface): - return True - return False - - return is_device_compatible_with_triton() - - @functools.lru_cache(None) def triton_backend(): from triton.compiler.compiler import make_backend
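
Usage sketch (illustrative, not part of the diff): the snippet below shows how the per-device availability hooks introduced above are intended to be consumed, assuming a build with this patch applied. `HalideScheduling` is a hypothetical backend name, used only to illustrate the `BaseScheduling.raise_if_unavailable` extension point mentioned in the description.

from torch._inductor.scheduler import BaseScheduling


class HalideScheduling(BaseScheduling):
    # Hypothetical backend: it reports availability per device instead of
    # relying on the old global has_triton() gate.
    @classmethod
    def raise_if_unavailable(cls, device=None):
        # Raise a RuntimeError with actionable instructions when the backend
        # cannot be used on `device`; returning normally means it is usable.
        raise RuntimeError("Halide is not installed in this environment")


if __name__ == "__main__":
    from torch._dynamo.eval_frame import is_inductor_supported
    from torch.utils._triton import has_triton

    # Per-device queries replace the old global check: has_triton("cpu") is True
    # only when the Triton package is importable and ships a 'cpu' backend, and
    # is_inductor_supported("cuda") is False when the CUDA/Triton stack is unusable.
    print(has_triton("cpu"))
    print(is_inductor_supported("cuda"))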