import contextlib
import functools
import hashlib
import os
import re
import sys
import textwrap
from argparse import Namespace
from dataclasses import fields, is_dataclass
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    NoReturn,
    Optional,
    Sequence,
    Set,
    Tuple,
    TypeVar,
    Union,
)

from typing_extensions import Literal

from torchgen.code_template import CodeTemplate

# Safely load fast C Yaml loader/dumper if they are available
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]

try:
    from yaml import CSafeDumper as Dumper
except ImportError:
    from yaml import SafeDumper as Dumper  # type: ignore[misc]
YamlDumper = Dumper


# A custom loader for YAML that errors on duplicate keys.
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        mapping = []
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert (
                key not in mapping
            ), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
            mapping.append(key)
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        return mapping


# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
# for your use.
Target = Enum(
    "Target",
    (
        # top level namespace (not including at)
        "DEFINITION",
        "DECLARATION",
        # TORCH_LIBRARY(...) { ... }
        "REGISTRATION",
        # namespace { ... }
        "ANONYMOUS_DEFINITION",
        # namespace cpu { ... }
        "NAMESPACED_DEFINITION",
        "NAMESPACED_DECLARATION",
    ),
)

# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula
IDENT_REGEX = r"(^|\W){}($|\W)"


# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if m is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    name, _, params = m.groups()
    return name, params.split(", ")


T = TypeVar("T")
S = TypeVar("S")

# These two functions purposely return generators in analogy to map()
# so that you don't mix up when you need to list() them


# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        r = func(x)
        if r is not None:
            yield r


# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        for r in func(x):
            yield r
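
# A minimal usage sketch for the helpers above (hypothetical inputs, for
# illustration only):
#
#   name, params = split_name_params("add.Tensor(Tensor self, Tensor other)")
#   # name == "add", params == ["Tensor self", "Tensor other"]
#
#   evens = list(mapMaybe(lambda x: x if x % 2 == 0 else None, range(5)))
#   # evens == [0, 2, 4]
#   doubled = list(concatMap(lambda x: [x, x], [1, 2]))
#   # doubled == [1, 1, 2, 2]
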
# Conveniently add error context to exceptions raised.  Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    try:
        yield
    except Exception as e:
        # TODO: this does the wrong thing with KeyError
        msg = msg_fn()
        msg = textwrap.indent(msg, "  ")
        msg = f"{e.args[0]}\n{msg}" if e.args else msg
        e.args = (msg,) + e.args[1:]
        raise


# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError("Unhandled type: {}".format(type(x).__name__))


@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)


# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    sha1 = hashlib.sha1(s.encode("latin1")).digest()
    return int.from_bytes(sha1, byteorder="little")
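
# A minimal usage sketch (hypothetical names `entry` and `process`, for
# illustration only):
#
#   with context(lambda: f"while processing entry {entry}"):
#       process(entry)  # any exception raised here gets the message appended
#
#   string_stable_hash("aten::add") % 5  # stable shard index across runs,
#                                        # unlike the builtin hash()
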
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        old_contents: Optional[str]
        try:
            with open(filename, "r") as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    # Render a template file using the environment produced by env_callable
    # (either a dict of substitutions or a literal string).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
    ) -> str:
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if "generated_comment" not in env:
                comment = "@" + "generated by torchgen/gen.py"
                comment += " from {}".format(os.path.basename(template_path))
                env["generated_comment"] = comment
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        filename = "{}/{}".format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = "set({}\n    {})".format(
            variable_name,
            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)


# Helper function to generate a file manager
def make_file_manager(
    options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
    template_dir = os.path.join(options.source_path, "templates")
    install_dir = install_dir if install_dir else options.install_dir
    return FileManager(
        install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
    )
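
# A minimal usage sketch for FileManager (hypothetical paths, templates, and
# generator functions; for illustration only):
#
#   fm = FileManager(install_dir="out", template_dir="templates", dry_run=False)
#   fm.write("Functions.h", lambda: {"declarations": ["void foo();"]})
#   fm.write_sharded(
#       "RegisterCPU.cpp",
#       items=native_functions,                            # hypothetical iterable
#       key_fn=lambda f: f.name,                           # stable key -> stable shard
#       env_callable=lambda f: {"definitions": [gen(f)]},  # hypothetical generator
#       num_shards=3,
#       sharded_keys={"definitions"},
#   )
#   # Writes RegisterCPU_0.cpp ... RegisterCPU_2.cpp plus RegisterCPUEverything.cpp.
#   fm.write_outputs("GENERATED_SRCS", "out/outputs.cmake")  # record the file list
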
# Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    # the built-in pprint module supports dataclasses from Python 3.10
    if sys.version_info >= (3, 10):
        from pprint import pformat

        return pformat(obj, indent, width)

    return _pformat(obj, indent=indent, width=width)


def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
    class_name = obj.__class__.__name__
    # update current indentation level with class name
    curr_indent += len(class_name) + 1
    fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]

    fields_str = []
    for name, attr in fields_list:
        # update the current indent level with the field name
        # dict, list, set and tuple also add indent as done in pprint
        _curr_indent = curr_indent + len(name) + 1
        if is_dataclass(attr):
            str_repr = _pformat(attr, indent, width, _curr_indent)
        elif isinstance(attr, dict):
            str_repr = _format_dict(attr, indent, width, _curr_indent)
        elif isinstance(attr, (list, set, tuple)):
            str_repr = _format_list(attr, indent, width, _curr_indent)
        else:
            str_repr = repr(attr)

        fields_str.append(f"{name}={str_repr}")

    indent_str = curr_indent * " "
    body = f",\n{indent_str}".join(fields_str)
    return f"{class_name}({body})"


def _format_dict(
    attr: Dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    curr_indent += indent + 3
    dict_repr = []
    for k, v in attr.items():
        k_repr = repr(k)
        v_str = (
            _pformat(v, indent, width, curr_indent + len(k_repr))
            if is_dataclass(v)
            else repr(v)
        )
        dict_repr.append(f"{k_repr}: {v_str}")

    return _format(dict_repr, indent, width, curr_indent, "{", "}")


def _format_list(
    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    curr_indent += indent + 1
    list_repr = [
        _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
        for l in attr
    ]
    start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
    return _format(list_repr, indent, width, curr_indent, start, end)


def _format(
    fields_str: List[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    delimiter, curr_indent_str = "", ""
    # if it exceeds the max width then we place one element per line
    if len(repr(fields_str)) >= width:
        delimiter = "\n"
        curr_indent_str = " " * curr_indent

    indent_str = " " * indent
    body = f", {delimiter}{curr_indent_str}".join(fields_str)
    return f"{start}{indent_str}{body}{end}"
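
# A minimal usage sketch (hypothetical dataclasses, for illustration only):
#
#   from dataclasses import dataclass
#
#   @dataclass
#   class Arg:
#       name: str
#       type: str
#
#   @dataclass
#   class Schema:
#       name: str
#       args: List[Arg]
#
#   # On Python < 3.10 this uses the _pformat fallback above; on 3.10+ it
#   # defers to pprint.pformat, which understands dataclasses natively.
#   print(dataclass_repr(Schema("add", [Arg("self", "Tensor"), Arg("other", "Tensor")])))
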
class NamespaceHelper:
    """A helper for constructing the namespace open and close strings for a nested set of namespaces.

    e.g. for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch
    """

    def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
        # cpp_namespace can be a colon joined string such as torch::lazy
        cpp_namespaces = namespace_str.split("::")
        assert (
            len(cpp_namespaces) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
        self.epilogue_ = "\n".join(
            [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
        )
        self.namespaces_ = cpp_namespaces
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> "NamespaceHelper":
        """
        Generate a helper from nested namespaces as well as the class/function name.
        E.g.: "torch::lazy::add"
        """
        names = namespaced_entity.split("::")
        entity_name = names[-1]
        namespace_str = "::".join(names[:-1])
        return NamespaceHelper(
            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
        )

    @property
    def prologue(self) -> str:
        return self.prologue_

    @property
    def epilogue(self) -> str:
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        return self.entity_name_

    # Only allow a certain number of namespace levels
    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
        Return default if namespace string is empty.
        """
        return self.cpp_namespace_ if self.cpp_namespace_ else default


class OrderedSet(Generic[T]):
    storage: Dict[T, Literal[None]]

    def __init__(self, iterable: Optional[Iterable[T]] = None):
        if iterable is None:
            self.storage = {}
        else:
            self.storage = {k: None for k in iterable}

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        return iter(self.storage.keys())

    def update(self, items: "OrderedSet[T]") -> None:
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        self.storage[item] = None

    def copy(self) -> "OrderedSet[T]":
        ret: OrderedSet[T] = OrderedSet()
        ret.storage = self.storage.copy()
        return ret

    @staticmethod
    def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
        ret = args[0].copy()
        for s in args[1:]:
            ret.update(s)
        return ret

    def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
        return OrderedSet.union(self, other)

    def __ior__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        if isinstance(other, OrderedSet):
            return self.storage == other.storage
        else:
            return set(self.storage.keys()) == other
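
# A minimal usage sketch (hypothetical values, for illustration only):
#
#   ns = NamespaceHelper.from_namespaced_entity("torch::lazy::add")
#   ns.entity_name  # "add"
#   ns.prologue     # "namespace torch {\nnamespace lazy {"
#   ns.epilogue     # "} // namespace lazy\n} // namespace torch"
#
#   s = OrderedSet([3, 1, 2])
#   s.add(1)        # already present: order and contents unchanged
#   list(s)         # [3, 1, 2] -- insertion order is preserved, unlike set()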