mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93389 Approved by: https://github.com/voznesenskym
82 lines
2.2 KiB
Python
82 lines
2.2 KiB
Python
import functools
|
|
from typing import Callable, Dict, List, Optional, Sequence, Tuple
|
|
|
|
from typing_extensions import Protocol
|
|
|
|
import torch
|
|
from torch import fx
|
|
|
|
|
|
class CompiledFn(Protocol):
    """Structural type for the callable that a backend compiler produces.

    Any object satisfies it if it can be called with tensors as positional
    arguments and returns a tuple of tensors.
    """

    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Invoke the compiled artifact on positional tensor inputs."""
|
|
|
|
|
|
# Signature every backend compiler must satisfy: it receives the captured FX
# graph plus a list of example input tensors and returns a CompiledFn.
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

# Registry mapping backend name -> compiler function. Entries are added by
# register_backend() and read by lookup_backend() / list_backends().
_BACKENDS: Dict[str, CompilerFn] = {}
|
|
|
|
|
|
def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand. Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Usable bare (``@register_backend``), with keyword arguments
    (``@register_backend(name="foo", tags=("bar",))``), or as a plain call
    (``register_backend(fn, name="foo")``).

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with;
            stored as a tuple on the compiler's ``_tags`` attribute

    Returns:
        ``compiler_fn`` itself once registered, or, when ``compiler_fn`` is
        None, a decorator that performs the registration with the given
        ``name`` and ``tags``.

    Raises:
        AssertionError: if ``compiler_fn`` is not callable or ``name`` is
            already registered.
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _BACKENDS, f"duplicate name: {name}"
    # Attach the tags *before* publishing the backend. If the attribute
    # cannot be set (e.g. bound methods or builtins reject new attributes),
    # nothing has been registered yet, so the failed call does not leave a
    # half-registered entry that would trip the duplicate-name check on retry.
    compiler_fn._tags = tuple(tags)
    _BACKENDS[name] = compiler_fn
    return compiler_fn
|
|
|
|
|
|
def lookup_backend(compiler_fn):
    """Expand backend strings to functions.

    Non-string arguments (e.g. a compiler callable passed directly) are
    returned unchanged. A string is looked up in the backend registry; if it
    is not present yet, ``_lazy_import()`` is called first so that backends
    registered when their modules are imported become visible.

    Raises:
        KeyError: if ``compiler_fn`` is a string that names no registered
            backend. The message lists the valid backend names.
    """
    if not isinstance(compiler_fn, str):
        return compiler_fn
    if compiler_fn not in _BACKENDS:
        _lazy_import()
    try:
        return _BACKENDS[compiler_fn]
    except KeyError:
        # A bare KeyError('name') gives no hint about what is accepted. Keep
        # the same exception type for existing callers but make it actionable.
        raise KeyError(
            f"unknown backend {compiler_fn!r}; "
            f"valid backends: {sorted(_BACKENDS)}"
        ) from None
|
|
|
|
|
|
def list_backends():
    """
    Return the sorted backend names accepted as the string in:

        torch.compile(..., backend="name")

    The lazily-imported backend modules are loaded first, so backends they
    register are included in the result.
    """
    _lazy_import()
    names = list(_BACKENDS)
    names.sort()
    return names
|
|
|
|
|
|
@functools.lru_cache(None)
def _lazy_import():
    """Import the modules that define backends, so their registrations run.

    Wrapped in ``lru_cache(None)`` with no arguments: after the first call
    that completes without raising, later calls return the cached ``None``
    without re-running the body. A call that raises is not cached, so the
    next call retries.
    """
    from .. import backends
    from ..utils import import_submodule

    # Presumably imports every submodule of the ``backends`` package so that
    # backends registered at import time land in the registry — confirm
    # against ..utils.import_submodule.
    import_submodule(backends)

    # TODO(jansel): refactor backends defined in other places
    from .. import debug_utils
    # NOTE: this rebinds the local name ``backends`` to
    # ``..optimizations.backends``; the package imported above is not
    # referenced after this point.
    from ..optimizations import backends, distributed, training

    # NOTE(review): presumably registers the AOT-autograd backends — confirm.
    training.create_aot_backends()
    # avoid unused import lint
    assert backends is not None
    assert distributed is not None
    assert debug_utils is not None
|