# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, or torch._C._linalg objects.
#

# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
#   no templates are currently shared between functions, but if that
#   happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
#   pass named arguments directly for everything, otherwise it's much too
#   hard to track what's actually being used and by whom
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
#   ideally in a data structure rather than code if possible. See e.g.
#   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
#   all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
#   especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
#   information as is available into the message. I.e. no need to
#   plumb in new params whose only purpose is to fill out an error
#   message, but use what's there
#
from collections import defaultdict
import itertools
import re
from .gen_variable_type import should_trace
from .utils import write, is_tensor_method

from tools.codegen.code_template import CodeTemplate
from tools.codegen.api.python import *
from tools.codegen.gen import cpp_string, with_native_function
from tools.codegen.model import *

from typing import Dict, Optional, List, Any

#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
# Future PRs will categorize this list and eliminate or hoist
# them out of eager-only codegen.
# See https://github.com/pytorch/pytorch/issues/30788
#

# These functions require manual Python bindings or are not exposed to Python
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
    'index', 'unique_dim_consecutive',
    '_indexCopy_', '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_solve.*', '_inverse.*',
    'full(_out)?',
    '_cholesky.*', '_triangular_solve.*', '_qr.*', '_symeig.*', '_svd.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense', 'to',
    'copy_sparse_to_sparse_', 'copy_',
    'numpy_T',  # this needs to be an attribute in Python, not a function
    'nonzero(_(out|numpy))?',
    'set_quantizer_',  # return types not supported yet
    'set_data',
    '.*_overrideable',  # overrideable functions for backend extension
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_'
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

NATIVE_NAMESPACE_MAPPING = {
    "torch": "THPVariableFunctionsModule",
    "torch.nn": "THPNNVariableFunctionsModule",
    "torch.fft": "THPFFTVariableFunctionsModule",
    "torch.linalg": "THPLinalgVariableFunctionsModule",
}

def should_generate_python_binding(declaration):
    name = declaration['name']
    for pattern in SKIP_PYTHON_BINDINGS:
        if re.match('^' + pattern + '$', name):
            return False

    simple_types = [arg['simple_type'] for arg in declaration['arguments']]
    signature = '{}({})'.format(name, ', '.join(simple_types))
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    return True
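
# Note: the blocklist patterns are matched as fully-anchored regexes
# ('^' + pattern + '$'), so e.g. 'full(_out)?' skips both `full` and
# `full_out`, while a literal entry like 'index' skips only the function
# named exactly `index`.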

#
# top-level codegen functions, called from gen_autograd
#

def get_py_variable_methods(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as methods on Tensor.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                not is_nn_module_function(declaration) and
                is_tensor_method(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])


def gen_py_variable_methods(out, declarations, template_path):
    """
    Generate Tensor methods.
    """
    PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')

    py_variable_methods = get_py_variable_methods(declarations)

    env = create_python_bindings(py_variable_methods, is_python_method=True, module=None)

    write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)


def get_py_nn_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "nn" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                is_nn_module_function(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])


def gen_py_nn_functions(out, declarations, template_path):
    """
    Generate functions in the "nn" module.
    """
    PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')

    py_nn_functions = get_py_nn_functions(declarations)

    env = create_python_bindings(py_nn_functions, is_python_method=False, module="torch.nn")

    write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)


def get_py_fft_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "fft" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                is_fft_module_function(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])


def gen_py_fft_functions(out, declarations, template_path):
    """
    Generate functions in the "fft" module.
    """
    PY_FFT_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_fft_functions.cpp')

    py_fft_functions = get_py_fft_functions(declarations)

    env = create_python_bindings(py_fft_functions, is_python_method=False, module="torch.fft")

    write(out, 'python_fft_functions.cpp', PY_FFT_FUNCTIONS_CPP, env)


def get_py_linalg_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "linalg" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                is_linalg_module_function(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])


def gen_py_linalg_functions(out, declarations, template_path):
    """
    Generate functions in the "linalg" module.
    """
    PY_LINALG_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_linalg_functions.cpp')

    py_linalg_functions = get_py_linalg_functions(declarations)

    env = create_python_bindings(py_linalg_functions, is_python_method=False, module="torch.linalg")

    write(out, 'python_linalg_functions.cpp', PY_LINALG_FUNCTIONS_CPP, env)


def get_py_torch_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "torch" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                not is_nn_module_function(declaration) and
                not is_fft_module_function(declaration) and
                not is_linalg_module_function(declaration) and
                is_torch_function(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])


def gen_py_torch_functions(out, declarations, template_path):
    """
    Generate functions in the "torch" module.
    """
    PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')

    py_torch_functions = get_py_torch_functions(declarations)

    env = create_python_bindings(py_torch_functions, is_python_method=False, module="torch")

    write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env)


def group_declarations_by_op_name(declarations):
    groups = defaultdict(list)
    for d in declarations:
        groups[op_name(d)].append(d)
    return groups


def create_python_bindings(python_functions, is_python_method, module):
    """Generates Python bindings to ATen functions"""
    py_methods = []
    py_method_defs = []
    py_forwards = []

    for name in sorted(python_functions.keys()):
        overload_decls = python_functions[name]

        for declaration in overload_decls:
            # TODO: change all methods to directly process python signatures instead of decls.
            declaration['python_signature'] = decl_to_python_signature(declaration, method=is_python_method)
            declaration['native_function'] = decl_to_native_function(declaration)

        py_methods.append(method_impl(name, overload_decls, is_python_method, module))
        py_method_defs.append(method_def(name, overload_decls, is_python_method, module))
        py_forwards.extend(forward_decls(name, overload_decls, is_python_method, module))

    return {
        'py_forwards': py_forwards,
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
    }


# handler for output/no-output overload pair
# (plugged into PY_VARIABLE_CASE as ${call_dispatch})
PY_VARIABLE_OUT = CodeTemplate("""\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures differ only in output params
PY_VARIABLE_CASE = CodeTemplate("""\
case ${i}: {
  ${body}
}
""")


def emit_dispatch_case(i, dictionary, is_python_method):
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single overload, or a pair of overloads that differ only in output params.
    In the latter case, a single signature is used for both, and dispatching
    switches on the presence/absence of passed output args.
    - i: this signature's position in the generated binding's signature list if
      the number of signatures > 1, otherwise None
    - dictionary: contains a no-output overload declaration under 'base', and
      optionally a second overload with outputs under 'out'
    - is_python_method: true if we're generating a python method, in which case
      self is not parsed but passed directly
    """
    base_decl = dictionary['base']
    python_sig = base_decl['python_signature']

    if 'out' in dictionary:
        # dispatch to output or no-output variant based on arg test
        out_decl = dictionary['out']
        python_sig = out_decl['python_signature']  # prefer output variant

        out_idx = get_python_output_index(out_decl)

        call_dispatch = emit_single_dispatch(python_sig, base_decl, is_python_method)
        call_dispatch_out = emit_single_dispatch(python_sig, out_decl, is_python_method)

        # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
        body = PY_VARIABLE_OUT.substitute(
            out_idx=out_idx,
            call_dispatch=call_dispatch,
            call_dispatch_out=call_dispatch_out,
        )
    else:
        # no-output version only
        body = emit_single_dispatch(python_sig, base_decl, is_python_method)

    if i is not None:
        # generate case for ith overload
        return PY_VARIABLE_CASE.substitute(i=i, body=body)
    else:
        # only one overload, omit case wrapper
        return body

#
# named tuple codegen
#

def namedtuple_fieldnames(declaration):
    returns = declaration['returns']
    if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
        return []
    else:
        def get_field_name(x):
            # See Note [field_name versus name]
            if 'field_name' not in x:
                # When building on Windows, `PyStructSequence_UnnamedField` could not be
                # resolved by the linker for some reason, which caused a build error:
                #
                # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
                # PyStructSequence_UnnamedField
                #
                # Thus, at this point in time, we do not support unnamed
                # fields in namedtuple; you must either name all fields,
                # or none of them.
                raise ValueError("Unnamed field is not supported by codegen")
            else:
                return x['field_name']
        return [get_field_name(x) for x in returns]
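
# For example, a declaration whose returns carry field names 'values' and
# 'indices' (as torch.max(input, dim) does) yields ['values', 'indices'];
# a declaration with a single unnamed Tensor return yields [].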

PY_NAMEDTUPLE_FIELDSDEF = CodeTemplate("""\
static PyStructSequence_Field ${fieldsname}[] = { ${fields,} {nullptr} };
""")

PY_NAMEDTUPLE_TYPEDEF = CodeTemplate("""\
static PyTypeObject ${typename};
static bool ${typename}_initialized = false;
if (!${typename}_initialized) {
  ${typename}_initialized = true;
  static PyStructSequence_Desc desc = { "torch.return_types.${name}", nullptr, ${fieldsname}, ${size} };
  PyStructSequence_InitType(&${typename}, &desc);
  ${typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}
""")


def emit_namedtuple_typedefs(declarations):
    """
    Generate a block of named tuple type def inits, and add typeref snippets
    to declarations that use them
    """
    flddefnames = {}  # map from unique field name lists to field def name
    flddefs = []      # field def declarations
    typenames = {}    # map from unique name + field name lists to typedef name
    typedefs = []     # typedef declarations and init code

    for decl in declarations:
        fieldnames = namedtuple_fieldnames(decl)
        if fieldnames == []:
            decl['namedtuple_typeref'] = ''
            continue

        fn_key = '_'.join(fieldnames)
        fieldsname = flddefnames.get(fn_key)
        if fieldsname is None:
            fieldsname = 'NamedTuple_fields{}'.format('' if flddefs == [] else len(flddefs))
            fields = ['{{"{}", ""}}'.format(fn) for fn in fieldnames]
            fieldsdef = PY_NAMEDTUPLE_FIELDSDEF.substitute(
                fieldsname=fieldsname,
                fields=fields
            )
            flddefnames[fn_key] = fieldsname
            flddefs.append(fieldsdef)

        name = decl['name']
        key = '{}_{}'.format(name, '_'.join(fieldnames))
        typename = typenames.get(key)
        if typename is None:
            typename = 'NamedTuple{}'.format('' if typedefs == [] else len(typedefs))
            typedef = PY_NAMEDTUPLE_TYPEDEF.substitute(
                name=name,
                typename=typename,
                size=len(fieldnames),
                fieldsname=fieldsname
            )
            typenames[key] = typename
            typedefs.append(typedef)

        decl['namedtuple_typeref'] = '&{}, '.format(typename)

    return flddefs + typedefs

#
# method impl codegen
#

def get_pycname(name):
    return 'THPVariable_{}'.format(name)


def is_noarg_binding(overloads):
    return len(overloads) == 1 and get_python_argc(overloads[0]) == 0


# python binding for all overloads of a particular function/method
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

""")

# python binding for single-overload function/method
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

""")

# python binding for a method with no args, shortcuts parsing
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

""")

TORCH_FUNCTION_CHECK = CodeTemplate("""\
if(_r.has_torch_function()) {
  return handle_torch_function(_r, ${self_}, args, kwargs, ${namespace}, ${modulename});
}
""")

TORCH_FUNCTION_CHECK_NOARGS = CodeTemplate("""\
if(check_has_torch_function(self_)) {
  return handle_torch_function(self_, ${name});
}
""")

# NOTE: we type the unpacked self as Tensor not Variable to avoid return type
# discrepancies on method resolution (e.g. Variable::detach_ returns void
# rather than Tensor &)
UNPACK_SELF = "Tensor& self = reinterpret_cast<THPVariable*>(self_)->cdata;"


def method_impl(name, declarations, is_python_method, module):
    """
    Generate a python binding for all overloads of an op.
    """
    pycname = get_pycname(name)

    method_header = ['HANDLE_TH_ERRORS']
    method_header += emit_namedtuple_typedefs(declarations)
    method_header += [UNPACK_SELF] if is_python_method else []

    method_footer = ['END_HANDLE_TH_ERRORS']

    check_has_torch_function = TORCH_FUNCTION_CHECK_NOARGS.substitute(
        name='"' + name + '"',
    ) if is_python_method else ''

    # emit dispatch
    if is_noarg_binding(declarations):
        python_sig = declarations[0]['python_signature']
        dispatch = emit_single_dispatch(python_sig, declarations[0], is_python_method)
        return PY_VARIABLE_METHOD_NOARGS.substitute(
            name=name,
            pycname=pycname,
            method_header=method_header,
            dispatch=dispatch,
            method_footer=method_footer,
            check_has_torch_function=check_has_torch_function,
        )

    method_footer = ['Py_RETURN_NONE;'] + method_footer

    grouped = group_overloads(declarations, is_python_method)
    is_singleton = len(grouped) == 1

    signatures = []
    dispatch = []
    for i, dictionary in enumerate(grouped):
        signature = dictionary['signature']
        signatures.append(f'{cpp_string(str(signature))},')
        overload_index = i if not is_singleton else None
        dispatch.append(emit_dispatch_case(overload_index, dictionary, is_python_method))

    if is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    if module:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace=NATIVE_NAMESPACE_MAPPING[module],
            modulename='"' + module + '"',
            self_="self_" if is_python_method else "nullptr",
        )
    else:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace="THPVariableClass",
            modulename='"torch.Tensor"',
            self_="self_" if is_python_method else "nullptr",
        )

    max_args = max([get_python_argc(decl) for decl in declarations])
    traceable = 'true' if all(should_trace(d) for d in declarations) else 'false'

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max_args,
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=check_has_torch_function,
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if is_python_method else "nullptr",
    )


#
# forward declarations
#

PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""")

PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args);
""")


def forward_decls(name, declarations, is_python_method, module):
    if is_python_method:
        return []

    if is_noarg_binding(declarations):
        template = PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION
    else:
        template = PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION

    pycname = get_pycname(name)
    return [template.substitute(pycname=pycname)]


#
# method def (binding table entry) codegen
#

# Python binary operator dunder methods
BINARY_OP_NAMES = [
    '__lt__', '__le__',
    '__gt__', '__ge__',
    '__eq__', '__ne__',

    '__add__', '__radd__', '__iadd__',
    '__sub__', '__rsub__', '__isub__',
    '__mul__', '__rmul__', '__imul__',
    '__matmul__', '__rmatmul__', '__imatmul__',
    '__truediv__', '__rtruediv__', '__itruediv__',
    '__floordiv__', '__rfloordiv__', '__ifloordiv__',
    '__mod__', '__rmod__', '__imod__',
    '__divmod__', '__rdivmod__', '__idivmod__',
    '__pow__', '__rpow__', '__ipow__',
    '__lshift__', '__rlshift__', '__ilshift__',
    '__rshift__', '__rrshift__', '__irshift__',
    '__and__', '__rand__', '__iand__',
    '__xor__', '__rxor__', '__ixor__',
    '__or__', '__ror__', '__ior__',
]
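
# Binary-op bindings are wrapped in TypeError_to_NotImplemented_ (see
# PY_VARIABLE_METHOD_BINOP_DEF below) so that an argument-parsing TypeError
# surfaces as Py_NotImplemented, letting Python fall back to the other
# operand's reflected dunder instead of raising.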

# PyMethodDef entry for binary op, throws not implemented error
PY_VARIABLE_METHOD_BINOP_DEF = CodeTemplate("""\
{"${name}", ${pyfunc_cast}(TypeError_to_NotImplemented_<${pycname}>), ${flags}, NULL},""")

# PyMethodDef entry
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", ${pyfunc_cast}(${pycname}), ${flags}, NULL},""")


def method_def(name, declarations, is_python_method, module):
    """
    Generate method def entry.
    """
    pycname = get_pycname(name)

    if is_noarg_binding(declarations):
        pyfunc_cast = ''
        flags = 'METH_NOARGS' if is_python_method else 'METH_VARARGS | METH_KEYWORDS'
    else:
        pyfunc_cast = 'castPyCFunctionWithKeywords'
        flags = 'METH_VARARGS | METH_KEYWORDS'

    if module == "torch":
        flags += ' | METH_STATIC'

    if name in BINARY_OP_NAMES:
        def_template = PY_VARIABLE_METHOD_BINOP_DEF
    else:
        def_template = PY_VARIABLE_METHOD_DEF

    return def_template.substitute(
        name=name,
        pycname=pycname,
        pyfunc_cast=pyfunc_cast,
        flags=flags,
    )

#
# overload sorting and grouping
#

def group_overloads(declarations, is_python_method):
    """Returns a list of dictionaries containing the optional keys:

       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out)
       "signature": the signature used for Python argument parsing

       Note that we merge pairs of declarations with signatures that
       are equivalent mod output arguments, and use a single entry in
       the python_arg_parser sig list for both (output arguments become
       optional).
    """
    grouped = defaultdict(dict)

    # first group by signature, ignoring out arguments
    for declaration in declarations:
        signature = get_python_signature(declaration, is_python_method, skip_outputs=True)
        v = grouped[signature]
        if declaration['name'].endswith('_out'):
            v['out'] = declaration
            # prefer the signature with optional out=... arguments
            v['signature'] = get_python_signature(declaration, is_python_method)
        else:
            v['base'] = declaration
            if 'signature' not in v:
                v['signature'] = signature

    result = []
    for x, dictionary in sorted(grouped.items()):
        if 'base' not in dictionary:
            candidates = []
            non_out_name = dictionary['out']['operator_name']
            for declaration in declarations:
                if declaration['name'] == non_out_name and not declaration['deprecated']:
                    signature = get_python_signature(declaration, is_python_method, skip_outputs=True)
                    candidates.append(signature)
            raise RuntimeError(
                "While identifying overloads, we found an out schema {} without a corresponding non-out variant. "
                "We expected the non-out variant to have schema: \n- {}\nPlease check that you spelled the schema "
                "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
                .format(dictionary['signature'], x) + "\n".join("- {}".format(candidate) for candidate in candidates))
        result.append(dictionary)
    return sort_declarations(result)
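
# For example, `abs(Tensor input)` and its out variant end up in one group
# whose parser signature is roughly `abs(Tensor input, *, Tensor out=None)`
# (a sketch; per the docstring above, the out argument is merged in as
# optional).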


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
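#
# For example, add(Tensor, Scalar) compares as "smaller" than
# add(Tensor, Tensor) under is_coord_smaller below, so the topological sort
# emits the two overloads in a deterministic relative order instead of
# whatever order Declarations.yaml happened to use.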
def sort_declarations(grouped_decls):

    def dynamic_type(arg):
        return arg['dynamic_type']

    def is_coord_smaller(arg1, arg2):
        return dynamic_type(arg1) == 'Scalar' and dynamic_type(arg2) == 'Tensor'

    def is_smaller(d1, d2):
        """Returns True if d1 < d2 in the partial order."""
        args1, args2 = d1['base']['arguments'], d2['base']['arguments']
        if len(args1) != len(args2):
            return False
        any_smaller = any(is_coord_smaller(arg1, arg2) for arg1, arg2 in zip(args1, args2))
        all_smaller_or_equal = all(dynamic_type(arg1) == dynamic_type(arg2) or
                                   is_coord_smaller(arg1, arg2)
                                   for arg1, arg2 in zip(args1, args2))
        return any_smaller and all_smaller_or_equal

    # Construct the relation graph
    larger_than = defaultdict(set)
    for i1, decl1 in enumerate(grouped_decls):
        for i2, decl2 in enumerate(grouped_decls):
            if is_smaller(decl1, decl2):
                larger_than[i1].add(i2)

    if not larger_than:
        return grouped_decls

    # Use a topological sort to sort decls according to the partial order.
    sorted_deps = [(i, decl) for i, decl in enumerate(grouped_decls)
                   if i not in larger_than]
    for i, decl in sorted_deps:
        for i2 in sorted(larger_than.keys()):
            larger = larger_than[i2]
            larger.discard(i)
            if not larger:
                del larger_than[i2]
                sorted_deps.append((i2, grouped_decls[i2]))

    return [decl for i, decl in sorted_deps]


#
# python signature codegen
#

def get_python_signature(declaration, is_python_method, skip_outputs=False):
    return declaration['python_signature'].signature_str(skip_outputs=skip_outputs)


#
# op args to python parsed args transform
#

def get_python_argc(decl):
    return len(decl['python_signature'].arguments())


def get_python_output_index(decl):
    ps: PythonSignature = decl['python_signature']
    return len(ps.input_args) + len(ps.input_kwargs)
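
# Illustrative sketch: for a signature parsed as
#   add(Tensor input, Tensor other, *, Scalar alpha=1, Tensor out=None)
# input_args has 2 entries and input_kwargs has 1, so the out argument is
# found at parser-output index 3.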


#
# declaration derived props, utils, etc.
# declarations are dicts loaded from Declarations.yaml,
# passed to our codegen methods by callers in gen_autograd
#

def is_output(arg):
    return arg.get('output', False)


def has_outputs(declaration):
    return any([is_output(arg) for arg in declaration['arguments']])


def is_torch_function(declaration):
    return 'namespace' in declaration['method_of']


def is_nn_module_function(declaration):
    return declaration.get('python_module') == 'nn'


def is_fft_module_function(declaration):
    return declaration.get('python_module') == 'fft'


def is_linalg_module_function(declaration):
    return declaration.get('python_module') == 'linalg'


def op_name(declaration):
    name = declaration['name']
    if has_outputs(declaration):
        if not name.endswith("_out"):
            raise RuntimeError(
                '{} has output params, expecting name ending with \'_out\''.
                format(declaration['name']))
        return name[:-4]
    else:
        if name.endswith("_out"):
            raise RuntimeError(
                '{}: name ends with \'_out\', expecting output params'.
                format(declaration['name']))
        return name
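
# e.g. op_name maps 'add_out' (which has output params) to 'add' and leaves
# 'add' itself untouched; a mismatch between the '_out' suffix and the
# presence of output params raises RuntimeError.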


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       Codegen API Integration
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# These helper functions allow us to call the new codegen API from the
# old codegen script (which operates on Declarations.yaml).

# TODO: remove all these HACKs after migration is completed!

# function schema str -> NativeFunction
NF_TABLE: Optional[Dict[str, NativeFunction]] = None


def init(native_yaml_path: str) -> None:
    from tools.codegen.gen import parse_native_yaml
    global NF_TABLE
    NF_TABLE = {str(f.func): f for f in parse_native_yaml(native_yaml_path)}


# Multiple decl entries can map to the same native function (because of deprecated decls).
def decl_to_native_function(decl: Dict[str, Any]) -> NativeFunction:
    assert NF_TABLE is not None, 'need to initialize codegen.api.python with init()'
    function_schema_str = decl['schema_string']
    assert function_schema_str.startswith('aten::'), f'unknown namespace: {function_schema_str}'
    function_schema_str = function_schema_str[len('aten::'):]
    assert function_schema_str in NF_TABLE, f'cannot find func: {function_schema_str}'
    return NF_TABLE[function_schema_str]
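
# Sketch of the lookup: a decl whose schema_string is, say,
#   'aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor'
# is found in NF_TABLE under the same string with the 'aten::' prefix stripped.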


# Each decl entry has a unique python signature.
def decl_to_python_signature(decl: Dict[str, Any], *, method: bool) -> PythonSignature:
    f = decl_to_native_function(decl)

    @with_native_function
    def go(f: NativeFunction) -> PythonSignature:
        return signature(f, method=method)

    python_sig = go(f)

    if decl.get('deprecated', False):
        # TODO: directly load 'deprecated.yaml'.
        # deprecated.yaml doesn't have complete type information, so we need to
        # leverage the source signature (to which it delegates the call).
        # A deprecated signature might reorder input_args and input_kwargs,
        # but never changes output_args nor python_binding_args (if any?),
        # so here we only look into these two types of args.
        src_args: Dict[str, PythonArgument] = {a.name: PythonArgument(
            name=a.name,
            type=a.type,
            default=None,
            default_init=None,
        ) for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)}
        args: List[Dict[str, Any]] = decl['arguments']
        input_arg_names: List[str] = \
            list(str(a['name']) for a in args if not a['kwarg_only'] and not a['output'])
        input_kwarg_names: List[str] = \
            list(str(a['name']) for a in args if a['kwarg_only'] and not a['output'])
        python_sig = PythonSignatureDeprecated(
            name=python_sig.name,
            input_args=tuple(src_args[n] for n in input_arg_names if not method or n != 'self'),
            input_kwargs=tuple(src_args[n] for n in input_kwarg_names),
            output_args=python_sig.output_args,
            tensor_options_args=python_sig.tensor_options_args,
            method=python_sig.method,
            deprecated_args_names=tuple(str(a['name']) for a in args),
            deprecated_args_exprs=tuple(decl.get('call_args')),
        )
    return python_sig


def emit_single_dispatch(ps: PythonSignature, decl: Dict[str, Any], method: bool) -> str:
    """
    Emit dispatch code for a single declared overload.
    """
    f = decl['native_function']

    @with_native_function
    def go(f: NativeFunction) -> str:
        # header comments
        deprecated = '[deprecated] ' if ps.deprecated else ''
        schema_comment = f'// {deprecated}aten::{f.func}'

        # dispatch lambda signature
        name = decl['name']
        lambda_formals = ', '.join(map(lambda a: f"{a.type_str} {a.name}",
                                       dispatch_lambda_args(ps, f, method=method)))
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ', '.join(cpp_dispatch_exprs(f, method, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f, method=method)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f, method=method)
        inits = '\n'.join(lambda_arg_exprs.inits)
        lambda_args = ', '.join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        # solution for enabling the 'requires_grad' argument for tensor methods
        # new_full, new_empty, and new_zeros. A much better but more difficult to
        # implement solution involves refactoring according to Ed's description here:
        # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (not has_tensor_options(f) or (
            ps.method and ('requires_grad' in parser_outputs)))
        set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \
            if need_set_requires_grad else ''

        auto_no_gil = '' if decl['with_gil'] else 'pybind11::gil_scoped_release no_gil;'

        namedtuple_typeref = decl['namedtuple_typeref']

        if lambda_return == 'void':
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  {auto_no_gil}
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""
        else:
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  {auto_no_gil}
  return {dispatch_callee}({dispatch_args});
}};
return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

    return go(f)