From 01b662bafe54dfda561d442015dc512edf8b1564 Mon Sep 17 00:00:00 2001 From: Zhicheng Yan Date: Sat, 2 Sep 2023 17:37:36 +0000 Subject: [PATCH] [gen_operators_yaml] add arguments to control include_all_overloads (#108396) Summary: In SelectiveBuildOperator, we can specify argument `include_all_overloads`. If True, all overloaded operators (for example, `aten::to.dtype_layout` and `aten::to.prim_Device` are considered as overloaded operators of `aten::to`) will be built and linked to the final binary. This can significantly increase the final binary size, which could be a deal breaker for on-device deployment. In this diff, we make backward-compatible changes to add new arguments `--not-include-all-overloads-static-root-ops` and `--not-include-all-overloads-closure-ops`. When they are set, we set `include_all_overloads` flag to False for static root ops and closure ops, and rely on the code analyzer to decide which overloaded operators are actually used. Test Plan: - unit test ``` buck test //xplat/caffe2/tools:gen_operators_yaml_test ``` - See test plan in D48771544 where we reduce the shared lib file `libmrengine.lib` from 16653072 bytes to 13686032 bytes. 
- See detailed document: https://fburl.com/gdoc/mc93h6kb Reviewed By: larryliu0820 Differential Revision: D48772302 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108396 Approved by: https://github.com/larryliu0820 --- pt_ops.bzl | 15 +++++- tools/code_analyzer/gen_operators_yaml.py | 34 +++++++++++-- tools/test/gen_operators_yaml_test.py | 60 ++++++++++++++++++++++- 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/pt_ops.bzl b/pt_ops.bzl index c680140366c..be77313cc91 100644 --- a/pt_ops.bzl +++ b/pt_ops.bzl @@ -1,5 +1,5 @@ -load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native") load("//tools/build_defs:expect.bzl", "expect") +load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native") load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") @@ -72,6 +72,13 @@ def pt_operator_library( # was hand-crafted which is not a support workflow for traced ops. yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset) + not_include_all_overloads_static_root_ops = kwargs.pop( + "not_include_all_overloads_static_root_ops", + False, + ) + + not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False) + fb_xplat_genrule( name = name, out = "model_operators.yaml", @@ -87,7 +94,9 @@ def pt_operator_library( "{optionally_model_versions} " + "{optionally_model_assets} " + "{optionally_model_traced_backends} " + - "{optionally_include_all_operators}" + "{optionally_include_all_operators}" + + "{not_include_all_overloads_static_root_ops}" + + "{not_include_all_overloads_closure_ops}" ).format( exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml", rule_name = name, @@ -100,6 +109,8 @@ def pt_operator_library( optionally_model_assets = "--model_assets " + (",".join(model_assets)) if 
model_assets != None else "", optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "", optionally_include_all_operators = "--include_all_operators " if include_all_operators else "", + not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "", + not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "", ), labels = labels + [ "pt_operator_library", diff --git a/tools/code_analyzer/gen_operators_yaml.py b/tools/code_analyzer/gen_operators_yaml.py index c5ca98cf364..aedb2acccbd 100644 --- a/tools/code_analyzer/gen_operators_yaml.py +++ b/tools/code_analyzer/gen_operators_yaml.py @@ -348,7 +348,7 @@ def fill_output(output: Dict[str, object], options: object): { "is_root_operator": True, "is_used_for_training": False, - "include_all_overloads": True, + "include_all_overloads": not options.not_include_all_overloads_static_root_ops, "debug_info": [options.model_name], }, ) @@ -362,7 +362,7 @@ def fill_output(output: Dict[str, object], options: object): { "is_root_operator": False, "is_used_for_training": False, - "include_all_overloads": True, + "include_all_overloads": not options.not_include_all_overloads_closure_ops, "debug_info": [options.model_name], }, ) @@ -489,7 +489,7 @@ def fill_output(output: Dict[str, object], options: object): output["kernel_metadata"] = kernel_metadata -def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace: +def add_arguments_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument( "--root-ops", "--root_ops", @@ -567,8 +567,32 @@ def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace: help="The name of pt_operator_library rule resulting in this generation", required=True, ) - options = parser.parse_args() - return options + 
parser.add_argument( + "--not-include-all-overloads-static-root-ops", + "--not_include_all_overloads_static_root_ops", + action="store_true", + default=False, + help="Set this flag to not include all overloaded operators for static root ops bucket in fill_output() subroutine", + required=False, + ) + parser.add_argument( + "--not-include-all-overloads-closure-ops", + "--not_include_all_overloads_closure_ops", + action="store_true", + default=False, + help="Set this flag to not include all overloaded operators for closure ops bucket in fill_output() subroutine", + required=False, + ) + return parser + + +def parse_options(parser: argparse.ArgumentParser) -> argparse.Namespace: + return parser.parse_args() + + +def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace: + parser = add_arguments_parser(parser) + return parse_options(parser) def main(argv) -> None: diff --git a/tools/test/gen_operators_yaml_test.py b/tools/test/gen_operators_yaml_test.py index 3c57b2a4748..956ec0f3412 100644 --- a/tools/test/gen_operators_yaml_test.py +++ b/tools/test/gen_operators_yaml_test.py @@ -1,9 +1,45 @@ #!/usr/bin/env python3 # Copyright 2004-present Facebook. All Rights Reserved. 
+import argparse +import json import unittest +from collections import defaultdict -from gen_operators_yaml import make_filter_from_options, verify_all_specified_present +from unittest.mock import Mock, patch + +from gen_operators_yaml import ( + fill_output, + get_parser_options, + make_filter_from_options, + verify_all_specified_present, +) + + +def _mock_options(): + options = argparse.Namespace() + options.root_ops = "aten::add,aten::cat" + options.training_root_ops = [] + options.output_path = "/tmp" + options.dep_graph_yaml_path = "dummy_pytorch_op_deps.yaml" + options.model_name = "test_model" + options.model_versions = None + options.model_assets = None + options.model_backends = None + options.models_yaml_path = None + options.include_all_operators = False + options.rule_name = "test_rule" + options.not_include_all_overloads_static_root_ops = True + options.not_include_all_overloads_closure_ops = True + + return options + + +def _mock_load_op_dep_graph(): + result = defaultdict(set) + result["aten::add"] = {"aten::add", "aten::as_strided_"} + result["aten::cat"] = {"aten::cat", "aten::as_strided_"} + return dict(result) class GenOperatorsYAMLTest(unittest.TestCase): @@ -186,3 +222,25 @@ class GenOperatorsYAMLTest(unittest.TestCase): model_name="abcd", new_style_rule=True, ) + + @patch("gen_operators_yaml.parse_options", return_value=_mock_options()) + @patch( + "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph() + ) + def test_fill_output_with_arguments_not_include_all_overloads( + self, mock_parse_options: Mock, mock_load_op_dep_graph: Mock + ): + parser = argparse.ArgumentParser(description="Generate used operators YAML") + options = get_parser_options(parser) + + model_dict = { + "model_name": options.model_name, + "asset_info": {}, + "is_new_style_rule": False, + } + output = {"debug_info": [json.dumps(model_dict)]} + + fill_output(output, options) + + for op_val in output["operators"].values(): + 
self.assertFalse(op_val["include_all_overloads"])