pytorch/tools/autograd/gen_python_functions.py
Edward Z. Yang fca03eeec1 Make proxy tensor support item() calls on torch.tensor constants (#81192)
This PR is doing a few interrelated things, all of which are necessary to get correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high level overview.

Let's break down the parts of this PR:

* Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This makes `enable_torch_dispatch_mode(current_mode.inner)` work, which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This actually could be ablated from this PR but it affects where the logic for allowing non fake tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need this because I'm always using proxy tensors as a mode (which ensures the ordering is right.)
* New `lift_fresh` view operator.  Note that like lift, we have to manually write the functionalize kernel for these functions.
* The actual change, which is to save constants when we see them in the proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly--see test.)

This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator was relatively new and I checked
functorch which has no explicit reference to lift.  So I think it
should not be too disruptive.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
2022-07-15 03:53:40 +00:00

1263 lines
41 KiB
Python

# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._sparse or torch._C._special objects.
#
# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
# no templates are currently shared between functions, but if that
# happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
# pass named arguments directly for everything, otherwise it's much too
# hard to track what's actually being used and by who
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
# ideally in a data structure rather than code if possible. See e.g.
# SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
# all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
# especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
# information as is available into the message. I.e. no need to
# plumb in new params whose only purpose is to fill out an error
# message, but use what's there
#
import itertools
import re
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple
import yaml
from torchgen.api import cpp
from torchgen.api.python import (
arg_parser_output_exprs,
argument_type_str,
cpp_dispatch_exprs,
cpp_dispatch_target,
dispatch_lambda_args,
dispatch_lambda_exprs,
dispatch_lambda_return_str,
has_tensor_options,
namedtuple_fieldnames,
PythonArgument,
PythonSignature,
PythonSignatureDeprecated,
PythonSignatureGroup,
PythonSignatureNativeFunctionPair,
signature,
)
from torchgen.api.types import CppSignatureGroup
from torchgen.code_template import CodeTemplate
from torchgen.context import with_native_function
from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
from torchgen.model import Argument, BaseOperatorName, NativeFunction, Type, Variant
from torchgen.utils import FileManager, split_name_params, YamlLoader
from .gen_trace_type import should_trace
#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
# Future PRs will categorize this list and eliminate or hoist
# them out of eager-only codegen.
# See https://github.com/pytorch/pytorch/issues/30788
#

# Regex patterns for functions that require manual Python bindings or are not
# exposed to Python at all. They are compiled into SKIP_PYTHON_BINDINGS below
# and matched against `cpp.name(f.func)` in should_generate_py_binding.
_SKIP_PYTHON_BINDINGS = [
    "alias",
    "contiguous",
    "is_cuda",
    "is_sparse",
    "is_sparse_csr",
    "size",
    "stride",
    ".*_backward",
    ".*_backward_(out|input|weight|bias)",
    ".*_forward",
    ".*_forward_out",
    ".*_jvp",
    "_unsafe_view",
    "tensor",
    "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
    "_arange.*",
    "_range.*",
    "linspace.*",
    "logspace.*",
    "_sparse_add_out",
    "_sparse_div.*",
    "_sparse_mul.*",
    "_sparse_sub.*",
    "_sparse_dense_add_out",
    "index",
    "index_out",
    "unique_dim_consecutive",
    "_cumsum.*",
    "_cumprod.*",
    "_sum.*",
    "_prod.*",
    "_th_.*",
    "_thnn_.*",
    "arange.*",
    "range.*",
    "_solve.*",
    "_inverse.*",
    "full(_out)?",
    "_cholesky.*",
    "_triangular_solve.*",
    "_qr.*",
    "_symeig.*",
    "_svd.*",
    "slice",
    "randint(_out)?",
    "item",
    "_local_scalar_dense",
    "to",
    "_to_copy",
    "copy_sparse_to_sparse_",
    "copy_",
    # these need to be attributes in Python, not functions
    "numpy_T",
    "matrix_H",
    "mT",
    "mH",
    "nonzero(_(out|numpy))?",
    "set_data",
    ".*_overrideable",  # overrideable functions for backend extension
    "data",
    "is_leaf",
    "output_nr",
    "_version",
    "requires_grad_",
    "retains_grad",
    "set_",
    "_fw_primal",
    "fake_quantize_per_tensor_affine_cachemask",
    "fake_quantize_per_channel_affine_cachemask",
    "_new_zeros_with_same_feature_meta",
    "_has_same_storage_numel",  # used for forward AD internals
    "_reshape_alias",
    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
    "copy",  # only used by the functionalization pass
    # NOTE(review): '.' is a regex wildcard here, and if cpp.name() returns the
    # name without its overload suffix these two patterns never match; confirm.
    "fill.Tensor",  # only used by the functionalization pass
    "fill.Scalar",  # only used by the functionalization pass
    "lift.*",
    "normal_functional",  # only used by the functionalization pass
]

# Compiled patterns. `re.match` already anchors at the start of the name; the
# explicit `^...$` also forces the pattern to span the whole name, so e.g.
# "index" skips "index" but not "index_add".
SKIP_PYTHON_BINDINGS = [
    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex: entries are compared for exact equality with
# `str(f.func)`.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
    "div.Scalar(Tensor self, Scalar other) -> Tensor",
    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
]
@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Decide whether native function ``f`` gets a generated Python binding.

    ``f`` is rejected when any of these holds:
    - it is tagged "generated";
    - its C++ name matches one of the SKIP_PYTHON_BINDINGS regexes;
    - its full schema string equals an entry of SKIP_PYTHON_BINDINGS_SIGNATURES.
    """
    # So far, all NativeFunctions that are entirely code-generated do not get
    # python bindings.
    if "generated" in f.tags:
        return False

    cpp_name = cpp.name(f.func)
    if any(regex.match(cpp_name) for regex in SKIP_PYTHON_BINDINGS):
        return False

    # Exact-string comparison; the signature list is not regex-based.
    schema_str = str(f.func)
    return not any(
        skipped == schema_str for skipped in SKIP_PYTHON_BINDINGS_SIGNATURES
    )
def get_pycname(name: BaseOperatorName) -> str:
    """Return the C++ identifier used for the generated binding of ``name``."""
    return "THPVariable_{}".format(name)
def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
    """True iff there is exactly one overload and its signature has no arguments."""
    if len(overloads) != 1:
        return False
    sole = overloads[0]
    return sole.signature.arguments_count() == 0
def is_py_variable_method(f: NativeFunction) -> bool:
    """True when ``f`` has no python_module and declares a method variant."""
    if f.python_module is not None:
        return False
    return Variant.method in f.variants


def is_py_torch_function(f: NativeFunction) -> bool:
    """True when ``f`` has no python_module and declares a function variant."""
    if f.python_module is not None:
        return False
    return Variant.function in f.variants
def _in_python_module(f: NativeFunction, module: str) -> bool:
    # Shared check behind the torch.<module> predicates below.
    return f.python_module == module


def is_py_nn_function(f: NativeFunction) -> bool:
    """True when ``f`` is bound under python_module "nn"."""
    return _in_python_module(f, "nn")


def is_py_fft_function(f: NativeFunction) -> bool:
    """True when ``f`` is bound under python_module "fft"."""
    return _in_python_module(f, "fft")


def is_py_linalg_function(f: NativeFunction) -> bool:
    """True when ``f`` is bound under python_module "linalg"."""
    return _in_python_module(f, "linalg")


def is_py_sparse_function(f: NativeFunction) -> bool:
    """True when ``f`` is bound under python_module "sparse"."""
    return _in_python_module(f, "sparse")


def is_py_special_function(f: NativeFunction) -> bool:
    """True when ``f`` is bound under python_module "special"."""
    return _in_python_module(f, "special")
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def gen(
    out: str,
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    template_path: str,
) -> None:
    """Generate all Python-binding C++ sources into ``out``.

    Native functions are parsed from ``native_yaml_path``/``tags_yaml_path`` and
    filtered by should_generate_py_binding. The following files are written:
    - python_variable_methods.cpp;
    - a sharded python_torch_functions.cpp;
    - one python_<submodule>_functions.cpp per torch submodule;
    - python_return_types.cpp;
    - python_enum_tag.cpp.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions = [
        nf for nf in parsed.native_functions if should_generate_py_binding(nf)
    ]

    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
    )

    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
    )

    # (predicate, python module, output file) for each unsharded submodule,
    # generated in this order.
    submodule_bindings = (
        (is_py_nn_function, "torch.nn", "python_nn_functions.cpp"),
        (is_py_fft_function, "torch.fft", "python_fft_functions.cpp"),
        (is_py_linalg_function, "torch.linalg", "python_linalg_functions.cpp"),
        (is_py_sparse_function, "torch.sparse", "python_sparse_functions.cpp"),
        (is_py_special_function, "torch.special", "python_special_functions.cpp"),
    )
    for pred, module, filename in submodule_bindings:
        create_python_bindings(fm, functions, pred, module, filename, method=False)

    # `return_types` bindings are generated from `functions` only: every method
    # returning a namedtuple currently also has a function variant. A
    # method-only op returning a namedtuple would need extra handling here.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )

    valid_tags = parse_tags_yaml(tags_yaml_path)

    def gen_tags_enum() -> Dict[str, str]:
        # One `.value("<tag>", at::Tag::<tag>)` line per valid tag.
        enum_values = "".join(
            f'\n.value("{tag}", at::Tag::{tag})' for tag in valid_tags
        )
        return {"enum_of_valid_tags": enum_values}

    fm.write("python_enum_tag.cpp", gen_tags_enum)
def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
    """Bucket the pairs accepted by ``pred`` by their base operator name.

    ``pred`` is evaluated on each pair's NativeFunction. Accepted pairs are
    appended in input order to the list keyed by ``function.func.name.name``.
    The result is a ``defaultdict(list)``.
    """
    buckets: Dict[
        BaseOperatorName, List[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for accepted in filter(lambda p: pred(p.function), pairs):
        buckets[accepted.function.func.name.name].append(accepted)
    return buckets
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Write the Python bindings for the ``pairs`` accepted by ``pred``.

    Ops are visited in order of the string form of their base name. For each
    op, four things are collected:
    - its method implementation;
    - its method-def entry;
    - its forward declarations;
    - an ``#include <ATen/ops/<base>.h>`` line.
    These are substituted into the template named ``filename``.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_methods: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []
    ops_headers: List[str] = []

    for op_name, op_overloads in sorted(grouped.items(), key=lambda kv: str(kv[0])):
        py_methods.append(method_impl(op_name, module, op_overloads, method=method))
        py_method_defs.append(method_def(op_name, module, op_overloads, method=method))
        py_forwards.extend(forward_decls(op_name, op_overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{op_name.base}.h>")

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            # NOTE: presumably split as "@" + "generated" so this generator
            # file itself is not tagged as generated -- confirm.
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        },
    )
def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """Write the namedtuple return-type getters and their lookup-map entries.

    For every op accepted by ``pred`` (visited in base-name string order), the
    generated getter definitions and map entries are newline-joined and
    substituted into the ``filename`` template. An op with no namedtuple
    returns contributes an empty string to each list.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_return_types_definition: List[str] = []
    py_return_types_map: List[str] = []

    for op_name in sorted(grouped, key=str):
        definitions, map_entries = generate_return_type_definition_and_map_entry(
            grouped[op_name]
        )
        # "\n".join of an empty list is already "".
        py_return_types_definition.append("\n".join(definitions))
        py_return_types_map.append("\n".join(map_entries))

    fm.write_with_template(
        filename,
        filename,
        lambda: {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "py_return_types": py_return_types_definition,
            "py_return_types_map": py_return_types_map,
        },
    )
def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int,
) -> None:
    """Sharded variant of create_python_bindings.

    Each (op name, overloads) group yields four single-op pieces:
    - an ops header;
    - its forward declarations;
    - its method implementation;
    - its method-def entry.
    ``fm.write_sharded`` distributes these over ``num_shards`` outputs of
    ``filename``; only the keys listed in ``sharded_keys`` are sharded.
    """
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        # Key handed to write_sharded (presumably used to pick the shard).
        op_name, _ = kv
        return op_name.base

    def env_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        op_name, fn_pairs = kv
        return {
            "ops_headers": [f"#include <ATen/ops/{op_name.base}.h>"],
            "py_forwards": list(forward_decls(op_name, fn_pairs, method=method)),
            "py_methods": [method_impl(op_name, module, fn_pairs, method=method)],
            "py_method_defs": [method_def(op_name, module, fn_pairs, method=method)],
        }

    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
        },
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )
def load_signatures(
    native_functions: List[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    """Pair every native function with its Python signature.

    Args:
        native_functions: functions to build signatures for.
        deprecated_yaml_path: YAML file describing deprecated overloads.
        method: build method (as opposed to function) signatures.
        skip_deprecated: if True, deprecated overloads are not loaded at all.
        pyi: forwarded to ``signature`` and load_deprecated_signatures.

    Returns:
        One pair per native function, in input order. Unless
        ``skip_deprecated`` is set, these are followed by the deprecated
        overload pairs read from ``deprecated_yaml_path``.
    """

    @with_native_function
    def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = list(map(gen_signature_pairs, native_functions))
    if skip_deprecated:
        # The deprecated pairs would be thrown away, so don't open and parse
        # the deprecated YAML file just to discard the result.
        return pairs

    deprecated = load_deprecated_signatures(
        pairs, deprecated_yaml_path, method=method, pyi=pyi
    )
    return pairs + deprecated
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    """Build signature pairs for the deprecated overloads in ``deprecated_yaml_path``.

    Each YAML entry provides a ``name`` (the deprecated schema) and an ``aten``
    call expression. The entry is joined against ``pairs`` by a type-only
    signature. For every matching pair, a PythonSignatureDeprecated is built
    that reuses the original argument types but follows the deprecated entry's
    parameter order and names. Entries matching no pair silently produce
    nothing.

    Returns:
        The deprecated pairs, in YAML order; each points at the original
        NativeFunction.
    """
    # The deprecated.yaml doesn't have complete type information, so we need
    # to find and leverage the original ATen signature (to which it delegates
    # the call) to generate the full python signature.
    # We join the deprecated and the original signatures using type-only form.

    # native function -> type-only signature
    @with_native_function
    def signature_original(f: NativeFunction) -> str:
        # remove inplace suffix but keep outplace suffix
        opname = str(f.func.name.name.base)
        if f.func.is_out_fn():
            opname += "_out"
        # the inplace "_" is kept only when `pyi` is set (presumably .pyi stub
        # generation)
        if f.func.name.name.inplace and pyi:
            opname += "_"
        args = CppSignatureGroup.from_native_function(
            f, method=False
        ).signature.arguments()
        # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
        types = ", ".join(
            argument_type_str(a.argument.type)
            for a in args
            if isinstance(a.argument, Argument)
        )
        return f"{opname}({types})"

    # deprecated -> type-only native signature (according to the call order)
    def signature_deprecated(
        opname: str, params: List[str], call_args: List[str]
    ) -> str:
        # create a mapping of parameter name to parameter type
        types: Dict[str, str] = {}
        for param in params:
            if param == "*":
                continue
            # each param must be exactly "<type> <name>"; any other shape makes
            # this unpack raise ValueError. (`type` shadows the builtin here.)
            type, name = param.split(" ")
            types[name] = type
        # if the name in the call is not in the parameter list, assume it's
        # a literal Scalar
        rearranged_types = ", ".join(types.get(arg, "Scalar") for arg in call_args)
        return f"{opname}({rearranged_types})"

    # group the original ATen signatures by type-only signature
    grouped: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for pair in pairs:
        grouped[signature_original(pair.function)].append(pair)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []

    # deprecated_yaml_path is a checked-in build input, read with the project's
    # YamlLoader.
    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    for deprecated in deprecated_defs:
        _, params = split_name_params(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])

        for pair in grouped[signature_deprecated(aten_name, params, call_args)]:
            # It uses the types from the original ATen declaration, but the
            # ordering and parameter names from the deprecated overload. Any
            # default parameter values from the original ATen declaration are
            # ignored.
            # Deprecated signature might reorder input_args and input_kwargs,
            # but never changes output_args nor TensorOptions (if any?),
            # so here we only look into these two types of args.
            python_sig = pair.signature
            src_args: Dict[str, PythonArgument] = {
                a.name: PythonArgument(
                    name=a.name,
                    type=a.type,
                    default=None,
                    default_init=None,
                )
                for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)
            }

            # args: every deprecated parameter name, in deprecated order
            # (including outputs)
            args: List[str] = []
            input_args: List[PythonArgument] = []
            input_kwargs: List[PythonArgument] = []

            # flips to True after the "*" marker: later params are keyword-only
            kwarg_only = False
            for param in params:
                if param == "*":
                    kwarg_only = True
                    continue
                _, param_name = param.split(" ")
                args.append(param_name)

                if param_name not in src_args:
                    # output argument
                    continue

                if not kwarg_only:
                    # for methods, `self` is not listed as a positional input
                    if not method or param_name != "self":
                        input_args.append(src_args[param_name])
                else:
                    input_kwargs.append(src_args[param_name])

            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=PythonSignatureDeprecated(
                        name=python_sig.name,
                        input_args=tuple(input_args),
                        input_kwargs=tuple(input_kwargs),
                        output_args=python_sig.output_args,
                        tensor_options_args=python_sig.tensor_options_args,
                        method=python_sig.method,
                        deprecated_args_names=tuple(args),
                        deprecated_args_exprs=tuple(call_args),
                        returns=python_sig.returns,
                    ),
                    function=pair.function,
                )
            )
    return results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Named Tuple Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@with_native_function
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    """Key for a namedtuple return type.

    The key is the C++ name of ``f`` followed by its return field names,
    all joined with "_".
    """
    key_parts = [cpp.name(f.func)]
    key_parts += namedtuple_fieldnames(f.func.returns)
    return "_".join(key_parts)
def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """Emit one namedtuple typedef per distinct named return type in ``overloads``.

    Returns:
        ``(typedefs, typenames)``:
        - ``typedefs`` holds C++ lines
          ``static PyTypeObject* <typename> = get_namedtuple("<name>");``.
        - ``typenames`` maps each gen_namedtuple_typename_key(...) value to its
          C++ variable name: "NamedTuple" first, then "NamedTuple1",
          "NamedTuple2", ...
        Overloads without named returns are skipped.
    """
    typenames: Dict[str, str] = {}  # typename key -> C++ variable name
    typedefs: List[str] = []  # typedef declarations and init code

    for overload in overloads:
        if not namedtuple_fieldnames(overload.function.func.returns):
            continue
        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        if tn_key in typenames:
            # this return type already has a typedef
            continue
        typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
        typenames[tn_key] = typename
        typedefs.append(
            f"""\
static PyTypeObject* {typename} = get_namedtuple("{name}");"""
        )

    return typedefs, typenames
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """Generate the `python_return_types.cpp` pieces for one op's overloads.

    For each distinct namedtuple return type two things are produced:
    - a ``get_<name>_namedtuple()`` function that lazily initializes and
      returns the PyStructSequence type;
    - a ``{"<name>", get_<name>_namedtuple()}, `` map entry.
    Overloads without named returns are skipped.

    Returns:
        ``(definitions, map_entries)``, in overload order.
    """
    typenames: Dict[str, str] = {}  # typename key -> C++ PyTypeObject variable name
    definitions: List[str] = []  # function definitions that register each typedef
    map_entries: List[str] = []  # C++ map entries <function name, namedtuple getter>

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        if tn_key in typenames:
            # this return type was already emitted for an earlier overload
            continue

        typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
        typenames[tn_key] = typename
        definitions.append(
            f"""\
PyTypeObject* get_{name}_namedtuple() {{
static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
static PyTypeObject {typename};
static bool is_initialized = false;
static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
if (!is_initialized) {{
PyStructSequence_InitType(&{typename}, &desc);
{typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
is_initialized = true;
}}
return &{typename};
}}
"""
        )
        map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Impl Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# python binding for all overloads of a particular function/method: the
# arguments are parsed with a PythonArgParser built from ${signatures}, then
# ${dispatch} is selected by `switch (_r.idx)`.
#
# NOTE: deliberately NOT a raw string. In a raw string the backslash after the
# opening quotes is kept literally (it is not a line continuation), which made
# every generated binding start with a stray "\" line. A plain string matches
# the sibling templates below.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});
ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
switch (_r.idx) {
${dispatch}
}
${method_footer}
}
"""
)

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures only differ in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
PY_VARIABLE_CASE = CodeTemplate(
    """\
case ${overload_index}: {
${body}
}
"""
)

# python binding for single-overload function/method: same as
# PY_VARIABLE_METHOD_VARARGS but ${dispatch} is emitted without a switch
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});
ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
${dispatch}
${method_footer}
}
"""
)

# python binding for a method with no args, shortcuts parsing (the C++
# signature has no kwargs parameter and no PythonArgParser is built)
PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
${method_header}
${check_has_torch_function}
${dispatch}
${method_footer}
}
"""
)
def method_impl(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """Generate the C++ binding function that covers every overload of ``name``.

    Template selection:
    - PY_VARIABLE_METHOD_NOARGS when the only overload takes no arguments;
    - PY_VARIABLE_METHOD_VARARGS_SINGLETON when the overloads group into a
      single parsed signature;
    - PY_VARIABLE_METHOD_VARARGS otherwise, which switches on the index of
      the matched signature.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    namedtuple_inits, namedtuple_typenames = emit_namedtuple_call(overloads)

    unpack_self = (
        ["const Tensor& self = THPVariable_Unpack(self_);"] if method else []
    )
    method_header = ["HANDLE_TH_ERRORS", *namedtuple_inits, *unpack_self]
    method_footer = (
        ["END_HANDLE_TH_ERRORS"]
        if noarg
        else ["Py_RETURN_NONE;", "END_HANDLE_TH_ERRORS"]
    )

    all_traceable = all(should_trace(o.function) for o in overloads)
    traceable = "true" if all_traceable else "false"

    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
    is_singleton = len(grouped_overloads) == 1

    signatures: List[str] = []
    dispatch: List[str] = []
    for overload_index, overload in enumerate(grouped_overloads):
        sig_str = overload.signature.signature_str()
        signatures.append(f"{cpp_string(str(sig_str))},")
        case_body = emit_dispatch_case(overload, namedtuple_typenames)
        if is_singleton:
            # a lone signature is emitted without a switch/case wrapper
            dispatch.append(case_body)
        else:
            dispatch.append(
                PY_VARIABLE_CASE.substitute(
                    overload_index=overload_index, body=case_body
                )
            )

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    elif is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    torch_function_check = gen_has_torch_function_check(
        name=name,
        module=module,
        noarg=noarg,
        method=method,
    )
    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max(o.signature.arguments_count() for o in overloads),
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=torch_function_check,
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if method else "nullptr",
    )
def gen_has_torch_function_check(
    name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
) -> str:
    """Return the C++ snippet that defers to ``__torch_function__`` handlers.

    Cases:
    - No-arg functions get an empty string.
    - No-arg methods test ``self_`` directly.
    - Everything else tests the parsed ``_r`` and passes the Python object
      for ``module`` (or THPVariableClass / "torch.Tensor" when ``module`` is
      None) to handle_torch_function.

    Raises:
        KeyError: if ``module`` is set but is not one of the known torch
            submodules.
    """
    if noarg and not method:
        return ""
    if noarg:
        return f"""\
if(check_has_torch_function(self_)) {{
return handle_torch_function(self_, "{name}");
}}
"""

    self_ = "self_" if method else "nullptr"
    if module:
        namespace = {
            "torch": "THPVariableFunctionsModule",
            "torch.nn": "THPNNVariableFunctionsModule",
            "torch.fft": "THPFFTVariableFunctionsModule",
            "torch.linalg": "THPLinalgVariableFunctionsModule",
            "torch.sparse": "THPSparseVariableFunctionsModule",
            "torch.special": "THPSpecialVariableFunctionsModule",
        }[module]
    else:
        namespace = "THPVariableClass"

    return f"""\
if(_r.has_torch_function()) {{
return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}");
}}
"""
# handler for output/no-output overload pair: at runtime ${call_dispatch}
# (the no-output variant) runs when the argument at ${out_idx} is None,
# otherwise ${call_dispatch_out} runs. Filled in by emit_dispatch_case.
PY_VARIABLE_OUT = CodeTemplate(
    """\
if (_r.isNone(${out_idx})) {
${call_dispatch}
} else {
${call_dispatch_out}
}
"""
)
def emit_dispatch_case(
    overload: PythonSignatureGroup,
    namedtuple_typenames: Dict[str, str],
) -> str:
    """Emit the dispatch code for a single parsed signature.

    A group without an out= variant dispatches straight to its base function.
    A group with an out= variant shares one Python signature for both
    functions and chooses between them at runtime, depending on whether the
    output argument was passed (see PY_VARIABLE_OUT).
    """
    parsed_sig = overload.signature
    if overload.outplace is None:
        # no-output version only
        return emit_single_dispatch(parsed_sig, overload.base, namedtuple_typenames)

    # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
    return PY_VARIABLE_OUT.substitute(
        out_idx=parsed_sig.output_idx(),
        call_dispatch=emit_single_dispatch(
            parsed_sig, overload.base, namedtuple_typenames
        ),
        call_dispatch_out=emit_single_dispatch(
            parsed_sig, overload.outplace, namedtuple_typenames
        ),
    )
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Forward Declarations Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> Tuple[str, ...]:
    """Return the C++ forward declaration of the binding for ``name``.

    Methods need none, so an empty tuple is returned for them. Otherwise the
    result is a 1-tuple whose prototype omits ``kwargs`` for no-arg ops.
    """
    if method:
        return ()

    pycname = get_pycname(name)
    if not is_noarg(overloads):
        return (
            f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""",
        )
    return (
        f"""\
static PyObject * {pycname}(PyObject* self_, PyObject* args);
""",
    )
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Def (Binding Table Entry) Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def method_def(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """Return the PyMethodDef table entry for ``name``.

    Flags:
    - METH_NOARGS only for no-arg ops when ``method`` is set;
    - otherwise METH_VARARGS | METH_KEYWORDS;
    - METH_STATIC is appended for the "torch" module.

    The function pointer is wrapped in castPyCFunctionWithKeywords except for
    no-arg ops. Dunder (binary-op) entries go through
    TypeError_to_NotImplemented_<...>.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    pyfunc_cast = "" if noarg else "castPyCFunctionWithKeywords"
    flags = "METH_NOARGS" if noarg and method else "METH_VARARGS | METH_KEYWORDS"
    if module == "torch":
        flags += " | METH_STATIC"

    if not name.dunder_method:
        # PyMethodDef entry
        return f"""\
{{"{name}", {pyfunc_cast}({pycname}), {flags}, NULL}},"""

    # PyMethodDef entry for binary op, throws not implemented error
    return f"""\
{{"{name}", {pyfunc_cast}(TypeError_to_NotImplemented_<{pycname}>), {flags}, NULL}},"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Overload Sorting and Grouping
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Sequence[PythonSignatureGroup]:
    """Pair each overload with its out= variant and return the sorted groups.

    Overloads are keyed by their Python signature string with outputs skipped.
    Every non-out overload becomes one PythonSignatureGroup. An out= overload
    with the same key becomes that group's ``outplace``, and its signature
    (the superset that also accepts ``out=``) is used for the group.

    Raises:
        RuntimeError: if two non-out (or two out=) overloads share a key, or an
            out= overload has no non-out counterpart. In the latter case the
            message lists non-deprecated, non-out overloads of the same name.
    """
    bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
    outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}

    # first group by signature ignoring out arguments
    for overload in overloads:
        key = overload.signature.signature_str(skip_outputs=True)
        bucket = outplaces if overload.function.func.is_out_fn() else bases
        if key in bucket:
            raise RuntimeError(
                f"Found duplicated function definition:\n- {overload.function.func}.\n"
                f"Existing definition:\n- {bucket[key].function.func}."
            )
        bucket[key] = overload

    # every out= overload must have a matching non-out overload
    for sig, out in outplaces.items():
        if sig in bases:
            continue
        out_op_name = str(out.function.func.name.name)
        candidates: List[str] = [
            other.signature.signature_str(skip_outputs=True)
            for other in overloads
            if str(other.function.func.name.name) == out_op_name
            and not other.function.func.is_out_fn()
            and not other.signature.deprecated
        ]
        out_sig = out.signature.signature_str()
        raise RuntimeError(
            f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
            f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
            "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
            + "\n".join(f"- {candidate}" for candidate in candidates)
        )

    grouped: List[PythonSignatureGroup] = []
    for key, base in bases.items():
        outplace = outplaces.get(key)
        if outplace is None:
            group = PythonSignatureGroup(
                signature=base.signature,
                base=base.function,
                outplace=None,
            )
        else:
            # prefer the signature with optional out=... arguments because it's
            # the superset that can be used to parse input for both base and
            # outplace.
            group = PythonSignatureGroup(
                signature=outplace.signature,
                base=base.function,
                outplace=outplace.function,
            )
        grouped.append(group)

    return sort_overloads(grouped)
# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
#
# A few examples of ambiguous python signature pairs.
#
# All parameters have the same type, except one taking Tensor the other taking
# Scalar. A numeric PyObject can be cast into Tensor, and a zero-dim Tensor
# object can be accepted as Scalar type parameter (see python_arg_parser.cpp).
# Therefore, same input arguments might be accepted by either python signature.
# We want to always parse the one taking Tensor first.
#
# bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
# bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
#
# If they have different number of parameters then they are not ambiguous - but
# the difference on output param can be ignored as it's optional.
#
# multiply(Tensor input, Tensor other, *, Tensor out=None)
# multiply(Tensor input, Scalar other)
#
# Both positional args and keyword-only args are considered together.
#
# subtract(Tensor other, *, Scalar alpha=1)
# subtract(Scalar other, Scalar alpha=1)
#
# A few ambiguous cases which it does NOT handle yet.
#
# If there is any difference in other parameters besides the Tensor/Scalar
# difference, then they are not considered ambiguous by this method anymore.
# However, the difference could be too trivial to disambiguate.
#
# foo(Tensor input, Scalar other, Scalar bar)
# foo(Tensor input, Tensor other, double bar)
#
# If they are taking different number of parameters then they are not considered
# ambiguous anymore, even if the difference is only on optional kwargs.
#
# foo(Scalar other, Scalar alpha=1)
# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
#
def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup],
) -> Sequence[PythonSignatureGroup]:
    """Order overload groups so that higher-priority signatures come first.

    The groups are first sorted by their signature string so the result is
    deterministic. They are then reordered by a topological sort (Kahn's
    algorithm) over the partial order defined by ``is_smaller``: a group is
    placed only after every group it is "smaller" (lower priority) than.
    See Note[Order of overloads matters].

    Args:
        grouped_overloads: the overload groups for a single Python name.

    Returns:
        A new list containing the same groups in priority order.

    Raises:
        RuntimeError: if the pairwise "smaller than" relation contains a cycle
            (e.g. ``Dimname[]`` vs ``Tensor[]`` satisfy both the Dimname rule
            and the TensorList rule in each direction), in which case no
            consistent order exists. Previously this surfaced as an opaque
            IndexError.
    """

    # NB: Smaller here means lower priority
    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        s1, s2 = str(t1), str(t2)
        return (
            (s1 == "Scalar" and s2 == "Tensor")
            or (s1 == "Scalar?" and s2 == "Tensor?")
            or ("Dimname" in s1 and "Dimname" not in s2)
            # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
            # discussed why it is important to prioritize int/int? over int[]
            or (s1 == "int[]" and s2 in ("int", "int?"))
            # TensorList currently throws an error during argument parsing, that's why it needs to be
            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
            or (s1 == "Tensor[]" and "[]" in s2)
            # Prioritize SymIntArrayRef overload over IntArrayRef
            or (s1 == "int[]" and s2 == "SymInt[]")
        )

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
        """Returns True if s1 < s2 in the partial order."""
        args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True)
        # Signatures of different arity are never considered ambiguous.
        if len(args1) != len(args2):
            return False
        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
        equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
        smaller_or_equal = all(
            str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type)
            for arg1, arg2 in zip(args1, args2)
        )
        return smaller_or_equal and not equal

    # First sort by signature so ties in the partial order resolve deterministically.
    grouped_overloads = sorted(
        grouped_overloads, key=lambda x: x.signature.signature_str()
    )

    # Construct the relation graph: larger_than[i] holds every index j such that
    # overload i is strictly smaller (lower priority) than overload j.
    larger_than: Dict[int, Set[int]] = defaultdict(set)
    for i1, overload1 in enumerate(grouped_overloads):
        for i2, overload2 in enumerate(grouped_overloads):
            if is_smaller(overload1.signature, overload2.signature):
                larger_than[i1].add(i2)

    if not larger_than:
        return list(grouped_overloads)

    # Use a topological sort to sort overloads according to the partial order.
    # Seed the queue with overloads that are not smaller than anything. Note that
    # a membership test on a defaultdict does not insert the key.
    num_overloads = len(grouped_overloads)
    sorted_ids: List[int] = [i for i in range(num_overloads) if i not in larger_than]
    idx = 0
    # Iterate while the queue still has unprocessed entries. In the acyclic case
    # this visits exactly num_overloads entries, the same as the previous
    # fixed-length loop.
    while idx < len(sorted_ids):
        i = sorted_ids[idx]
        idx += 1
        for j in sorted(larger_than.keys()):
            larger = larger_than[j]
            larger.discard(i)
            if not larger:
                # Everything j is smaller than has been placed, so j can follow.
                del larger_than[j]
                sorted_ids.append(j)

    if len(sorted_ids) != num_overloads:
        unresolved = ", ".join(
            str(grouped_overloads[j].signature.signature_str())
            for j in sorted(larger_than)
        )
        raise RuntimeError(
            "sort_overloads: the overload priority relation contains a cycle; "
            f"cannot order: {unresolved}"
        )

    return [grouped_overloads[x] for x in sorted_ids]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Codegen API Integration
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def emit_single_dispatch(
    ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
) -> str:
    """
    Emit dispatch code for a single native function.

    The generated C++ snippet does four things:

    - Runs the argument initializers.
    - Declares a ``dispatch_<name>`` lambda that releases the GIL and calls the
      ATen dispatch target.
    - Invokes that lambda with the converted parser outputs, optionally
      chaining ``.set_requires_grad(...)`` onto the call.
    - Either returns ``None`` (void return) or ``wrap``s the result, with the
      namedtuple typename when one is registered for ``f``.
    """

    @with_native_function
    def go(f: NativeFunction) -> str:
        # Leading comment naming the schema being bound.
        deprecated = "[deprecated] " if ps.deprecated else ""
        schema_comment = f"// {deprecated}aten::{f.func}"

        # Pieces of the dispatch lambda's signature.
        name = cpp.name(f.func)
        lambda_formals = ", ".join(
            f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f)
        )
        lambda_return = dispatch_lambda_return_str(f)

        # What the lambda calls, and with which arguments.
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))

        # Translate arg-parser outputs into the arguments passed to the lambda.
        parser_outputs = arg_parser_output_exprs(ps, f)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
        inits = "\n".join(lambda_arg_exprs.inits)
        lambda_args = ", ".join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        # solution for enabling the 'requires_grad' argument for tensor methods
        # new_full, new_empty, and new_zeros. A much better but more difficult to
        # implement solution involves refactoring according to Ed's description here:
        # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        set_requires_grad = ""
        if ps.tensor_options_args and (
            not has_tensor_options(f)
            or (ps.method and ("requires_grad" in parser_outputs))
        ):
            set_requires_grad = (
                f'.set_requires_grad({parser_outputs["requires_grad"].expr})'
            )

        if lambda_return != "void":
            # Non-void: wrap the result, naming the namedtuple type if one exists.
            typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
            namedtuple_typeref = "" if typename is None else f"{typename}, "
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
pybind11::gil_scoped_release no_gil;
return {dispatch_callee}({dispatch_args});
}};
return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

        # Void: call for side effects only, then return None to Python.
        return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
pybind11::gil_scoped_release no_gil;
{dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""

    return go(f)