Revert "has_triton: Use the device interface for detecting Triton availability (#139171)"
This reverts commit 48bfe9afc7.
Reverted https://github.com/pytorch/pytorch/pull/139171 on behalf of https://github.com/masnesral due to Performance regression for huggingface ([comment](https://github.com/pytorch/pytorch/pull/139171#issuecomment-2868939790))
This commit is contained in: parent 70c8047c2d, commit 01bb249978
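
For orientation when reading the hunks below: the revert switches several call sites from has_triton_package() back to has_triton() (both in torch.utils._triton). The following is a minimal sketch of the distinction as it appears in the definitions restored further down; the call sites shown here are hypothetical, not part of the commit.

# Sketch only: has_triton_package() asks "is the triton package importable?",
# while the restored has_triton() also asks "is there a device it can target?".
from torch.utils._triton import has_triton, has_triton_package

if has_triton_package():
    # The triton Python package is importable, so triton modules can be
    # imported safely, e.g. for version reporting.
    import triton
    print("triton", triton.__version__)

if has_triton():
    # triton is importable AND a supported backend is present (CUDA with
    # compute capability >= 7, XPU, or a Triton CPU backend), so Triton
    # kernels can actually be compiled and run on this machine.
    print("Triton kernels can run here")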
@@ -1548,7 +1548,7 @@ def record_compilation_metrics(
         "dynamo_config": _get_dynamo_config_for_logging(),
         "inductor_config": _scrubbed_inductor_config_for_logging(),
         "cuda_version": torch.version.cuda,
-        "triton_version": triton.__version__ if has_triton_package() else "",
+        "triton_version": triton.__version__ if has_triton() else "",
         "remote_cache_version": remote_cache_version,
         "inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type,
         "python_version": sys.version,
@@ -3830,14 +3830,15 @@ def build_checkpoint_variable(**options):
     )
 
 
-def is_compile_supported(device_type: str) -> bool:
+def is_compile_supported(device_type):
     from .eval_frame import is_dynamo_supported
 
+    type = torch.device(device_type).type
     compile_supported = is_dynamo_supported()
-    if device_type == "cpu":
+    if type == "cpu":
         pass
-    elif device_type in ["cuda", "xpu"] and compile_supported:
-        compile_supported = has_triton(device_type)
+    elif type in ["cuda", "xpu"] and compile_supported:
+        compile_supported = has_triton()
     else:
         compile_supported = False
     return compile_supported
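
The restored is_compile_supported above normalizes the incoming device string before checking support. A tiny illustration of that normalization step (hypothetical inputs, not from the commit):

# torch.device(...).type strips any device index, so an indexed device string
# takes the same branch as the bare device type.
import torch

assert torch.device("cuda:0").type == "cuda"
assert torch.device("cpu").type == "cpu"
# Hence is_compile_supported("cuda:0") behaves like is_compile_supported("cuda"),
# and the availability check is the device-agnostic has_triton() again.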
@@ -596,11 +596,11 @@ class VariableBuilder:
 
     def _wrap(self, value):
         # import here to avoid circular dependencies
-        from torch.utils._triton import has_triton_package, has_triton_tma
+        from torch.utils._triton import has_triton, has_triton_tma
 
         from ..decorators import DynamoConfigPatchProxy
 
-        if has_triton_package():
+        if has_triton():
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction
        else:
@@ -39,14 +39,14 @@ if TYPE_CHECKING:
 from torch._dynamo.variables.functions import TritonKernelVariable
 from torch._subclasses.functional_tensor import BaseFunctionalizeAPI
 from torch.fx.proxy import Proxy
-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton
 
 TritonMetaParamsType = dict[str, int]
 TritonGridTupleType = tuple[Union[int, sympy.Expr, SymInt], ...]
 TritonGridCallableType = Callable[[TritonMetaParamsType], tuple[int, ...]]
 TritonGridType = Union[TritonGridTupleType, TritonGridCallableType]
 
-if has_triton_package():
+if has_triton():
     from triton.runtime.autotuner import Autotuner, Config as TritonConfig
     from triton.runtime.jit import JITFunction
 else:
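
The two hunks above follow the same guard pattern: triton runtime classes are only imported when the availability check passes, with fallback definitions in an else branch that this view truncates. A sketch of that pattern follows; the fallback class bodies are assumptions for illustration, not the files' actual code.

# Guard-pattern sketch: import the real triton classes when available,
# otherwise define placeholders so later isinstance() checks still work
# (the placeholder bodies are assumptions, not taken from the commit).
from torch.utils._triton import has_triton

if has_triton():
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction
else:

    class Autotuner:  # placeholder, assumption for illustration
        pass

    class JITFunction:  # placeholder, assumption for illustration
        pass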
@@ -95,7 +95,7 @@ from .._dynamo.backends.common import aot_autograd
 from .._dynamo.exc import ShortenTraceback, SkipFrame
 from ..fx._lazy_graph_module import _use_lazy_graph_module
 from ..fx.graph import _PyTreeCodeGen
-from ..utils._triton import has_triton_package
+from ..utils._triton import has_triton
 from . import config, metrics
 from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration
 from .debug import DebugContext
@@ -1857,7 +1857,7 @@ def get_cpp_wrapper_config() -> dict[str, object]:
         "triton.autotune_at_compile_time": (
             config.triton.autotune_at_compile_time
             if config.triton.autotune_at_compile_time is not None
-            else has_triton_package()
+            else has_triton()
         ),
         "triton.autotune_cublasLt": False,
         "triton.cudagraphs": False,  # TODO: to be removed
@@ -24,9 +24,7 @@ from torch._inductor.autoheuristic.autoheuristic_utils import (
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.utils._mode_utils import no_dispatch
 
-from ..codegen.common import get_scheduling_for_device
-from ..codegen.cuda_combined_scheduling import CUDACombinedScheduling
-from ..codegen.triton import TritonScheduling
+from ...utils._triton import has_triton
 from ..pattern_matcher import (
     fwd_only,
     gen_register_replacement,
@@ -460,11 +458,7 @@ def _should_pad_bench(
     ):
         return True
 
-    scheduling_factory = get_scheduling_for_device(mat1.device.type)
-    if scheduling_factory is None or not isinstance(
-        scheduling_factory(None),
-        (TritonScheduling, CUDACombinedScheduling),
-    ):
+    if not has_triton():
         return False
 
     if not is_mm_compute_bound(m, k, n, mat1.dtype):
@@ -11,7 +11,7 @@ from typing_extensions import override
 
 import torch
 from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton
 
 from ..remote_cache import (
     create_cache,
@@ -36,7 +36,7 @@ def inductor_meta_from_config() -> _InductorMetaTy:
     from torch._inductor import config
 
     backend_hash = None
-    if has_triton_package():
+    if has_triton():
         try:
             backend_hash = torch.utils._triton.triton_hash_with_backend()
         except RuntimeError:
@@ -1504,7 +1504,7 @@ def use_triton_template(
 
 
 def use_triton_tma_template(*matrices: IRNode) -> bool:
-    from torch.utils._triton import has_triton_tma
+    from torch.utils._triton import has_triton_tma_device
 
     from .virtualized import V
 
@@ -1535,7 +1535,7 @@ def use_triton_tma_template(*matrices: IRNode) -> bool:
 
     return (
         config.triton.enable_persistent_tma_matmul
-        and has_triton_tma()
+        and has_triton_tma_device()
         and all(_is_tma_compatible(m) for m in matrices)
     )
 
@@ -8,7 +8,7 @@ from typing import Optional
 
 import torch
 from torch._dynamo.utils import warn_once
-from torch.utils._triton import has_triton_package
+from torch.utils._triton import has_triton
 
 from ._triton_ops_meta import get_meta
 
@@ -1323,7 +1323,7 @@ def bsr_dense_addmm(
         return out_backup
 
 
-if has_triton_package():
+if has_triton():
     import triton
     import triton.language as tl
 
@@ -1,11 +1,6 @@
 # mypy: allow-untyped-defs
 import functools
 import hashlib
-from typing import TYPE_CHECKING
-
-
-if TYPE_CHECKING:
-    from torch.types import Device
 
 
 @functools.lru_cache(None)
@@ -14,47 +9,11 @@ def has_triton_package() -> bool:
         from triton.compiler.compiler import triton_key
 
         return triton_key is not None
-    except (ImportError, RuntimeError):
+    except ImportError:
         return False
-
-
-@functools.lru_cache(None)
-def has_triton(device: "Device" = None) -> bool:
-    """
-    Determine if Triton is available for use on this system for a given device
-    (if device is not None) or any available device type if no device is given.
-    """
-    import torch
-    from torch._dynamo.device_interface import (
-        DeviceInterface,
-        get_interface_for_device,
-        get_registered_device_interfaces,
-    )
-
-    if not has_triton_package():
+    except RuntimeError:
         return False
-
-    def device_has_triton(di: type[DeviceInterface]) -> bool:
-        if not di.is_available():
-            return False
-
-        try:
-            di.raise_if_triton_unavailable(device)
-        except RuntimeError:
-            return False
-
-        return True
-
-    if device is None:
-        return any(
-            device_has_triton(di) for _, di in get_registered_device_interfaces()
-        )
-
-    if not isinstance(device, (str, torch.device)):
-        device = torch.device(device)
-
-    return device_has_triton(get_interface_for_device(device))
 
 
 @functools.lru_cache(None)
 def has_triton_tma():
@@ -102,6 +61,40 @@ def has_triton_tma_device():
     return False
 
 
+@functools.lru_cache(None)
+def has_triton() -> bool:
+    if not has_triton_package():
+        return False
+
+    from torch._dynamo.device_interface import get_interface_for_device
+
+    def cuda_extra_check(device_interface):
+        return device_interface.Worker.get_device_properties().major >= 7
+
+    def cpu_extra_check(device_interface):
+        import triton.backends
+
+        return "cpu" in triton.backends.backends
+
+    def _return_true(device_interface):
+        return True
+
+    triton_supported_devices = {
+        "cuda": cuda_extra_check,
+        "xpu": _return_true,
+        "cpu": cpu_extra_check,
+    }
+
+    def is_device_compatible_with_triton():
+        for device, extra_check in triton_supported_devices.items():
+            device_interface = get_interface_for_device(device)
+            if device_interface.is_available() and extra_check(device_interface):
+                return True
+        return False
+
+    return is_device_compatible_with_triton()
+
+
 @functools.lru_cache(None)
 def triton_backend():
     from triton.compiler.compiler import make_backend
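
One practical note on the restored has_triton() above: it is wrapped in functools.lru_cache(None), so the device probe runs once per process and the answer is then reused. A short sketch; the cache_clear() call is the standard lru_cache API, shown here as an assumption about how a re-probe could be forced (e.g. in tests), not something the commit does.

from torch.utils._triton import has_triton

first = has_triton()      # evaluates the cuda / xpu / cpu checks once
second = has_triton()     # served from the lru_cache, no re-probe
assert first == second

has_triton.cache_clear()  # standard functools.lru_cache method; next call re-probes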