mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37404 Many aten operators are really more like utility functions, e.g. aten::is_nonzero, aten::is_floating_point, etc. These ops can be called via overloaded C++ operators, so seemingly trivial and innocent code changes can affect how these ops are used by other ops (thus changing the output of the static analyzer). Most of these util ops are rather small in terms of build size cost, so for the purpose of optimizing binary size with a custom build, whether to include these ops or not does not make a significant difference. In fact, for non-trivial models a set of these ops is almost always used. This PR introduces the (optional) '__BASE__' ops section to the dependency graph. We can maintain the list of frequently used small util ops for the internal BUCK build. This way, the output dependency graph will only contain meaningful edges with significant binary size impact, and it will be more stable against trivial code changes (the graph is checked into the FB codebase). Having a stable and sparse deps graph by factoring out frequently used base ops is also a nice property that allows us to explore alternative custom build solutions in case we find it hard to maintain the static code analyzer. Test Plan: Imported from OSS Differential Revision: D21280835 Pulled By: ljk53 fbshipit-source-id: c4d0d1f07ca868c60f23118d877fc1eeead4c875
134 lines
3.5 KiB
Python
134 lines
3.5 KiB
Python
"""
|
|
This util is used to parse op_deps_pass output (in yaml) and convert it into
|
|
other formats for downstream use cases. It is not used by OSS cmake build.
|
|
|
|
To run this file by hand from the root of the PyTorch repository, run:
|
|
|
|
python -m tools.code_analyzer.op_deps_processor \
|
|
--op_dependency build_code_analyzer/work/torch_result.yaml \
|
|
--output pt_deps.bzl
|
|
"""
|
|
|
|
import argparse
|
|
import yaml
|
|
|
|
from ..autograd.utils import CodeTemplate
|
|
|
|
# Output templates consumed by convert(): each format is a triple of
# (whole-file template, per-op template, per-dependency-edge template).
# convert() substitutes `ops` / `op_deps` with Python lists of already
# rendered strings; how a list value is laid out (joined / indented) is
# decided by CodeTemplate in ..autograd.utils, which is not visible here.
# NOTE(review): the `|` lines in this copy look like artifacts of the page it
# was scraped from; inside the triple-quoted strings they would become literal
# template text. Confirm against the upstream file.
#
# Bazel/Starlark output: one `TORCH_DEPS` dict for the whole graph. The
# `# ...` lines below are inside the string literal, so they are emitted into
# the generated file as a header, not Python comments.
BAZEL_OUTPUT = CodeTemplate("""\
|
|
# Generated for selective build without using static dispatch.
|
|
# Manually run the script to update:
|
|
# ANALYZE_TORCH=1 DEPLOY=1 tools/code_analyzer/build.sh
|
|
TORCH_DEPS = {
|
|
${ops}
|
|
}
|
|
""")
|
|
|
|
# One dict entry per op: the quoted op name mapped to a list of its deps.
BAZEL_OP = CodeTemplate("""\
|
|
"${op_name}": [
|
|
${op_deps}
|
|
],
|
|
""")
|
|
|
|
# One quoted list element per dependency edge (op_name is supplied by
# convert() but unused by this template).
BAZEL_OP_DEP = CodeTemplate("""\
|
|
"${dep_name}",
|
|
""")
|
|
|
|
# Graphviz output: a single directed graph using the "circo" layout engine.
DOT_OUTPUT = CodeTemplate("""\
|
|
digraph {
|
|
layout="circo";
|
|
${ops}
|
|
}
|
|
""")
|
|
|
|
# DOT has no per-node grouping here; an op just contributes its edges.
DOT_OP = CodeTemplate("""\
|
|
${op_deps}
|
|
""")
|
|
|
|
# One directed edge "op" -> "dep" per dependency.
DOT_OP_DEP = CodeTemplate("""\
|
|
"${op_name}" -> "${dep_name}";
|
|
""")
|
|
|
|
|
|
def load_op_deps(fname):
    """Load the op dependency graph written by op_deps_pass.

    Args:
        fname: path to the YAML file produced by op_deps_pass.

    Returns:
        The parsed YAML document (expected to be a list of op dicts with a
        ``name`` key and an optional ``depends`` list). An empty file yields
        ``[]`` rather than ``None``, so callers that iterate the graph
        (``process_base_ops``, ``convert``) do not fail with a TypeError.
    """
    # safe_load: only plain YAML types are constructed from the input file.
    with open(fname, 'r', encoding='utf-8') as stream:
        graph = yaml.safe_load(stream)
    # yaml.safe_load returns None for an empty document.
    return [] if graph is None else graph
|
|
|
|
|
|
def process_base_ops(graph, base_ops):
    """Factor ``base_ops`` out of ``graph`` into a leading ``__BASE__`` entry.

    Mutates ``graph`` in place and returns None:
      * every op's ``depends`` list has all base ops removed (an op without a
        ``depends`` key gets an empty list), which compresses the output
        graph; and
      * an entry ``{'name': '__BASE__', 'depends': [...]}`` listing each base
        op exactly once, in first-seen order, is inserted at index 0.

    Args:
        graph: list of op dicts, each with a ``name`` and an optional
            ``depends`` list of ``{'name': ...}`` dicts.
        base_ops: iterable of op names (``namespace::op``, no overload).
            May be any iterable; it is consumed exactly once.
    """
    # Deduplicate while keeping the caller's order; ordering via a set alone
    # would make the generated __BASE__ section unstable across runs.
    unique_base_ops = list(dict.fromkeys(base_ops))
    # Build the lookup set once: membership is tested for every edge below.
    base_set = set(unique_base_ops)

    # Remove base ops from all `depends` lists to compress the output graph.
    for op in graph:
        op['depends'] = [
            dep for dep in op.get('depends', []) if dep['name'] not in base_set
        ]

    # Add the base ops section at the beginning.
    graph.insert(0, {
        'name': '__BASE__',
        'depends': [{'name': name} for name in unique_base_ops]})
|
|
|
|
|
|
def convert(fname, graph, output_template, op_template, op_dep_template):
    """Render an op dependency graph through three templates into ``fname``.

    Args:
        fname: output path; opened with mode 'w', so any existing file is
            overwritten.
        graph: list of op dicts, each with a ``name`` and an optional
            ``depends`` list of ``{'name': ...}`` dicts.
        output_template: rendered once with ``ops`` set to the list of
            rendered op entries.
        op_template: rendered per op with ``op_name`` and ``op_deps`` (the
            list of that op's rendered dependency entries).
        op_dep_template: rendered per edge with ``op_name`` and ``dep_name``.

    Self-references are dropped; an op left with no dependencies produces
    no entry at all.
    """
    rendered_ops = []
    for entry in graph:
        name = entry['name']
        rendered_deps = [
            op_dep_template.substitute(op_name=name, dep_name=edge['name'])
            for edge in entry.get('depends', [])
            if edge['name'] != name  # an op depending on itself is noise
        ]
        # Only ops with at least one outgoing edge are emitted.
        if rendered_deps:
            rendered_ops.append(
                op_template.substitute(op_name=name, op_deps=rendered_deps))

    with open(fname, 'w') as out:
        out.write(output_template.substitute(ops=rendered_ops))
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Util to parse & convert op_deps_pass output')
    # Accept the hyphenated spelling as well: the module docstring's example
    # command uses `--op-dependency`, which previously failed to parse.
    parser.add_argument(
        '--op_dependency', '--op-dependency',
        dest='op_dependency',
        required=True,
        help='input yaml file of op dependency graph produced by op_deps_pass')
    # `choices` rejects an unknown format before any input is read.
    parser.add_argument(
        '--format',
        default='bazel',
        choices=['bazel', 'dot'],
        help='output file format [bazel, dot]')
    parser.add_argument(
        '--base_ops',
        nargs='*',
        help='optional list of `base` ops that should always be kept in '
             'custom build, to make the output stable from trivial changes; '
             'each item is `namespace`::`operator name` without overload; '
             'e.g.: aten::empty aten::size ...')
    parser.add_argument(
        '--output',
        required=True,
        help='output file')
    args = parser.parse_args()

    # (whole-file, per-op, per-edge) templates for each supported format.
    templates = {
        'bazel': (BAZEL_OUTPUT, BAZEL_OP, BAZEL_OP_DEP),
        'dot': (DOT_OUTPUT, DOT_OP, DOT_OP_DEP),
    }[args.format]

    deps = load_op_deps(args.op_dependency)

    if args.base_ops:
        process_base_ops(deps, args.base_ops)

    convert(args.output, deps, *templates)
|