[inductor] Refactor conditional triton imports into triton_compat.py (#143814)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143814
Approved by: https://github.com/Skylion007
ghstack dependencies: #143813
This commit is contained in:
Jason Ansel 2024-12-25 07:21:57 -08:00 committed by PyTorch MergeBot
parent efac5ed81b
commit cf76c05b4d
5 changed files with 137 additions and 90 deletions

View File

@@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from typing_extensions import override
import torch
from torch.utils._triton import has_triton, has_triton_package
from torch.utils._triton import has_triton
from ..remote_cache import (
create_cache,
@@ -19,14 +19,12 @@ from ..remote_cache import (
RemoteCacheBackend,
RemoteCacheJsonSerde,
)
from .triton_compat import Config
if TYPE_CHECKING:
from ..remote_cache import Sample
if has_triton_package():
from triton import Config
log = logging.getLogger(__name__)

View File

@@ -2,16 +2,15 @@
import copy
import itertools
import logging
from typing import Callable, Optional
from typing import Callable, Optional, TYPE_CHECKING
from .hints import TRITON_MAX_BLOCK
from .runtime_utils import red_text, triton_config_to_hashable
try:
import triton
except ImportError:
triton = None
if TYPE_CHECKING:
from .triton_compat import triton
log = logging.getLogger(__name__)

View File

@@ -0,0 +1,118 @@
from __future__ import annotations
from typing import Any
import torch
try:
import triton
except ImportError:
triton = None
if triton is not None:
import triton.language as tl
from triton import Config
from triton.compiler import CompiledKernel
from triton.runtime.autotuner import OutOfResources
from triton.runtime.jit import KernelInterface
try:
from triton.runtime.autotuner import PTXASError
except ImportError:
class PTXASError(Exception): # type: ignore[no-redef]
pass
try:
from triton.compiler.compiler import ASTSource
except ImportError:
ASTSource = None
try:
from triton.backends.compiler import GPUTarget
except ImportError:
GPUTarget = None
# In the latest triton, math functions were shuffled around into different modules:
# https://github.com/openai/triton/pull/3172
try:
from triton.language.extra import libdevice
libdevice = tl.extra.libdevice # noqa: F811
math = tl.math
except ImportError:
if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"):
libdevice = tl.extra.cuda.libdevice
math = tl.math
elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"):
libdevice = tl.extra.intel.libdevice
math = tl.math
else:
libdevice = tl.math
math = tl
try:
from triton.language.standard import _log2
except ImportError:
def _log2(x: Any) -> Any:
raise NotImplementedError
else:
def _raise_error(*args: Any, **kwargs: Any) -> Any:
raise RuntimeError("triton package is not installed")
class OutOfResources(Exception): # type: ignore[no-redef]
pass
class PTXASError(Exception): # type: ignore[no-redef]
pass
Config = object
CompiledKernel = object
KernelInterface = object
ASTSource = None
GPUTarget = None
_log2 = _raise_error
libdevice = None
math = None
class triton: # type: ignore[no-redef]
@staticmethod
def jit(*args: Any, **kwargs: Any) -> Any:
return _raise_error
class tl: # type: ignore[no-redef]
@staticmethod
def constexpr(val: Any) -> Any:
return val
tensor = Any
dtype = Any
try:
autograd_profiler = torch.autograd.profiler
except AttributeError: # Compile workers only have a mock version of torch
class autograd_profiler: # type: ignore[no-redef]
_is_profiler_enabled = False
__all__ = [
"Config",
"CompiledKernel",
"OutOfResources",
"KernelInterface",
"PTXASError",
"ASTSource",
"GPUTarget",
"tl",
"_log2",
"libdevice",
"math",
"triton",
]

View File

@@ -2,35 +2,7 @@
# mypy: allow-untyped-defs
import warnings
import triton
import triton.language as tl
# In the latest triton, math functions were shuffled around into different modules:
# https://github.com/openai/triton/pull/3172
try:
from triton.language.extra import libdevice
libdevice = tl.extra.libdevice # noqa: F811
math = tl.math
except ImportError:
if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"):
libdevice = tl.extra.cuda.libdevice
math = tl.math
elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"):
libdevice = tl.extra.intel.libdevice
math = tl.math
else:
libdevice = tl.math
math = tl
try:
from triton.language.standard import _log2
except ImportError:
def _log2(x):
raise NotImplementedError
from .triton_compat import _log2, libdevice, math, tl, triton # noqa: F401
def set_driver_to_cpu():

View File

@@ -23,6 +23,7 @@ from torch.utils._ordered_set import OrderedSet
from ..triton_bundler import TritonBundler
from ..utils import prefix_is_reduction
from . import triton_helpers
from .autotune_cache import AutotuneCache
from .benchmarking import benchmarker
from .coordinate_descent_tuner import CoordescTuner
@@ -50,58 +51,17 @@ from .runtime_utils import (
triton_hash_to_path_key,
validate_triton_config,
)
try:
import triton
except ImportError:
triton = None
if triton is not None:
from triton import Config
from triton.compiler import CompiledKernel
from triton.runtime.autotuner import OutOfResources
from triton.runtime.jit import KernelInterface
from . import triton_helpers
try:
from triton.runtime.autotuner import PTXASError
except ImportError:
class PTXASError(Exception): # type: ignore[no-redef]
pass
try:
from triton.compiler.compiler import ASTSource
except ImportError:
ASTSource = None
try:
from triton.backends.compiler import GPUTarget
except ImportError:
GPUTarget = None
else:
from types import ModuleType
class OutOfResources(Exception): # type: ignore[no-redef]
pass
class PTXASError(Exception): # type: ignore[no-redef]
pass
Config = object
KernelInterface = object
ASTSource = None
GPUTarget = None
triton_helpers = ModuleType("triton_helpers")
try:
autograd_profiler = torch.autograd.profiler
except AttributeError: # Compile workers only have a mock version of torch
class autograd_profiler: # type: ignore[no-redef]
_is_profiler_enabled = False
from .triton_compat import (
ASTSource,
autograd_profiler,
CompiledKernel,
Config,
GPUTarget,
KernelInterface,
OutOfResources,
PTXASError,
triton,
)
log = logging.getLogger(__name__)