import re
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional
from enum import Enum
import contextlib
import textwrap

# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
# for your use.
Target = Enum('Target', (
    # top level namespace (not including at)
    'DEFINITION',
    'DECLARATION',
    # TORCH_LIBRARY(...) { ... }
    'REGISTRATION',
    # namespace { ... }
    'ANONYMOUS_DEFINITION',
    # namespace cpu { ... }
    'NAMESPACED_DEFINITION',
    'NAMESPACED_DECLARATION',
))

# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula; the `{}` placeholder
# is filled in with the identifier, e.g. IDENT_REGEX.format('foo').
IDENT_REGEX = r'(^|\W){}($|\W)'


def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a function schema string into its base name and parameter list.

    ``'foo.overload(Tensor self, int dim) -> Tensor'`` yields
    ``('foo', ['Tensor self', 'int dim'])``. The ``.overload`` suffix and
    everything after the parameter list's closing parenthesis (such as the
    return type) are discarded.

    The end of the parameter list is found by matching parentheses. This
    means parenthesised annotations inside the list (``Tensor(a!) out``) and
    parenthesised return types after it (``-> (Tensor a, Tensor b)``) are
    handled correctly. An empty parameter list yields ``[]``.

    Raises:
        RuntimeError: if *schema* does not start with ``name(`` or
            ``name.overload(``, or if its parameter list is never closed.
    """
    # TODO: Use a real parser here. Parameters are still split on the
    # literal ', ', so a comma inside a single parameter (for example a
    # list default such as [1, 2]) would split that parameter in two.
    m = re.match(r'(\w+)(\.\w+)?\(', schema)
    if m is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    start = m.end()
    depth = 1  # we are positioned just past the parameter list's '('
    for i, ch in enumerate(schema[start:], start):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth == 0:
                params = schema[start:i]
                break
    else:
        # Ran off the end of the string without closing the parameter list.
        raise RuntimeError(f'Unsupported function schema: {schema}')
    name = m.group(1)
    # ''.split(', ') would give [''], i.e. one bogus empty parameter.
    return name, (params.split(', ') if params else [])


T = TypeVar('T')
S = TypeVar('S')

# These two functions purposely return generators in analogy to map(),
# so that it is obvious you need to list() them to materialize the results.


def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily map *func* over *xs*, omitting results that are ``None``.

    Only ``None`` is dropped; other falsy results (``0``, ``''``, ``[]``)
    are yielded.
    """
    for x in xs:
        r = func(x)
        if r is not None:
            yield r


def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily map *func* over *xs* and concatenate the returned sequences."""
    for x in xs:
        yield from func(x)


# Conveniently add error context to exceptions raised. Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg: str) -> Iterator[None]:
    """Annotate any ``Exception`` escaping the ``with`` body with *msg*.

    Each line of *msg* is indented by one space. The result is appended on a
    new line after the exception's first argument; if the exception has no
    arguments, the result becomes its only argument. Any further arguments
    are kept as-is. The same exception object is then re-raised.
    ``BaseException`` subclasses that are not ``Exception`` pass through
    unchanged.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError -- KeyError's str()
        # shows its single argument via repr(), so the appended newline and
        # context text come out escaped on one line.
        note = textwrap.indent(msg, ' ')
        if err.args:
            head, *tail = err.args
            err.args = (f'{head}\n{note}', *tail)
        else:
            err.args = (note,)
        raise