Summary:
RFC: https://github.com/pytorch/rfcs/pull/40

This PR (re)introduces Python codegen for unboxing wrappers. Given an entry of `native_functions.yaml`, the codegen generates the corresponding C++ code to convert ivalues from the stack to their proper types. To trigger the codegen, run:

```
tools/jit/gen_unboxing.py -d cg/torch/share/ATen
```

Merged changes on CI test. In https://github.com/pytorch/pytorch/issues/71782 I added an e2e test for static dispatch + codegen unboxing. The test exports a mobile model of MobileNetV2, then loads and runs it on a new binary for the lite interpreter: `test/mobile/custom_build/lite_predictor.cpp`.

## Lite predictor build specifics

1. Codegen: `gen.py` generates `RegisterCPU.cpp` and `RegisterSchema.cpp`. With this PR, once `static_dispatch` mode is enabled, `gen.py` no longer generates `TORCH_LIBRARY` API calls in those cpp files, thus avoiding interaction with the dispatcher. Once `USE_LIGHTWEIGHT_DISPATCH` is turned on, `cmake/Codegen.cmake` calls `gen_unboxing.py`, which generates `UnboxingFunctions.h`, `UnboxingFunctions_[0-4].cpp` and `RegisterCodegenUnboxedKernels_[0-4].cpp`.
2. Build: `USE_LIGHTWEIGHT_DISPATCH` adds the generated sources to `all_cpu_cpp` in `aten/src/ATen/CMakeLists.txt`. All other files remain unchanged. The `Operators_[0-4].cpp` files are not strictly necessary, but we can rely on the linker to strip them out.

## Current CI job test coverage update

Created a new CI job `linux-xenial-py3-clang5-mobile-lightweight-dispatch-build` that enables the following build options:
* `USE_LIGHTWEIGHT_DISPATCH=1`
* `BUILD_LITE_INTERPRETER=1`
* `STATIC_DISPATCH_BACKEND=CPU`

This job triggers `test/mobile/lightweight_dispatch/build.sh` and builds `libtorch`. The script then runs the C++ tests written in `test_lightweight_dispatch.cpp` and `test_codegen_unboxing.cpp`. Recent commits added tests to cover as many C++ argument types as possible: in `build.sh` we install the PyTorch Python API so that test models can be exported in `tests_setup.py`, then run the C++ test binary to execute these models on the lightweight-dispatch-enabled runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69881

Reviewed By: iseeyuan

Differential Revision: D33692299

Pulled By: larryliu0820

fbshipit-source-id: 211e59f2364100703359b4a3d2ab48ca5155a023
(cherry picked from commit 58e1c9a25e3d1b5b656282cf3ac2f548d98d530b)
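A rough sketch (not part of the PR itself) of reproducing the CI configuration locally, assuming the three build options listed above are passed to the build script as environment variables:

```
USE_LIGHTWEIGHT_DISPATCH=1 BUILD_LITE_INTERPRETER=1 STATIC_DISPATCH_BACKEND=CPU \
  test/mobile/lightweight_dispatch/build.sh
```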
246 lines · 9.3 KiB · Python
import contextlib
import functools
import hashlib
import os
import re
import textwrap
from argparse import Namespace
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional, Dict, Any, Union, Set, NoReturn
from enum import Enum

from tools.codegen.code_template import CodeTemplate

# Safely load fast C Yaml loader/dumper if they are available
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]

try:
    from yaml import CSafeDumper as Dumper
except ImportError:
    from yaml import SafeDumper as Dumper  # type: ignore[misc]
YamlDumper = Dumper

# A custom loader for YAML that errors on duplicate keys.
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        # First pass: collect the keys so we can detect duplicates.
        mapping = []
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert key not in mapping, f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
            mapping.append(key)
        # Second pass: build the actual mapping via the base loader.
        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        return mapping

# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
# for your use.
Target = Enum('Target', (
    # top level namespace (not including at)
    'DEFINITION',
    'DECLARATION',
    # TORCH_LIBRARY(...) { ... }
    'REGISTRATION',
    # namespace { ... }
    'ANONYMOUS_DEFINITION',
    # namespace cpu { ... }
    'NAMESPACED_DEFINITION',
    'NAMESPACED_DECLARATION',
))

# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
|
|
# occurrence of a parameter in the derivative formula
|
|
IDENT_REGEX = r'(^|\W){}($|\W)'
|
|
|
|
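# Example for IDENT_REGEX above (illustrative only):
# re.search(IDENT_REGEX.format('foo'), 'foo, bar') matches, while
# re.search(IDENT_REGEX.format('foo'), 'foobar') does not, because the
# identifier must be delimited by non-word characters (or the string boundary).
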
# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    m = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if m is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    name, _, params = m.groups()
    return name, params.split(', ')
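# Example for split_name_params above (illustrative only):
# split_name_params('add.Tensor(Tensor self, Tensor other)')
# returns ('add', ['Tensor self', 'Tensor other']).
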
T = TypeVar('T')
S = TypeVar('S')

# These two functions purposely return generators in analogy to map()
# so that you don't mix up when you need to list() them

# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        r = func(x)
        if r is not None:
            yield r

# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        for r in func(x):
            yield r
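# Examples for mapMaybe/concatMap above (illustrative only):
# list(mapMaybe(lambda x: x * 2 if x % 2 else None, [1, 2, 3])) == [2, 6]
# list(concatMap(lambda x: [x, x], [1, 2])) == [1, 1, 2, 2]
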
# Conveniently add error context to exceptions raised. Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    try:
        yield
    except Exception as e:
        # TODO: this does the wrong thing with KeyError
        msg = msg_fn()
        msg = textwrap.indent(msg, '  ')
        msg = f'{e.args[0]}\n{msg}' if e.args else msg
        e.args = (msg,) + e.args[1:]
        raise
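# A minimal usage sketch for context() above (the names `schema` and `parse` are
# hypothetical, illustrative only):
#
#     with context(lambda: f'while processing {schema}'):
#         parse(schema)
#
# Any exception escaping the block has the indented context message appended to
# its first argument before being re-raised.
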
# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError("Unhandled type: {}".format(type(x).__name__))
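# A minimal exhaustiveness-checking sketch combining Target and assert_never (the
# function below is hypothetical, illustrative only; Literal comes from typing):
#
#     def describe(target: Union[Literal[Target.DEFINITION], Literal[Target.DECLARATION]]) -> str:
#         if target is Target.DEFINITION:
#             return 'definition'
#         elif target is Target.DECLARATION:
#             return 'declaration'
#         else:
#             assert_never(target)  # mypy flags this call if a listed case is unhandled
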
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)


# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    sha1 = hashlib.sha1(s.encode('latin1')).digest()
    return int.from_bytes(sha1, byteorder='little')
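# Illustrative property of string_stable_hash above: because it is derived from
# a SHA-1 digest of the string, string_stable_hash('aten::add') % 4, for example,
# yields the same shard index in every process and across runs, whereas the builtin
# hash() of a str is randomized per process (PYTHONHASHSEED).
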
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        old_contents: Optional[str]
        try:
            with open(filename, 'r') as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, 'w') as f:
                f.write(contents)

    def write_with_template(self, filename: str, template_fn: str,
                            env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        filename = '{}/{}'.format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            env = env_callable()
            if isinstance(env, dict):
                # TODO: Update the comment reference to the correct location
                if 'generated_comment' not in env:
                    comment = "@" + "generated by tools/codegen/gen.py"
                    comment += " from {}".format(os.path.basename(template_fn))
                    env['generated_comment'] = comment
                template = _read_template(os.path.join(self.template_dir, template_fn))
                self._write_if_changed(filename, template.substitute(env))
            elif isinstance(env, str):
                self._write_if_changed(filename, env)
            else:
                assert_never(env)

    def write(self, filename: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]) -> None:
        self.write_with_template(filename, filename, env_callable)
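    # A minimal usage sketch for write()/write_with_template() above; the file name
    # and environment key are hypothetical, illustrative only:
    #
    #     fm = FileManager(install_dir='out', template_dir='templates', dry_run=False)
    #     fm.write('Functions.h', lambda: {'declarations': ['void foo();']})
    #
    # This substitutes out/Functions.h from templates/Functions.h; a 'generated_comment'
    # entry is added to the environment if the callable does not provide one.
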
    # Write the environment produced by env_callable for each item into num_shards
    # output files (plus a combined "Everything" file), assigning items to shards by
    # a stable hash of key_fn(item).
    def write_sharded(
            self,
            filename: str,
            items: Iterable[T],
            *,
            key_fn: Callable[[T], str],
            env_callable: Callable[[T], Dict[str, List[str]]],
            num_shards: int,
            base_env: Optional[Dict[str, Any]] = None,
            sharded_keys: Set[str]
    ) -> None:

        everything: Dict[str, Any] = {'shard_id': 'Everything'}
        shards: List[Dict[str, Any]] = [{'shard_id': f'_{i}'} for i in range(num_shards)]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(shard[key], list), "sharded keys in base_env must be a list"
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind('.')
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard['shard_id']
            self.write_with_template(f"{base_filename}{shard_id}{extension}",
                                     filename,
                                     lambda: shard)

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}")
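    # A minimal usage sketch for write_sharded() above; `native_functions`, `root_name`,
    # and `emit` are hypothetical, illustrative only:
    #
    #     fm.write_sharded(
    #         'RegisterCPU.cpp', native_functions,
    #         key_fn=lambda fn: fn.root_name,
    #         env_callable=lambda fn: {'definitions': [emit(fn)]},
    #         num_shards=4,
    #         sharded_keys={'definitions'},
    #     )
    #
    # This writes RegisterCPU_0.cpp ... RegisterCPU_3.cpp plus RegisterCPUEverything.cpp,
    # all substituted from the RegisterCPU.cpp template; the Everything variant is then
    # dropped from self.filenames so it is not treated as a compiled output.
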
    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = 'set({}\n    {})'.format(
            variable_name, '\n    '.join('"' + name + '"' for name in sorted(self.filenames)))
        self._write_if_changed(filename, content)


# Helper function to generate file manager
def make_file_manager(options: Namespace, install_dir: Optional[str] = None) -> FileManager:
    template_dir = os.path.join(options.source_path, "templates")
    install_dir = install_dir if install_dir else options.install_dir
    return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run)
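# A minimal usage sketch for make_file_manager above; the paths are hypothetical,
# illustrative only:
#
#     options = Namespace(source_path='aten/src/ATen', install_dir='build/out', dry_run=False)
#     fm = make_file_manager(options)
#
# Templates are read from '<source_path>/templates' and outputs land in install_dir
# (options.install_dir unless an explicit install_dir argument is given).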