mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Refactor conditional triton imports into triton_compat.py (#143814)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143814 Approved by: https://github.com/Skylion007 ghstack dependencies: #143813
This commit is contained in:
parent
efac5ed81b
commit
cf76c05b4d
|
|
@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
|||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from torch.utils._triton import has_triton, has_triton_package
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
from ..remote_cache import (
|
||||
create_cache,
|
||||
|
|
@ -19,14 +19,12 @@ from ..remote_cache import (
|
|||
RemoteCacheBackend,
|
||||
RemoteCacheJsonSerde,
|
||||
)
|
||||
from .triton_compat import Config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..remote_cache import Sample
|
||||
|
||||
if has_triton_package():
|
||||
from triton import Config
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,16 +2,15 @@
|
|||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, TYPE_CHECKING
|
||||
|
||||
from .hints import TRITON_MAX_BLOCK
|
||||
from .runtime_utils import red_text, triton_config_to_hashable
|
||||
|
||||
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
triton = None
|
||||
if TYPE_CHECKING:
|
||||
from .triton_compat import triton
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
118
torch/_inductor/runtime/triton_compat.py
Normal file
118
torch/_inductor/runtime/triton_compat.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
triton = None
|
||||
|
||||
|
||||
if triton is not None:
|
||||
import triton.language as tl
|
||||
from triton import Config
|
||||
from triton.compiler import CompiledKernel
|
||||
from triton.runtime.autotuner import OutOfResources
|
||||
from triton.runtime.jit import KernelInterface
|
||||
|
||||
try:
|
||||
from triton.runtime.autotuner import PTXASError
|
||||
except ImportError:
|
||||
|
||||
class PTXASError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
try:
|
||||
from triton.compiler.compiler import ASTSource
|
||||
except ImportError:
|
||||
ASTSource = None
|
||||
|
||||
try:
|
||||
from triton.backends.compiler import GPUTarget
|
||||
except ImportError:
|
||||
GPUTarget = None
|
||||
|
||||
# In the latest triton, math functions were shuffled around into different modules:
|
||||
# https://github.com/openai/triton/pull/3172
|
||||
try:
|
||||
from triton.language.extra import libdevice
|
||||
|
||||
libdevice = tl.extra.libdevice # noqa: F811
|
||||
math = tl.math
|
||||
except ImportError:
|
||||
if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"):
|
||||
libdevice = tl.extra.cuda.libdevice
|
||||
math = tl.math
|
||||
elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"):
|
||||
libdevice = tl.extra.intel.libdevice
|
||||
math = tl.math
|
||||
else:
|
||||
libdevice = tl.math
|
||||
math = tl
|
||||
|
||||
try:
|
||||
from triton.language.standard import _log2
|
||||
except ImportError:
|
||||
|
||||
def _log2(x: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
|
||||
def _raise_error(*args: Any, **kwargs: Any) -> Any:
|
||||
raise RuntimeError("triton package is not installed")
|
||||
|
||||
class OutOfResources(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class PTXASError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
Config = object
|
||||
CompiledKernel = object
|
||||
KernelInterface = object
|
||||
ASTSource = None
|
||||
GPUTarget = None
|
||||
_log2 = _raise_error
|
||||
libdevice = None
|
||||
math = None
|
||||
|
||||
class triton: # type: ignore[no-redef]
|
||||
@staticmethod
|
||||
def jit(*args: Any, **kwargs: Any) -> Any:
|
||||
return _raise_error
|
||||
|
||||
class tl: # type: ignore[no-redef]
|
||||
@staticmethod
|
||||
def constexpr(val: Any) -> Any:
|
||||
return val
|
||||
|
||||
tensor = Any
|
||||
dtype = Any
|
||||
|
||||
|
||||
try:
|
||||
autograd_profiler = torch.autograd.profiler
|
||||
except AttributeError: # Compile workers only have a mock version of torch
|
||||
|
||||
class autograd_profiler: # type: ignore[no-redef]
|
||||
_is_profiler_enabled = False
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"CompiledKernel",
|
||||
"OutOfResources",
|
||||
"KernelInterface",
|
||||
"PTXASError",
|
||||
"ASTSource",
|
||||
"GPUTarget",
|
||||
"tl",
|
||||
"_log2",
|
||||
"libdevice",
|
||||
"math",
|
||||
"triton",
|
||||
]
|
||||
|
|
@ -2,35 +2,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import warnings
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# In the latest triton, math functions were shuffled around into different modules:
|
||||
# https://github.com/openai/triton/pull/3172
|
||||
try:
|
||||
from triton.language.extra import libdevice
|
||||
|
||||
libdevice = tl.extra.libdevice # noqa: F811
|
||||
math = tl.math
|
||||
except ImportError:
|
||||
if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"):
|
||||
libdevice = tl.extra.cuda.libdevice
|
||||
math = tl.math
|
||||
elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"):
|
||||
libdevice = tl.extra.intel.libdevice
|
||||
math = tl.math
|
||||
else:
|
||||
libdevice = tl.math
|
||||
math = tl
|
||||
|
||||
|
||||
try:
|
||||
from triton.language.standard import _log2
|
||||
except ImportError:
|
||||
|
||||
def _log2(x):
|
||||
raise NotImplementedError
|
||||
from .triton_compat import _log2, libdevice, math, tl, triton # noqa: F401
|
||||
|
||||
|
||||
def set_driver_to_cpu():
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from torch.utils._ordered_set import OrderedSet
|
|||
|
||||
from ..triton_bundler import TritonBundler
|
||||
from ..utils import prefix_is_reduction
|
||||
from . import triton_helpers
|
||||
from .autotune_cache import AutotuneCache
|
||||
from .benchmarking import benchmarker
|
||||
from .coordinate_descent_tuner import CoordescTuner
|
||||
|
|
@ -50,58 +51,17 @@ from .runtime_utils import (
|
|||
triton_hash_to_path_key,
|
||||
validate_triton_config,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
triton = None
|
||||
|
||||
if triton is not None:
|
||||
from triton import Config
|
||||
from triton.compiler import CompiledKernel
|
||||
from triton.runtime.autotuner import OutOfResources
|
||||
from triton.runtime.jit import KernelInterface
|
||||
|
||||
from . import triton_helpers
|
||||
|
||||
try:
|
||||
from triton.runtime.autotuner import PTXASError
|
||||
except ImportError:
|
||||
|
||||
class PTXASError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
try:
|
||||
from triton.compiler.compiler import ASTSource
|
||||
except ImportError:
|
||||
ASTSource = None
|
||||
|
||||
try:
|
||||
from triton.backends.compiler import GPUTarget
|
||||
except ImportError:
|
||||
GPUTarget = None
|
||||
else:
|
||||
from types import ModuleType
|
||||
|
||||
class OutOfResources(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
class PTXASError(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
Config = object
|
||||
KernelInterface = object
|
||||
ASTSource = None
|
||||
GPUTarget = None
|
||||
triton_helpers = ModuleType("triton_helpers")
|
||||
|
||||
try:
|
||||
autograd_profiler = torch.autograd.profiler
|
||||
except AttributeError: # Compile workers only have a mock version of torch
|
||||
|
||||
class autograd_profiler: # type: ignore[no-redef]
|
||||
_is_profiler_enabled = False
|
||||
from .triton_compat import (
|
||||
ASTSource,
|
||||
autograd_profiler,
|
||||
CompiledKernel,
|
||||
Config,
|
||||
GPUTarget,
|
||||
KernelInterface,
|
||||
OutOfResources,
|
||||
PTXASError,
|
||||
triton,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user