mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51585 Some payoff from the stack of refactors. When I initially landed at::cpu, Brian asked me why I couldn't just separate the anonymous and namespaced definitions. Well, it used to be annoying. Now it's not annoying anymore, so go ahead and split them up. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D26209873 Pulled By: ezyang fbshipit-source-id: 63057d22acfaa0c17229947d9e65ec1193e360ec
72 lines
2.4 KiB
Python
72 lines
2.4 KiB
Python
import re
|
|
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional
|
|
from enum import Enum
|
|
import contextlib
|
|
import textwrap
|
|
|
|
# Many code-generation functions produce both the definition and the
# declaration of the same entity (for example, both share one function
# signature), so they are written as a single function that takes a Target
# telling it which piece of code to emit.
#
# This is an OPEN enum: more cases may be added in the future, so callers
# should state explicitly which targets they accept, e.g. via
# Union[Literal[Target.XXX]] in their signatures.
class Target(Enum):
    # top level namespace (not including at)
    DEFINITION = 1
    DECLARATION = 2
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = 3
    # namespace { ... }
    ANONYMOUS_DEFINITION = 4
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = 5
    NAMESPACED_DECLARATION = 6
|
|
|
|
# Regex template for a whole-word occurrence of an identifier: once the `{}`
# slot is filled with a name such as "foo" (e.g. via str.format), it matches
# "foo" in "foo, bar" but not in "foobar". Used to search for a parameter's
# occurrence inside a derivative formula.
IDENT_REGEX = (
    r'(^|\W)'  # start of string, or a non-word character before the name
    '{}'       # slot for the identifier being searched for
    r'($|\W)'  # end of string, or a non-word character after the name
)
|
|
|
|
def _split_top_level(params: str) -> List[str]:
    """Split ``params`` on ', ' separators that are not nested in (), [] or {}.

    For input containing no brackets this gives exactly the same result as
    ``params.split(', ')``.
    """
    pieces: List[str] = []
    depth = 0
    start = 0
    i = 0
    while i < len(params):
        ch = params[i]
        if ch in '([{':
            depth += 1
        elif ch in ')]}':
            # Clamp at zero so a stray closing bracket cannot disable
            # splitting for the rest of the string.
            depth = max(depth - 1, 0)
        elif depth == 0 and params.startswith(', ', i):
            pieces.append(params[start:i])
            i += 2
            start = i
            continue
        i += 1
    pieces.append(params[start:])
    return pieces


# TODO: Use a real parser here; this can still get bamboozled, e.g. by string
# default values that themselves contain ", ", or by text after the parameter
# list that contains ')'.
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a function schema into its base name and its parameter strings.

    For example, ``add.Tensor(Tensor self, Tensor other)`` yields
    ``('add', ['Tensor self', 'Tensor other'])``. Any ``.overload`` suffix on
    the name is discarded. The parameter text is everything between the first
    '(' and the last ')' in ``schema``. It is split on ', ', except where the
    separator is nested inside brackets (so ``int[2] stride=[1, 1]`` stays one
    parameter). An empty parameter list yields ``[]``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)``.
    """
    m = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if m is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    name, _, params = m.groups()
    if not params:
        # "foo()" has no parameters; a plain split would report [''].
        return name, []
    return name, _split_top_level(params)
|
|
|
|
# Generic type variables used by the sequence helpers mapMaybe/concatMap.
T, S = TypeVar('T'), TypeVar('S')
|
|
|
|
# The two helpers below deliberately return generators, by analogy with map(),
# so that callers remember to wrap them in list() when they need a sequence.
|
|
|
|
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs``, dropping ``None`` results.

    Only ``None`` is filtered out; other falsy results such as ``0`` or ``''``
    are kept. Nothing is evaluated until the returned generator is iterated.
    """
    yield from (result for result in map(func, xs) if result is not None)
|
|
|
|
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily apply ``func`` to each element of ``xs`` and flatten the results.

    Each call to ``func`` returns a sequence. The elements of those sequences
    are yielded one after another, in order. Nothing is evaluated until the
    returned generator is iterated.
    """
    for item in xs:
        yield from func(item)
|
|
|
|
@contextlib.contextmanager
def context(msg: str) -> Iterator[None]:
    """Attach ``msg`` as extra context to any Exception raised inside the block.

    If the ``with`` body raises an ``Exception``, ``msg`` (indented by two
    spaces on every line) is appended on a new line to the exception's first
    argument. The same exception object is then re-raised, so its type and
    any remaining args are unchanged. If the exception had no args, the
    indented message becomes its only arg. Useful for saying which item was
    being processed when an error occurred.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError -- presumably because
        # str(KeyError) shows repr(args[0]), so the appended context would be
        # rendered quoted, with an escaped newline.
        indented = textwrap.indent(msg, '  ')
        if err.args:
            indented = f'{err.args[0]}\n{indented}'
        err.args = (indented, *err.args[1:])
        raise
|