mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Consolidate all python targets in the tools folder (#80408)
Summary: All buck targets that points to caffe2/tools folder are now moved to tools/BUCK. This also eliminates all python library/binary import in pt_defs.bzl, which caused T124308913. Test Plan: CI Differential Revision: D37468313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80408 Approved by: https://github.com/seemethere, https://github.com/malfet
This commit is contained in:
parent
70e86b4562
commit
b62d39eda0
10
.github/workflows/_buck-build-test.yml
vendored
10
.github/workflows/_buck-build-test.yml
vendored
|
|
@ -62,7 +62,15 @@ jobs:
|
||||||
command: |
|
command: |
|
||||||
sh scripts/buck_setup.sh
|
sh scripts/buck_setup.sh
|
||||||
|
|
||||||
- name: Build C10
|
- name: Build tools
|
||||||
|
run: |
|
||||||
|
buck build tools: --keep-going
|
||||||
|
|
||||||
|
- name: Run tools tests
|
||||||
|
run: |
|
||||||
|
buck test tools:selective_build_test tools:gen_oplist_test tools:gen_operators_yaml_test
|
||||||
|
|
||||||
|
- name: Build c10
|
||||||
run: |
|
run: |
|
||||||
buck build c10:c10
|
buck build c10:c10
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ exclude_patterns = [
|
||||||
'torch/lib/**',
|
'torch/lib/**',
|
||||||
'venv/**',
|
'venv/**',
|
||||||
'**/*.pyi',
|
'**/*.pyi',
|
||||||
|
'tools/test/test_selective_build.py',
|
||||||
]
|
]
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
|
|
@ -145,6 +146,10 @@ include_patterns = [
|
||||||
exclude_patterns = [
|
exclude_patterns = [
|
||||||
# (linbinyu) copied from internal repo
|
# (linbinyu) copied from internal repo
|
||||||
'tools/code_analyzer/gen_operators_yaml.py',
|
'tools/code_analyzer/gen_operators_yaml.py',
|
||||||
|
'tools/gen_vulkan_spv.py',
|
||||||
|
'tools/test/gen_operators_yaml_test.py',
|
||||||
|
'tools/test/gen_oplist_test.py',
|
||||||
|
'tools/test/test_selective_build.py',
|
||||||
]
|
]
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
|
|
@ -334,6 +339,7 @@ exclude_patterns = [
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
'tools/linter/adapters/grep_linter.py',
|
'tools/linter/adapters/grep_linter.py',
|
||||||
|
# @lint-ignore TXT2
|
||||||
'--pattern= ',
|
'--pattern= ',
|
||||||
'--linter-name=TABS',
|
'--linter-name=TABS',
|
||||||
'--error-name=saw some tabs',
|
'--error-name=saw some tabs',
|
||||||
|
|
@ -565,6 +571,9 @@ include_patterns = [
|
||||||
'torch/_decomp/**/*.py',
|
'torch/_decomp/**/*.py',
|
||||||
'test/onnx/**/*.py',
|
'test/onnx/**/*.py',
|
||||||
]
|
]
|
||||||
|
exclude_patterns = [
|
||||||
|
'tools/gen_vulkan_spv.py',
|
||||||
|
]
|
||||||
command = [
|
command = [
|
||||||
'python3',
|
'python3',
|
||||||
'tools/linter/adapters/black_linter.py',
|
'tools/linter/adapters/black_linter.py',
|
||||||
|
|
|
||||||
104
buckbuild.bzl
104
buckbuild.bzl
|
|
@ -3,8 +3,6 @@
|
||||||
|
|
||||||
load("@bazel_skylib//lib:paths.bzl", "paths")
|
load("@bazel_skylib//lib:paths.bzl", "paths")
|
||||||
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
|
||||||
load("//tools/build_defs:fb_python_binary.bzl", "fb_python_binary")
|
|
||||||
load("//tools/build_defs:fb_python_library.bzl", "fb_python_library")
|
|
||||||
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
||||||
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||||
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
|
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
|
||||||
|
|
@ -416,7 +414,7 @@ def gen_aten_files(
|
||||||
name = name,
|
name = name,
|
||||||
default_outs = ["."],
|
default_outs = ["."],
|
||||||
outs = get_aten_generated_files(backends),
|
outs = get_aten_generated_files(backends),
|
||||||
cmd = "$(exe {}:gen_aten_bin) ".format(ROOT) + " ".join([
|
cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join([
|
||||||
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
|
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
|
||||||
"--install_dir $OUT",
|
"--install_dir $OUT",
|
||||||
] + extra_params),
|
] + extra_params),
|
||||||
|
|
@ -442,7 +440,7 @@ def gen_aten_unboxing_files(
|
||||||
name = genrule_name,
|
name = genrule_name,
|
||||||
default_outs = ["."],
|
default_outs = ["."],
|
||||||
outs = get_unboxing_generated_files(),
|
outs = get_unboxing_generated_files(),
|
||||||
cmd = "$(exe {}:gen_unboxing_bin) ".format(ROOT) + " ".join([
|
cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join([
|
||||||
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
|
"--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
|
||||||
"--install_dir $OUT",
|
"--install_dir $OUT",
|
||||||
] + extra_params),
|
] + extra_params),
|
||||||
|
|
@ -515,7 +513,7 @@ def pt_operator_query_codegen(
|
||||||
# @lint-ignore BUCKLINT
|
# @lint-ignore BUCKLINT
|
||||||
fb_native.genrule(
|
fb_native.genrule(
|
||||||
name = oplist_dir_name,
|
name = oplist_dir_name,
|
||||||
cmd = ("$(exe {}:gen_oplist) ".format(ROOT) +
|
cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) +
|
||||||
"--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
|
"--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
|
||||||
("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
|
("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
|
||||||
"--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
|
"--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
|
||||||
|
|
@ -620,7 +618,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
|
||||||
outs = get_generate_code_bin_outs(),
|
outs = get_generate_code_bin_outs(),
|
||||||
default_outs = ["."],
|
default_outs = ["."],
|
||||||
bash = "mkdir -p tools && " +
|
bash = "mkdir -p tools && " +
|
||||||
"$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
"$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
||||||
# Mobile build only needs libtorch - skip python bindings for now, except
|
# Mobile build only needs libtorch - skip python bindings for now, except
|
||||||
# for ovrsource, which needs Python bindings.
|
# for ovrsource, which needs Python bindings.
|
||||||
(["--subset libtorch"] if not is_arvr_mode() else []) + [
|
(["--subset libtorch"] if not is_arvr_mode() else []) + [
|
||||||
|
|
@ -630,7 +628,7 @@ def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple
|
||||||
] + extra_params,
|
] + extra_params,
|
||||||
),
|
),
|
||||||
cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " +
|
cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " +
|
||||||
"$(exe {}tools/setup_helpers:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
"$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
|
||||||
# Mobile build only needs libtorch - skip python bindings for now, except
|
# Mobile build only needs libtorch - skip python bindings for now, except
|
||||||
# for ovrsource, which needs Python bindings.
|
# for ovrsource, which needs Python bindings.
|
||||||
(["--subset libtorch"] if not is_arvr_mode() else []) + [
|
(["--subset libtorch"] if not is_arvr_mode() else []) + [
|
||||||
|
|
@ -950,7 +948,7 @@ def define_buck_targets(
|
||||||
"torch/csrc/api/include/torch/version.h.in",
|
"torch/csrc/api/include/torch/version.h.in",
|
||||||
"version.txt",
|
"version.txt",
|
||||||
],
|
],
|
||||||
cmd = "$(exe {}tools/setup_helpers:gen-version-header) ".format(ROOT_PATH) + " ".join([
|
cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([
|
||||||
"--template-path",
|
"--template-path",
|
||||||
"torch/csrc/api/include/torch/version.h.in",
|
"torch/csrc/api/include/torch/version.h.in",
|
||||||
"--version-path",
|
"--version-path",
|
||||||
|
|
@ -995,28 +993,13 @@ def define_buck_targets(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
fb_python_library(
|
|
||||||
name = "substitutelib",
|
|
||||||
srcs = ["tools/substitute.py"],
|
|
||||||
base_module = "",
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_python_binary(
|
|
||||||
name = "substitute",
|
|
||||||
main_module = "tools.substitute",
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
deps = [
|
|
||||||
":substitutelib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# @lint-ignore BUCKLINT
|
# @lint-ignore BUCKLINT
|
||||||
fb_native.genrule(
|
fb_native.genrule(
|
||||||
name = "generate_aten_config",
|
name = "generate_aten_config",
|
||||||
srcs = [
|
srcs = [
|
||||||
"aten/src/ATen/Config.h.in",
|
"aten/src/ATen/Config.h.in",
|
||||||
],
|
],
|
||||||
cmd = "$(exe :substitute) " + " ".join([
|
cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([
|
||||||
"--install_dir",
|
"--install_dir",
|
||||||
"$OUT",
|
"$OUT",
|
||||||
"--input-file",
|
"--input-file",
|
||||||
|
|
@ -1072,79 +1055,6 @@ def define_buck_targets(
|
||||||
default_outs = ["."],
|
default_outs = ["."],
|
||||||
)
|
)
|
||||||
|
|
||||||
fb_python_binary(
|
|
||||||
name = "gen_aten_bin",
|
|
||||||
main_module = "torchgen.gen",
|
|
||||||
visibility = [
|
|
||||||
"PUBLIC",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
ROOT_PATH + "torchgen:torchgen",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_python_binary(
|
|
||||||
name = "gen_unboxing_bin",
|
|
||||||
main_module = "tools.jit.gen_unboxing",
|
|
||||||
visibility = [
|
|
||||||
"PUBLIC",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
ROOT_PATH + "tools/jit:jit",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_python_library(
|
|
||||||
name = "gen_oplist_lib",
|
|
||||||
srcs = subdir_glob([
|
|
||||||
("tools/code_analyzer", "gen_oplist.py"),
|
|
||||||
("tools/code_analyzer", "gen_op_registration_allowlist.py"),
|
|
||||||
]),
|
|
||||||
base_module = "",
|
|
||||||
tests = [
|
|
||||||
":gen_oplist_test",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
third_party("pyyaml"),
|
|
||||||
ROOT_PATH + "tools/lite_interpreter:gen_selected_mobile_ops_header",
|
|
||||||
ROOT_PATH + "torchgen:torchgen",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_python_library(
|
|
||||||
name = "gen_operators_yaml_lib",
|
|
||||||
srcs = subdir_glob([
|
|
||||||
("tools/code_analyzer", "gen_operators_yaml.py"),
|
|
||||||
("tools/code_analyzer", "gen_op_registration_allowlist.py"),
|
|
||||||
]),
|
|
||||||
base_module = "",
|
|
||||||
tests = [
|
|
||||||
":gen_operators_yaml_test",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
third_party("pyyaml"),
|
|
||||||
ROOT_PATH + "torchgen:torchgen",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_python_binary(
|
|
||||||
name = "gen_oplist",
|
|
||||||
main_module = "gen_oplist",
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
deps = [
|
|
||||||
":gen_oplist_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
fb_python_binary(
|
|
||||||
name = "gen_operators_yaml",
|
|
||||||
main_module = "gen_operators_yaml",
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
deps = [
|
|
||||||
":gen_operators_yaml_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
gen_aten_files(
|
gen_aten_files(
|
||||||
name = "gen_aten",
|
name = "gen_aten",
|
||||||
extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),
|
extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ if(NOT USE_VULKAN_SHADERC_RUNTIME)
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND
|
COMMAND
|
||||||
"${PYTHON_EXECUTABLE}"
|
"${PYTHON_EXECUTABLE}"
|
||||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen_vulkan_spv.py
|
${CMAKE_CURRENT_LIST_DIR}/../tools/gen_vulkan_spv.py
|
||||||
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
|
--glsl-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/vulkan/glsl
|
||||||
--output-path ${VULKAN_GEN_OUTPUT_PATH}
|
--output-path ${VULKAN_GEN_OUTPUT_PATH}
|
||||||
--glslc-path=${GLSLC_PATH}
|
--glslc-path=${GLSLC_PATH}
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ def pt_operator_library(
|
||||||
name = name,
|
name = name,
|
||||||
out = "model_operators.yaml",
|
out = "model_operators.yaml",
|
||||||
cmd = (
|
cmd = (
|
||||||
"$(exe {root}:gen_operators_yaml) " +
|
"$(exe {exe}) " +
|
||||||
"{optionally_root_ops} " +
|
"{optionally_root_ops} " +
|
||||||
"{optionally_training_root_ops} " +
|
"{optionally_training_root_ops} " +
|
||||||
"--rule_name {rule_name} " +
|
"--rule_name {rule_name} " +
|
||||||
|
|
@ -52,7 +52,7 @@ def pt_operator_library(
|
||||||
"{optionally_model_traced_backends} " +
|
"{optionally_model_traced_backends} " +
|
||||||
"{optionally_include_all_operators}"
|
"{optionally_include_all_operators}"
|
||||||
).format(
|
).format(
|
||||||
root = "//" if IS_OSS else "//xplat/caffe2",
|
exe = "//tools:gen_operators_yaml" if IS_OSS else "//xplat/caffe2/tools:gen_operators_yaml",
|
||||||
rule_name = name,
|
rule_name = name,
|
||||||
model_name = model_name,
|
model_name = model_name,
|
||||||
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
|
dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
|
||||||
|
|
|
||||||
263
tools/BUCK.bzl
Normal file
263
tools/BUCK.bzl
Normal file
|
|
@ -0,0 +1,263 @@
|
||||||
|
# @lint-ignore-every FBCODEBZLADDLOADS
|
||||||
|
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
|
||||||
|
|
||||||
|
# shared by internal and OSS BUCK
|
||||||
|
def define_tools_targets(
|
||||||
|
python_binary,
|
||||||
|
python_library,
|
||||||
|
python_test,
|
||||||
|
third_party,
|
||||||
|
torchgen_deps,
|
||||||
|
contacts = []):
|
||||||
|
python_library(
|
||||||
|
name = "substitutelib",
|
||||||
|
srcs = ["substitute.py"],
|
||||||
|
base_module = "",
|
||||||
|
)
|
||||||
|
|
||||||
|
python_binary(
|
||||||
|
name = "substitute",
|
||||||
|
main_module = "substitute",
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
":substitutelib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "jit",
|
||||||
|
# @lint-ignore BUCKRESTRICTEDSYNTAX
|
||||||
|
srcs = glob([
|
||||||
|
"jit/*.py",
|
||||||
|
"jit/templates/*",
|
||||||
|
]),
|
||||||
|
base_module = "tools",
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
torchgen_deps,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_binary(
|
||||||
|
name = "gen_unboxing_bin",
|
||||||
|
main_module = "tools.jit.gen_unboxing",
|
||||||
|
visibility = [
|
||||||
|
"PUBLIC",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":jit",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "gen_selected_mobile_ops_header",
|
||||||
|
srcs = ["lite_interpreter/gen_selected_mobile_ops_header.py"],
|
||||||
|
base_module = "tools",
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "gen_oplist_lib",
|
||||||
|
srcs = subdir_glob([
|
||||||
|
("code_analyzer", "gen_oplist.py"),
|
||||||
|
("code_analyzer", "gen_op_registration_allowlist.py"),
|
||||||
|
]),
|
||||||
|
base_module = "",
|
||||||
|
tests = [
|
||||||
|
":gen_oplist_test",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":gen_selected_mobile_ops_header",
|
||||||
|
torchgen_deps,
|
||||||
|
third_party("pyyaml"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_binary(
|
||||||
|
name = "gen_oplist",
|
||||||
|
main_module = "gen_oplist",
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
":gen_oplist_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "gen_operators_yaml_lib",
|
||||||
|
srcs = subdir_glob([
|
||||||
|
("code_analyzer", "gen_operators_yaml.py"),
|
||||||
|
("code_analyzer", "gen_op_registration_allowlist.py"),
|
||||||
|
]),
|
||||||
|
base_module = "",
|
||||||
|
tests = [
|
||||||
|
":gen_operators_yaml_test",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
third_party("pyyaml"),
|
||||||
|
torchgen_deps,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_binary(
|
||||||
|
name = "gen_operators_yaml",
|
||||||
|
main_module = "gen_operators_yaml",
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
":gen_operators_yaml_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "autograd",
|
||||||
|
# @lint-ignore BUCKRESTRICTEDSYNTAX
|
||||||
|
srcs = glob(
|
||||||
|
["autograd/*.py"],
|
||||||
|
),
|
||||||
|
base_module = "tools",
|
||||||
|
resources = [
|
||||||
|
"autograd/deprecated.yaml",
|
||||||
|
"autograd/derivatives.yaml",
|
||||||
|
"autograd/templates/ADInplaceOrViewType.cpp",
|
||||||
|
"autograd/templates/Functions.cpp",
|
||||||
|
"autograd/templates/Functions.h",
|
||||||
|
"autograd/templates/TraceType.cpp",
|
||||||
|
"autograd/templates/VariableType.cpp",
|
||||||
|
"autograd/templates/VariableType.h",
|
||||||
|
"autograd/templates/annotated_fn_args.py.in",
|
||||||
|
"autograd/templates/python_enum_tag.cpp",
|
||||||
|
"autograd/templates/python_fft_functions.cpp",
|
||||||
|
"autograd/templates/python_functions.cpp",
|
||||||
|
"autograd/templates/python_functions.h",
|
||||||
|
"autograd/templates/python_linalg_functions.cpp",
|
||||||
|
"autograd/templates/python_nn_functions.cpp",
|
||||||
|
"autograd/templates/python_return_types.cpp",
|
||||||
|
"autograd/templates/python_sparse_functions.cpp",
|
||||||
|
"autograd/templates/python_special_functions.cpp",
|
||||||
|
"autograd/templates/python_torch_functions.cpp",
|
||||||
|
"autograd/templates/python_variable_methods.cpp",
|
||||||
|
"autograd/templates/variable_factories.h",
|
||||||
|
],
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
third_party("pyyaml"),
|
||||||
|
torchgen_deps,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "generate_code",
|
||||||
|
srcs = [
|
||||||
|
"setup_helpers/generate_code.py",
|
||||||
|
],
|
||||||
|
base_module = "tools",
|
||||||
|
deps = [
|
||||||
|
":autograd",
|
||||||
|
":jit",
|
||||||
|
torchgen_deps,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_binary(
|
||||||
|
name = "generate_code_bin",
|
||||||
|
main_module = "tools.setup_helpers.generate_code",
|
||||||
|
# Windows does not support inplace:
|
||||||
|
# https://github.com/facebook/buck/issues/2161.
|
||||||
|
#
|
||||||
|
# Note that //arvr/mode/embedded/win/clang-aarch64-release sets
|
||||||
|
# its target platform to
|
||||||
|
# ovr_config//platform/embedded:clang-aarch64-linux-release, hence
|
||||||
|
# that is why we are selecting that OS to trigger this behavior.
|
||||||
|
package_style = select({
|
||||||
|
"DEFAULT": "inplace",
|
||||||
|
"ovr_config//os:linux-arm64": "standalone",
|
||||||
|
}),
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
# Because Windows does not support inplace packaging, we need to
|
||||||
|
# ensure it is unzipped before executing it, otherwise it will not
|
||||||
|
# be able to find any resources using path manipulation.
|
||||||
|
#
|
||||||
|
# See note above about why the OS is Linux here and not Windows.
|
||||||
|
zip_safe = select({
|
||||||
|
"DEFAULT": True,
|
||||||
|
"ovr_config//os:linux-arm64": False,
|
||||||
|
}),
|
||||||
|
deps = [
|
||||||
|
":generate_code",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "gen-version-header-lib",
|
||||||
|
srcs = [
|
||||||
|
"setup_helpers/gen_version_header.py",
|
||||||
|
],
|
||||||
|
base_module = "",
|
||||||
|
deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_binary(
|
||||||
|
name = "gen-version-header",
|
||||||
|
main_module = "setup_helpers.gen_version_header",
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
":gen-version-header-lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_library(
|
||||||
|
name = "gen_aten_vulkan_spv_lib",
|
||||||
|
srcs = [
|
||||||
|
"gen_vulkan_spv.py",
|
||||||
|
],
|
||||||
|
base_module = "",
|
||||||
|
deps = [
|
||||||
|
torchgen_deps,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_binary(
|
||||||
|
name = "gen_aten_vulkan_spv_bin",
|
||||||
|
main_module = "gen_vulkan_spv",
|
||||||
|
visibility = [
|
||||||
|
"PUBLIC",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":gen_aten_vulkan_spv_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_test(
|
||||||
|
name = "selective_build_test",
|
||||||
|
srcs = [
|
||||||
|
"test/test_selective_build.py",
|
||||||
|
],
|
||||||
|
contacts = contacts,
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
torchgen_deps,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_test(
|
||||||
|
name = "gen_oplist_test",
|
||||||
|
srcs = [
|
||||||
|
"test/gen_oplist_test.py",
|
||||||
|
],
|
||||||
|
contacts = contacts,
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
deps = [
|
||||||
|
":gen_oplist_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
python_test(
|
||||||
|
name = "gen_operators_yaml_test",
|
||||||
|
srcs = [
|
||||||
|
"test/gen_operators_yaml_test.py",
|
||||||
|
],
|
||||||
|
visibility = ["PUBLIC"],
|
||||||
|
contacts = contacts,
|
||||||
|
deps = [
|
||||||
|
":gen_operators_yaml_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
10
tools/BUCK.oss
Normal file
10
tools/BUCK.oss
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
load("//:buckbuild.bzl", "third_party")
|
||||||
|
load(":BUCK.bzl", "define_tools_targets")
|
||||||
|
|
||||||
|
define_tools_targets(
|
||||||
|
python_binary = python_binary,
|
||||||
|
python_library = python_library,
|
||||||
|
python_test = python_test,
|
||||||
|
third_party = third_party,
|
||||||
|
torchgen_deps = "//torchgen:torchgen",
|
||||||
|
)
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
python_library(
|
|
||||||
name = "autograd",
|
|
||||||
srcs = glob(
|
|
||||||
["*.py"],
|
|
||||||
),
|
|
||||||
base_module = "tools.autograd",
|
|
||||||
resources = [
|
|
||||||
"deprecated.yaml",
|
|
||||||
"derivatives.yaml",
|
|
||||||
"templates/ADInplaceOrViewType.cpp",
|
|
||||||
"templates/Functions.cpp",
|
|
||||||
"templates/Functions.h",
|
|
||||||
"templates/TraceType.cpp",
|
|
||||||
"templates/VariableType.cpp",
|
|
||||||
"templates/VariableType.h",
|
|
||||||
"templates/annotated_fn_args.py.in",
|
|
||||||
"templates/python_fft_functions.cpp",
|
|
||||||
"templates/python_functions.cpp",
|
|
||||||
"templates/python_functions.h",
|
|
||||||
"templates/python_linalg_functions.cpp",
|
|
||||||
"templates/python_nn_functions.cpp",
|
|
||||||
"templates/python_return_types.cpp",
|
|
||||||
"templates/python_sparse_functions.cpp",
|
|
||||||
"templates/python_special_functions.cpp",
|
|
||||||
"templates/python_torch_functions.cpp",
|
|
||||||
"templates/python_variable_methods.cpp",
|
|
||||||
"templates/variable_factories.h",
|
|
||||||
"templates/python_enum_tag.cpp",
|
|
||||||
],
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
deps = [
|
|
||||||
"//third_party:pyyaml",
|
|
||||||
"//torchgen:torchgen",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
# Only used for PyTorch open source BUCK build
|
|
||||||
# @lint-ignore-every BUCKRESTRICTEDSYNTAX
|
|
||||||
# @lint-ignore-every FBCODEBZLADDLOADS
|
|
||||||
|
|
||||||
def fb_python_binary(**kwgs):
|
|
||||||
if read_config("pt", "is_oss", "0") == "0":
|
|
||||||
fail("This file is for open source pytorch build. Do not use it in fbsource!")
|
|
||||||
|
|
||||||
python_binary(**kwgs)
|
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
# Only used for PyTorch open source BUCK build
|
|
||||||
# @lint-ignore-every BUCKRESTRICTEDSYNTAX
|
|
||||||
# @lint-ignore-every FBCODEBZLADDLOADS
|
|
||||||
|
|
||||||
def fb_python_library(**kwgs):
|
|
||||||
if read_config("pt", "is_oss", "0") == "0":
|
|
||||||
fail("This file is for open source pytorch build. Do not use it in fbsource!")
|
|
||||||
|
|
||||||
python_library(**kwgs)
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
||||||
python_library(
|
|
||||||
name = "jit",
|
|
||||||
srcs = glob([
|
|
||||||
"*.py",
|
|
||||||
"templates/*",
|
|
||||||
]),
|
|
||||||
base_module = "tools.jit",
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
deps = [
|
|
||||||
"//torchgen:torchgen",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
python_library(
|
|
||||||
name = "gen_selected_mobile_ops_header",
|
|
||||||
srcs = ["gen_selected_mobile_ops_header.py"],
|
|
||||||
base_module = "tools.lite_interpreter",
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
)
|
|
||||||
|
|
@ -1,41 +0,0 @@
|
||||||
python_library(
|
|
||||||
name = "generate_code",
|
|
||||||
srcs = [
|
|
||||||
"generate_code.py",
|
|
||||||
],
|
|
||||||
base_module = "tools.setup_helpers",
|
|
||||||
deps = [
|
|
||||||
"//tools/autograd:autograd",
|
|
||||||
"//tools/jit:jit",
|
|
||||||
"//torchgen:torchgen",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python_binary(
|
|
||||||
name = "generate_code_bin",
|
|
||||||
main_module = "tools.setup_helpers.generate_code",
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
# package_style = "inplace",
|
|
||||||
zip_safe = False,
|
|
||||||
deps = [
|
|
||||||
":generate_code",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
python_library(
|
|
||||||
name = "gen-version-header-lib",
|
|
||||||
srcs = [
|
|
||||||
"gen_version_header.py",
|
|
||||||
],
|
|
||||||
base_module = "tools.setup_helpers",
|
|
||||||
deps = [],
|
|
||||||
)
|
|
||||||
|
|
||||||
python_binary(
|
|
||||||
name = "gen-version-header",
|
|
||||||
main_module = "tools.setup_helpers.gen_version_header",
|
|
||||||
visibility = ["PUBLIC"],
|
|
||||||
deps = [
|
|
||||||
":gen-version-header-lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
190
tools/test/gen_operators_yaml_test.py
Normal file
190
tools/test/gen_operators_yaml_test.py
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2004-present Facebook. All Rights Reserved.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from gen_operators_yaml import make_filter_from_options, verify_all_specified_present
|
||||||
|
|
||||||
|
|
||||||
|
class GenOperatorsYAMLTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_filter_creation(self):
|
||||||
|
filter_func = make_filter_from_options(
|
||||||
|
model_name="abc",
|
||||||
|
model_versions=["100", "101"],
|
||||||
|
model_assets=None,
|
||||||
|
model_backends=None,
|
||||||
|
)
|
||||||
|
config = [
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abc",
|
||||||
|
"version": 100,
|
||||||
|
"asset": "asset-1",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
"traced_operators": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abc",
|
||||||
|
"version": 102,
|
||||||
|
"asset": "asset-1",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abcd",
|
||||||
|
"version": 100,
|
||||||
|
"asset": "asset-1",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
"traced_operators": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abc",
|
||||||
|
"version": 101,
|
||||||
|
"asset": "asset-2",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered_configs = list(filter(filter_func, config))
|
||||||
|
assert (
|
||||||
|
len(filtered_configs) == 2
|
||||||
|
), "Expected 2 elements in filtered_configs, but got {}".format(
|
||||||
|
len(filtered_configs)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_verification_success(self):
|
||||||
|
filter_func = make_filter_from_options(
|
||||||
|
model_name="abc",
|
||||||
|
model_versions=["100", "101"],
|
||||||
|
model_assets=["asset-1", "asset-2"],
|
||||||
|
model_backends=None,
|
||||||
|
)
|
||||||
|
config = [
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abc",
|
||||||
|
"version": 100,
|
||||||
|
"asset": "asset-1",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
"traced_operators": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abc",
|
||||||
|
"version": 101,
|
||||||
|
"asset": "asset-2",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
filtered_configs = list(filter(filter_func, config))
|
||||||
|
try:
|
||||||
|
verify_all_specified_present(
|
||||||
|
model_assets=["asset-1", "asset-2"],
|
||||||
|
model_versions=["100", "101"],
|
||||||
|
selected_models_yaml=filtered_configs,
|
||||||
|
rule_name="test",
|
||||||
|
model_name="abc",
|
||||||
|
new_style_rule=True,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
self.fail(
|
||||||
|
"expected verify_all_specified_present to succeed instead it raised an exception"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_verification_fail(self):
|
||||||
|
config = [
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abc",
|
||||||
|
"version": 100,
|
||||||
|
"asset": "asset-1",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
"traced_operators": [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "abc",
|
||||||
|
"version": 101,
|
||||||
|
"asset": "asset-2",
|
||||||
|
"backend": "CPU",
|
||||||
|
},
|
||||||
|
"root_operators": [],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
good_assets = ["asset-1", "asset-2"]
|
||||||
|
good_versions = ["100", "101"]
|
||||||
|
good_name = "abc"
|
||||||
|
|
||||||
|
# Test bad asset
|
||||||
|
filter_func_bad_asset = make_filter_from_options(
|
||||||
|
model_name=good_name,
|
||||||
|
model_versions=good_versions,
|
||||||
|
model_assets=["asset-1", "asset-2", "asset-3"],
|
||||||
|
model_backends=None,
|
||||||
|
)
|
||||||
|
filtered_configs_asset = list(filter(filter_func_bad_asset, config))
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
verify_all_specified_present(
|
||||||
|
model_assets=["asset-1", "asset-2", "asset-3"],
|
||||||
|
model_versions=good_versions,
|
||||||
|
selected_models_yaml=filtered_configs_asset,
|
||||||
|
rule_name="test",
|
||||||
|
model_name=good_name,
|
||||||
|
new_style_rule=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test bad version
|
||||||
|
filter_func_bad_version = make_filter_from_options(
|
||||||
|
model_name=good_name,
|
||||||
|
model_versions=["100", "101", "102"],
|
||||||
|
model_assets=good_assets,
|
||||||
|
model_backends=None,
|
||||||
|
)
|
||||||
|
filtered_configs_version = list(filter(filter_func_bad_version, config))
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
verify_all_specified_present(
|
||||||
|
model_assets=good_assets,
|
||||||
|
model_versions=["100", "101", "102"],
|
||||||
|
selected_models_yaml=filtered_configs_version,
|
||||||
|
rule_name="test",
|
||||||
|
model_name=good_name,
|
||||||
|
new_style_rule=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test bad name
|
||||||
|
filter_func_bad_name = make_filter_from_options(
|
||||||
|
model_name="abcd",
|
||||||
|
model_versions=good_versions,
|
||||||
|
model_assets=good_assets,
|
||||||
|
model_backends=None,
|
||||||
|
)
|
||||||
|
filtered_configs_name = list(filter(filter_func_bad_name, config))
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
verify_all_specified_present(
|
||||||
|
model_assets=good_assets,
|
||||||
|
model_versions=good_versions,
|
||||||
|
selected_models_yaml=filtered_configs_name,
|
||||||
|
rule_name="test",
|
||||||
|
model_name="abcd",
|
||||||
|
new_style_rule=True,
|
||||||
|
)
|
||||||
35
tools/test/gen_oplist_test.py
Normal file
35
tools/test/gen_oplist_test.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2004-present Facebook. All Rights Reserved.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from gen_oplist import throw_if_any_op_includes_overloads
|
||||||
|
|
||||||
|
|
||||||
|
class GenOplistTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_throw_if_any_op_includes_overloads(self):
|
||||||
|
selective_builder = MagicMock()
|
||||||
|
selective_builder.operators = MagicMock()
|
||||||
|
selective_builder.operators.items.return_value = [
|
||||||
|
("op1", MagicMock(include_all_overloads=True)),
|
||||||
|
("op2", MagicMock(include_all_overloads=False)),
|
||||||
|
("op3", MagicMock(include_all_overloads=True)),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertRaises(
|
||||||
|
Exception, throw_if_any_op_includes_overloads, selective_builder
|
||||||
|
)
|
||||||
|
|
||||||
|
selective_builder.operators.items.return_value = [
|
||||||
|
("op1", MagicMock(include_all_overloads=False)),
|
||||||
|
("op2", MagicMock(include_all_overloads=False)),
|
||||||
|
("op3", MagicMock(include_all_overloads=False)),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Here we do not expect it to throw an exception since none of the ops
|
||||||
|
# include all overloads.
|
||||||
|
throw_if_any_op_includes_overloads(selective_builder)
|
||||||
281
tools/test/test_selective_build.py
Normal file
281
tools/test/test_selective_build.py
Normal file
|
|
@ -0,0 +1,281 @@
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from torchgen.selective_build.operator import *
|
||||||
|
from torchgen.selective_build.selector import (
|
||||||
|
combine_selective_builders,
|
||||||
|
SelectiveBuilder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSelectiveBuild(unittest.TestCase):
|
||||||
|
def test_selective_build_operator(self):
|
||||||
|
op = SelectiveBuildOperator(
|
||||||
|
"aten::add.int",
|
||||||
|
is_root_operator=True,
|
||||||
|
is_used_for_training=False,
|
||||||
|
include_all_overloads=False,
|
||||||
|
_debug_info=None,
|
||||||
|
)
|
||||||
|
self.assertTrue(op.is_root_operator)
|
||||||
|
self.assertFalse(op.is_used_for_training)
|
||||||
|
self.assertFalse(op.include_all_overloads)
|
||||||
|
|
||||||
|
def test_selector_factory(self):
|
||||||
|
yaml_config_v1 = """
|
||||||
|
debug_info:
|
||||||
|
- model1@v100
|
||||||
|
- model2@v51
|
||||||
|
operators:
|
||||||
|
aten::add:
|
||||||
|
is_used_for_training: No
|
||||||
|
is_root_operator: Yes
|
||||||
|
include_all_overloads: Yes
|
||||||
|
aten::add.int:
|
||||||
|
is_used_for_training: Yes
|
||||||
|
is_root_operator: No
|
||||||
|
include_all_overloads: No
|
||||||
|
aten::mul.int:
|
||||||
|
is_used_for_training: Yes
|
||||||
|
is_root_operator: No
|
||||||
|
include_all_overloads: No
|
||||||
|
"""
|
||||||
|
|
||||||
|
yaml_config_v2 = """
|
||||||
|
debug_info:
|
||||||
|
- model1@v100
|
||||||
|
- model2@v51
|
||||||
|
operators:
|
||||||
|
aten::sub:
|
||||||
|
is_used_for_training: No
|
||||||
|
is_root_operator: Yes
|
||||||
|
include_all_overloads: No
|
||||||
|
debug_info:
|
||||||
|
- model1@v100
|
||||||
|
aten::sub.int:
|
||||||
|
is_used_for_training: Yes
|
||||||
|
is_root_operator: No
|
||||||
|
include_all_overloads: No
|
||||||
|
"""
|
||||||
|
|
||||||
|
yaml_config_all = "include_all_operators: Yes"
|
||||||
|
|
||||||
|
yaml_config_invalid = "invalid:"
|
||||||
|
|
||||||
|
selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)
|
||||||
|
|
||||||
|
self.assertTrue(selector1.is_operator_selected("aten::add"))
|
||||||
|
self.assertTrue(selector1.is_operator_selected("aten::add.int"))
|
||||||
|
# Overload name is not used for checking in v1.
|
||||||
|
self.assertTrue(selector1.is_operator_selected("aten::add.float"))
|
||||||
|
|
||||||
|
def gen():
|
||||||
|
return SelectiveBuilder.from_yaml_str(yaml_config_invalid)
|
||||||
|
|
||||||
|
self.assertRaises(Exception, gen)
|
||||||
|
|
||||||
|
selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)
|
||||||
|
|
||||||
|
self.assertTrue(selector_all.is_operator_selected("aten::add"))
|
||||||
|
self.assertTrue(selector_all.is_operator_selected("aten::sub"))
|
||||||
|
self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
|
||||||
|
self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||||
|
|
||||||
|
selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)
|
||||||
|
|
||||||
|
self.assertFalse(selector2.is_operator_selected("aten::add"))
|
||||||
|
self.assertTrue(selector2.is_operator_selected("aten::sub"))
|
||||||
|
self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
|
||||||
|
|
||||||
|
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||||
|
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
|
||||||
|
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
|
||||||
|
self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
|
||||||
|
self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))
|
||||||
|
|
||||||
|
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
|
||||||
|
self.assertFalse(
|
||||||
|
selector_legacy_v1.is_operator_selected_for_training("aten::add")
|
||||||
|
)
|
||||||
|
|
||||||
|
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||||
|
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
|
||||||
|
self.assertFalse(
|
||||||
|
selector_legacy_v1.is_operator_selected_for_training("aten::add")
|
||||||
|
)
|
||||||
|
self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
|
||||||
|
self.assertFalse(
|
||||||
|
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
|
||||||
|
)
|
||||||
|
|
||||||
|
selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
|
||||||
|
["aten::add", "aten::add.int", "aten::mul.int"],
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
|
||||||
|
self.assertTrue(
|
||||||
|
selector_legacy_v1.is_operator_selected_for_training("aten::add")
|
||||||
|
)
|
||||||
|
self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
|
||||||
|
self.assertTrue(
|
||||||
|
selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_operator_combine(self):
|
||||||
|
op1 = SelectiveBuildOperator(
|
||||||
|
"aten::add.int",
|
||||||
|
is_root_operator=True,
|
||||||
|
is_used_for_training=False,
|
||||||
|
include_all_overloads=False,
|
||||||
|
_debug_info=None,
|
||||||
|
)
|
||||||
|
op2 = SelectiveBuildOperator(
|
||||||
|
"aten::add.int",
|
||||||
|
is_root_operator=False,
|
||||||
|
is_used_for_training=False,
|
||||||
|
include_all_overloads=False,
|
||||||
|
_debug_info=None,
|
||||||
|
)
|
||||||
|
op3 = SelectiveBuildOperator(
|
||||||
|
"aten::add",
|
||||||
|
is_root_operator=True,
|
||||||
|
is_used_for_training=False,
|
||||||
|
include_all_overloads=False,
|
||||||
|
_debug_info=None,
|
||||||
|
)
|
||||||
|
op4 = SelectiveBuildOperator(
|
||||||
|
"aten::add.int",
|
||||||
|
is_root_operator=True,
|
||||||
|
is_used_for_training=True,
|
||||||
|
include_all_overloads=False,
|
||||||
|
_debug_info=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
op5 = combine_operators(op1, op2)
|
||||||
|
|
||||||
|
self.assertTrue(op5.is_root_operator)
|
||||||
|
self.assertFalse(op5.is_used_for_training)
|
||||||
|
|
||||||
|
op6 = combine_operators(op1, op4)
|
||||||
|
|
||||||
|
self.assertTrue(op6.is_root_operator)
|
||||||
|
self.assertTrue(op6.is_used_for_training)
|
||||||
|
|
||||||
|
def gen_new_op():
|
||||||
|
return combine_operators(op1, op3)
|
||||||
|
|
||||||
|
self.assertRaises(Exception, gen_new_op)
|
||||||
|
|
||||||
|
def test_training_op_fetch(self):
|
||||||
|
yaml_config = """
|
||||||
|
operators:
|
||||||
|
aten::add.int:
|
||||||
|
is_used_for_training: No
|
||||||
|
is_root_operator: Yes
|
||||||
|
include_all_overloads: No
|
||||||
|
aten::add:
|
||||||
|
is_used_for_training: Yes
|
||||||
|
is_root_operator: No
|
||||||
|
include_all_overloads: Yes
|
||||||
|
"""
|
||||||
|
|
||||||
|
selector = SelectiveBuilder.from_yaml_str(yaml_config)
|
||||||
|
self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
|
||||||
|
self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
|
||||||
|
|
||||||
|
def test_kernel_dtypes(self):
|
||||||
|
yaml_config = """
|
||||||
|
kernel_metadata:
|
||||||
|
add_kernel:
|
||||||
|
- int8
|
||||||
|
- int32
|
||||||
|
sub_kernel:
|
||||||
|
- int16
|
||||||
|
- int32
|
||||||
|
add/sub_kernel:
|
||||||
|
- float
|
||||||
|
- complex
|
||||||
|
"""
|
||||||
|
|
||||||
|
selector = SelectiveBuilder.from_yaml_str(yaml_config)
|
||||||
|
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
|
||||||
|
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
|
||||||
|
|
||||||
|
def test_merge_kernel_dtypes(self):
|
||||||
|
yaml_config1 = """
|
||||||
|
kernel_metadata:
|
||||||
|
add_kernel:
|
||||||
|
- int8
|
||||||
|
add/sub_kernel:
|
||||||
|
- float
|
||||||
|
- complex
|
||||||
|
- none
|
||||||
|
mul_kernel:
|
||||||
|
- int8
|
||||||
|
"""
|
||||||
|
|
||||||
|
yaml_config2 = """
|
||||||
|
kernel_metadata:
|
||||||
|
add_kernel:
|
||||||
|
- int32
|
||||||
|
sub_kernel:
|
||||||
|
- int16
|
||||||
|
- int32
|
||||||
|
add/sub_kernel:
|
||||||
|
- float
|
||||||
|
- complex
|
||||||
|
"""
|
||||||
|
|
||||||
|
selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
|
||||||
|
selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)
|
||||||
|
|
||||||
|
selector = combine_selective_builders(selector1, selector2)
|
||||||
|
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
|
||||||
|
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
|
||||||
|
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
|
||||||
|
self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
|
||||||
|
|
||||||
|
def test_all_kernel_dtypes_selected(self):
|
||||||
|
yaml_config = """
|
||||||
|
include_all_non_op_selectives: True
|
||||||
|
"""
|
||||||
|
|
||||||
|
selector = SelectiveBuilder.from_yaml_str(yaml_config)
|
||||||
|
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
|
||||||
|
self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
|
||||||
Loading…
Reference in New Issue
Block a user