import re
import os
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional, Dict, Any, Union, Set, NoReturn
from enum import Enum
import contextlib
import textwrap
import hashlib
import functools

from tools.codegen.code_template import CodeTemplate

# Safely load fast C Yaml loader/dumper if they are available
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]

try:
    from yaml import CSafeDumper as Dumper
except ImportError:
    from yaml import SafeDumper as Dumper  # type: ignore[misc]
YamlDumper = Dumper

# A custom loader for YAML that errors on duplicate keys.
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        # First pass: collect the keys ourselves so we can detect duplicates,
        # which the base loader silently accepts (last value wins).
        keys = []
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert key not in keys, f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
            keys.append(key)
        # Second pass: let the base loader build the actual mapping.
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        return mapping
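
# For illustration only (not part of the original file): loading a document
# with a repeated key fails fast, assuming `import yaml` at module level.
#
#   yaml.load('a: 1\na: 2', Loader=YamlLoader)
#   # AssertionError: Found a duplicate key in the yaml. key=a, ...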
# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
# for your use.
Target = Enum('Target', (
    # top level namespace (not including at)
    'DEFINITION',
    'DECLARATION',
    # TORCH_LIBRARY(...) { ... }
    'REGISTRATION',
    # namespace { ... }
    'ANONYMOUS_DEFINITION',
    # namespace cpu { ... }
    'NAMESPACED_DEFINITION',
    'NAMESPACED_DECLARATION',
))
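
# A minimal sketch of the Union[Literal[...]] convention described above; the
# function name here is hypothetical, and Literal comes from typing on
# Python 3.8+ (typing_extensions before that):
#
#   def gen_kernel(target: Union[Literal[Target.DEFINITION],
#                                Literal[Target.DECLARATION]]) -> str: ...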
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula
IDENT_REGEX = r'(^|\W){}($|\W)'
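
# For illustration only: the template is filled with the identifier and
# searched with re.search, e.g.
#
#   re.search(IDENT_REGEX.format('foo'), 'foo, bar')  # matches
#   re.search(IDENT_REGEX.format('foo'), 'foobar')    # None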
# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    m = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if m is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    name, _, params = m.groups()
    return name, params.split(', ')
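
# For illustration only (note the overload name, if any, is dropped):
#
#   split_name_params('add.Tensor(Tensor self, Tensor other)')
#   # ('add', ['Tensor self', 'Tensor other'])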
T = TypeVar('T')
S = TypeVar('S')

# These two functions purposely return generators in analogy to map()
# so that you don't mix up when you need to list() them
# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        r = func(x)
        if r is not None:
            yield r
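
# For illustration only:
#
#   list(mapMaybe(lambda x: x * 10 if x % 2 else None, [1, 2, 3]))
#   # [10, 30]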
# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        for r in func(x):
            yield r
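
# For illustration only:
#
#   list(concatMap(lambda x: [x, x], [1, 2]))
#   # [1, 1, 2, 2]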
# Conveniently add error context to exceptions raised. Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    try:
        yield
    except Exception as e:
        # TODO: this does the wrong thing with KeyError
        msg = msg_fn()
        msg = textwrap.indent(msg, '  ')
        msg = f'{e.args[0]}\n{msg}' if e.args else msg
        e.args = (msg,) + e.args[1:]
        raise
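
# For illustration only: the context message is appended, indented, to the
# exception's own message as it propagates.
#
#   with context(lambda: 'while processing add.Tensor'):
#       raise RuntimeError('bad schema')
#   # RuntimeError: bad schema
#   #   while processing add.Tensor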
# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError("Unhandled type: {}".format(type(x).__name__))
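
# For illustration only: if every variant of a union is handled, mypy narrows
# the value to NoReturn by the time it reaches assert_never, so a missed case
# becomes a type error. This mirrors the pattern used in write_with_template
# below:
#
#   if isinstance(env, dict):
#       ...
#   elif isinstance(env, str):
#       ...
#   else:
#       assert_never(env)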
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)

# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    sha1 = hashlib.sha1(s.encode('latin1')).digest()
    return int.from_bytes(sha1, byteorder='little')
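
# For illustration only: builtin hash() is salted per process (PYTHONHASHSEED),
# so it cannot assign items to shards reproducibly across builds; this
# SHA1-based hash returns the same integer for the same string in every run.
#
#   string_stable_hash('add.Tensor') % 5  # same shard index on every run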
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        old_contents: Optional[str]
        try:
            with open(filename, 'r') as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            with open(filename, 'w') as f:
                f.write(contents)

    def write_with_template(self, filename: str, template_fn: str,
                            env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        filename = '{}/{}'.format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            env = env_callable()
            if isinstance(env, dict):
                # TODO: Update the comment reference to the correct location
                if 'generated_comment' not in env:
                    comment = "@" + "generated by tools/codegen/gen.py"
                    comment += " from {}".format(os.path.basename(template_fn))
                    env['generated_comment'] = comment
                template = _read_template(os.path.join(self.template_dir, template_fn))
                self._write_if_changed(filename, template.substitute(env))
            elif isinstance(env, str):
                self._write_if_changed(filename, env)
            else:
                assert_never(env)

    def write(self, filename: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        self.write_with_template(filename, filename, env_callable)
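
    # For illustration only: the output file and its template share a name, so
    # this renders {template_dir}/Functions.h into {install_dir}/Functions.h.
    # The filename and env key here are hypothetical.
    #
    #   fm.write('Functions.h', lambda: {'function_declarations': decls})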

    def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str]
    ) -> None:

        everything: Dict[str, Any] = {'shard_id': 'Everything'}
        shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(shard[key], list), "sharded keys in base_env must be a list"
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind('.')
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard['shard_id']
            # The lambda is invoked synchronously inside write_with_template,
            # so capturing the loop variable `shard` here is safe.
            self.write_with_template(f"{base_filename}{shard_id}{extension}",
                                     filename,
                                     lambda: shard)

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}")
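
    # For illustration only: with filename='Operators.cpp' and num_shards=2,
    # this writes Operators_0.cpp, Operators_1.cpp, and OperatorsEverything.cpp,
    # all rendered from the Operators.cpp template; only the numbered shards
    # remain in self.filenames.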

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = 'set({}\n    {})'.format(
            variable_name,
            '\n    '.join('"' + name + '"' for name in sorted(self.filenames)))
        self._write_if_changed(filename, content)
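
# For illustration only (the variable name and paths are hypothetical): the
# CMake-includable output looks like
#
#   set(GENERATED_CXX
#       "out/Operators_0.cpp"
#       "out/Operators_1.cpp")
#
# and a typical construction of the manager itself is
#
#   fm = FileManager(install_dir='out', template_dir='templates', dry_run=False)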