From 24f882369a6172e13dfa7ffe569ddf52e17abe04 Mon Sep 17 00:00:00 2001 From: Ian Graves Date: Tue, 18 Apr 2023 17:19:55 +0000 Subject: [PATCH] [EdgeML] Remove dependency on all_mobile_model_configs.yaml from pt_operator_library BUCK rule (#99122) Summary: Removes the dependency on the unified YAML file Test Plan: Smoke test via some caffe2 tests. ``` buck2 run xplat/caffe2:supported_mobile_models_test ``` Build a major FoA app that uses model tracing and confirm it still works. ``` buck2 build fb4a ``` CI/CD for the rest. If operator tracing / bundling was broken, I'd hope in the 1000+ tests spawned by this change should catch it. Differential Revision: D44946368 Pull Request resolved: https://github.com/pytorch/pytorch/pull/99122 Approved by: https://github.com/dhruvbird --- pt_ops.bzl | 39 +++++++++++++++++++++-- tools/code_analyzer/gen_operators_yaml.py | 21 ++++++------ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/pt_ops.bzl b/pt_ops.bzl index 30328d55873..26a859b1e27 100644 --- a/pt_ops.bzl +++ b/pt_ops.bzl @@ -1,3 +1,4 @@ +load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native") load("//tools/build_defs:expect.bzl", "expect") load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") @@ -37,6 +38,40 @@ def pt_operator_library( labels = kwargs.pop("labels", []) visibility = kwargs.pop("visibility", ["PUBLIC"]) + # Sanity check the model name and versions. While the input to both is an array, the + # codegen script only ever outputs a single item in the array so we can just assume that + # here. If you ever need to depends on more than one assets, just break it up into a separate + # BUCK targets. + if model_assets or model_versions: + if len(model_assets) != 1: + fail("Model assets must be of size 1") + if len(model_versions) != 1: + fail("Model versions must be of size 1") + + # Is this a traced operator therefore has a YAML file with ops? + yaml_option = "" + if model_assets and len(model_assets) > 0: + # We know these lists are only of length 1 via earlier assert. + model_asset = model_assets[0] + model_version = model_versions[0] + + # Pass the YAML file from this asset to the genrule below. + yaml_dep = "{}_v{}_yaml".format(model_asset, model_version) + fb_native.filegroup( + name = yaml_dep, + srcs = [ + model_asset + ".yaml", + ], + # The visibility is not set to PUBLIC as this an internal detail. If you see this error + # in your buck build flow, you are trying to use a hand-crafted "pt_operator_library" that + # with parameters not supported outside of codegen targets! + ) + + # Since all selective traced ops are created by automation, we can assume they + # have a YAML file at this very location. If it doesn't exist, it means the targets + # was hand-crafted which is not a support workflow for traced ops. + yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset) + fb_xplat_genrule( name = name, out = "model_operators.yaml", @@ -48,7 +83,7 @@ def pt_operator_library( "--output_path \"${{OUT}}\" " + "--model_name {model_name} " + "--dep_graph_yaml_path {dep_graph_yaml} " + - "--models_yaml_path {models_yaml} " + + "{optionally_model_yamls} " + "{optionally_model_versions} " + "{optionally_model_assets} " + "{optionally_model_traced_backends} " + @@ -58,7 +93,7 @@ def pt_operator_library( rule_name = name, model_name = model_name, dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ", - models_yaml = "none" if IS_OSS else "$(location fbsource//xplat/pytorch_models/build:all_mobile_model_configs)/all_mobile_model_configs.yaml ", + optionally_model_yamls = "" if (IS_OSS or yaml_option == None) else yaml_option, optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "", optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "", optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "", diff --git a/tools/code_analyzer/gen_operators_yaml.py b/tools/code_analyzer/gen_operators_yaml.py index b0be5efb52a..c5ca98cf364 100644 --- a/tools/code_analyzer/gen_operators_yaml.py +++ b/tools/code_analyzer/gen_operators_yaml.py @@ -73,12 +73,11 @@ from torchgen.selective_build.selector import merge_kernel_metadata # # 4. Model Metadata (--model-name, --model-versions, --model-assets, # --model-backends): Self-descriptive. These are used to tell this -# script which model operator lists to fetch from the Unified Model -# Build Metadata YAML file. +# script which model operator lists to fetch from the Model +# Build Metadata YAML files. # -# 5. Unified Model YAML file (--models-yaml-path): A path to the Unified -# model YAML operator list file. This yaml file contains (for each -# model/version/asset/backend) the set of used root and traced +# 5. Model YAML files (--models-yaml-path): These yaml files contains +# (for each model/version/asset/backend) the set of used root and traced # operators. This is used to extract the actual set of operators # needed to be included in the build. # @@ -214,8 +213,11 @@ def fill_output(output: Dict[str, object], options: object): options.model_assets.split(",") if options.model_assets is not None else None ) - with open(options.models_yaml_path, "rb") as models_yaml_file: - all_models_yaml = yaml.safe_load(models_yaml_file) or [] + all_models_yaml = [] + if options.models_yaml_path: + for yaml_path in options.models_yaml_path: + with open(yaml_path, "rb") as f: + all_models_yaml.append(yaml.safe_load(f)) model_filter_func = make_filter_from_options( options.model_name, model_versions, model_assets, options.model_backends @@ -546,8 +548,9 @@ def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace: "--models-yaml-path", "--models_yaml_path", type=str, - help="The path to where the unified Mobile Model Config YAML resides.", - required=True, + help="The paths to the mobile model config YAML files.", + required=False, + nargs="+", ) parser.add_argument( "--include-all-operators",