mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This improves a dry-run of `gen.py` from 0.80s to 0.45s.
`FileManager` in `dry_run` mode doesn't actually need to compute the
environment; it just records the filenames that would have been
written.
cc ezyang bhosmer bdhirsh
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69805
Reviewed By: ngimel
Differential Revision: D33944912
Pulled By: albanD
fbshipit-source-id: 74f22af3f2bd5afdef7105961270198566fa91e5
(cherry picked from commit 6fcdc15954)
238 lines
8.9 KiB
Python
238 lines
8.9 KiB
Python
import contextlib
|
|
import functools
|
|
import hashlib
|
|
import os
|
|
import re
|
|
import textwrap
|
|
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional, Dict, Any, Union, Set, NoReturn
|
|
from enum import Enum
|
|
|
|
from tools.codegen.code_template import CodeTemplate
|
|
|
|
# Safely load fast C Yaml loader/dumper if they are available
|
|
try:
|
|
from yaml import CSafeLoader as Loader
|
|
except ImportError:
|
|
from yaml import SafeLoader as Loader # type: ignore[misc]
|
|
|
|
try:
|
|
from yaml import CSafeDumper as Dumper
|
|
except ImportError:
|
|
from yaml import SafeDumper as Dumper # type: ignore[misc]
|
|
YamlDumper = Dumper
|
|
|
|
# A custom loader for YAML that errors on duplicate keys.
|
|
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
|
|
class YamlLoader(Loader):
    """YAML loader that refuses mappings containing a repeated key."""

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build the mapping for ``node`` after checking its keys for duplicates.

        Every key of the mapping node is constructed once up front and compared
        against the keys already seen; a repeat trips an assertion that names
        the key and the line where the mapping starts.  The real mapping is
        then produced by the parent loader.
        """
        # A list (not a set) is used for the seen keys, so membership testing
        # only relies on equality of the constructed keys.
        seen_keys = []
        for key_node, _ in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert key not in seen_keys, f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
            seen_keys.append(key)
        return super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
|
|
|
|
# Many of these functions share logic for defining both the definition
|
|
# and declaration (for example, the function signature is the same), so
|
|
# we organize them into one function that takes a Target to say which
|
|
# code we want.
|
|
#
|
|
# This is an OPEN enum (we may add more cases to it in the future), so be sure
|
|
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
|
|
# for your use.
|
|
class Target(Enum):
    # Members are numbered 1..6 in declaration order, matching what the
    # functional Enum('Target', (...)) form assigns automatically.

    # top level namespace (not including at)
    DEFINITION = 1
    DECLARATION = 2
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = 3
    # namespace { ... }
    ANONYMOUS_DEFINITION = 4
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = 5
    NAMESPACED_DECLARATION = 6
|
|
|
|
# Regex template for locating a whole identifier: after .format(name) it
# finds "foo" in "foo, bar" but not inside "foobar".  Used to search for the
# occurrence of a parameter in the derivative formula.  The name is
# substituted verbatim (it is not regex-escaped here).
IDENT_REGEX = r'(^|\W){}($|\W)'
|
|
|
|
# TODO: Use a real parser here; this will get bamboozled
|
|
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a schema string such as ``name.overload(a, b)`` into its parts.

    Returns the base name (the optional ``.overload`` suffix is dropped) and
    the text between the outermost parentheses split on ``', '``.  An empty
    parameter list therefore yields ``['']``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)``.
    """
    parsed = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if parsed is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    # Group 2 is the overload suffix, which callers do not receive.
    base_name = parsed.group(1)
    param_text = parsed.group(3)
    return base_name, param_text.split(', ')
|
|
|
|
T = TypeVar('T')
S = TypeVar('S')

# mapMaybe and concatMap are deliberately generators, mirroring map(), so
# that it stays obvious where a list() call is required.

# Apply func to every element and drop the results that are None.
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping ``None`` results.

    Only ``None`` is filtered out; other falsy results (``0``, ``''``,
    ``False``) are yielded.
    """
    for result in map(func, xs):
        if result is None:
            continue
        yield result
|
|
|
|
# Map over function that returns sequences and cat them all together
|
|
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
    for x in xs:
        yield from func(x)
|
|
|
|
# Conveniently add error context to exceptions raised. Lets us
|
|
# easily say that an error occurred while processing a specific
|
|
# context.
|
|
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Attach extra context to any ``Exception`` raised inside the ``with`` body.

    ``msg_fn`` is only called when an exception actually occurs.  Its result
    is indented and appended on a new line after the exception's first
    argument (or becomes the first argument when there are none); remaining
    arguments are kept, and the same exception is re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        note = textwrap.indent(msg_fn(), ' ')
        if err.args:
            note = f'{err.args[0]}\n{note}'
        err.args = (note, *err.args[1:])
        raise
|
|
|
|
# A little trick from https://github.com/python/mypy/issues/6366
|
|
# for getting mypy to do exhaustiveness checking
|
|
# TODO: put this somewhere else, maybe
|
|
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the runtime type of ``x``.

    Intended for the final ``else`` of a type dispatch: the ``NoReturn``
    parameter lets mypy report any case that was not handled earlier.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
|
|
|
|
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template at path ``template_fn`` via ``CodeTemplate.from_file``.

    Results are memoized per path with no size limit, so each distinct path
    is loaded at most once per process.
    """
    template = CodeTemplate.from_file(template_fn)
    return template
|
|
|
|
|
|
# String hash that's stable across different executions, unlike builtin hash
|
|
def string_stable_hash(s: str) -> int:
    """Return an integer hash of ``s`` that is the same in every process.

    The string is encoded as latin1, hashed with SHA-1, and the 20-byte
    digest is read as a little-endian unsigned integer.
    """
    hasher = hashlib.sha1()
    hasher.update(s.encode('latin1'))
    return int.from_bytes(hasher.digest(), byteorder='little')
|
|
|
|
# A small abstraction for writing out generated files and keeping track
|
|
# of what files have been written (so you can write out a list of output
|
|
# files)
|
|
class FileManager:
    """Writes generated files and records which files have been written.

    Every file requested through ``write``/``write_with_template``/
    ``write_sharded`` is recorded in ``filenames`` (prefixed with
    ``install_dir``), so that ``write_outputs`` can emit the list of outputs.
    In ``dry_run`` mode the names are still recorded, but environments are
    not computed and no generated file is written.
    """

    # Directory prefix prepended (with '/') to every output filename.
    install_dir: str
    # Directory that template filenames are resolved against.
    template_dir: str
    # When True, only record filenames; do not compute envs or write files.
    dry_run: bool
    # Full paths of all files written (or, in dry-run mode, requested) so far.
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        """Create a manager with an empty set of recorded filenames."""
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds them.

        A file that cannot be read counts as changed.  The parent directory
        is created when ``filename`` has one.
        """
        old_contents: Optional[str]
        try:
            with open(filename, 'r') as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist.  A bare filename
            # has an empty dirname, and os.makedirs('') raises, so skip it.
            out_dir = os.path.dirname(filename)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(filename, 'w') as f:
                f.write(contents)

    def write_with_template(self, filename: str, template_fn: str,
                            env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Record ``filename`` and, unless in dry-run mode, generate it.

        Args:
            filename: output path relative to ``install_dir``.
            template_fn: template path relative to ``template_dir``; only
                read when the environment is a dict.
            env_callable: produces either the complete file contents (str)
                or a substitution environment (dict) for the template.  It
                is not called in dry-run mode.  A dict environment without a
                ``'generated_comment'`` key gets one added in place.

        Raises:
            AssertionError: if the same output file is requested twice.
        """
        filename = '{}/{}'.format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if self.dry_run:
            return

        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if 'generated_comment' not in env:
                # Built by concatenation so this source line itself does not
                # contain the marker as one contiguous token.
                comment = "@" + "generated by tools/codegen/gen.py"
                comment += " from {}".format(os.path.basename(template_fn))
                env['generated_comment'] = comment
            template = _read_template(os.path.join(self.template_dir, template_fn))
            self._write_if_changed(filename, template.substitute(env))
        elif isinstance(env, str):
            self._write_if_changed(filename, env)
        else:
            assert_never(env)

    def write(self, filename: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        """Like ``write_with_template``, using ``filename`` as the template name too."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str]
    ) -> None:
        """Generate ``num_shards`` shard files plus one "Everything" file.

        Each item is assigned to the shard ``string_stable_hash(key_fn(item))
        % num_shards``; the lists in ``env_callable(item)`` are appended to
        that shard's environment and to the "Everything" environment.  All
        files are rendered from the template named ``filename`` and are named
        by inserting the shard id (``_0``, ``_1``, ..., ``Everything``)
        before the extension.

        Args:
            filename: template name and base for the output names.
            items: the items to distribute; ignored in dry-run mode.
            key_fn: returns the string used to pick an item's shard.
            env_callable: returns the per-item lists to merge; every key
                must be listed in ``sharded_keys``.
            num_shards: number of numbered shards.
            base_env: entries copied into every environment before merging.
            sharded_keys: keys whose values are lists accumulated per shard.

        Raises:
            AssertionError: if a sharded key in ``base_env`` is not a list,
                or an item's environment uses an undeclared key.
        """
        everything: Dict[str, Any] = {'shard_id': 'Everything'}
        shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(shard[key], list), "sharded keys in base_env must be a list"
                    # Copy so that merging into one shard does not leak into
                    # base_env or into the other shards.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind('.')
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard['shard_id']
            # The lambda is invoked (if at all) before write_with_template
            # returns, so capturing the loop variable here is safe.
            self.write_with_template(f"{base_filename}{shard_id}{extension}",
                                     filename,
                                     lambda: shard)

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}")

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script.

        The content has the form ``set(<variable_name> "<file>" ...)`` with
        the recorded filenames sorted, one per line.  This file is written
        even in dry-run mode.
        """
        content = 'set({}\n {})'.format(
            variable_name, '\n '.join('"' + name + '"' for name in sorted(self.filenames)))
        self._write_if_changed(filename, content)
|