mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34055 Enable custom mobile build with dynamic dispatch for OSS build. It calls a Python util script to calculate transitive dependencies from the op dependency graph and the list of used root ops, then passes the result as the op registration whitelist to aten codegen, so that only these used ops are registered and kept at link time. For custom build with dynamic dispatch to work correctly, it's critical to have an accurate list of used ops. The current assumption is that only those ops referenced by the TorchScript model are used. It works well if client code doesn't call the libtorch API (e.g. tensor methods) directly; otherwise the extra used ops need to be added to the whitelist manually, as shown by the HACK in prepare_model.py. Also, if JIT starts calling extra ops independent of the specific model, then the extra ops need to be added to the whitelist as well. Verified the correctness of the whole process with MobileNetV2: ``` TEST_CUSTOM_BUILD_DYNAMIC=1 test/mobile/custom_build/build.sh ``` Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D20193327 Pulled By: ljk53 fbshipit-source-id: 9d369b8864856b098342aea79e0ac8eec04149aa
65 lines
1.8 KiB
Python
65 lines
1.8 KiB
Python
"""
|
|
This util takes the op dependency graph of ATen and the list of root ops, and
|
|
outputs all transitive dependencies of the root ops. It is invoked from cmake
|
|
for custom mobile build.
|
|
"""
|
|
|
|
import argparse
|
|
import yaml
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
def canonical_name(opname):
    """Return the operator name with its overload suffix removed.

    Everything from the first '.' onward is dropped; a name that contains
    no '.' is returned unchanged.
    """
    # The overload part is discarded because the code analyzer does not
    # support overload names yet.
    base, _dot, _overload = opname.partition('.')
    return base
|
|
|
|
|
|
def load_op_dep_graph(fname):
    """Load the op dependency graph from a YAML file.

    The file is expected to hold a sequence of mappings, each with a 'name'
    key and an optional 'depends' key listing mappings that themselves have
    a 'name' key.

    Returns a defaultdict(set) mapping each canonical op name to the set of
    canonical names it depends on.
    """
    with open(fname, 'r') as yaml_file:
        entries = yaml.safe_load(yaml_file)

    graph = defaultdict(set)
    for entry in entries:
        source = canonical_name(entry['name'])
        # Add one dependency at a time so that an op without any
        # dependencies never gets an (empty) entry in the graph.
        for dependency in entry.get('depends', []):
            target = canonical_name(dependency['name'])
            graph[source].add(target)
    return graph
|
|
|
|
|
|
def load_root_ops(fname):
    """Load the list of root operators from a YAML file.

    The file is expected to hold a sequence of operator name strings.

    Returns a list of canonical op names in file order; duplicates are
    kept as they appear.
    """
    with open(fname, 'r') as yaml_file:
        raw_names = yaml.safe_load(yaml_file)
    return [canonical_name(raw_name) for raw_name in raw_names]
|
|
|
|
|
|
def gen_transitive_closure(dep_graph, root_ops):
    """Compute every op reachable from the root ops through the graph.

    Args:
        dep_graph: mapping from an op name to an iterable of the op names
            it depends on; ops missing from the mapping have no
            dependencies.
        root_ops: list of op names to start from. It is copied, not
            modified.

    Returns:
        A single string holding the root ops and all of their transitive
        dependencies, de-duplicated, sorted, and separated by one space.
    """
    reachable = set(root_ops)
    worklist = root_ops[:]

    while worklist:
        current = worklist.pop()
        # Only ops that have not been seen before need to be expanded.
        newly_found = set(dep_graph.get(current, [])) - reachable
        reachable.update(newly_found)
        worklist.extend(newly_found)

    ordered = sorted(reachable)
    return ' '.join(ordered)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the two YAML inputs and print the
    # space-separated transitive closure of the root ops to stdout.
    parser = argparse.ArgumentParser(
        description='Util to produce transitive dependencies for custom build')
    # Both inputs are mandatory. Without required=True a missing flag is left
    # as None and only fails later inside open() with an unrelated TypeError;
    # with it, argparse prints a usage error naming the missing option.
    parser.add_argument(
        '--op-dependency',
        required=True,
        help='input yaml file of op dependency graph')
    parser.add_argument(
        '--root-ops',
        required=True,
        help='input yaml file of root (directly used) operators')
    args = parser.parse_args()

    deps = load_op_dep_graph(args.op_dependency)
    root_ops = load_root_ops(args.root_ops)
    print(gen_transitive_closure(deps, root_ops))
|