mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275 In preparation for addressing https://github.com/pytorch/pytorch/issues/73212 Diff was generated with: ``` git mv tools/codegen torchgen git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g' sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt ``` and manual edits to: * tools/test/test_gen_backend_stubs.py * torchgen/build.bzl * torchgen/gen_backend_stubs.py aka this diff: ``` diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 3dc26c6d2d..104054575e 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 path = os.path.dirname(os.path.realpath(__file__)) -gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py') +gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py') # gen_backend_stubs.py is an integration point that is called directly by external backends. # The tests here are to confirm that badly formed inputs result in reasonable error messages. 
diff --git a/torchgen/build.bzl b/torchgen/build.bzl index ed04e35a43..d00078a3cf 100644 --- a/torchgen/build.bzl +++ b/torchgen/build.bzl @@ -1,6 +1,6 @@ def define_targets(rules): rules.py_library( - name = "codegen", + name = "torchgen", srcs = rules.glob(["**/*.py"]), deps = [ rules.requirement("PyYAML"), @@ -11,6 +11,6 @@ def define_targets(rules): rules.py_binary( name = "gen", - srcs = [":codegen"], + srcs = [":torchgen"], visibility = ["//visibility:public"], ) diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index c1a672a655..beee7a15e0 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -474,7 +474,7 @@ def run( ) -> None: # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py - pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute() + pytorch_root = pathlib.Path(__file__).parent.parent.absolute() template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates") def make_file_manager(install_dir: str) -> FileManager: ``` run_all_fbandroid_tests Test Plan: sandcastle Reviewed By: albanD, ngimel Differential Revision: D35770317 fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf (cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
173 lines
6.4 KiB
Python
173 lines
6.4 KiB
Python
from typing import Dict, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
|
|
# This class holds information about a single operator used to determine
# the outcome of a selective/custom PyTorch build that doesn't include
# registration code for all the supported operators. This is done to
# reduce the size of the generated binary so that it can be deployed in
# situations where binary size comes at a premium.
#
@dataclass(frozen=True)
class SelectiveBuildOperator:
    """Immutable selective-build record for one operator (or operator family)."""

    # The name of the operator, including its namespace prefix (e.g. aten::).
    # The operator name may or may not have the overload name. If this
    # operator name does not specify an overload name, the way to determine
    # if this entry refers to the family of operators with this base name
    # or just the operator with this name is to look at the value of the
    # 'include_all_overloads' flag in this class.
    name: str

    # True if this is a root operator (i.e. called directly from a
    # TorchScript model, etc...). An operator is considered to be a
    # root operator if it is called directly from any one of the models
    # that this instance of the pytorch library was built for. Hence, it
    # may not be a root operator in all of the models that are used in
    # this instance of the pytorch library.
    is_root_operator: bool

    # Is this operator used for on-device training? If True, then we need to
    # use the information to generate code in VariableType_N.cpp for registration
    # of training related operators. Again, this is True if this operator
    # is used for training in one or more models used by this instance of the
    # pytorch library.
    is_used_for_training: bool

    # If True, it indicates that this operator instance (object) refers to an
    # operator without the overload name and should apply to all overloads
    # which have this operator name as the base name. This flag is applicable
    # only for objects that have operator names without a DOT (period) character
    # in them.
    #
    # Note: This flag is a temporary workaround to grandfather in the current
    # static selective (custom) build mechanism, which largely ignores overload
    # names when determining whether to select operators for registration
    # purposes.
    include_all_overloads: bool

    # Debug Information at the operator level
    _debug_info: Optional[Tuple[str, ...]]

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: Dict[str, object]
    ) -> "SelectiveBuildOperator":
        """Build an operator record from one entry of a selective-build dict.

        Args:
            op_name: The operator name this entry describes.
            op_info: Mapping that may contain the keys ``name``,
                ``is_root_operator``, ``is_used_for_training``,
                ``include_all_overloads`` and ``debug_info``. Any boolean flag
                that is absent defaults to ``True``. ``debug_info`` may be a
                list or a tuple; every element is converted with ``str``.

        Returns:
            A new ``SelectiveBuildOperator``.

        Raises:
            Exception: if ``op_info`` contains keys outside the allowed set.
            AssertionError: if ``name`` is present and differs from
                ``op_name``, or if a field has the wrong type.
        """
        allowed_keys = {
            "name",
            "is_root_operator",
            "is_used_for_training",
            "include_all_overloads",
            "debug_info",
        }

        unexpected_keys = set(op_info.keys()) - allowed_keys
        if unexpected_keys:
            # Sort so the message is stable across runs; raw set iteration
            # order for strings depends on hash randomization.
            raise Exception(
                "Got unexpected top level keys: {}".format(
                    ",".join(sorted(unexpected_keys)),
                )
            )

        if "name" in op_info:
            assert op_name == op_info["name"], (
                "Operator name mismatch: '{}' vs '{}'".format(op_name, op_info["name"])
            )

        is_root_operator = op_info.get("is_root_operator", True)
        assert isinstance(is_root_operator, bool), "'is_root_operator' must be a bool"

        is_used_for_training = op_info.get("is_used_for_training", True)
        assert isinstance(is_used_for_training, bool), "'is_used_for_training' must be a bool"

        include_all_overloads = op_info.get("include_all_overloads", True)
        assert isinstance(include_all_overloads, bool), "'include_all_overloads' must be a bool"

        debug_info: Optional[Tuple[str, ...]] = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            # Accept tuples as well as lists: to_dict() emits debug_info as a
            # tuple, and its output must be readable back by this method.
            assert isinstance(di_list, (list, tuple)), "'debug_info' must be a list"
            debug_info = tuple(str(x) for x in di_list)

        return SelectiveBuildOperator(
            name=op_name,
            is_root_operator=is_root_operator,
            is_used_for_training=is_used_for_training,
            include_all_overloads=include_all_overloads,
            _debug_info=debug_info,
        )

    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> "SelectiveBuildOperator":
        """Create a record for a legacy (overload-less) operator name.

        All flags are set to ``True`` (root, training, all overloads) and no
        debug info is attached.
        """
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            _debug_info=None,
        )

    def to_dict(self) -> Dict[str, object]:
        """Serialize the flags (and debug info, if any) to a plain dict.

        The operator name is not included; it is supplied separately as
        ``op_name`` when reading the dict back with ``from_yaml_dict``.
        """
        ret: Dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
        }
        if self._debug_info is not None:
            ret["debug_info"] = self._debug_info

        return ret
|
|
|
|
|
|
def merge_debug_info(
    lhs: Optional[Tuple[str, ...]],
    rhs: Optional[Tuple[str, ...]],
) -> Optional[Tuple[str, ...]]:
    """Merge two optional debug-info tuples, keeping each entry only once.

    Args:
        lhs: First debug-info tuple, or None.
        rhs: Second debug-info tuple, or None.

    Returns:
        None if both inputs are None; otherwise a tuple of the distinct
        entries in order of first appearance (all of ``lhs``, then ``rhs``).
    """
    if lhs is None and rhs is None:
        return None

    # dict.fromkeys de-duplicates while preserving insertion order. A plain
    # set() would make the output order depend on the per-process string hash
    # seed, making generated files non-reproducible.
    return tuple(dict.fromkeys((lhs or ()) + (rhs or ())))
|
|
|
|
|
|
def combine_operators(
    lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator"
) -> "SelectiveBuildOperator":
    """Fold two records for the same operator into a single record.

    Every boolean flag on the result is the logical OR of the two inputs,
    and the debug info of both is merged via ``merge_debug_info``.

    Raises:
        Exception: if ``lhs`` and ``rhs`` carry different operator names.
    """
    left_name = str(lhs.name)
    right_name = str(rhs.name)
    if left_name != right_name:
        raise Exception(
            "Expected both arguments to have the same name, but got '{}' and '{}' instead".format(
                left_name,
                right_name,
            )
        )

    # A flag is set on the merged record when either input sets it, i.e. when
    # the operator is a root / training / all-overloads operator in any of the
    # models used in this instance of the pytorch library.
    root = lhs.is_root_operator or rhs.is_root_operator
    training = lhs.is_used_for_training or rhs.is_used_for_training
    all_overloads = lhs.include_all_overloads or rhs.include_all_overloads
    merged_debug = merge_debug_info(lhs._debug_info, rhs._debug_info)

    return SelectiveBuildOperator(
        name=lhs.name,
        is_root_operator=root,
        is_used_for_training=training,
        include_all_overloads=all_overloads,
        _debug_info=merged_debug,
    )
|
|
|
|
|
|
def merge_operator_dicts(
    lhs: Dict[str, SelectiveBuildOperator],
    rhs: Dict[str, SelectiveBuildOperator],
) -> Dict[str, SelectiveBuildOperator]:
    """Union two name -> operator maps into a new dict.

    Entries from ``lhs`` are visited first, then those from ``rhs``. When a
    name appears in both, the two records are combined with
    ``combine_operators``. Neither input dict is modified.
    """
    merged: Dict[str, SelectiveBuildOperator] = {}
    for source in (lhs, rhs):
        for name, operator in source.items():
            merged[name] = (
                combine_operators(merged[name], operator)
                if name in merged
                else operator
            )
    return merged
|
|
|
|
|
|
def strip_operator_overload_name(op_name: str) -> str:
    """Return ``op_name`` with everything from the first '.' onward removed.

    For example ``"aten::add.Tensor"`` becomes ``"aten::add"``; a name that
    contains no '.' is returned unchanged.
    """
    base_name, _separator, _overload = op_name.partition(".")
    return base_name
|