# Generates Python bindings for ATen functions # # The bindings are generated as methods on python_variable or functions on the # torch._C._nn. torch._C._fft, or torch._C._linalg objects. # # Code tries to stick to the following rules: # # - templates should be colocated with the functions that use them. # no templates are currently shared between functions, but if that # happens, maybe put the template with the first one # # - don't use environment dictionaries when calling template.substitute(). # pass named arguments directly for everything, otherwise it's much too # hard to track what's actually being used and by who # # - colocate any new hacks/adjustments with existing ones of the same kind. # ideally in a data structure rather than code if possible. See e.g. # SCHEMA_DEFAULT_CONVERSION_HACKS, etc. # # - similarly, conversions from one format to another should ideally happen # all at once in a single place. # # - no nontrivial nested functions. couple-liners are ok but please no more. # especially avoid functions that read/write outer variables defined far away. # # - raise RuntimeError instead of asserting, and put as much # information as is available into the message. I.e. no need to # plumb in new params whose only purpose is to fill out an error # message, but use what's there # from collections import defaultdict import re from .gen_variable_type import should_trace from .utils import write, is_tensor_method from tools.codegen.code_template import CodeTemplate from tools.codegen.api.python import * from tools.codegen.gen import cpp_string, with_native_function from tools.codegen.model import * from typing import Dict, Optional, List, Any # # declarations blocklist # We skip codegen for these functions, for various reasons. # Future PRs will categorize this list and eliminate or hoist # them out of eager-only codegen. 
# See https://github.com/pytorch/pytorch/issues/30788
#
# These functions require manual Python bindings or are not exposed to Python.
# Every entry is a regular expression; should_generate_python_binding anchors
# it with '^' and '$', so a pattern has to match the whole declaration name.
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
    '_sparse_dense_add_out', 'index', 'unique_dim_consecutive',
    '_indexCopy_', '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*', 'arange.*', 'range.*', '_solve.*', '_inverse.*',
    'full(_out)?', '_cholesky.*', '_triangular_solve.*', '_qr.*',
    '_symeig.*', '_svd.*', 'slice', 'randint(_out)?', 'item',
    '_local_scalar_dense', 'to', 'copy_sparse_to_sparse_', 'copy_',
    'numpy_T',  # this needs to be an attribute in Python, not a function
    'nonzero(_(out|numpy))?',
    'set_quantizer_',  # return types not supported yet
    'set_data',
    '.*_overrideable',  # overrideable functions for backend extension
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_',
    'retain_grad', 'set_',
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

# Python module name -> C++ object passed as the `namespace` argument of
# handle_torch_function in the generated bindings (see TORCH_FUNCTION_CHECK).
NATIVE_NAMESPACE_MAPPING = {
    "torch": "THPVariableFunctionsModule",
    "torch.nn": "THPNNVariableFunctionsModule",
    "torch.fft": "THPFFTVariableFunctionsModule",
    "torch.linalg": "THPLinalgVariableFunctionsModule",
}


def should_generate_python_binding(declaration):
    """
    Decide whether `declaration` gets a generated Python binding.

    It is skipped when its name fully matches a regex in SKIP_PYTHON_BINDINGS,
    or when its "name(simple_type, ...)" string is exactly one of
    SKIP_PYTHON_BINDINGS_SIGNATURES. The name check runs first.
    """
    name = declaration['name']
    if any(re.match('^' + pattern + '$', name) for pattern in SKIP_PYTHON_BINDINGS):
        return False

    arg_types = ', '.join(arg['simple_type'] for arg in declaration['arguments'])
    signature = f'{name}({arg_types})'
    return signature not in SKIP_PYTHON_BINDINGS_SIGNATURES

#
# top-level codegen functions, called from gen_autograd
#


def _write_bindings(out, template_path, filename, select, declarations, *, is_python_method, module):
    """
    Shared body of the gen_py_* entry points.

    Loads `<template_path>/<filename>`, binds the declarations chosen by
    `select`, and writes the rendered file as `filename` into `out`.
    """
    template = CodeTemplate.from_file(template_path + '/' + filename)
    grouped = select(declarations)
    env = create_python_bindings(grouped, is_python_method=is_python_method, module=module)
    write(out, filename, template, env)


def get_py_variable_methods(declarations):
    """
    Get declarations (grouped by name) which should be generated as methods on Tensor.
    """
    return group_declarations_by_op_name([
        d for d in declarations
        if should_generate_python_binding(d)
        and not is_nn_module_function(d)
        and is_tensor_method(d)
    ])


def gen_py_variable_methods(out, declarations, template_path):
    """
    Generate Tensor methods.
    """
    _write_bindings(out, template_path, 'python_variable_methods.cpp',
                    get_py_variable_methods, declarations,
                    is_python_method=True, module=None)


def get_py_nn_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated as functions in the "nn" module.
    """
    return group_declarations_by_op_name([
        d for d in declarations
        if should_generate_python_binding(d) and is_nn_module_function(d)
    ])


def gen_py_nn_functions(out, declarations, template_path):
    """
    Generate functions in the "nn" module.
    """
    _write_bindings(out, template_path, 'python_nn_functions.cpp',
                    get_py_nn_functions, declarations,
                    is_python_method=False, module="torch.nn")


def get_py_fft_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated as functions in the "fft" module.
    """
    return group_declarations_by_op_name([
        d for d in declarations
        if should_generate_python_binding(d) and is_fft_module_function(d)
    ])


def gen_py_fft_functions(out, declarations, template_path):
    """
    Generate functions in the "fft" module.
    """
    _write_bindings(out, template_path, 'python_fft_functions.cpp',
                    get_py_fft_functions, declarations,
                    is_python_method=False, module="torch.fft")


def get_py_linalg_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated as functions in the "linalg" module.
    """
    return group_declarations_by_op_name([
        d for d in declarations
        if should_generate_python_binding(d) and is_linalg_module_function(d)
    ])


def gen_py_linalg_functions(out, declarations, template_path):
    """
    Generate functions in the "linalg" module.
    """
    _write_bindings(out, template_path, 'python_linalg_functions.cpp',
                    get_py_linalg_functions, declarations,
                    is_python_method=False, module="torch.linalg")


def get_py_torch_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated as functions in the "torch" module.
    """
    return group_declarations_by_op_name([
        d for d in declarations
        if should_generate_python_binding(d)
        and not is_nn_module_function(d)
        and not is_fft_module_function(d)
        and not is_linalg_module_function(d)
        and is_torch_function(d)
    ])


def gen_py_torch_functions(out, declarations, template_path):
    """
    Generate functions in the "torch" module.
    """
    _write_bindings(out, template_path, 'python_torch_functions.cpp',
                    get_py_torch_functions, declarations,
                    is_python_method=False, module="torch")


def group_declarations_by_op_name(declarations):
    """
    Bucket declarations by op_name(), so that an op and its `_out` variant end
    up in the same list. Returns a defaultdict(list).
    """
    buckets = defaultdict(list)
    for decl in declarations:
        buckets[op_name(decl)].append(decl)
    return buckets


def create_python_bindings(python_functions, is_python_method, module):
    """
    Generates Python bindings to ATen functions.

    For every op name (in sorted order), first attaches 'python_signature' and
    'native_function' to each overload declaration (mutating the dicts). Then
    emits that op's method body, method-table entry and forward declarations.
    Returns the template environment with keys 'py_forwards', 'py_methods' and
    'py_method_defs'.
    """
    py_methods = []
    py_method_defs = []
    py_forwards = []

    for name in sorted(python_functions.keys()):
        overloads = python_functions[name]
        # TODO: change all methods to directly process python signatures instead of decls.
        for decl in overloads:
            decl['python_signature'] = decl_to_python_signature(decl, method=is_python_method)
            decl['native_function'] = decl_to_native_function(decl)

        py_methods.append(method_impl(name, overloads, is_python_method, module))
        py_method_defs.append(method_def(name, overloads, is_python_method, module))
        py_forwards.extend(forward_decls(name, overloads, is_python_method, module))

    return {
        'py_forwards': py_forwards,
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
    }


# handler for output/no-output overload pair
# (its rendering is plugged into PY_VARIABLE_CASE as ${body})
PY_VARIABLE_OUT = CodeTemplate("""\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures differ only in output params
PY_VARIABLE_CASE = CodeTemplate("""\
case ${i}: {
  ${body}
}
""")


def emit_dispatch_case(i, dictionary, is_python_method):
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single overload, or a pair that differ only in output params. In the latter
    case, a single signature is used for both and dispatching switches on the
    presence/absence of passed output args.

    - i: this signature's position in generated binding's signature list if
      number of signatures > 1, otherwise None
    - dictionary: contains a no-output overload declaration under 'base', and
      optionally a second overload with outputs under 'out'
    - is_python_method: true if we're generating a python method, in which case
      self is not parsed but passed directly
    """
    base_decl = dictionary['base']
    python_sig = base_decl['python_signature']

    if 'out' not in dictionary:
        # no-output version only
        body = emit_single_dispatch(python_sig, base_decl, is_python_method)
    else:
        out_decl = dictionary['out']
        # Both variants are parsed with the out variant's signature.
        python_sig = out_decl['python_signature']
        out_idx = get_python_output_index(out_decl)
        # Branch at runtime on whether the output argument was passed (_r.isNone()).
        body = PY_VARIABLE_OUT.substitute(
            out_idx=out_idx,
            call_dispatch=emit_single_dispatch(python_sig, base_decl, is_python_method),
            call_dispatch_out=emit_single_dispatch(python_sig, out_decl, is_python_method),
        )

    if i is None:
        # only one overload, omit case wrapper
        return body
    # generate case for ith overload
    return PY_VARIABLE_CASE.substitute(i=i, body=body)

#
# named tuple codegen
#


def namedtuple_fieldnames(declaration):
    """
    Return the named-tuple field names for a declaration's returns.

    Returns [] when there is at most one return, or when no return carries a
    'field_name'. Raises ValueError if only some of the returns are named.
    """
    returns = declaration['returns']
    if len(returns) <= 1 or not any('field_name' in ret for ret in returns):
        return []

    names = []
    for ret in returns:
        # See Note [field_name versus name]
        if 'field_name' not in ret:
            # When building on Windows, `PyStructSequence_UnnamedField` could not be
            # resolved by the linker for some reason, which cause error in building:
            #
            # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
            # PyStructSequence_UnnamedField
            #
            # Thus, at this point in time, we do not support unnamed
            # fields in namedtuple; you must either name all fields,
            # or none of them.
            raise ValueError("Unnamed field is not supported by codegen")
        names.append(ret['field_name'])
    return names


PY_NAMEDTUPLE_FIELDSDEF = CodeTemplate("""\
static PyStructSequence_Field ${fieldsname}[] = { ${fields,} {nullptr} };
""")

PY_NAMEDTUPLE_TYPEDEF = CodeTemplate("""\
static PyTypeObject ${typename};
static bool ${typename}_initialized = false;
if (!${typename}_initialized) {
  ${typename}_initialized = true;
  static PyStructSequence_Desc desc = { "torch.return_types.${name}", nullptr, ${fieldsname}, ${size} };
  PyStructSequence_InitType(&${typename}, &desc);
  ${typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}
""")


def emit_namedtuple_typedefs(declarations):
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them.

    Every declaration gets a 'namedtuple_typeref' entry: '' when it returns no
    named tuple, otherwise '&<typename>, '. Field arrays are shared between
    declarations with identical field names. Types are shared only when both
    the op name and the field names match. Returns the field-array snippets
    followed by the type-init snippets.
    """
    fields_name_by_key = {}  # joined field names -> field array name
    field_defs = []          # rendered field array declarations
    type_name_by_key = {}    # "<op name>_<joined field names>" -> type name
    type_defs = []           # rendered type declarations + init code

    for decl in declarations:
        fieldnames = namedtuple_fieldnames(decl)
        if not fieldnames:
            decl['namedtuple_typeref'] = ''
            continue

        fields_key = '_'.join(fieldnames)
        fieldsname = fields_name_by_key.get(fields_key)
        if fieldsname is None:
            # first array is unsuffixed, later ones get their index appended
            fieldsname = 'NamedTuple_fields{}'.format(len(field_defs) or '')
            field_defs.append(PY_NAMEDTUPLE_FIELDSDEF.substitute(
                fieldsname=fieldsname,
                fields=['{{"{}", ""}}'.format(fn) for fn in fieldnames],
            ))
            fields_name_by_key[fields_key] = fieldsname

        name = decl['name']
        type_key = '{}_{}'.format(name, fields_key)
        typename = type_name_by_key.get(type_key)
        if typename is None:
            typename = 'NamedTuple{}'.format(len(type_defs) or '')
            type_defs.append(PY_NAMEDTUPLE_TYPEDEF.substitute(
                name=name,
                typename=typename,
                size=len(fieldnames),
                fieldsname=fieldsname,
            ))
            type_name_by_key[type_key] = typename

        decl['namedtuple_typeref'] = '&{}, '.format(typename)

    return field_defs + type_defs

#
# method impl codegen
#


def get_pycname(name):
    """C function name of the generated binding for op `name`."""
    return f'THPVariable_{name}'


def is_noarg_binding(overloads):
    """True for a single overload whose Python signature takes no arguments."""
    if len(overloads) != 1:
        return False
    return get_python_argc(overloads[0]) == 0


# python binding for all overloads of a particular function/method
# NOTE(review): because of the r-prefix, the leading backslash-newline is kept
# literally in the template (it is not a line continuation in raw strings).
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

""")

# python binding for single-overload function/method
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

""")

# python binding for a method with no args, shortcuts parsing
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

""")

TORCH_FUNCTION_CHECK = CodeTemplate("""\
if(_r.has_torch_function()) {
  return handle_torch_function(_r, ${self_}, args, kwargs, ${namespace}, ${modulename});
}
""")

TORCH_FUNCTION_CHECK_NOARGS = CodeTemplate("""\
if(check_has_torch_function(self_)) {
  return handle_torch_function(self_, ${name});
}
""")

# NOTE: we type the unpacked self as Tensor not Variable to avoid return type
# discrepancies on method resolution (e.g.
# Variable::detach_ returns void
# rather than Tensor &)
# The cast must spell out its target type: a bare `reinterpret_cast(self_)`
# does not compile as C++.
UNPACK_SELF = "Tensor& self = reinterpret_cast<THPVariable*>(self_)->cdata;"


def method_impl(name, declarations, is_python_method, module):
    """
    Generate a python binding for all overloads of an op.

    - name: op name; the C function name is derived from it via get_pycname()
    - declarations: all overload declarations grouped under `name`; each must
      already carry 'python_signature' (read here and by the dispatch emitters)
    - is_python_method: True when generating a Tensor method, in which case
      `self_` is unpacked into `self` and handed to the parser
    - module: key of NATIVE_NAMESPACE_MAPPING, or None for Tensor methods
      (which use THPVariableClass / "torch.Tensor" for __torch_function__)

    Returns the substituted C++ function template.
    """
    pycname = get_pycname(name)
    self_ = "self_" if is_python_method else "nullptr"

    method_header = ['HANDLE_TH_ERRORS']
    method_header += emit_namedtuple_typedefs(declarations)
    method_header += [UNPACK_SELF] if is_python_method else []

    method_footer = ['END_HANDLE_TH_ERRORS']

    # emit dispatch
    if is_noarg_binding(declarations):
        # No argument parser in this form, so the __torch_function__ check
        # inspects self directly (methods only).
        check_has_torch_function = TORCH_FUNCTION_CHECK_NOARGS.substitute(
            name='"' + name + '"',
        ) if is_python_method else ''
        python_sig = declarations[0]['python_signature']
        dispatch = emit_single_dispatch(python_sig, declarations[0], is_python_method)
        return PY_VARIABLE_METHOD_NOARGS.substitute(
            name=name,
            pycname=pycname,
            method_header=method_header,
            dispatch=dispatch,
            method_footer=method_footer,
            check_has_torch_function=check_has_torch_function,
        )

    # Reached when no dispatch case returned (e.g. unmatched switch index).
    method_footer = ['Py_RETURN_NONE;'] + method_footer

    grouped = group_overloads(declarations, is_python_method)
    is_singleton = len(grouped) == 1

    signatures = []
    dispatch = []
    for i, dictionary in enumerate(grouped):
        signature = dictionary['signature']
        signatures.append(f'{cpp_string(str(signature))},')
        # a single overload needs no `case i:` wrapper
        overload_index = i if not is_singleton else None
        dispatch.append(emit_dispatch_case(overload_index, dictionary, is_python_method))

    if is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    if module:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace=NATIVE_NAMESPACE_MAPPING[module],
            modulename='"' + module + '"',
            self_=self_,
        )
    else:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace="THPVariableClass",
            modulename='"torch.Tensor"',
            self_=self_,
        )

    # ParsedArgs<N> must be large enough for the widest overload
    max_args = max(get_python_argc(decl) for decl in declarations)
    traceable = 'true' if all(should_trace(d) for d in declarations) else 'false'

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max_args,
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=check_has_torch_function,
        dispatch=dispatch,
        method_footer=method_footer,
        self_=self_,
    )

#
# forward declarations
#


PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""")

PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args);
""")


def forward_decls(name, declarations, is_python_method, module):
    """
    Return the C++ forward declaration(s) for the binding of op `name`:
    an empty list for Tensor methods, otherwise a one-element list whose
    prototype matches the noarg or varargs form chosen by method_impl.
    """
    if is_python_method:
        return []

    if is_noarg_binding(declarations):
        template = PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION
    else:
        template = PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION

    pycname = get_pycname(name)
    return [template.substitute(pycname=pycname)]

#
# method def (binding table entry) codegen
#


# Python binary operator dunder methods
BINARY_OP_NAMES = [
    '__lt__', '__le__',
    '__gt__', '__ge__',
    '__eq__', '__ne__',

    '__add__', '__radd__', '__iadd__',
    '__sub__', '__rsub__', '__isub__',
    '__mul__', '__rmul__', '__imul__',
    '__matmul__', '__rmatmul__', '__imatmul__',
    '__truediv__', '__rtruediv__', '__itruediv__',
    '__floordiv__', '__rfloordiv__', '__ifloordiv__',
    '__mod__', '__rmod__', '__imod__',
    '__divmod__', '__rdivmod__', '__idivmod__',
    '__pow__', '__rpow__', '__ipow__',
    '__lshift__', '__rlshift__', '__ilshift__',
    '__rshift__', '__rrshift__', '__irshift__',
    '__and__', '__rand__', '__iand__',
    '__xor__', '__rxor__', '__ixor__',
    '__or__', '__ror__', '__ior__',
]

# PyMethodDef entry for binary op, throws not implemented error
PY_VARIABLE_METHOD_BINOP_DEF = CodeTemplate("""\
{"${name}", ${pyfunc_cast}(TypeError_to_NotImplemented_<${pycname}>), ${flags}, NULL},""")

# PyMethodDef entry
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", ${pyfunc_cast}(${pycname}), ${flags}, NULL},""")


def method_def(name, declarations, is_python_method, module):
    """
    Generate method def entry.

    The calling-convention flags mirror the function shape emitted by
    method_impl. Functions of the "torch" module are also marked METH_STATIC.
    Names in BINARY_OP_NAMES are wrapped in TypeError_to_NotImplemented_.
    """
    pycname = get_pycname(name)

    if is_noarg_binding(declarations):
        pyfunc_cast = ''
        flags = 'METH_NOARGS' if is_python_method else 'METH_VARARGS | METH_KEYWORDS'
    else:
        pyfunc_cast = 'castPyCFunctionWithKeywords'
        flags = 'METH_VARARGS | METH_KEYWORDS'

    if module == "torch":
        flags += ' | METH_STATIC'

    if name in BINARY_OP_NAMES:
        def_template = PY_VARIABLE_METHOD_BINOP_DEF
    else:
        def_template = PY_VARIABLE_METHOD_DEF

    return def_template.substitute(
        name=name,
        pycname=pycname,
        pyfunc_cast=pyfunc_cast,
        flags=flags,
    )

#
# overload sorting and grouping
#


def group_overloads(declarations, is_python_method):
    """Returns a list of dictionaries containing the optional keys:

       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out)
       "signature": the signature used for Python argument parsing

       Note that we merge pairs of declarations with signatures that
       are equivalent mod output arguments, and use a single entry in
       the python_arg_parser sig list for both (output arguments become
       optional)

       Raises RuntimeError if an out variant has no matching non-out variant.
    """
    grouped = defaultdict(dict)

    # first group by signature ignoring out arguments
    for declaration in declarations:
        signature = get_python_signature(declaration, is_python_method, skip_outputs=True)
        v = grouped[signature]
        if declaration['name'].endswith('_out'):
            v['out'] = declaration
            # prefer the signature with optional out=... arguments
            v['signature'] = get_python_signature(declaration, is_python_method)
        else:
            v['base'] = declaration
            if 'signature' not in v:
                v['signature'] = signature

    result = []
    for x, dictionary in sorted(grouped.items()):
        if 'base' not in dictionary:
            # Collect plausible non-out variants to make the error actionable.
            candidates = []
            non_out_name = dictionary['out']['operator_name']
            for declaration in declarations:
                # .get(): match decl_to_python_signature, so that a declaration
                # without the key cannot turn this diagnostic into a KeyError.
                if declaration['name'] == non_out_name and not declaration.get('deprecated', False):
                    signature = get_python_signature(declaration, is_python_method, skip_outputs=True)
                    candidates.append(signature)
            raise RuntimeError(
                "While identifying overloads, we found an out schema {} without a corresponding non-out variant. "
                "We expected the non-out variant to have schema: \n- {}\nPlease check that you spelled the schema "
                "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
                .format(dictionary['signature'], x) + "\n".join("- {}".format(candidate) for candidate in candidates))
        result.append(dictionary)
    return sort_declarations(result)


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):
    """
    Reorder grouped overloads (dicts with a 'base' declaration) so that an
    overload accepting 'Tensor' comes before an otherwise identical overload
    accepting 'Scalar' in the same position.

    Overload A is below B when both have the same number of arguments, at
    least one position is Scalar in A but Tensor in B, and every other
    position has equal dynamic types. Overloads are emitted with nothing
    above them first (a Kahn-style topological sort). If no pair is related,
    the input list object itself is returned.
    """
    def arg_type(arg):
        return arg['dynamic_type']

    def scalar_vs_tensor(lhs_arg, rhs_arg):
        return arg_type(lhs_arg) == 'Scalar' and arg_type(rhs_arg) == 'Tensor'

    def below(lhs, rhs):
        """True when overload group `lhs` is strictly below `rhs` in the partial order."""
        lhs_args, rhs_args = lhs['base']['arguments'], rhs['base']['arguments']
        if len(lhs_args) != len(rhs_args):
            return False
        # Both conditions are always evaluated, as in the original formulation.
        some_strict = any(scalar_vs_tensor(a, b) for a, b in zip(lhs_args, rhs_args))
        each_compatible = all(arg_type(a) == arg_type(b) or scalar_vs_tensor(a, b)
                              for a, b in zip(lhs_args, rhs_args))
        return some_strict and each_compatible

    # above[i]: indices of overloads that must be emitted before overload i
    above = defaultdict(set)
    for lo, lo_decl in enumerate(grouped_decls):
        for hi, hi_decl in enumerate(grouped_decls):
            if below(lo_decl, hi_decl):
                above[lo].add(hi)

    if not above:
        return grouped_decls

    # Seed with overloads that have nothing above them. The list grows while it
    # is being iterated: each emitted index may release further overloads.
    ordered = [(idx, decl) for idx, decl in enumerate(grouped_decls) if idx not in above]
    for done, _ in ordered:
        for idx in sorted(above):
            blockers = above[idx]
            blockers.discard(done)
            if not blockers:
                del above[idx]
                ordered.append((idx, grouped_decls[idx]))

    return [decl for _, decl in ordered]

#
# python signature codegen
#


def get_python_signature(declaration, is_python_method, skip_outputs=False):
    """Signature string of the declaration's attached PythonSignature (is_python_method is unused)."""
    python_sig = declaration['python_signature']
    return python_sig.signature_str(skip_outputs=skip_outputs)

#
# op args to python parsed args transform
#


def get_python_argc(decl):
    """Number of arguments in the declaration's Python signature."""
    all_args = decl['python_signature'].arguments()
    return len(all_args)


def get_python_output_index(decl):
    """Parser index of the first output argument: positional inputs + keyword inputs."""
    ps: PythonSignature = decl['python_signature']
    n_positional = len(ps.input_args)
    n_keyword = len(ps.input_kwargs)
    return n_positional + n_keyword

#
# declaration derived props, utils, etc.
# declarations are dicts loaded from Declarations.yaml,
# passed to our codegen methods by callers in gen_autograd
#

def is_output(arg):
    """Return the argument's 'output' flag; a missing flag counts as False."""
    return arg.get('output', False)


def has_outputs(declaration):
    """True if any of the declaration's arguments is an output argument."""
    output_flags = [is_output(arg) for arg in declaration['arguments']]
    return any(output_flags)


def is_torch_function(declaration):
    """True if the declaration's 'method_of' entry contains 'namespace'."""
    return 'namespace' in declaration['method_of']


def _python_module_of(declaration):
    # None when the declaration has no 'python_module' entry
    return declaration.get('python_module')


def is_nn_module_function(declaration):
    return _python_module_of(declaration) == 'nn'


def is_fft_module_function(declaration):
    return _python_module_of(declaration) == 'fft'


def is_linalg_module_function(declaration):
    return _python_module_of(declaration) == 'linalg'


def op_name(declaration):
    """
    Name shared by an overload and its out variant: the declaration name with
    the trailing '_out' removed when the declaration has output arguments.

    Raises RuntimeError when the '_out' suffix and the presence of output
    arguments disagree.
    """
    name = declaration['name']
    if not has_outputs(declaration):
        if name.endswith("_out"):
            raise RuntimeError(
                f"{name}: name ends with '_out', expecting output params")
        return name
    if not name.endswith("_out"):
        raise RuntimeError(
            f"{name} has output params, expecting name ending with '_out'")
    return name[:-4]

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                       Codegen API Integration
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# These helper functions allow us to call the new codegen API from the
# old codegen script (which operates on Declarations.yaml).

# TODO: remove all these HACKs after migration is completed!

# function schema str -> NativeFunction
NF_TABLE: Optional[Dict[str, NativeFunction]] = None


def init(native_yaml_path: str) -> None:
    """Parse native_functions.yaml and index each NativeFunction by str(f.func) in NF_TABLE."""
    from tools.codegen.gen import parse_native_yaml
    global NF_TABLE
    table = {}
    for native_fn in parse_native_yaml(native_yaml_path):
        table[str(native_fn.func)] = native_fn
    NF_TABLE = table

# Multiple decl entries can map to the same native function (because of deprecated decl).
def decl_to_native_function(decl: Dict[str, Any]) -> NativeFunction:
    """
    Look up the NativeFunction backing a Declarations.yaml entry.

    The key is decl['schema_string'] with its 'aten::' prefix removed, matched
    against the str(f.func) keys that init() stores in NF_TABLE.

    Raises RuntimeError if init() has not run, if the schema string is not in
    the 'aten::' namespace, or if no native function has that schema.
    """
    # Explicit raises instead of asserts (see the rules at the top of this
    # file). Under `python -O` the asserts vanish, and a schema without the
    # 'aten::' prefix would then be silently mis-sliced below.
    if NF_TABLE is None:
        raise RuntimeError('need to initialize codegen.api.python with init()')
    function_schema_str = decl['schema_string']
    if not function_schema_str.startswith('aten::'):
        raise RuntimeError(f'unknown namespace: {function_schema_str}')
    function_schema_str = function_schema_str[len('aten::'):]
    if function_schema_str not in NF_TABLE:
        raise RuntimeError(f'cannot find func: {function_schema_str}')
    return NF_TABLE[function_schema_str]

# Each decl entry has unique python signature.
def decl_to_python_signature(decl: Dict[str, Any], *, method: bool) -> PythonSignature:
    """
    Build the PythonSignature used to parse arguments for one declaration.

    For declarations flagged 'deprecated', a PythonSignatureDeprecated is
    built instead. Its input args/kwargs follow the order and kwarg_only
    flags of decl['arguments']. Their types come from the native function's
    signature, with defaults dropped. The output and tensor-options args are
    taken unchanged from the native signature.
    """
    # Imported locally: the module-level imports of this file do not bring
    # `itertools` into scope, so the deprecated path below would otherwise
    # fail with NameError.
    import itertools

    f = decl_to_native_function(decl)

    @with_native_function
    def go(f: NativeFunction) -> PythonSignature:
        return signature(f, method=method)

    python_sig = go(f)

    if decl.get('deprecated', False):
        # TODO: directly load 'deprecated.yaml'.
        # deprecated.yaml doesn't have complete type information, we need
        # leverage the source signature (to which it delegates the call).
        # Deprecated signature might reorder input_args and input_kwargs,
        # but never changes output_args nor python_binding_args (if any?),
        # so here we only look into these two types of args.
        src_args: Dict[str, PythonArgument] = {a.name: PythonArgument(
            name=a.name,
            type=a.type,
            default=None,
            default_init=None,
        ) for a in itertools.chain(python_sig.input_args, python_sig.input_kwargs)}

        args: List[Dict[str, Any]] = decl['arguments']
        input_arg_names: List[str] = \
            list(str(a['name']) for a in args if not a['kwarg_only'] and not a['output'])
        input_kwarg_names: List[str] = \
            list(str(a['name']) for a in args if a['kwarg_only'] and not a['output'])
        python_sig = PythonSignatureDeprecated(
            name=python_sig.name,
            # for methods, self is bound by the receiver, not parsed
            input_args=tuple(src_args[n] for n in input_arg_names if not method or n != 'self'),
            input_kwargs=tuple(src_args[n] for n in input_kwarg_names),
            output_args=python_sig.output_args,
            tensor_options_args=python_sig.tensor_options_args,
            method=python_sig.method,
            deprecated_args_names=tuple(str(a['name']) for a in args),
            deprecated_args_exprs=tuple(decl.get('call_args')),
        )
    return python_sig

def emit_single_dispatch(ps: PythonSignature, decl: Dict[str, Any], method: bool) -> str:
    """
    Emit dispatch code for a single declared overload.

    The emitted C++ defines a `dispatch_<name>` lambda that calls the ATen
    target, releasing the GIL unless decl['with_gil'] is set. It then invokes
    that lambda on the values unpacked from the argument parser. Void
    results end with Py_RETURN_NONE; other results are returned through
    wrap(), prefixed with the decl's named-tuple type reference.
    Requires decl['native_function'] and decl['namedtuple_typeref'] to be set.
    """
    f = decl['native_function']

    @with_native_function
    def go(f: NativeFunction) -> str:
        # header comments
        deprecated = '[deprecated] ' if ps.deprecated else ''
        schema_comment = f'// {deprecated}aten::{f.func}'

        # dispatch lambda signature
        name = decl['name']
        lambda_formals = ', '.join(map(lambda a: f"{a.type_str} {a.name}",
                                       dispatch_lambda_args(ps, f, method=method)))
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ', '.join(cpp_dispatch_exprs(f, method, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f, method=method)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f, method=method)
        inits = '\n'.join(lambda_arg_exprs.inits)
        lambda_args = ', '.join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        #       solution for enabling the 'requires_grad' argument for tensor methods
        #       new_full, new_empty, and new_zeros. A much better but more difficult to
        #       implement solution involves refactoring according to Ed's description here:
        #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (not has_tensor_options(f) or (
            ps.method and ('requires_grad' in parser_outputs)))
        set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \
            if need_set_requires_grad else ''

        auto_no_gil = '' if decl['with_gil'] else 'pybind11::gil_scoped_release no_gil;'

        namedtuple_typeref = decl['namedtuple_typeref']

        if lambda_return == 'void':
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  {auto_no_gil}
  {dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""
        else:
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
  {auto_no_gil}
  return {dispatch_callee}({dispatch_args});
}};
return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

    return go(f)