mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Revert "Allow specifying tags for aten operators in native_functions.yaml"
This reverts commit 1dab71ab25.
Reverted https://github.com/pytorch/pytorch/pull/72549 on behalf of https://github.com/malfet
This commit is contained in:
parent
cce831c805
commit
ea44645c9a
|
|
@ -66,7 +66,6 @@ cp torch/_utils_internal.py tools/shared
|
|||
# Generate PyTorch files
|
||||
time python tools/setup_helpers/generate_code.py \
|
||||
--native-functions-path aten/src/ATen/native/native_functions.yaml \
|
||||
--tags-path aten/src/ATen/native/tags.yaml \
|
||||
--nn-path aten/src/
|
||||
|
||||
# Build the docs
|
||||
|
|
|
|||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
|
|
@ -493,7 +493,7 @@ jobs:
|
|||
set -eux
|
||||
time python3 -mtools.generate_torch_version --is_debug=false
|
||||
time python3 -mtools.codegen.gen -s aten/src/ATen -d build/aten/src/ATen
|
||||
time python3 -mtools.pyi.gen_pyi --native-functions-path aten/src/ATen/native/native_functions.yaml --tags-path aten/src/ATen/native/tags.yaml --deprecated-functions-path "tools/autograd/deprecated.yaml"
|
||||
time python3 -mtools.pyi.gen_pyi --native-functions-path aten/src/ATen/native/native_functions.yaml --deprecated-functions-path "tools/autograd/deprecated.yaml"
|
||||
- name: Run mypy
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@ mkdir -p "$OUT"/pyi/torch/_C
|
|||
mkdir -p "$OUT"/pyi/torch/nn
|
||||
python -m tools.pyi.gen_pyi \
|
||||
--native-functions-path aten/src/ATen/native/native_functions.yaml \
|
||||
--tags-path aten/src/ATen/native/tags.yaml \
|
||||
--deprecated-functions-path tools/autograd/deprecated.yaml \
|
||||
--out "$OUT"/pyi
|
||||
|
||||
|
|
@ -46,7 +45,6 @@ python -m tools.pyi.gen_pyi \
|
|||
python -m tools.autograd.gen_autograd \
|
||||
"$OUT"/torch/share/ATen/Declarations.yaml \
|
||||
aten/src/ATen/native/native_functions.yaml \
|
||||
aten/src/ATen/native/tags.yaml \
|
||||
"$OUT"/autograd \
|
||||
tools/autograd
|
||||
|
||||
|
|
@ -54,6 +52,5 @@ python -m tools.autograd.gen_autograd \
|
|||
mkdir -p "$OUT"/annotated_fn_args
|
||||
python -m tools.autograd.gen_annotated_fn_args \
|
||||
aten/src/ATen/native/native_functions.yaml \
|
||||
aten/src/ATen/native/tags.yaml \
|
||||
"$OUT"/annotated_fn_args \
|
||||
tools/autograd
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ py_binary(
|
|||
],
|
||||
)
|
||||
|
||||
aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + ["aten/src/ATen/native/tags.yaml"] + glob(["aten/src/ATen/templates/**"])
|
||||
aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + glob(["aten/src/ATen/templates/**"])
|
||||
|
||||
generated_cpu_cpp = [
|
||||
"aten/src/ATen/RegisterBackendSelect.cpp",
|
||||
|
|
@ -185,7 +185,6 @@ genrule(
|
|||
name = "all_generated_code",
|
||||
srcs = [
|
||||
"aten/src/ATen/native/native_functions.yaml",
|
||||
"aten/src/ATen/native/tags.yaml",
|
||||
"aten/src/ATen/native/ts_native_functions.yaml",
|
||||
"torch/csrc/lazy/core/shape_inference.h",
|
||||
"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
|
||||
|
|
@ -195,7 +194,7 @@ genrule(
|
|||
"aten/src/ATen/templates/LazyIr.h",
|
||||
],
|
||||
outs = libtorch_cpp_generated_sources + libtorch_python_generated_sources,
|
||||
cmd = "$(location :generate_code) --install_dir `dirname $(location torch/csrc/autograd/generated/variable_factories.h)`/../.. --native-functions-path $(location aten/src/ATen/native/native_functions.yaml) --tags-path $(location aten/src/ATen/native/tags.yaml) --nn-path aten/src --gen_lazy_ts_backend",
|
||||
cmd = "$(location :generate_code) --install_dir `dirname $(location torch/csrc/autograd/generated/variable_factories.h)`/../.. --native-functions-path $(location aten/src/ATen/native/native_functions.yaml) --nn-path aten/src --gen_lazy_ts_backend",
|
||||
tools = [":generate_code"],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml`
|
||||
|
||||
- tag: inplace_view
|
||||
desc: |
|
||||
This tag indicates if an operator *only* modifies the tensor metadata
|
||||
|
|
@ -417,7 +417,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||
COMMAND
|
||||
"${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
|
||||
--native-functions-path "aten/src/ATen/native/native_functions.yaml"
|
||||
--tags-path "aten/src/ATen/native/tags.yaml"
|
||||
--nn-path "aten/src"
|
||||
$<$<BOOL:${INTERN_DISABLE_AUTOGRAD}>:--disable-autograd>
|
||||
$<$<BOOL:${SELECTED_OP_LIST}>:--selected-op-list-path="${SELECTED_OP_LIST}">
|
||||
|
|
@ -426,7 +425,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||
${GEN_PER_OPERATOR_FLAG}
|
||||
DEPENDS
|
||||
"${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml"
|
||||
"${TORCH_ROOT}/aten/src/ATen/native/tags.yaml"
|
||||
"${TORCH_ROOT}/aten/src/ATen/native/ts_native_functions.yaml"
|
||||
"${TORCH_ROOT}/torch/csrc/lazy/core/shape_inference.h"
|
||||
"${TORCH_ROOT}/torch/csrc/lazy/ts_backend/ts_native_functions.cpp"
|
||||
|
|
|
|||
|
|
@ -135,7 +135,6 @@ if(INTERN_BUILD_ATEN_OPS)
|
|||
COMMAND ${GEN_UNBOXING_COMMAND_sources}
|
||||
DEPENDS ${all_unboxing_script} ${sources_templates}
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/tags.yaml
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
|
||||
)
|
||||
else() # Otherwise do not generate or include sources into build.
|
||||
|
|
@ -211,7 +210,6 @@ if(INTERN_BUILD_ATEN_OPS)
|
|||
COMMAND ${GEN_COMMAND_${gen_type}}
|
||||
DEPENDS ${all_python} ${${gen_type}_templates}
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/tags.yaml
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
|
||||
)
|
||||
endforeach()
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ python -m tools.codegen.gen
|
|||
|
||||
python tools/setup_helpers/generate_code.py \
|
||||
--native-functions-path aten/src/ATen/native/native_functions.yaml \
|
||||
--tags-path aten/src/ATen/native/tags.yaml \
|
||||
--nn-path aten/src
|
||||
|
||||
popd
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ test/test_overrides.py.
|
|||
|
||||
python -m tools.autograd.gen_annotated_fn_args \
|
||||
aten/src/ATen/native/native_functions.yaml \
|
||||
aten/src/ATen/native/tags.yaml \
|
||||
$OUTPUT_DIR \
|
||||
tools/autograd
|
||||
|
||||
|
|
@ -30,8 +29,8 @@ from .gen_python_functions import should_generate_py_binding, is_py_torch_functi
|
|||
is_py_nn_function, is_py_linalg_function, is_py_variable_method, is_py_special_function, \
|
||||
is_py_fft_function
|
||||
|
||||
def gen_annotated(native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str) -> None:
|
||||
native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
|
||||
def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
|
||||
native_functions = parse_native_yaml(native_yaml_path).native_functions
|
||||
mappings = (
|
||||
(is_py_torch_function, 'torch._C._VariableFunctions'),
|
||||
(is_py_nn_function, 'torch._C._nn'),
|
||||
|
|
@ -78,14 +77,12 @@ def main() -> None:
|
|||
description='Generate annotated_fn_args script')
|
||||
parser.add_argument('native_functions', metavar='NATIVE',
|
||||
help='path to native_functions.yaml')
|
||||
parser.add_argument('tags', metavar='TAGS',
|
||||
help='path to tags.yaml')
|
||||
parser.add_argument('out', metavar='OUT',
|
||||
help='path to output directory')
|
||||
parser.add_argument('autograd', metavar='AUTOGRAD',
|
||||
help='path to template directory')
|
||||
args = parser.parse_args()
|
||||
gen_annotated(args.native_functions, args.tags, args.out, args.autograd)
|
||||
gen_annotated(args.native_functions, args.out, args.autograd)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ repository, run:
|
|||
python -m tools.autograd.gen_autograd \
|
||||
build/aten/src/ATen/Declarations.yaml \
|
||||
aten/src/ATen/native/native_functions.yaml \
|
||||
aten/src/ATen/native/tags.yaml \
|
||||
$OUTPUT_DIR \
|
||||
tools/autograd
|
||||
|
||||
|
|
@ -42,7 +41,6 @@ from .load_derivatives import load_derivatives
|
|||
|
||||
def gen_autograd(
|
||||
native_functions_path: str,
|
||||
tags_path: str,
|
||||
out: str,
|
||||
autograd_dir: str,
|
||||
operator_selector: SelectiveBuilder,
|
||||
|
|
@ -50,12 +48,11 @@ def gen_autograd(
|
|||
) -> None:
|
||||
# Parse and load derivatives.yaml
|
||||
differentiability_infos = load_derivatives(
|
||||
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path,
|
||||
tags_path)
|
||||
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
|
||||
|
||||
template_path = os.path.join(autograd_dir, 'templates')
|
||||
|
||||
native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
|
||||
native_funcs = parse_native_yaml(native_functions_path).native_functions
|
||||
fns = list(sorted(filter(
|
||||
operator_selector.is_native_function_selected_for_training,
|
||||
native_funcs), key=lambda f: cpp.name(f.func)))
|
||||
|
|
@ -63,9 +60,9 @@ def gen_autograd(
|
|||
|
||||
# Generate VariableType.h/cpp
|
||||
if not disable_autograd:
|
||||
gen_variable_type(out, native_functions_path, tags_path, fns_with_diff_infos, template_path)
|
||||
gen_variable_type(out, native_functions_path, fns_with_diff_infos, template_path)
|
||||
|
||||
gen_inplace_or_view_type(out, native_functions_path, tags_path, fns_with_diff_infos, template_path)
|
||||
gen_inplace_or_view_type(out, native_functions_path, fns_with_diff_infos, template_path)
|
||||
|
||||
# operator filter not applied as tracing sources are excluded in selective build
|
||||
gen_trace_type(out, native_funcs, template_path)
|
||||
|
|
@ -74,18 +71,16 @@ def gen_autograd(
|
|||
out, differentiability_infos, template_path)
|
||||
|
||||
# Generate variable_factories.h
|
||||
gen_variable_factories(out, native_functions_path, tags_path, template_path)
|
||||
gen_variable_factories(out, native_functions_path, template_path)
|
||||
|
||||
|
||||
def gen_autograd_python(
|
||||
native_functions_path: str,
|
||||
tags_path: str,
|
||||
out: str,
|
||||
autograd_dir: str,
|
||||
) -> None:
|
||||
differentiability_infos = load_derivatives(
|
||||
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path,
|
||||
tags_path)
|
||||
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
|
||||
|
||||
template_path = os.path.join(autograd_dir, 'templates')
|
||||
|
||||
|
|
@ -96,7 +91,7 @@ def gen_autograd_python(
|
|||
# Generate Python bindings
|
||||
deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
|
||||
gen_python_functions.gen(
|
||||
out, native_functions_path, tags_path, deprecated_path, template_path)
|
||||
out, native_functions_path, deprecated_path, template_path)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
@ -104,14 +99,12 @@ def main() -> None:
|
|||
description='Generate autograd C++ files script')
|
||||
parser.add_argument('native_functions', metavar='NATIVE',
|
||||
help='path to native_functions.yaml')
|
||||
parser.add_argument('tags', metavar='TAGS',
|
||||
help='path to tags.yaml')
|
||||
parser.add_argument('out', metavar='OUT',
|
||||
help='path to output directory')
|
||||
parser.add_argument('autograd', metavar='AUTOGRAD',
|
||||
help='path to autograd directory')
|
||||
args = parser.parse_args()
|
||||
gen_autograd(args.native_functions, args.tags,
|
||||
gen_autograd(args.native_functions,
|
||||
args.out, args.autograd,
|
||||
SelectiveBuilder.get_nop_selector())
|
||||
|
||||
|
|
|
|||
|
|
@ -421,7 +421,6 @@ def gen_inplace_or_view_type_env(fn: NativeFunctionWithDifferentiabilityInfo) ->
|
|||
def gen_inplace_or_view_type(
|
||||
out: str,
|
||||
native_yaml_path: str,
|
||||
tags_yaml_path: str,
|
||||
fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
|
||||
template_path: str
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -160,9 +160,9 @@ def is_py_special_function(f: NativeFunction) -> bool:
|
|||
#
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
||||
def gen(out: str, native_yaml_path: str, tags_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
|
||||
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
|
||||
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
|
||||
native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
|
||||
native_functions = parse_native_yaml(native_yaml_path).native_functions
|
||||
native_functions = list(filter(should_generate_py_binding, native_functions))
|
||||
|
||||
methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ def fully_qualified_type(argument_type: str) -> str:
|
|||
qualified_type = f'{argument_type[:index]}at::{argument_type[index:]}'
|
||||
return maybe_optional_type(qualified_type, is_opt)
|
||||
|
||||
def gen_variable_factories(out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str) -> None:
|
||||
native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
|
||||
def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
|
||||
native_functions = parse_native_yaml(native_yaml_path).native_functions
|
||||
factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
|
||||
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
|
||||
fm.write_with_template('variable_factories.h', 'variable_factories.h', lambda: {
|
||||
|
|
|
|||
|
|
@ -371,7 +371,6 @@ for (const auto& _t: ${arg}) {
|
|||
def gen_variable_type(
|
||||
out: str,
|
||||
native_yaml_path: str,
|
||||
tags_yaml_path: str,
|
||||
fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
|
||||
template_path: str,
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from tools.codegen.utils import IDENT_REGEX, split_name_params, YamlLoader
|
|||
|
||||
_GLOBAL_LOAD_DERIVATIVE_CACHE = {}
|
||||
|
||||
def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
|
||||
def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Sequence[DifferentiabilityInfo]:
|
||||
# Do some caching as this is a deterministic function
|
||||
global _GLOBAL_LOAD_DERIVATIVE_CACHE
|
||||
key = (derivatives_yaml_path, native_yaml_path)
|
||||
|
|
@ -30,7 +30,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str, tags_yam
|
|||
with open(derivatives_yaml_path, 'r') as f:
|
||||
definitions = yaml.load(f, Loader=YamlLoader)
|
||||
|
||||
functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
|
||||
functions = parse_native_yaml(native_yaml_path).native_functions
|
||||
|
||||
# What's the difference between function schema v.s. signature?
|
||||
# function schema is the complete declaration including mutability annotation / default value and etc.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from tools.codegen.model import (
|
||||
FunctionSchema, BaseTy, BaseType, NativeFunction, Argument,
|
||||
FunctionSchema, BaseTy, BaseType, NativeFunction, Argument, Tag,
|
||||
)
|
||||
from tools.codegen.api.types import (
|
||||
Binding, NamedCType, ConstRefCType, BaseCType, CType, tensorT, longT
|
||||
|
|
@ -44,7 +44,7 @@ mutated_view_idx_binding = Binding(
|
|||
# The name returned here corresponds to the name of the inner function called by the lambda.
|
||||
def name(f: NativeFunction, *, functional_op: NativeFunction, is_reverse: bool, include_namespace: bool) -> str:
|
||||
# For inplace_view ops, the lambda calls out to the corresponding functional view op
|
||||
fn = functional_op if 'inplace_view' in f.tags else f
|
||||
fn = functional_op if f.tag is Tag.inplace_view else f
|
||||
name = fn.func.name.unambiguous_name()
|
||||
if is_reverse:
|
||||
# in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import argparse
|
|||
import pathlib
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
|
||||
from tools.codegen.model import (Argument, DispatchKey, FunctionSchema,
|
||||
Location, NativeFunction,
|
||||
|
|
@ -18,7 +17,7 @@ from tools.codegen.model import (Argument, DispatchKey, FunctionSchema,
|
|||
is_cuda_dispatch_key,
|
||||
is_generic_dispatch_key,
|
||||
is_ufunc_dispatch_key,
|
||||
BaseOperatorName)
|
||||
Tag, BaseOperatorName)
|
||||
from tools.codegen.api.types import (Binding, CppSignature, CppSignatureGroup,
|
||||
DispatcherSignature, NativeSignature)
|
||||
from tools.codegen.api import cpp
|
||||
|
|
@ -114,7 +113,7 @@ _GLOBAL_PARSE_NATIVE_YAML_CACHE = {}
|
|||
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
|
||||
ParsedYaml = namedtuple('ParsedYaml', ['native_functions', 'backend_indices'])
|
||||
|
||||
def parse_native_yaml_struct(es: object, valid_tags: Set[str], path: str = "<stdin>") -> ParsedYaml:
|
||||
def parse_native_yaml_struct(es: object, path: str = "<stdin>") -> ParsedYaml:
|
||||
assert isinstance(es, list)
|
||||
rs: List[NativeFunction] = []
|
||||
bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
|
||||
|
|
@ -123,7 +122,7 @@ def parse_native_yaml_struct(es: object, valid_tags: Set[str], path: str = "<std
|
|||
loc = Location(path, e['__line__'])
|
||||
funcs = e.get('func')
|
||||
with context(lambda: f'in {loc}:\n {funcs}'):
|
||||
func, m = NativeFunction.from_yaml(e, loc, valid_tags)
|
||||
func, m = NativeFunction.from_yaml(e, loc)
|
||||
rs.append(func)
|
||||
BackendIndex.grow_index(bs, m)
|
||||
error_check_native_functions(rs)
|
||||
|
|
@ -145,38 +144,12 @@ def parse_native_yaml_struct(es: object, valid_tags: Set[str], path: str = "<std
|
|||
index=v)
|
||||
return ParsedYaml(rs, indices)
|
||||
|
||||
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
|
||||
assert isinstance(es, list)
|
||||
rs: Set[str] = set()
|
||||
for e in es:
|
||||
assert isinstance(e.get('__line__'), int), e
|
||||
loc = Location(path, e['__line__'])
|
||||
tags = e.get('tag')
|
||||
with context(lambda: f'in {loc}:\n {tags}'):
|
||||
e_i = e.copy()
|
||||
name = e_i.pop('tag')
|
||||
desc = e_i.pop('desc', '')
|
||||
# ensure that each tag has a non-empty description
|
||||
assert desc != ''
|
||||
rs.add(name)
|
||||
return rs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def parse_tags_yaml(path: str) -> Set[str]:
|
||||
# TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
|
||||
with open(path, 'r') as f:
|
||||
es = yaml.load(f, Loader=LineLoader)
|
||||
valid_tags = parse_tags_yaml_struct(es, path=path)
|
||||
return valid_tags
|
||||
|
||||
def parse_native_yaml(path: str, tags_yaml_path: str) -> ParsedYaml:
|
||||
# TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
|
||||
def parse_native_yaml(path: str) -> ParsedYaml:
|
||||
global _GLOBAL_PARSE_NATIVE_YAML_CACHE
|
||||
if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
|
||||
valid_tags = parse_tags_yaml(tags_yaml_path)
|
||||
with open(path, 'r') as f:
|
||||
es = yaml.load(f, Loader=LineLoader)
|
||||
_GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(es, valid_tags, path=path)
|
||||
_GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(es, path=path)
|
||||
|
||||
return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
|
||||
|
||||
|
|
@ -195,7 +168,7 @@ def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
|
|||
f"{f.func.name} is marked as a structured_delegate pointing to " \
|
||||
f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. " \
|
||||
f"Consider adding 'structured=True' to the delegated operator"
|
||||
if 'inplace_view' in f.tags:
|
||||
if f.tag is not None and f.tag is Tag.inplace_view:
|
||||
base_name = f.func.name.name
|
||||
overload_name = f.func.name.overload_name
|
||||
assert base_name.inplace, \
|
||||
|
|
@ -1688,8 +1661,7 @@ def main() -> None:
|
|||
)
|
||||
|
||||
native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
|
||||
tags_yaml_path = os.path.join(options.source_path, 'native/tags.yaml')
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path)
|
||||
native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
|
||||
grouped_native_functions = get_grouped_native_functions(native_functions)
|
||||
structured_native_functions = [g for g in grouped_native_functions
|
||||
|
|
|
|||
|
|
@ -301,8 +301,7 @@ def run(source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[st
|
|||
fm = make_file_manager(output_dir)
|
||||
|
||||
native_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/native_functions.yaml')
|
||||
tags_yaml_path = os.path.join(pytorch_root, 'aten/src/ATen/native/tags.yaml')
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path)
|
||||
native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
|
||||
grouped_native_functions = get_grouped_native_functions(native_functions)
|
||||
parsed_backend_yaml = parse_backend_yaml(source_yaml, grouped_native_functions, backend_indices)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from tools.codegen.api.translate import translate
|
|||
from tools.codegen.context import with_native_function
|
||||
from tools.codegen.model import (
|
||||
Argument, NativeFunction, SchemaKind, BackendIndex,
|
||||
FunctionSchema, SelfArgument, TensorOptionsArguments, BaseType, BaseTy
|
||||
Tag, FunctionSchema, SelfArgument, TensorOptionsArguments, BaseType, BaseTy
|
||||
)
|
||||
from tools.codegen.selective_build.selector import SelectiveBuilder
|
||||
from typing import List, Optional, Union, Tuple
|
||||
|
|
@ -107,7 +107,7 @@ def emit_view_functionalization_body(
|
|||
# view op case
|
||||
assert f.is_view_op
|
||||
|
||||
if 'inplace_view' in f.tags:
|
||||
if f.tag is Tag.inplace_view:
|
||||
# This op is both an inplace op AND a view op.
|
||||
# See Note [Functionalization Pass - Inplace View Ops] for details.
|
||||
# I currently have the view meta call into the out-of-place variant of the view, to avoid
|
||||
|
|
@ -142,7 +142,7 @@ def emit_view_functionalization_body(
|
|||
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
|
||||
meta_call_args = [e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)]
|
||||
|
||||
if 'inplace_view' in f.tags:
|
||||
if f.tag is Tag.inplace_view:
|
||||
# See Note [Functionalization Pass - Inplace View Ops] for more details
|
||||
return f"""
|
||||
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
|
||||
|
|
|
|||
|
|
@ -135,8 +135,7 @@ def run_gen_lazy_tensor(aten_path: str, source_yaml: str, output_dir: str,
|
|||
fm = make_file_manager(output_dir)
|
||||
|
||||
native_yaml_path = os.path.join(aten_path, 'native/native_functions.yaml')
|
||||
tags_yaml_path = os.path.join(aten_path, 'native/tags.yaml')
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path)
|
||||
native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
|
||||
grouped_native_functions = get_grouped_native_functions(native_functions)
|
||||
|
||||
|
|
|
|||
|
|
@ -270,6 +270,19 @@ class DeviceCheckType(Enum):
|
|||
NoCheck = 0
|
||||
ExactSame = 1
|
||||
|
||||
class Tag(Enum):
|
||||
inplace_view = 0
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
@staticmethod
|
||||
def parse(value: str) -> 'Tag':
|
||||
for k, v in Tag.__members__.items():
|
||||
if k == value:
|
||||
return v
|
||||
raise AssertionError(f'unknown tag {value}')
|
||||
|
||||
# The basic input to the code generation is native_functions.yaml.
|
||||
# The name "native", BTW, comes from the distinction between native
|
||||
# functions and legacy TH functions. The legacy TH functions are gone,
|
||||
|
|
@ -377,7 +390,8 @@ class NativeFunction:
|
|||
|
||||
# Tags are used to describe semantic information about (groups of) operators,
|
||||
# That aren't easily inferrable directly from the operator's schema.
|
||||
tags: Set[str]
|
||||
# For now operators have at most one tag.
|
||||
tag: Optional['Tag']
|
||||
|
||||
# NB: The benefit of defining a dataclass is that we automatically get
|
||||
# a constructor defined for all the fields we specify. No need
|
||||
|
|
@ -387,8 +401,7 @@ class NativeFunction:
|
|||
@staticmethod
|
||||
def from_yaml(
|
||||
ei: Dict[str, object],
|
||||
loc: 'Location',
|
||||
valid_tags: Set[str]
|
||||
loc: 'Location'
|
||||
) -> Tuple['NativeFunction', Dict[DispatchKey, Dict['OperatorName', 'BackendMetadata']]]:
|
||||
"""
|
||||
Parse a NativeFunction from a dictionary as directly parsed
|
||||
|
|
@ -457,18 +470,9 @@ class NativeFunction:
|
|||
assert precomputed_dict is None or structured is True
|
||||
precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
|
||||
|
||||
tags_s = e.pop('tags', '')
|
||||
assert isinstance(tags_s, str)
|
||||
tags: Set[str] = set()
|
||||
if len(tags_s) > 0:
|
||||
assert len(valid_tags) > 0
|
||||
for t in tags_s.split(', '):
|
||||
# TODO: verify that the tag is valid and has an entry in tags.yaml
|
||||
if t in valid_tags:
|
||||
tags.add(t)
|
||||
else:
|
||||
raise AssertionError(f'illegal tag {t}')
|
||||
assert isinstance(tags, set)
|
||||
tag_str = e.pop('tags', None)
|
||||
assert tag_str is None or isinstance(tag_str, str), f'not a str: {tag_str}'
|
||||
tag = Tag.parse(tag_str) if tag_str else None
|
||||
|
||||
from tools.codegen.api import cpp
|
||||
|
||||
|
|
@ -585,7 +589,7 @@ class NativeFunction:
|
|||
is_abstract=is_abstract,
|
||||
has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
|
||||
has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
|
||||
tags=tags,
|
||||
tag=tag,
|
||||
), backend_metadata
|
||||
|
||||
|
||||
|
|
@ -641,7 +645,7 @@ class NativeFunction:
|
|||
def is_view_op(self) -> bool:
|
||||
rets = self.func.returns
|
||||
is_non_mutating_view = len(rets) > 0 and any(r.annotation is not None and not r.annotation.is_write for r in rets)
|
||||
is_inplace_view = 'inplace_view' in self.tags
|
||||
is_inplace_view = self.tag is not None and self.tag is Tag.inplace_view
|
||||
is_wildcard_view = any(inp.annotation is not None and
|
||||
inp.annotation.alias_set_after != "" for inp in self.func.schema_order_arguments())
|
||||
return is_non_mutating_view or is_inplace_view or is_wildcard_view
|
||||
|
|
|
|||
|
|
@ -123,8 +123,7 @@ def main() -> None:
|
|||
default='benchmarks/static_runtime/test_generated_ops.cc')
|
||||
options = parser.parse_args()
|
||||
native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
|
||||
tags_yaml_path = os.path.join(options.source_path, 'native/tags.yaml')
|
||||
parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path)
|
||||
parsed_yaml = gen.parse_native_yaml(native_yaml_path)
|
||||
native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
|
||||
grouped_native_functions = gen.get_grouped_native_functions(native_functions)
|
||||
structured_native_functions = [g for g in grouped_native_functions
|
||||
|
|
|
|||
|
|
@ -206,8 +206,7 @@ def main() -> None:
|
|||
selector = SelectiveBuilder.get_nop_selector()
|
||||
|
||||
native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
|
||||
tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
|
||||
parsed_yaml = parse_native_yaml(native_yaml_path)
|
||||
native_functions, backend_indices = (
|
||||
parsed_yaml.native_functions,
|
||||
parsed_yaml.backend_indices,
|
||||
|
|
|
|||
|
|
@ -51,8 +51,6 @@ def run_autogen() -> None:
|
|||
"tools/setup_helpers/generate_code.py",
|
||||
"--native-functions-path",
|
||||
"aten/src/ATen/native/native_functions.yaml",
|
||||
"--tags-path",
|
||||
"aten/src/ATen/native/tags.yaml",
|
||||
"--nn-path",
|
||||
"aten/src",
|
||||
"--gen_lazy_ts_backend",
|
||||
|
|
|
|||
|
|
@ -259,7 +259,7 @@ def gen_nn_functional(fm: FileManager) -> None:
|
|||
})
|
||||
|
||||
|
||||
def gen_pyi(native_yaml_path: str, tags_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
|
||||
def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -> None:
|
||||
"""gen_pyi()
|
||||
|
||||
This function generates a pyi file for torch.
|
||||
|
|
@ -388,7 +388,7 @@ def gen_pyi(native_yaml_path: str, tags_yaml_path: str, deprecated_yaml_path: st
|
|||
' other: Union[Tensor, Number],'
|
||||
' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
|
||||
|
||||
native_functions = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
|
||||
native_functions = parse_native_yaml(native_yaml_path).native_functions
|
||||
native_functions = list(filter(should_generate_py_binding, native_functions))
|
||||
|
||||
function_signatures = load_signatures(native_functions, deprecated_yaml_path, method=False, pyi=True)
|
||||
|
|
@ -625,9 +625,6 @@ def main() -> None:
|
|||
parser.add_argument('--native-functions-path', metavar='NATIVE',
|
||||
default='aten/src/ATen/native/native_functions.yaml',
|
||||
help='path to native_functions.yaml')
|
||||
parser.add_argument('--tags-path', metavar='TAGS',
|
||||
default='aten/src/ATen/native/tags.yaml',
|
||||
help='path to tags.yaml')
|
||||
parser.add_argument('--deprecated-functions-path', metavar='DEPRECATED',
|
||||
default='tools/autograd/deprecated.yaml',
|
||||
help='path to deprecated.yaml')
|
||||
|
|
@ -636,7 +633,7 @@ def main() -> None:
|
|||
help='path to output directory')
|
||||
args = parser.parse_args()
|
||||
fm = FileManager(install_dir=args.out, template_dir='.', dry_run=False)
|
||||
gen_pyi(args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm)
|
||||
gen_pyi(args.native_functions_path, args.deprecated_functions_path, fm)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ except ImportError:
|
|||
source_files = {'.py', '.cpp', '.h'}
|
||||
|
||||
NATIVE_FUNCTIONS_PATH = 'aten/src/ATen/native/native_functions.yaml'
|
||||
TAGS_PATH = 'aten/src/ATen/native/tags.yaml'
|
||||
|
||||
# TODO: This is a little inaccurate, because it will also pick
|
||||
# up setup_helper scripts which don't affect code generation
|
||||
|
|
@ -30,7 +29,6 @@ def all_generator_source() -> List[str]:
|
|||
def generate_code(ninja_global: Optional[str] = None,
|
||||
nn_path: Optional[str] = None,
|
||||
native_functions_path: Optional[str] = None,
|
||||
tags_path: Optional[str] = None,
|
||||
install_dir: Optional[str] = None,
|
||||
subset: Optional[str] = None,
|
||||
disable_autograd: bool = False,
|
||||
|
|
@ -60,7 +58,6 @@ def generate_code(ninja_global: Optional[str] = None,
|
|||
if subset == "pybindings" or not subset:
|
||||
gen_autograd_python(
|
||||
native_functions_path or NATIVE_FUNCTIONS_PATH,
|
||||
tags_path or TAGS_PATH,
|
||||
autograd_gen_dir,
|
||||
autograd_dir)
|
||||
|
||||
|
|
@ -71,7 +68,6 @@ def generate_code(ninja_global: Optional[str] = None,
|
|||
|
||||
gen_autograd(
|
||||
native_functions_path or NATIVE_FUNCTIONS_PATH,
|
||||
tags_path or TAGS_PATH,
|
||||
autograd_gen_dir,
|
||||
autograd_dir,
|
||||
disable_autograd=disable_autograd,
|
||||
|
|
@ -81,7 +77,6 @@ def generate_code(ninja_global: Optional[str] = None,
|
|||
if subset == "python" or not subset:
|
||||
gen_annotated(
|
||||
native_functions_path or NATIVE_FUNCTIONS_PATH,
|
||||
tags_path or TAGS_PATH,
|
||||
python_install_dir,
|
||||
autograd_dir)
|
||||
|
||||
|
|
@ -140,7 +135,6 @@ def get_selector(
|
|||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description='Autogenerate code')
|
||||
parser.add_argument('--native-functions-path')
|
||||
parser.add_argument('--tags-path')
|
||||
parser.add_argument('--nn-path')
|
||||
parser.add_argument('--ninja-global')
|
||||
parser.add_argument('--install_dir')
|
||||
|
|
@ -184,7 +178,6 @@ def main() -> None:
|
|||
options.ninja_global,
|
||||
options.nn_path,
|
||||
options.native_functions_path,
|
||||
options.tags_path,
|
||||
options.install_dir,
|
||||
options.subset,
|
||||
options.disable_autograd,
|
||||
|
|
|
|||
|
|
@ -130,8 +130,7 @@ class TestGenAutogradFunctions(unittest.TestCase):
|
|||
# to edit for use.
|
||||
DEFAULT_NATIVE_FUNCTION, _ = tools.codegen.model.NativeFunction.from_yaml(
|
||||
{'func': 'func() -> bool'},
|
||||
tools.codegen.model.Location(__file__, 1),
|
||||
set())
|
||||
loc=tools.codegen.model.Location(__file__, 1))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ class TestCodegenModel(expecttest.TestCase):
|
|||
def assertParseErrorInline(self, yaml_str: str, expect: str) -> None:
|
||||
es = yaml.load(yaml_str, Loader=LineLoader)
|
||||
try:
|
||||
parse_native_yaml_struct(es, set())
|
||||
parse_native_yaml_struct(es)
|
||||
except AssertionError as e:
|
||||
# hack to strip out the context
|
||||
msg, _ = str(e).split(' in ', 2)
|
||||
|
|
@ -25,7 +25,7 @@ class TestCodegenModel(expecttest.TestCase):
|
|||
def assertUfuncErrorInline(self, yaml_str: str, expect: str) -> None:
|
||||
# parse a single structured group out of the yaml to g
|
||||
es = yaml.load(yaml_str, Loader=LineLoader)
|
||||
parsed_yaml = parse_native_yaml_struct(es, set())
|
||||
parsed_yaml = parse_native_yaml_struct(es)
|
||||
native_functions, backend_indices = parsed_yaml.native_functions, parsed_yaml.backend_indices
|
||||
grouped_native_functions = gen.get_grouped_native_functions(native_functions)
|
||||
assert len(grouped_native_functions) == 1
|
||||
|
|
|
|||
|
|
@ -203,7 +203,6 @@ add_custom_command(
|
|||
COMMAND
|
||||
"${PYTHON_EXECUTABLE}" -mtools.pyi.gen_pyi
|
||||
--native-functions-path "aten/src/ATen/native/native_functions.yaml"
|
||||
--tags-path "aten/src/ATen/native/tags.yaml"
|
||||
--deprecated-functions-path "tools/autograd/deprecated.yaml"
|
||||
DEPENDS
|
||||
"${TORCH_SRC_DIR}/_C/__init__.pyi.in"
|
||||
|
|
@ -211,7 +210,6 @@ add_custom_command(
|
|||
"${TORCH_SRC_DIR}/nn/functional.pyi.in"
|
||||
"${TOOLS_PATH}/pyi/gen_pyi.py"
|
||||
"${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml"
|
||||
"${TORCH_ROOT}/aten/src/ATen/native/tags.yaml"
|
||||
"${TORCH_ROOT}/tools/autograd/deprecated.yaml"
|
||||
WORKING_DIRECTORY
|
||||
"${TORCH_ROOT}"
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ def deindent(code: str) -> str:
|
|||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def gen_external(native_functions_path, tags_path, external_path):
|
||||
native_functions = parse_native_yaml(native_functions_path, tags_path)
|
||||
def gen_external(native_functions_path, external_path):
|
||||
native_functions = parse_native_yaml(native_functions_path)
|
||||
func_decls = []
|
||||
func_registrations = []
|
||||
for func in native_functions:
|
||||
|
|
@ -83,14 +83,11 @@ def main() -> None:
|
|||
parser.add_argument('--native_functions',
|
||||
help='path to native_functions.yaml',
|
||||
default='../../../../aten/src/ATen/native/native_functions.yaml')
|
||||
parser.add_argument('--tags',
|
||||
help='path to tags.yaml',
|
||||
default='../../../../aten/src/ATen/native/tags.yaml')
|
||||
parser.add_argument('--template_path',
|
||||
help='path to external_functions_codegen_template.cpp',
|
||||
default='../../../../tools/jit/templates/external_functions_codegen_template.cpp')
|
||||
args = parser.parse_args()
|
||||
gen_external(args.native_functions, args.tags, args.template_path)
|
||||
gen_external(args.native_functions, args.template_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user