mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66238 The codegen should error if it sees two yaml entries with the same key. The default behavior of python's yaml loader is to overwrite duplicate keys with the new value. This would have caught a nasty bug that showed up in https://github.com/pytorch/pytorch/pull/66225/files#r723796194. I tested it on that linked PR, to confirm that it errors correctly (and gives the line number containing the duplicate). Test Plan: Imported from OSS Reviewed By: dagitses, albanD, sean-ngo Differential Revision: D31464585 Pulled By: bdhirsh fbshipit-source-id: 5b35157ffa9a933bf4b344c4b9fe2878698370a3
97 lines
3.4 KiB
Python
97 lines
3.4 KiB
Python
import re
|
|
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional
|
|
from enum import Enum
|
|
import contextlib
|
|
import textwrap
|
|
|
|
# Safely load fast C Yaml loader/dumper if they are available
|
|
try:
|
|
from yaml import CSafeLoader as Loader
|
|
except ImportError:
|
|
from yaml import SafeLoader as Loader # type: ignore[misc]
|
|
|
|
try:
|
|
from yaml import CSafeDumper as Dumper
|
|
except ImportError:
|
|
from yaml import SafeDumper as Dumper # type: ignore[misc]
|
|
YamlDumper = Dumper
|
|
|
|
# A custom loader for YAML that errors on duplicate keys.
|
|
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
|
|
class YamlLoader(Loader):
|
|
def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
|
|
mapping = []
|
|
for key_node, value_node in node.value:
|
|
key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call]
|
|
assert key not in mapping, f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
|
|
mapping.append(key)
|
|
mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call]
|
|
return mapping
|
|
|
|
# Many of these functions share logic for defining both the definition
|
|
# and declaration (for example, the function signature is the same), so
|
|
# we organize them into one function that takes a Target to say which
|
|
# code we want.
|
|
#
|
|
# This is an OPEN enum (we may add more cases to it in the future), so be sure
|
|
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
|
|
# for your use.
|
|
Target = Enum('Target', (
|
|
# top level namespace (not including at)
|
|
'DEFINITION',
|
|
'DECLARATION',
|
|
# TORCH_LIBRARY(...) { ... }
|
|
'REGISTRATION',
|
|
# namespace { ... }
|
|
'ANONYMOUS_DEFINITION',
|
|
# namespace cpu { ... }
|
|
'NAMESPACED_DEFINITION',
|
|
'NAMESPACED_DECLARATION',
|
|
))
|
|
|
|
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
|
|
# occurrence of a parameter in the derivative formula
|
|
IDENT_REGEX = r'(^|\W){}($|\W)'
|
|
|
|
# TODO: Use a real parser here; this will get bamboozled
|
|
def split_name_params(schema: str) -> Tuple[str, List[str]]:
|
|
m = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
|
|
if m is None:
|
|
raise RuntimeError(f'Unsupported function schema: {schema}')
|
|
name, _, params = m.groups()
|
|
return name, params.split(', ')
|
|
|
|
T = TypeVar('T')
|
|
S = TypeVar('S')
|
|
|
|
# These two functions purposely return generators in analogy to map()
|
|
# so that you don't mix up when you need to list() them
|
|
|
|
# Map over function that may return None; omit Nones from output sequence
|
|
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
|
|
for x in xs:
|
|
r = func(x)
|
|
if r is not None:
|
|
yield r
|
|
|
|
# Map over function that returns sequences and cat them all together
|
|
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
|
|
for x in xs:
|
|
for r in func(x):
|
|
yield r
|
|
|
|
# Conveniently add error context to exceptions raised. Lets us
|
|
# easily say that an error occurred while processing a specific
|
|
# context.
|
|
@contextlib.contextmanager
|
|
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
|
|
try:
|
|
yield
|
|
except Exception as e:
|
|
# TODO: this does the wrong thing with KeyError
|
|
msg = msg_fn()
|
|
msg = textwrap.indent(msg, ' ')
|
|
msg = f'{e.args[0]}\n{msg}' if e.args else msg
|
|
e.args = (msg,) + e.args[1:]
|
|
raise
|