pytorch/tools/code_analyzer/gen_transitive_deps.py
Jiakai Liu 3c042a6ab9 [pytorch][mobile] support for custom mobile build with dynamic dispatch (#34055)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34055

Enable custom mobile build with dynamic dispatch for OSS build.

It calls a Python util script to calculate the transitive dependencies
from the op dependency graph and the list of used root ops, then passes
the result as the op registration whitelist to the ATen codegen, so that
only these used ops are registered and kept at link time.
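
The closure computation described above can be sketched on a toy graph.
This is a minimal illustration, not the script itself: the op names and
dependency edges below are made up for the example and do not reflect the
real ATen dependency graph, and `transitive_closure` is a hypothetical
helper name.

```python
def transitive_closure(dep_graph, root_ops):
    """Return the sorted list of root ops plus everything they reach."""
    result = set(root_ops)
    pending = list(root_ops)
    while pending:
        cur = pending.pop()
        for dep in dep_graph.get(cur, ()):
            if dep not in result:
                result.add(dep)
                pending.append(dep)
    return sorted(result)


# Hypothetical dependency graph: op name -> set of ops it calls.
dep_graph = {
    "aten::conv2d": {"aten::convolution"},
    "aten::convolution": {"aten::_convolution"},
    "aten::relu": set(),
    "aten::unused_op": {"aten::relu"},
}

# Only ops reachable from the root ops end up in the whitelist.
whitelist = transitive_closure(dep_graph, ["aten::conv2d", "aten::relu"])
print(" ".join(whitelist))
```

Ops that are never reached from a root op (here `aten::unused_op`) stay
out of the whitelist, which is what allows them to be dropped at link time.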

For custom build with dynamic dispatch to work correctly, it's critical
to have an accurate list of used ops. The current assumption is that only
the ops referenced by the TorchScript model are used. This works well if
client code doesn't call the libtorch API (e.g. tensor methods) directly;
otherwise the extra used ops need to be added to the whitelist manually,
as shown by the HACK in prepare_model.py.

Also, if the JIT starts calling extra ops independently of any specific
model, those extra ops need to be added to the whitelist as well.
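
One way to picture the manual step is merging the ops referenced by the
model with a hand-maintained list of extras before computing the closure.
The sketch below is an assumption about how that could look, not the code
in prepare_model.py: `build_root_ops` and both op lists are hypothetical,
and overload names are stripped the same way the util script does.

```python
def canonical_name(opname):
    # Drop the overload suffix, e.g. "aten::add.Tensor" -> "aten::add".
    return opname.split('.', 1)[0]


def build_root_ops(model_ops, extra_ops):
    """Merge model ops and manually listed ops, keeping first-seen order."""
    seen = set()
    result = []
    for op in list(model_ops) + list(extra_ops):
        name = canonical_name(op)
        if name not in seen:
            seen.add(name)
            result.append(name)
    return result


# Hypothetical inputs: ops referenced by the model, plus ops that client
# code calls directly through the libtorch API.
model_ops = ["aten::conv2d", "aten::add.Tensor", "aten::relu"]
extra_ops = ["aten::ones", "aten::add.Scalar"]

root_ops = build_root_ops(model_ops, extra_ops)
print(root_ops)
```

The merged list would then be written out as the root-ops YAML file that
the util script reads.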

Verified the correctness of the whole process with MobileNetV2:
```
TEST_CUSTOM_BUILD_DYNAMIC=1 test/mobile/custom_build/build.sh
```

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D20193327

Pulled By: ljk53

fbshipit-source-id: 9d369b8864856b098342aea79e0ac8eec04149aa
2020-03-03 19:25:16 -08:00


"""
This util takes the op dependency graph of ATen and the list of root ops, and
outputs all transitive dependencies of the root ops. It is invoked from cmake
for custom mobile build.
"""
import argparse
import yaml

from collections import defaultdict


def canonical_name(opname):
    # Skip the overload name part as it's not supported by code analyzer yet.
    return opname.split('.', 1)[0]


def load_op_dep_graph(fname):
    with open(fname, 'r') as stream:
        result = defaultdict(set)
        for op in yaml.safe_load(stream):
            op_name = canonical_name(op['name'])
            for dep in op.get('depends', []):
                dep_name = canonical_name(dep['name'])
                result[op_name].add(dep_name)
        return result


def load_root_ops(fname):
    result = []
    with open(fname, 'r') as stream:
        for op in yaml.safe_load(stream):
            result.append(canonical_name(op))
    return result


def gen_transitive_closure(dep_graph, root_ops):
    result = set(root_ops)
    queue = root_ops[:]
    while queue:
        cur = queue.pop()
        for dep in dep_graph.get(cur, []):
            if dep not in result:
                result.add(dep)
                queue.append(dep)
    return ' '.join(sorted(result))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Util to produce transitive dependencies for custom build')
    parser.add_argument(
        '--op-dependency',
        help='input yaml file of op dependency graph')
    parser.add_argument(
        '--root-ops',
        help='input yaml file of root (directly used) operators')
    args = parser.parse_args()

    deps = load_op_dep_graph(args.op_dependency)
    root_ops = load_root_ops(args.root_ops)
    print(gen_transitive_closure(deps, root_ops))