mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[RFC] Switch PyTorch Selective Build (Custom Build) to use the SelectiveBuilder abstraction (#45722)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45722 This diff does a bunch of things: 1. Introduces some abstractions as detailed in https://fb.quip.com/2oEzAR5MKqbD to help with selective build related codegen in multiple files. 2. Adds helper methods to combine operators, debug info, operator lists, etc... 3. Currently, the selective build machinery querying `op_registration_whitelist` directly at various places in the code. `op_registration_whitelist` is a list of allowed operator names (without overload name). We want to move to a world where the overload names are also included so that we can be more selective about which operators we include. To that effect, it makes sense to hide the checking logic in a separate abstraction and have the build use that abstraction instead of putting all this selective build specific logic in the code-generator itself. This change is attempting to do just that. 4. Updates generate_code, unboxing-wrapper codegen, and autograd codegen to accept the operator selector paradigm as opposed to a selected operator list. 5. Update `tools/code_analyzer/gen_op_registration_allowlist.py` to expose providing an actual structured operator dependency graph in addition to a serialized string. There are a bunch of structural changes as well: 1. `root_op_list.yaml` and `combined_op_list.yaml` are now actual YAML files (not a space separated list of operator names) 2. `generate_code.py` accepts only paths to operator list YAML files (both old style as well as new style) and not list of operator names on the command line as arguments 3. `gen.py` optionally also accepts a custom build related operators YAML path (this file has information about which operators to register in the generated library). ghstack-source-id: 114578753 (Note: this ignores all push blocking failures!) Test Plan: `buck test caffe2/test:selective_build` Generated YAML files after the change: {P143981979} {P143982025} {P143982056} Ensure that the generated files are same before and after the change: ``` [dhruvbird@devvm2490 /tmp/TypeDefault.cpp] find -name "*.cpp" | xargs md5sum d72c3d125baa7b77e4c5581bbc7110d2 ./after_change/gen_aten/TypeDefault.cpp 42353036c83ebc7620a7159235b9647f ./after_change/lite_predictor_lib_aten/TypeDefault.cpp d72c3d125baa7b77e4c5581bbc7110d2 ./before_change/gen_aten/TypeDefault.cpp 42353036c83ebc7620a7159235b9647f ./before_change/lite_predictor_lib_aten/TypeDefault.cpp ``` `VariableTypes_N.cpp` are generated the same both before and after the change: ``` [dhruvbird@devvm2490 /tmp/VariableType] find -name "*.cpp" | xargs -n 1 md5sum | sort 3be89f63fd098291f01935077a60b677 ./after/VariableType_2.cpp 3be89f63fd098291f01935077a60b677 ./before/VariableType_2.cpp 40a3e59d64e9dbe86024cf314f127fd6 ./after/VariableType_4.cpp 40a3e59d64e9dbe86024cf314f127fd6 ./before/VariableType_4.cpp a4911699ceda3c3a430f08c64e8243fd ./after/VariableType_1.cpp a4911699ceda3c3a430f08c64e8243fd ./before/VariableType_1.cpp ca9aa611fcb2a573a8cba4e269468c99 ./after/VariableType_0.cpp ca9aa611fcb2a573a8cba4e269468c99 ./before/VariableType_0.cpp e18f639ed23d802dc4a31cdba40df570 ./after/VariableType_3.cpp e18f639ed23d802dc4a31cdba40df570 ./before/VariableType_3.cpp ``` Reviewed By: ljk53 Differential Revision: D23837010 fbshipit-source-id: ad06b1756af5be25baa39fd801dfdf09bc565442
This commit is contained in:
parent
bcd68dfa5f
commit
0c5cd8c2b9
|
|
@ -27,7 +27,8 @@ import os
|
|||
import yaml
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from .utils import YamlLoader, split_name_params, op_name_without_overload
|
||||
from .utils import YamlLoader, split_name_params, op_name_with_overload
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
# See NOTE [ Autograd View Variables ] in variable.h for details.
|
||||
# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
|
||||
|
|
@ -218,15 +219,17 @@ def load_deprecated_signatures(aten_decls, deprecated_path):
|
|||
return declarations
|
||||
|
||||
|
||||
def gen_autograd(aten_path, out, autograd_dir, disable_autograd=False, selected_op_list=None):
|
||||
def gen_autograd(aten_path, out, autograd_dir, operator_selector: SelectiveBuilder, disable_autograd=False):
|
||||
full_aten_decls = load_aten_declarations(aten_path)
|
||||
|
||||
def filter_decls(aten_decls, selected_op_list):
|
||||
if selected_op_list is None:
|
||||
return aten_decls
|
||||
return [decl for decl in aten_decls if op_name_without_overload(decl) in selected_op_list]
|
||||
def filter_decls(aten_decls, operator_selector):
|
||||
def is_operator_selected_for_training(decl):
|
||||
op_name = op_name_with_overload(decl)
|
||||
return operator_selector.is_operator_selected_for_training(op_name)
|
||||
|
||||
aten_decls = filter_decls(full_aten_decls, selected_op_list)
|
||||
return [decl for decl in aten_decls if is_operator_selected_for_training(decl)]
|
||||
|
||||
aten_decls = filter_decls(full_aten_decls, operator_selector)
|
||||
|
||||
# Parse and load derivatives.yaml
|
||||
from .load_derivatives import load_derivatives
|
||||
|
|
|
|||
|
|
@ -74,9 +74,8 @@ def is_tensor_method(declaration):
|
|||
def is_out_variant(decl):
|
||||
return decl['name'].endswith('_out')
|
||||
|
||||
def op_name_without_overload(decl):
|
||||
name = decl['name'] if not is_out_variant(decl) else decl['name'][:-4]
|
||||
return 'aten::{}'.format(name)
|
||||
def op_name_with_overload(decl):
|
||||
return decl['operator_name_with_overload']
|
||||
|
||||
def load_op_list_and_strip_overload(op_list, op_list_path):
|
||||
if op_list is None and op_list_path is None:
|
||||
|
|
|
|||
|
|
@ -59,7 +59,10 @@ def gen_transitive_closure(dep_graph, root_ops):
|
|||
result.add(dep)
|
||||
queue.append(dep)
|
||||
|
||||
return ' '.join(sorted(result))
|
||||
return sorted(result)
|
||||
|
||||
def gen_transitive_closure_str(dep_graph, root_ops):
|
||||
return ' '.join(gen_transitive_closure(dep_graph, root_ops))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -77,4 +80,4 @@ if __name__ == "__main__":
|
|||
|
||||
deps = load_op_dep_graph(args.op_dependency) if args.op_dependency else {}
|
||||
root_ops = load_root_ops(args.root_ops)
|
||||
print(gen_transitive_closure(deps, root_ops))
|
||||
print(gen_transitive_closure_str(deps, root_ops))
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import tools.codegen.api.cpp as cpp
|
|||
import tools.codegen.api.dispatcher as dispatcher
|
||||
import tools.codegen.api.native as native
|
||||
import tools.codegen.local as local
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
try:
|
||||
# use faster C loader if available
|
||||
|
|
@ -189,9 +190,9 @@ KEYWORD_ALL_BACKENDS = ('DefaultBackend', 'Math')
|
|||
def compute_type_method(
|
||||
dispatch: Optional[str], *,
|
||||
target: Target,
|
||||
# Which operators to actually generate code for. If None, generate
|
||||
# code for all operators
|
||||
op_registration_whitelist: Optional[Set[str]],
|
||||
# Selector object to determine which operators to generate
|
||||
# registration code for.
|
||||
selector: SelectiveBuilder,
|
||||
# Only valid for generating registrations. If True, only generate
|
||||
# def() invocations (for schema registration); do not generate
|
||||
# any impl() invocations for, e.g., catch-all kernels
|
||||
|
|
@ -210,8 +211,8 @@ def compute_type_method(
|
|||
if f.dispatch is not None and target is not Target.REGISTRATION:
|
||||
return None
|
||||
|
||||
if op_registration_whitelist is not None and \
|
||||
f"aten::{f.func.name.name}" not in op_registration_whitelist and target is Target.REGISTRATION:
|
||||
op_name = f"aten::{f.func.name}"
|
||||
if target is Target.REGISTRATION and not selector.is_operator_selected(op_name):
|
||||
return None
|
||||
|
||||
name = native.name(f.func)
|
||||
|
|
@ -882,6 +883,33 @@ class FileManager:
|
|||
filename,
|
||||
''.join(name + ";" for name in sorted(self.filenames)))
|
||||
|
||||
def get_custom_build_selector(
|
||||
provided_op_registration_allowlist: Optional[List[str]],
|
||||
op_selection_yaml_path: Optional[str]) -> SelectiveBuilder:
|
||||
assert not (
|
||||
provided_op_registration_allowlist is not None and
|
||||
op_selection_yaml_path is not None), (
|
||||
"Both provided_op_registration_allowlist and " +
|
||||
"op_selection_yaml_path can NOT be provided at the " +
|
||||
"same time.")
|
||||
|
||||
op_registration_allowlist: Optional[Set[str]] = None
|
||||
if provided_op_registration_allowlist is not None:
|
||||
op_registration_allowlist = set(provided_op_registration_allowlist)
|
||||
|
||||
if op_registration_allowlist is not None:
|
||||
selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
op_registration_allowlist,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
elif op_selection_yaml_path is not None:
|
||||
selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
|
||||
else:
|
||||
selector = SelectiveBuilder.get_nop_selector()
|
||||
|
||||
return selector
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description='Generate ATen source files')
|
||||
parser.add_argument(
|
||||
|
|
@ -905,12 +933,22 @@ def main() -> None:
|
|||
'--vulkan',
|
||||
action='store_true',
|
||||
help='Generate Vulkan backend functions')
|
||||
# TODO: --op_registration_whitelist will be removed when all call-sites
|
||||
# for gen.py are moved over to using the operator YAML file for mobile
|
||||
# custom build.
|
||||
parser.add_argument(
|
||||
'--op_registration_whitelist',
|
||||
nargs='*',
|
||||
help='filter op registrations by the whitelist (if set); '
|
||||
'each item is `namespace`::`operator name` without overload name; '
|
||||
'e.g.: aten::empty aten::conv2d ...')
|
||||
parser.add_argument(
|
||||
'--op_selection_yaml_path',
|
||||
help='Provide a path to the operator selection (for custom build) YAML '
|
||||
'that contains the information about the set of selected operators '
|
||||
'and their categories (training, ...). Each operator is either a '
|
||||
'full operator name with overload or just a bare operator name. '
|
||||
'The operator names also contain the namespace prefix (e.g. aten::)')
|
||||
parser.add_argument(
|
||||
'--backend_whitelist',
|
||||
nargs='*',
|
||||
|
|
@ -923,11 +961,10 @@ def main() -> None:
|
|||
'those that are not listed on --op_registration_whitelist')
|
||||
options = parser.parse_args()
|
||||
|
||||
op_registration_whitelist: Optional[Set[str]]
|
||||
if options.op_registration_whitelist is not None:
|
||||
op_registration_whitelist = set(options.op_registration_whitelist)
|
||||
else:
|
||||
op_registration_whitelist = None
|
||||
selector = get_custom_build_selector(
|
||||
options.op_registration_whitelist,
|
||||
options.op_selection_yaml_path,
|
||||
)
|
||||
|
||||
native_functions = parse_native_yaml(os.path.join(options.source_path, 'native/native_functions.yaml'))
|
||||
|
||||
|
|
@ -995,7 +1032,7 @@ def main() -> None:
|
|||
'Type': f'{dispatch}Type',
|
||||
'extra_cuda_headers': extra_cuda_headers if 'CUDA' in dispatch else '', # TODO: remove this
|
||||
'type_derived_method_declarations': list(mapMaybe(
|
||||
compute_type_method(dispatch, target=Target.DECLARATION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method(dispatch, target=Target.DECLARATION, selector=selector),
|
||||
native_functions
|
||||
)),
|
||||
})
|
||||
|
|
@ -1013,48 +1050,49 @@ def main() -> None:
|
|||
'',
|
||||
'Backend': dispatch,
|
||||
'type_derived_method_definitions': list(mapMaybe(
|
||||
compute_type_method(dispatch, target=Target.DEFINITION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method(dispatch, target=Target.DEFINITION, selector=selector),
|
||||
native_functions
|
||||
)),
|
||||
'function_registrations': list(mapMaybe(
|
||||
compute_type_method(
|
||||
dispatch, target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
|
||||
native_functions)),
|
||||
dispatch, target=Target.REGISTRATION, selector=selector),
|
||||
native_functions
|
||||
)),
|
||||
})
|
||||
del fm
|
||||
|
||||
cpu_fm.write('TypeDefault.h', lambda: {
|
||||
'type_method_declarations':
|
||||
list(mapMaybe(
|
||||
compute_type_method(None, target=Target.DECLARATION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method(None, target=Target.DECLARATION, selector=selector),
|
||||
native_functions)) +
|
||||
list(mapMaybe(
|
||||
compute_type_method('Math', target=Target.DECLARATION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method('Math', target=Target.DECLARATION, selector=selector),
|
||||
native_functions)) +
|
||||
list(mapMaybe(
|
||||
compute_type_method('DefaultBackend', target=Target.DECLARATION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method('DefaultBackend', target=Target.DECLARATION, selector=selector),
|
||||
native_functions)),
|
||||
})
|
||||
cpu_fm.write('TypeDefault.cpp', lambda: {
|
||||
'type_method_definitions':
|
||||
list(mapMaybe(
|
||||
compute_type_method(None, target=Target.DEFINITION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method(None, target=Target.DEFINITION, selector=selector),
|
||||
native_functions)) +
|
||||
list(mapMaybe(
|
||||
compute_type_method('Math', target=Target.DEFINITION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method('Math', target=Target.DEFINITION, selector=selector),
|
||||
native_functions)) +
|
||||
list(mapMaybe(
|
||||
compute_type_method('DefaultBackend', target=Target.DEFINITION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method('DefaultBackend', target=Target.DEFINITION, selector=selector),
|
||||
native_functions)),
|
||||
|
||||
'function_registrations': list(mapMaybe(
|
||||
compute_type_method(None, target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method(None, target=Target.REGISTRATION, selector=selector),
|
||||
native_functions)) + list(mapMaybe(
|
||||
compute_type_method('Math', target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method('Math', target=Target.REGISTRATION, selector=selector),
|
||||
native_functions)),
|
||||
|
||||
'default_backend_function_registrations': list(mapMaybe(
|
||||
compute_type_method('DefaultBackend', target=Target.REGISTRATION, op_registration_whitelist=op_registration_whitelist),
|
||||
compute_type_method('DefaultBackend', target=Target.REGISTRATION, selector=selector),
|
||||
native_functions)),
|
||||
})
|
||||
cpu_fm.write('Functions.h', lambda: {
|
||||
|
|
@ -1085,7 +1123,7 @@ def main() -> None:
|
|||
if options.force_schema_registration:
|
||||
def computeSchemaRegister() -> Dict[str, object]:
|
||||
schema_registrations = list(mapMaybe(
|
||||
compute_type_method(None, target=Target.REGISTRATION, op_registration_whitelist=None, def_only=True),
|
||||
compute_type_method(None, target=Target.REGISTRATION, selector=SelectiveBuilder.get_nop_selector(), def_only=True),
|
||||
native_functions))
|
||||
return {
|
||||
'schema_registrations': schema_registrations,
|
||||
|
|
|
|||
0
tools/codegen/selective_build/__init__.py
Normal file
0
tools/codegen/selective_build/__init__.py
Normal file
159
tools/codegen/selective_build/operator.py
Normal file
159
tools/codegen/selective_build/operator.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
from typing import Dict, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
# This class holds information about a single operator used to determine
|
||||
# the outcome of a selective/custom PyTorch build that doesn't include
|
||||
# registration code for all the supported operators. This is done to
|
||||
# reduce the size of the generated binary so that it can be deployed in
|
||||
# situations where binary size comes at a premium.
|
||||
#
|
||||
@dataclass(frozen=True)
|
||||
class SelectiveBuildOperator():
|
||||
# The name of the operator. This includes the aten::, etc... prefix
|
||||
# The operator name may or may not have the overload name. If this
|
||||
# operator name does not specify an overload name, the way to determine
|
||||
# if this entry refers to the family of operators with this base name
|
||||
# or just the operator with this name is to look at the value of the
|
||||
# 'include_all_overloads' flag in this class.
|
||||
name: str
|
||||
|
||||
# True if this is a root operator (i.e. called directly from a
|
||||
# TorchScript model, etc...). An operator is considered to be a
|
||||
# root operator if it is called directly from any one of the models
|
||||
# that this instance of the pytorch library was built for. Hence, it
|
||||
# may not be a root operator in all of the models that are used in
|
||||
# this instance of the pytorch library.
|
||||
is_root_operator: bool
|
||||
|
||||
# Is this operator used for on-device training? If True, then we need to
|
||||
# use the information to generate code in VariableType_N.cpp for registration
|
||||
# of training related operators. Again, this is True if this operator
|
||||
# is used for training in one or more models used by this instance of the
|
||||
# pytorch library.
|
||||
is_used_for_training: bool
|
||||
|
||||
# If True, it indicates that this operator instance (object) refers to an
|
||||
# operator without the overload name and should apply to all overloads
|
||||
# which have this operator name as the base name. This flag is applicable
|
||||
# only for objects that have operator names without a DOT (period) character
|
||||
# in them.
|
||||
#
|
||||
# Note: This flag is a temporary workaround to grandfather in the current
|
||||
# static selective (custom) build mechanism, which largely ignores overload
|
||||
# names when determining whether to select operators for registration
|
||||
# purposes.
|
||||
include_all_overloads: bool
|
||||
|
||||
# Debug Information at the operator level
|
||||
_debug_info: Optional[Tuple[str, ...]]
|
||||
|
||||
@staticmethod
|
||||
def from_yaml_dict(op_name: str, op_info: Dict[str, object]) -> 'SelectiveBuildOperator':
|
||||
allowed_keys = {'name', 'is_root_operator', 'is_used_for_training', 'include_all_overloads', 'debug_info'}
|
||||
|
||||
if len(set(op_info.keys()) - allowed_keys) > 0:
|
||||
raise Exception("Got unexpected top level keys: {}".format(
|
||||
",".join(set(op_info.keys()) - allowed_keys),
|
||||
))
|
||||
|
||||
if 'name' in op_info:
|
||||
assert op_name == op_info['name']
|
||||
|
||||
is_root_operator = op_info.get('is_root_operator', True)
|
||||
assert isinstance(is_root_operator, bool)
|
||||
|
||||
is_used_for_training = op_info.get('is_used_for_training', True)
|
||||
assert isinstance(is_used_for_training, bool)
|
||||
|
||||
include_all_overloads = op_info.get('include_all_overloads', True)
|
||||
assert isinstance(include_all_overloads, bool)
|
||||
|
||||
debug_info: Optional[Tuple[str, ...]] = None
|
||||
if 'debug_info' in op_info:
|
||||
di_list = op_info['debug_info']
|
||||
assert isinstance(di_list, list)
|
||||
debug_info = tuple(map(lambda x: str(x), di_list))
|
||||
|
||||
return SelectiveBuildOperator(
|
||||
name=op_name,
|
||||
is_root_operator=is_root_operator,
|
||||
is_used_for_training=is_used_for_training,
|
||||
include_all_overloads=include_all_overloads,
|
||||
_debug_info=debug_info,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_legacy_operator_name_without_overload(name: str) -> 'SelectiveBuildOperator':
|
||||
return SelectiveBuildOperator(
|
||||
name=name,
|
||||
is_root_operator=True,
|
||||
is_used_for_training=True,
|
||||
include_all_overloads=True,
|
||||
_debug_info=None,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
ret: Dict[str, object] = {
|
||||
'is_root_operator': self.is_root_operator,
|
||||
'is_used_for_training': self.is_used_for_training,
|
||||
'include_all_overloads': self.include_all_overloads,
|
||||
}
|
||||
if self._debug_info is not None:
|
||||
ret['debug_info'] = self._debug_info
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def merge_debug_info(
|
||||
lhs: Optional[Tuple[str, ...]],
|
||||
rhs: Optional[Tuple[str, ...]],
|
||||
) -> Optional[Tuple[str, ...]]:
|
||||
# Ensure that when merging, each entry shows up just once.
|
||||
if lhs is None and rhs is None:
|
||||
return None
|
||||
|
||||
return tuple(set((lhs or ()) + (rhs or ())))
|
||||
|
||||
|
||||
def combine_operators(
|
||||
lhs: 'SelectiveBuildOperator',
|
||||
rhs: 'SelectiveBuildOperator') -> 'SelectiveBuildOperator':
|
||||
if str(lhs.name) != str(rhs.name):
|
||||
raise Exception(
|
||||
"Expected both arguments to have the same name, but got '{}' and '{}' instead".format(
|
||||
str(lhs.name),
|
||||
str(rhs.name),
|
||||
)
|
||||
)
|
||||
|
||||
return SelectiveBuildOperator(
|
||||
name=lhs.name,
|
||||
# Consider this operator to be a root operator if it is a
|
||||
# root operator in any of the models used in this instance of
|
||||
# the pytorch library.
|
||||
is_root_operator=lhs.is_root_operator or rhs.is_root_operator,
|
||||
# Consider this operator to be a training operator if it is
|
||||
# an operator used for training in any of the models used
|
||||
# in this instance of the pytorch library.
|
||||
is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training,
|
||||
include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads,
|
||||
_debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
|
||||
)
|
||||
|
||||
def merge_operator_dicts(
|
||||
lhs: Dict[str, SelectiveBuildOperator],
|
||||
rhs: Dict[str, SelectiveBuildOperator],
|
||||
) -> Dict[str, SelectiveBuildOperator]:
|
||||
operators: Dict[str, SelectiveBuildOperator] = {}
|
||||
for (op_name, op) in list(lhs.items()) + list(rhs.items()):
|
||||
new_op = op
|
||||
if op_name in operators:
|
||||
new_op = combine_operators(operators[op_name], op)
|
||||
|
||||
operators[op_name] = new_op
|
||||
|
||||
return operators
|
||||
|
||||
|
||||
def strip_operator_overload_name(op_name: str) -> str:
|
||||
return op_name.split(".")[0]
|
||||
160
tools/codegen/selective_build/selector.py
Normal file
160
tools/codegen/selective_build/selector.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
from typing import Dict, Set, Optional, Tuple
|
||||
import yaml
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tools.codegen.selective_build.operator import *
|
||||
|
||||
# A SelectiveBuilder holds information extracted from the selective build
|
||||
# YAML specification.
|
||||
#
|
||||
# It includes information about the build's selectivity, the debug_info
|
||||
# associated with this selective build (opaque string), and the set of
|
||||
# operators that should be included in the build.
|
||||
#
|
||||
@dataclass(frozen=True)
|
||||
class SelectiveBuilder:
|
||||
|
||||
# If true, then the build is not selective, and includes all
|
||||
# operators.
|
||||
include_all_operators: bool
|
||||
|
||||
# Debug Information at the selective/custom build level.
|
||||
_debug_info: Optional[Tuple[str, ...]]
|
||||
|
||||
# A dictionary of operator -> operator metadata.
|
||||
operators: Dict[str, SelectiveBuildOperator]
|
||||
|
||||
@staticmethod
|
||||
def get_nop_selector() -> 'SelectiveBuilder':
|
||||
return SelectiveBuilder.from_yaml_dict({'include_all_operators': True})
|
||||
|
||||
@staticmethod
|
||||
def from_yaml_dict(data: Dict[str, object]) -> 'SelectiveBuilder':
|
||||
valid_top_level_keys = {
|
||||
'include_all_operators',
|
||||
'debug_info',
|
||||
'operators',
|
||||
}
|
||||
top_level_keys = set(data.keys())
|
||||
if len(top_level_keys - valid_top_level_keys) > 0:
|
||||
raise Exception("Got unexpected top level keys: {}".format(
|
||||
",".join(top_level_keys - valid_top_level_keys),
|
||||
))
|
||||
include_all_operators = data.get('include_all_operators', False)
|
||||
assert isinstance(include_all_operators, bool)
|
||||
|
||||
debug_info = None
|
||||
if 'debug_info' in data:
|
||||
di_list = data['debug_info']
|
||||
assert isinstance(di_list, list)
|
||||
|
||||
debug_info = tuple(map(lambda x: str(x), di_list))
|
||||
|
||||
operators = {}
|
||||
operators_dict = data.get('operators', {})
|
||||
assert isinstance(operators_dict, dict)
|
||||
|
||||
for (k, v) in operators_dict.items():
|
||||
operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v)
|
||||
return SelectiveBuilder(include_all_operators, debug_info, operators)
|
||||
|
||||
@staticmethod
|
||||
def from_yaml_str(config_contents: str) -> 'SelectiveBuilder':
|
||||
contents = yaml.load(config_contents)
|
||||
return SelectiveBuilder.from_yaml_dict(contents)
|
||||
|
||||
@staticmethod
|
||||
def from_yaml_path(config_path: str) -> 'SelectiveBuilder':
|
||||
with open(config_path, 'r') as f:
|
||||
contents = yaml.load(f)
|
||||
return SelectiveBuilder.from_yaml_dict(contents)
|
||||
|
||||
@staticmethod
|
||||
def from_legacy_op_registration_allow_list(
|
||||
allow_list: Set[str],
|
||||
is_root_operator: bool,
|
||||
is_used_for_training: bool) -> 'SelectiveBuilder':
|
||||
operators = {}
|
||||
for op in allow_list:
|
||||
operators[op] = {
|
||||
'name': op,
|
||||
'is_root_operator': is_root_operator,
|
||||
'is_used_for_training': is_used_for_training,
|
||||
'include_all_overloads': True,
|
||||
}
|
||||
return SelectiveBuilder.from_yaml_dict({
|
||||
'operators': operators,
|
||||
})
|
||||
|
||||
def is_operator_selected(self, name: str) -> bool:
|
||||
if self.include_all_operators:
|
||||
return True
|
||||
|
||||
if name in self.operators:
|
||||
return True
|
||||
name = strip_operator_overload_name(name)
|
||||
return name in self.operators and self.operators[name].include_all_overloads
|
||||
|
||||
def is_operator_selected_for_training(self, name: str) -> bool:
|
||||
if not self.is_operator_selected(name):
|
||||
return False
|
||||
if self.include_all_operators:
|
||||
return True
|
||||
|
||||
not_training_op = SelectiveBuildOperator(
|
||||
name='',
|
||||
is_root_operator=False,
|
||||
is_used_for_training=False,
|
||||
include_all_overloads=False,
|
||||
_debug_info=None,
|
||||
)
|
||||
op = not_training_op
|
||||
if name in self.operators:
|
||||
op = self.operators[name]
|
||||
|
||||
name = strip_operator_overload_name(name)
|
||||
base_op = not_training_op
|
||||
if name in self.operators:
|
||||
base_op = self.operators[name]
|
||||
|
||||
return (
|
||||
op.is_used_for_training or
|
||||
(base_op.include_all_overloads and base_op.is_used_for_training)
|
||||
)
|
||||
|
||||
def is_root_operator(self, name: str) -> bool:
|
||||
if not self.is_operator_selected(name):
|
||||
return False
|
||||
if self.include_all_operators:
|
||||
return True
|
||||
|
||||
if name in self.operators:
|
||||
op: SelectiveBuildOperator = self.operators[name]
|
||||
return op.is_root_operator
|
||||
name = strip_operator_overload_name(name)
|
||||
if name not in self.operators:
|
||||
return False
|
||||
base_op: SelectiveBuildOperator = self.operators[name]
|
||||
return base_op.include_all_overloads and base_op.is_root_operator
|
||||
|
||||
def to_dict(self) -> Dict[str, object]:
|
||||
ret: Dict[str, object] = {
|
||||
'include_all_operators': self.include_all_operators,
|
||||
}
|
||||
operators = {}
|
||||
for (op_name, op) in self.operators.items():
|
||||
operators[op_name] = op.to_dict()
|
||||
ret['operators'] = operators
|
||||
|
||||
if self._debug_info is not None:
|
||||
ret['debug_info'] = self._debug_info
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def combine_selective_builders(lhs: SelectiveBuilder, rhs: SelectiveBuilder) -> SelectiveBuilder:
|
||||
include_all_operators = lhs.include_all_operators or rhs.include_all_operators
|
||||
debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
|
||||
operators = merge_operator_dicts(lhs.operators, rhs.operators)
|
||||
return SelectiveBuilder(include_all_operators, debug_info, operators)
|
||||
|
|
@ -24,7 +24,8 @@ from itertools import groupby
|
|||
from functools import reduce
|
||||
from ..autograd.gen_autograd import load_aten_declarations
|
||||
from ..autograd.gen_autograd import RETURNS_VIEWS_OF_INPUT
|
||||
from ..autograd.utils import CodeTemplate, write, is_out_variant, op_name_without_overload
|
||||
from ..autograd.utils import CodeTemplate, write, is_out_variant, op_name_with_overload
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
# JIT has a type system of
|
||||
# Scalar = int | float | bool # int is the largest int (int64_t),
|
||||
|
|
@ -280,8 +281,8 @@ def gen_unboxing_wrappers(
|
|||
declarations,
|
||||
out,
|
||||
template_path,
|
||||
operator_selector: SelectiveBuilder,
|
||||
disable_autograd=False,
|
||||
selected_op_list=None,
|
||||
force_schema_registration=False,
|
||||
):
|
||||
GENERATED_UNBOXING_WRAPPERS_CPP = CodeTemplate.from_file(template_path + '/generated_unboxing_wrappers.cpp')
|
||||
|
|
@ -386,18 +387,19 @@ def gen_unboxing_wrappers(
|
|||
|
||||
return constructor
|
||||
|
||||
def filter_decls(jit_decls, disable_autograd, selected_op_list, force_schema_registration):
|
||||
def filter_decls(jit_decls, disable_autograd, operator_selector: SelectiveBuilder, force_schema_registration):
|
||||
result = []
|
||||
for decl in jit_decls:
|
||||
if disable_autograd and is_backward_op(decl):
|
||||
continue
|
||||
op_name = op_name_without_overload(decl)
|
||||
if selected_op_list is not None and op_name not in selected_op_list:
|
||||
op_name = op_name_with_overload(decl)
|
||||
if operator_selector.is_root_operator(op_name):
|
||||
result.append(decl)
|
||||
else:
|
||||
if force_schema_registration:
|
||||
decl['emit_dummy_placeholder'] = True
|
||||
else:
|
||||
continue
|
||||
result.append(decl)
|
||||
result.append(decl)
|
||||
|
||||
return result
|
||||
|
||||
# This function declares an order on declarations. This is necessary because
|
||||
|
|
@ -467,7 +469,7 @@ def gen_unboxing_wrappers(
|
|||
reorder_out_args(decl)
|
||||
|
||||
jit_decls.extend(additional_jit_decls)
|
||||
jit_decls = filter_decls(jit_decls, disable_autograd, selected_op_list, force_schema_registration)
|
||||
jit_decls = filter_decls(jit_decls, disable_autograd, operator_selector, force_schema_registration)
|
||||
|
||||
# generation is deterministic
|
||||
jit_decl_groups = sort_decls(jit_decls)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ source_files = {'.py', '.cpp', '.h'}
|
|||
|
||||
DECLARATIONS_PATH = 'torch/share/ATen/Declarations.yaml'
|
||||
|
||||
|
||||
# TODO: This is a little inaccurate, because it will also pick
|
||||
# up setup_helper scripts which don't affect code generation
|
||||
def all_generator_source():
|
||||
|
|
@ -25,16 +24,13 @@ def generate_code(ninja_global=None,
|
|||
install_dir=None,
|
||||
subset=None,
|
||||
disable_autograd=False,
|
||||
selected_op_list_path=None,
|
||||
selected_op_list=None,
|
||||
force_schema_registration=False):
|
||||
# cwrap depends on pyyaml, so we can't import it earlier
|
||||
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.insert(0, root)
|
||||
force_schema_registration=False,
|
||||
operator_selector=None):
|
||||
from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
|
||||
from tools.autograd.gen_annotated_fn_args import gen_annotated
|
||||
from tools.autograd.utils import load_op_list_and_strip_overload
|
||||
from tools.jit.gen_unboxing_wrappers import gen_unboxing_wrappers
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
|
||||
# Build ATen based Variable classes
|
||||
if install_dir is None:
|
||||
|
|
@ -55,22 +51,24 @@ def generate_code(ninja_global=None,
|
|||
if subset == "pybindings" or not subset:
|
||||
gen_autograd_python(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, autograd_dir)
|
||||
|
||||
if operator_selector is None:
|
||||
operator_selector = SelectiveBuilder.get_nop_selector()
|
||||
|
||||
if subset == "libtorch" or not subset:
|
||||
selected_op_list = load_op_list_and_strip_overload(selected_op_list, selected_op_list_path)
|
||||
|
||||
gen_autograd(
|
||||
declarations_path or DECLARATIONS_PATH,
|
||||
autograd_gen_dir,
|
||||
autograd_dir,
|
||||
disable_autograd=disable_autograd,
|
||||
selected_op_list=selected_op_list,
|
||||
operator_selector=operator_selector,
|
||||
)
|
||||
gen_unboxing_wrappers(
|
||||
declarations_path or DECLARATIONS_PATH,
|
||||
jit_gen_dir,
|
||||
tools_jit_templates,
|
||||
disable_autograd=disable_autograd,
|
||||
selected_op_list=selected_op_list,
|
||||
operator_selector=operator_selector,
|
||||
force_schema_registration=force_schema_registration)
|
||||
|
||||
if subset == "python" or not subset:
|
||||
|
|
@ -79,6 +77,56 @@ def generate_code(ninja_global=None,
|
|||
python_install_dir,
|
||||
autograd_dir)
|
||||
|
||||
def get_selector_from_legacy_operator_selection_list(
|
||||
selected_op_list_path: str,
|
||||
):
|
||||
from tools.autograd.utils import load_op_list_and_strip_overload
|
||||
|
||||
selected_op_list = load_op_list_and_strip_overload(
|
||||
None,
|
||||
selected_op_list_path,
|
||||
)
|
||||
|
||||
# Internal build doesn't use this flag any more. Only used by OSS
|
||||
# build now. Every operator should be considered a root operator
|
||||
# (hence generating unboxing code for it, which is consistent with
|
||||
# the current behaviour), and also be considered as used for
|
||||
# training, since OSS doesn't support training on mobile for now.
|
||||
#
|
||||
is_root_operator = True
|
||||
is_used_for_training = True
|
||||
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
selector: SelectiveBuilder = SelectiveBuilder.get_nop_selector()
|
||||
if selected_op_list is not None:
|
||||
selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||
selected_op_list,
|
||||
is_root_operator,
|
||||
is_used_for_training,
|
||||
)
|
||||
|
||||
return selector
|
||||
|
||||
|
||||
def get_selector(selected_op_list_path, operators_yaml_path):
|
||||
# cwrap depends on pyyaml, so we can't import it earlier
|
||||
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.insert(0, root)
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
assert not (selected_op_list_path is not None and
|
||||
operators_yaml_path is not None), \
|
||||
("Expected at most one of selected_op_list_path and " +
|
||||
"operators_yaml_path to be set.")
|
||||
|
||||
if selected_op_list_path is None and operators_yaml_path is None:
|
||||
return SelectiveBuilder.get_nop_selector()
|
||||
elif selected_op_list_path is not None:
|
||||
return get_selector_from_legacy_operator_selection_list(selected_op_list_path)
|
||||
else:
|
||||
return SelectiveBuilder.from_yaml_path(operators_yaml_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Autogenerate code')
|
||||
|
|
@ -101,11 +149,8 @@ def main():
|
|||
help='Path to the yaml file that contains the list of operators to include for custom build.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--selected-op-list',
|
||||
nargs="*",
|
||||
type=str,
|
||||
help="""List of operator names to include for custom build, in addition to those in selected-op-list-path.
|
||||
For example, --selected-op-list aten::add.Tensor aten::_convolution.""",
|
||||
'--operators_yaml_path',
|
||||
help='Path to the model YAML file that contains the list of operators to include for custom build.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--force_schema_registration',
|
||||
|
|
@ -114,6 +159,7 @@ def main():
|
|||
'listed on --selected-op-list'
|
||||
)
|
||||
options = parser.parse_args()
|
||||
|
||||
generate_code(
|
||||
options.ninja_global,
|
||||
options.declarations_path,
|
||||
|
|
@ -121,9 +167,9 @@ def main():
|
|||
options.install_dir,
|
||||
options.subset,
|
||||
options.disable_autograd,
|
||||
options.selected_op_list_path,
|
||||
options.selected_op_list,
|
||||
options.force_schema_registration,
|
||||
# options.selected_op_list
|
||||
operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user