# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, or torch._C._linalg objects.
#
# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
#   no templates are currently shared between functions, but if that
#   happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
#   pass named arguments directly for everything, otherwise it's much too
#   hard to track what's actually being used and by who
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
#   ideally in a data structure rather than code if possible. See e.g.
#   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
#   all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
#   especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
#   information as is available into the message. I.e. no need to
#   plumb in new params whose only purpose is to fill out an error
#   message, but use what's there
#

from collections import defaultdict
import re
from .gen_variable_type import should_trace
from .utils import write, is_tensor_method

from tools.codegen.code_template import CodeTemplate

#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
# Future PRs will categorize this list and eliminate or hoist
# them out of eager-only codegen.
# See https://github.com/pytorch/pytorch/issues/30788
#

# These functions require manual Python bindings or are not exposed to Python
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)',
    '.*_forward', '.*_forward_out',
    '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
    'index', 'unique_dim_consecutive',
    '_indexCopy_', '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_solve.*', '_inverse.*',
    'full(_out)?',
    '_cholesky.*', '_triangular_solve.*', '_qr.*', '_symeig.*', '_svd.*',
    'slice', 'randint(_out)?',
    'item', '_local_scalar_dense', 'to',
    'copy_sparse_to_sparse_', 'copy_',
    'numpy_T',  # this needs to be an attribute in Python, not a function
    'nonzero(_(out|numpy))?',
    'set_quantizer_',  # return types not supported yet
    'set_data',
    '.*_overrideable',  # overrideable functions for backend extension
    'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retain_grad', 'set_'
]
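# A quick illustration (not executed by codegen) of how the entries above are
# applied: should_generate_python_binding() below wraps each pattern in ^...$
# and full-matches it against the op name, so e.g.
#
#   re.match('^' + '_th_.*' + '$', '_th_addmm')   -> match, binding is skipped
#   re.match('^' + '_th_.*' + '$', 'addmm')       -> None, binding is kept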
# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

NATIVE_NAMESPACE_MAPPING = {
    "torch": "THPVariableFunctionsModule",
    "torch.nn": "THPNNVariableFunctionsModule",
    "torch.fft": "THPFFTVariableFunctionsModule",
    "torch.linalg": "THPLinalgVariableFunctionsModule",
}

def should_generate_python_binding(declaration):
    name = declaration['name']
    for pattern in SKIP_PYTHON_BINDINGS:
        if re.match('^' + pattern + '$', name):
            return False

    simple_types = [arg['simple_type'] for arg in declaration['arguments']]
    signature = '{}({})'.format(name, ', '.join(simple_types))
    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
        if pattern == signature:
            return False

    return True

#
# top-level codegen functions, called from gen_autograd
#

def get_py_variable_methods(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as methods on Tensor.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                not is_nn_module_function(declaration) and
                is_tensor_method(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])

def gen_py_variable_methods(out, declarations, template_path):
    """
    Generate Tensor methods.
    """
    PY_VARIABLE_METHODS_CPP = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')

    py_variable_methods = get_py_variable_methods(declarations)

    env = create_python_bindings(py_variable_methods, is_python_method=True, module=None)

    write(out, 'python_variable_methods.cpp', PY_VARIABLE_METHODS_CPP, env)

def get_py_nn_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "nn" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                is_nn_module_function(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])

def gen_py_nn_functions(out, declarations, template_path):
    """
    Generate functions in the "nn" module.
    """
    PY_NN_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')

    py_nn_functions = get_py_nn_functions(declarations)

    env = create_python_bindings(py_nn_functions, is_python_method=False, module="torch.nn")

    write(out, 'python_nn_functions.cpp', PY_NN_FUNCTIONS_CPP, env)

def get_py_fft_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "fft" module.
    """
    def should_bind(declaration):
        return (should_generate_python_binding(declaration) and
                is_fft_module_function(declaration))

    return group_declarations_by_op_name([d for d in declarations if should_bind(d)])

def gen_py_fft_functions(out, declarations, template_path):
    """
    Generate functions in the "fft" module.
    """
    PY_FFT_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_fft_functions.cpp')

    py_fft_functions = get_py_fft_functions(declarations)

    env = create_python_bindings(py_fft_functions, is_python_method=False, module="torch.fft")

    write(out, 'python_fft_functions.cpp', PY_FFT_FUNCTIONS_CPP, env)

def get_py_linalg_functions(declarations):
    """
    Get declarations (grouped by name) which should be generated
    as functions in the "linalg" module.
""" def should_bind(declaration): return (should_generate_python_binding(declaration) and is_linalg_module_function(declaration)) return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) def gen_py_linalg_functions(out, declarations, template_path): """ Generate functions in the "linalg" module. """ PY_LINALG_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_linalg_functions.cpp') py_linalg_functions = get_py_linalg_functions(declarations) env = create_python_bindings(py_linalg_functions, is_python_method=False, module="torch.linalg") write(out, 'python_linalg_functions.cpp', PY_LINALG_FUNCTIONS_CPP, env) def get_py_torch_functions(declarations): """ Get declarations (grouped by name) which should be generated as functions in the "torch" module. """ def should_bind(declaration): return (should_generate_python_binding(declaration) and not is_nn_module_function(declaration) and not is_fft_module_function(declaration) and not is_linalg_module_function(declaration) and is_torch_function(declaration)) return group_declarations_by_op_name([d for d in declarations if should_bind(d)]) def gen_py_torch_functions(out, declarations, template_path): """ Generate functions in the "torch" module. """ PY_TORCH_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp') py_torch_functions = get_py_torch_functions(declarations) env = create_python_bindings(py_torch_functions, is_python_method=False, module="torch") write(out, 'python_torch_functions.cpp', PY_TORCH_FUNCTIONS_CPP, env) def group_declarations_by_op_name(declarations): groups = defaultdict(list) for d in declarations: groups[op_name(d)].append(d) return groups def create_python_bindings(python_functions, is_python_method, module): """Generates Python bindings to ATen functions""" py_methods = [] py_method_defs = [] py_forwards = [] for name in sorted(python_functions.keys()): overload_decls = python_functions[name] py_methods.append(method_impl(name, overload_decls, is_python_method, module)) py_method_defs.append(method_def(name, overload_decls, is_python_method, module)) py_forwards.extend(forward_decls(name, overload_decls, is_python_method, module)) return { 'py_forwards': py_forwards, 'py_methods': py_methods, 'py_method_defs': py_method_defs, } # # extracting and storing parsed args # UNPACK_METHODS = { 'const Tensor &': 'tensor', 'Tensor &': 'tensor', 'c10::optional': 'optionalTensor', 'const c10::optional&': 'optionalTensor', 'c10::optional': 'generator', 'Storage': 'storage', 'Storage &': 'storage', 'const ScalarType &': 'scalartype', 'const Device &': 'device', 'c10::optional': 'toDimnameListOptional', 'c10::optional': 'scalartypeOptional', 'c10::optional': 'layoutOptional', 'c10::optional': 'memoryformatOptional', 'c10::optional': 'scalarOptional', 'c10::optional': 'intlistOptional', 'c10::optional': 'toInt64Optional', 'c10::optional': 'toBoolOptional', 'c10::optional': 'toDoubleOptional', 'c10::optional>': 'doublelistOptional', 'ArrayRef': 'doublelist', 'IntArrayRef': 'intlist', 'Scalar': 'scalar', 'ScalarType': 'scalartype', 'Dimname': 'dimname', 'DimnameList': 'dimnamelist', 'TensorList': 'tensorlist', 'int64_t': 'toInt64', 'bool': 'toBool', 'double': 'toDouble', 'std::string': 'string', 'c10::optional': 'stringOptional', } UNPACK_WITH_SIZE_METHODS = { 'TensorList': 'tensorlist_n<{}>', 'DimnameList': 'dimnamelist', 'IntArrayRef': 'intlist', 'c10::optional': 'intlistOptional', } UNPACK_WITH_DEFAULT_METHODS = { 'const ScalarType &': 'scalartypeWithDefault', 'const 
    'const Device &': 'deviceWithDefault',
    'c10::optional<Layout>': 'layoutWithDefault',
}

def parsed_arg_expr(arg, arg_index):
    # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
    typename = arg['type']

    default_init = arg.get('python_default_init')
    if default_init is not None:
        # Note: only introduced by make_python_binding_args
        default_init = arg['python_default_init']
        if typename not in UNPACK_WITH_DEFAULT_METHODS:
            raise RuntimeError(
                'type \'{}\' is not supported in python_default_init'.
                format(typename))
        unpack_with_default = UNPACK_WITH_DEFAULT_METHODS[typename]
        return '_r.{}({}, {})'.format(unpack_with_default, arg_index, default_init)

    size = arg.get('size')
    if size is not None:
        if typename not in UNPACK_WITH_SIZE_METHODS:
            raise RuntimeError(
                'type \'{}\' with definite size ({}) is not supported'.
                format(typename, size))
        unpack_with_size = UNPACK_WITH_SIZE_METHODS[typename].format(size)
        return '_r.{}({})'.format(unpack_with_size, arg_index)

    unpack = UNPACK_METHODS.get(typename)
    if unpack is None:
        raise RuntimeError('type \'{}\' is not supported'.format(typename))

    return '_r.{}({})'.format(unpack, arg_index)

# TODO make this part of something more general, or get rid of it
def unpack_optional_dimname_list_hack(name, expr):
    # optional<ArrayRef<T>> are special. The PythonArgParser returns an
    # optional<vector<T>>, which cannot be implicitly converted to
    # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
    result = """\
auto __{name} = {expr};
c10::optional<{typ}> {name} = __{name} ? c10::make_optional({typ}(__{name}.value())) : c10::nullopt;
""".format(name=name, expr=expr, typ='DimnameList')
    return [line.strip() for line in result.split('\n')]

def parse_arg(arg, arg_index, unpack_to_local=False):
    # get parsed rhs
    expr = parsed_arg_expr(arg, arg_index)

    # maybe unpack to local
    name = arg['name']
    typename = arg['type']
    if typename == 'c10::optional<DimnameList>':
        inits = unpack_optional_dimname_list_hack(name, expr)
        expr = name
    elif unpack_to_local:
        inits = ['auto {} = {};'.format(name, expr)]
        expr = name
    else:
        inits = []

    return expr, inits
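# Worked examples (illustrative only, using made-up argument dicts) of the
# unpack expressions the functions above produce from a parsed-arg index:
#
#   {'name': 'other',   'type': 'const Tensor &'},             index 1 -> '_r.tensor(1)'
#   {'name': 'dims',    'type': 'IntArrayRef', 'size': 2},     index 0 -> '_r.intlist(0)'
#   {'name': 'tensors', 'type': 'TensorList',  'size': 2},     index 0 -> '_r.tensorlist_n<2>(0)'
#
# With unpack_to_local=True, parse_arg() additionally emits an
# 'auto <name> = <expr>;' init and returns the local name as the expression.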
#
# schema type to cpp type conversions
# some of these are to prevent dangling refs to temps, others are more obscure
# TODO don't know if these fold into more general conversions somewhere, hope so
#

TEMP_SAFE_CPP_DECL_TYPE = {
    'Tensor &': 'Tensor',
}

def get_cpp_decl_type(typename, ensure_temp_safe=True):
    if ensure_temp_safe:
        typename = TEMP_SAFE_CPP_DECL_TYPE.get(typename, typename)
    return typename

def get_cpp_formal(arg, ensure_temp_safe=True):
    decl_type = get_cpp_decl_type(arg['type'], ensure_temp_safe)
    return '{} {}'.format(decl_type, arg['name'])

# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,Tensor,int64_t>',
    'std::tuple<double,int64_t>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void',
    'QScheme', 'double',
    'IntArrayRef',
    'ScalarType'
}

def get_simple_return_type(declaration):
    # Use the simple_return_type (Tensor) rather than the fancy return type
    # (Tensor &). This is important because the dispatch lambdas take
    # mutable arguments *by value*, not by reference. If you then return
    # a reference to such an argument, you will now have a pointer to a
    # dangling stack entry. Not good.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #                                             ^^^^^^
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #                                             ^^^^^^^
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to temporary. Maybe we could assign it to a
    # variable itself.)
    #
    simple_return_type = declaration['return_type'].replace(' &', '')
    if simple_return_type not in SUPPORTED_RETURN_TYPES:
        raise RuntimeError(declaration['name'] + " returns unsupported type " + simple_return_type)
    return simple_return_type

#
# dispatch codegen
#

def get_dispatch_callee(declaration):
    # format the name of the receiving function or method
    if is_tensor_method(declaration):
        return 'self.{}'.format(declaration['name'])
    elif is_torch_function(declaration):
        namespace = function_namespace(declaration)
        return '{}::{}'.format(namespace, declaration['name'])
    else:
        raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

def get_op_args(declaration, argmap):
    # returns a list of argmap values in op call order, with two wrinkles:
    # 1. 'self' is eliminated for methods, it's baked into the callee expression elsewhere
    # 2. declaration['call_args'] shims legacy overrides and may contain constant values,
    #    not just names (see load_deprecated_signatures() in gen_autograd.py)
    call_args_override = declaration.get('call_args')
    if call_args_override:
        # names or constants
        keys = call_args_override
    else:
        # only names
        keys = [param['name'] for param in declaration['arguments']]

    if is_tensor_method(declaration):
        # exclude self for method calls
        keys = [k for k in keys if k != 'self']

    if call_args_override:
        # assume missing keys are constants
        return [argmap.get(k, k) for k in keys]
    else:
        return [argmap[k] for k in keys]

TENSOR_OPTIONS_DECL = CodeTemplate("""\
const auto ${name} = TensorOptions()
    .dtype(${dtype})
    .device(${device})
    .layout(${layout})
    .requires_grad(${requires_grad})
    .pinned_memory(${pin_memory});
""")

# addition to output-variant handler in which tensor options params
# (if present) are checked against properties of a tensor output param
# TODO remove hardcoding, use unpack logic from emit_single_dispatch
PY_VARIABLE_CHECK_OUT_TYPE_HACK = CodeTemplate("""\
check_out_type_matches(_r.tensor(${out_idx}), _r.scalartype(${type_idx}),
                       _r.isNone(${type_idx}), _r.layoutOptional(${layout_idx}),
                       _r.device(${device_idx}), _r.isNone(${device_idx}));
""")

# Unpack parsed args to locals, call the op, and wrap the result.
# Lambda is so GIL is back on by wrap() time (wrap can allocate)
PY_VARIABLE_WRAP = CodeTemplate("""\
${inits}
auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} {
  ${auto_no_gil}
  return ${dispatch_callee}(${dispatch_args});
};
return wrap(${namedtuple_typeref}dispatch_${name}(${lambda_args})${set_requires_grad});
""")

# void return variant
PY_VARIABLE_RETURN_VOID = CodeTemplate("""\
${inits}
auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} {
  ${auto_no_gil}
  ${dispatch_callee}(${dispatch_args});
};
dispatch_${name}(${lambda_args})${set_requires_grad};
Py_RETURN_NONE;
""")
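# Rough sketch of what PY_VARIABLE_WRAP expands to once emit_single_dispatch()
# (below) fills in the holes, using a hypothetical single-Tensor namespace
# overload 'cos' as the example (the real generated text depends on the
# declaration and is not produced verbatim here):
#
#   // aten::cos(Tensor self) -> Tensor
#   auto dispatch_cos = [](const Tensor & self) -> Tensor {
#     pybind11::gil_scoped_release no_gil;
#     return at::cos(self);
#   };
#   return wrap(dispatch_cos(_r.tensor(0)));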
""" deprecated = '[deprecated] ' if declaration.get('deprecated', False) else '' schema_comment = '// ' + deprecated + declaration['schema_string'] inits = [schema_comment] pa = declaration['python_arglists'] args = pa['input_args'] + pa['input_kwargs'] + pa['output_args'] has_options = has_tensor_options(declaration) argmap = {} if is_python_method: # self is passed directly to python binding, rather than parsed argmap['self'] = {'value': 'self', 'formal': 'Tensor & self'} for i, arg in enumerate(args): unpack = is_scatter(arg) or (has_options and is_tensor_self(arg)) arg_expr, unpack_stmts = parse_arg(arg, i, unpack_to_local=unpack) inits.extend(unpack_stmts) if is_scatter(arg): for j, elem in enumerate(arg['scatter_args']): argmap[elem['name']] = { 'value': '{}[{}]'.format(arg_expr, j), 'formal': get_cpp_formal(elem, ensure_temp_safe=False), } else: argmap[arg['name']] = {'value': arg_expr, 'formal': get_cpp_formal(arg)} # synthetic python binding args deliver op args binding_argmap, binding_inits, set_requires_grad = \ handle_python_binding_args(declaration, output_gap) argmap.update(binding_argmap) inits.extend(binding_inits) lambda_formals = [argmap[arg['name']]['formal'] for arg in declaration['arguments']] lambda_args = [argmap[arg['name']]['value'] for arg in declaration['arguments']] dispatch_callee = get_dispatch_callee(declaration) dispatch_args = get_op_args(declaration, {name: name for name, _ in argmap.items()}) auto_no_gil = [] if declaration['with_gil'] else ['pybind11::gil_scoped_release no_gil;'] simple_return_type = get_simple_return_type(declaration) if simple_return_type == 'void': template = PY_VARIABLE_RETURN_VOID else: template = PY_VARIABLE_WRAP return template.substitute( name=declaration['name'], inits=inits, lambda_formals=lambda_formals, lambda_args=lambda_args, dispatch_callee=dispatch_callee, dispatch_args=dispatch_args, auto_no_gil=auto_no_gil, set_requires_grad=set_requires_grad, simple_return_type=simple_return_type, namedtuple_typeref=declaration['namedtuple_typeref'], ) # arg['name'] to arg['simple_type'] for scattered tensor options fields TENSOR_OPTIONS_FIELDS = { 'dtype': 'ScalarType', 'device': 'Device', 'layout': 'Layout', 'pin_memory': 'bool', 'requires_grad': 'bool', } def handle_python_binding_args(declaration, output_gap): # map synthetic python binding args to op args and misc other stuff # note: this logic shares arcane knowledge with make_python_binding_args # and isn't completely airtight w.r.t. the possible contents of # python_binding_args. TODO argmap = {} inits = [] set_requires_grad = '' pa = declaration['python_arglists'] python_binding_args = pa['python_binding_args'] if len(python_binding_args) == 0: # nothing to see here return argmap, inits, set_requires_grad args = pa['input_args'] + pa['input_kwargs'] + pa['output_args'] binding_arg_base = len(args) + output_gap binding_arg_offsets = {arg['name']: i for i, arg in enumerate(python_binding_args)} def binding_arg_index(name): return binding_arg_base + binding_arg_offsets[name] def parse_binding_arg(name): binding_arg = python_binding_args[binding_arg_offsets[name]] expr, _ = parse_arg(binding_arg, binding_arg_index(name)) return expr has_output = len(pa['output_args']) == 1 tensor_options_arg = get_tensor_options(declaration) if tensor_options_arg is not None: # if our op has a tensor options arg, these are its scattered fields. 
        # first some checks
        if has_output:
            raise RuntimeError('{}: tensor options with output arg'.format(declaration['name']))
        for arg in python_binding_args:
            typename = TENSOR_OPTIONS_FIELDS.get(arg['name'])
            if typename is None:
                raise RuntimeError(
                    '{}: unrecognized tensor options field \'{}\' in python binding arguments'.
                    format(declaration['name'], arg['name']))
            if typename != arg['simple_type']:
                raise RuntimeError(
                    '{}: unrecognized type \'{}\' for tensor options field \'{}\' in python binding arguments'.
                    format(declaration['name'], arg['type'], arg['name']))
        python_binding_argnames = [arg['name'] for arg in python_binding_args]
        if not all([key in python_binding_argnames for key in TENSOR_OPTIONS_FIELDS.keys()]):
            raise RuntimeError(
                '{}: incomplete tensor options args: {}'.
                format(declaration['name'], [arg['name'] for arg in python_binding_args]))

        # generate a gathering initialization of options struct
        argname = tensor_options_arg['name']
        inits.append(TENSOR_OPTIONS_DECL.substitute({
            'name': argname,
            'dtype': parse_binding_arg('dtype'),
            'layout': parse_binding_arg('layout'),
            'device': parse_binding_arg('device'),
            'requires_grad': parse_binding_arg('requires_grad'),
            'pin_memory': parse_binding_arg('pin_memory'),
        }))
        inits.append('torch::utils::maybe_initialize_cuda({});'.format(argname))

        # and add to op arg map
        argmap['options'] = {
            'value': argname,
            'formal': get_cpp_formal(tensor_options_arg),
        }

    else:
        # not the scattered fields of a tensor options - sort of a grab bag
        if 'dtype' in binding_arg_offsets:
            # we're an output-arg variant, check these args against output tensor
            if not has_output:
                raise RuntimeError(
                    '{}: dtype in python_binding_args without output arg'.
                    format(declaration['name']))
            if not all([name in binding_arg_offsets for name in ['layout', 'device']]):
                raise RuntimeError(
                    '{}: incomplete tensor options for output check'.
                    format(declaration['name']))
            check_type = PY_VARIABLE_CHECK_OUT_TYPE_HACK.substitute(
                out_idx=get_python_output_index(declaration),
                type_idx=binding_arg_index('dtype'),
                layout_idx=binding_arg_index('layout'),
                device_idx=binding_arg_index('device'),
            )
            inits.append(check_type)

        # we'll set requires_grad on outgoing tensor
        if 'requires_grad' not in binding_arg_offsets:
            raise RuntimeError(
                '{}: expected "requires_grad" in python_binding_args absent tensor options arg but found [{}]'.
                format(declaration['name'], [arg['name'] for arg in python_binding_args]))
        requires_grad = parse_binding_arg('requires_grad')
        set_requires_grad = '.set_requires_grad({})'.format(requires_grad)

    return argmap, inits, set_requires_grad

# handler for output/no-output overload pair
# (plugged into PY_VARIABLE_CASE as ${call_dispatch})
PY_VARIABLE_OUT = CodeTemplate("""\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures differ only in output params
PY_VARIABLE_CASE = CodeTemplate("""\
case ${i}: {
  ${body}
}
""")
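# Putting the two templates together, an overload pair with an optional out=
# argument ends up dispatched roughly like this (illustrative only; the case
# number, isNone index, and inner calls depend on the actual signature):
#
#   case 1: {
#     if (_r.isNone(3)) {
#       // ... no-output variant, from emit_single_dispatch(base_decl, ...)
#     } else {
#       // ... output variant, from emit_single_dispatch(out_decl, ...)
#     }
#   }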
def emit_dispatch_case(i, dictionary, is_python_method):
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single overload, or a pair that differ only in output params. In the latter
    case, a single signature is used for both and dispatching switches on the
    presence/absence of passed output args.

    - i: this signature's position in the generated binding's signature list
      if number of signatures > 1, otherwise None

    - dictionary: contains a no-output overload declaration under 'base',
      and optionally a second overload with outputs under 'out'

    - is_python_method: true if we're generating a python method, in which
      case self is not parsed but passed directly
    """
    base_decl = dictionary['base']

    if 'out' in dictionary:
        # dispatch to output or no-output variant based on arg test
        out_decl = dictionary['out']
        out_idx = get_python_output_index(out_decl)
        output_gap = get_python_argc(out_decl) - get_python_argc(base_decl)

        call_dispatch = emit_single_dispatch(base_decl, is_python_method, output_gap)
        call_dispatch_out = emit_single_dispatch(out_decl, is_python_method)

        # dispatch output and no-output variants, branch on _r.isNone()
        body = PY_VARIABLE_OUT.substitute(
            out_idx=out_idx,
            call_dispatch=call_dispatch,
            call_dispatch_out=call_dispatch_out,
        )
    else:
        # no-output version only
        body = emit_single_dispatch(base_decl, is_python_method)

    if i is not None:
        # generate case for ith overload
        return PY_VARIABLE_CASE.substitute(i=i, body=body)
    else:
        # only one overload, omit case wrapper
        return body

#
# named tuple codegen
#

def namedtuple_fieldnames(declaration):
    returns = declaration['returns']
    if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
        return []
    else:
        def get_field_name(x):
            # See Note [field_name versus name]
            if 'field_name' not in x:
                # When building on Windows, `PyStructSequence_UnnamedField` could not be
                # resolved by the linker for some reason, which caused an error in building:
                #
                # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
                # PyStructSequence_UnnamedField
                #
                # Thus, at this point in time, we do not support unnamed
                # fields in namedtuple; you must either name all fields,
                # or none of them.
                raise ValueError("Unnamed field is not supported by codegen")
            else:
                return x['field_name']
        return [get_field_name(x) for x in returns]

PY_NAMEDTUPLE_FIELDSDEF = CodeTemplate("""\
static PyStructSequence_Field ${fieldsname}[] = { ${fields,} {nullptr} };
""")

PY_NAMEDTUPLE_TYPEDEF = CodeTemplate("""\
static PyTypeObject ${typename};
static bool ${typename}_initialized = false;
if (!${typename}_initialized) {
  ${typename}_initialized = true;
  static PyStructSequence_Desc desc = { "torch.return_types.${name}", nullptr, ${fieldsname}, ${size} };
  PyStructSequence_InitType(&${typename}, &desc);
  ${typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}
""")

def emit_namedtuple_typedefs(declarations):
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them
    """
    flddefnames = {}    # map from unique field name lists to field def name
    flddefs = []        # field def declarations
    typenames = {}      # map from unique name + field name lists to typedef name
    typedefs = []       # typedef declarations and init code

    for decl in declarations:
        fieldnames = namedtuple_fieldnames(decl)
        if fieldnames == []:
            decl['namedtuple_typeref'] = ''
            continue

        fn_key = '_'.join(fieldnames)
        fieldsname = flddefnames.get(fn_key)
        if fieldsname is None:
            fieldsname = 'NamedTuple_fields{}'.format('' if flddefs == [] else len(flddefs))
            fields = ['{{"{}", ""}}'.format(fn) for fn in fieldnames]
            fieldsdef = PY_NAMEDTUPLE_FIELDSDEF.substitute(
                fieldsname=fieldsname,
                fields=fields
            )
            flddefnames[fn_key] = fieldsname
            flddefs.append(fieldsdef)

        name = decl['name']
        key = '{}_{}'.format(name, '_'.join(fieldnames))
        typename = typenames.get(key)
        if typename is None:
            typename = 'NamedTuple{}'.format('' if typedefs == [] else len(typedefs))
            typedef = PY_NAMEDTUPLE_TYPEDEF.substitute(
                name=name,
                typename=typename,
                size=len(fieldnames),
                fieldsname=fieldsname
            )
            typenames[key] = typename
            typedefs.append(typedef)

        decl['namedtuple_typeref'] = '&{}, '.format(typename)

    return flddefs + typedefs

#
# method impl codegen
#

def get_pycname(name):
    return 'THPVariable_{}'.format(name)

def is_noarg_binding(overloads):
    return len(overloads) == 1 and get_python_argc(overloads[0]) == 0

# python binding for all overloads of a particular function/method
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}
""")

# python binding for single-overload function/method
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}
""")

# python binding for a method with no args, shortcuts parsing
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}
""")

TORCH_FUNCTION_CHECK = CodeTemplate("""\
if(_r.has_torch_function()) {
  return handle_torch_function(_r, ${self_}, args, kwargs, ${namespace},
      ${modulename});
}
""")

TORCH_FUNCTION_CHECK_NOARGS = CodeTemplate("""\
if(check_has_torch_function(self_)) {
  return handle_torch_function(self_, ${name});
}
""")

# NOTE: we type the unpacked self as Tensor not Variable to avoid return type
# discrepancies on method resolution (e.g. Variable::detach_ returns void
# rather than Tensor &)
UNPACK_SELF = "Tensor& self = reinterpret_cast<THPVariable*>(self_)->cdata;"

def method_impl(name, declarations, is_python_method, module):
    """
    Generate a python binding for all overloads of an op.
    """
    for declaration in declarations:
        # formals for python binding signature
        declaration['python_arglists'] = make_python_arglists(declaration, is_python_method)

    pycname = get_pycname(name)

    method_header = ['HANDLE_TH_ERRORS']
    method_header += emit_namedtuple_typedefs(declarations)
    method_header += [UNPACK_SELF] if is_python_method else []

    method_footer = ['END_HANDLE_TH_ERRORS']

    check_has_torch_function = TORCH_FUNCTION_CHECK_NOARGS.substitute(
        name='"' + name + '"',
    ) if is_python_method else ''

    # emit dispatch
    if is_noarg_binding(declarations):
        dispatch = emit_single_dispatch(declaration, is_python_method)
        return PY_VARIABLE_METHOD_NOARGS.substitute(
            name=name,
            pycname=pycname,
            method_header=method_header,
            dispatch=dispatch,
            method_footer=method_footer,
            check_has_torch_function=check_has_torch_function,
        )

    method_footer = ['Py_RETURN_NONE;'] + method_footer

    grouped = group_overloads(declarations, is_python_method)
    is_singleton = len(grouped) == 1

    signatures = []
    dispatch = []
    for i, dictionary in enumerate(grouped):
        signature = dictionary['signature']
        signatures.append('"{}",'.format(signature))
        overload_index = i if not is_singleton else None
        dispatch.append(emit_dispatch_case(overload_index, dictionary, is_python_method))

    if is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    if module:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace=NATIVE_NAMESPACE_MAPPING[module],
            modulename='"' + module + '"',
            self_="self_" if is_python_method else "nullptr",
        )
    else:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace="THPVariableClass",
            modulename='"torch.Tensor"',
            self_="self_" if is_python_method else "nullptr",
        )

    max_args = max([get_python_argc(decl) for decl in declarations])
    traceable = 'true' if all(should_trace(d) for d in declarations) else 'false'

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max_args,
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=check_has_torch_function,
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if is_python_method else "nullptr",
    )

#
# forward declarations
#

PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""")

PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args);
""")

def forward_decls(name, declarations, is_python_method, module):
    if is_python_method:
        return []

    if is_noarg_binding(declarations):
        template = PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION
    else:
        template = PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION

    pycname = get_pycname(name)
    return [template.substitute(pycname=pycname)]

#
# method def (binding table entry) codegen
#

# Python binary operator dunder methods
BINARY_OP_NAMES = [
    '__lt__', '__le__',
    '__gt__', '__ge__',
    '__eq__', '__ne__',
    '__add__', '__radd__', '__iadd__',
    '__sub__',
    '__rsub__', '__isub__',
    '__mul__', '__rmul__', '__imul__',
    '__matmul__', '__rmatmul__', '__imatmul__',
    '__truediv__', '__rtruediv__', '__itruediv__',
    '__floordiv__', '__rfloordiv__', '__ifloordiv__',
    '__mod__', '__rmod__', '__imod__',
    '__divmod__', '__rdivmod__', '__idivmod__',
    '__pow__', '__rpow__', '__ipow__',
    '__lshift__', '__rlshift__', '__ilshift__',
    '__rshift__', '__rrshift__', '__irshift__',
    '__and__', '__rand__', '__iand__',
    '__xor__', '__rxor__', '__ixor__',
    '__or__', '__ror__', '__ior__',
]

# PyMethodDef entry for binary op, throws not implemented error
PY_VARIABLE_METHOD_BINOP_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycfunc_voidcast}TypeError_to_NotImplemented_<${pycname}>, ${flags}, NULL},""")

# PyMethodDef entry
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycfunc_voidcast}${pycname}, ${flags}, NULL},""")

def method_def(name, declarations, is_python_method, module):
    """
    Generate method def entry.
    """
    pycname = get_pycname(name)

    if is_noarg_binding(declarations):
        pycfunc_voidcast = ''
        flags = 'METH_NOARGS' if is_python_method else 'METH_VARARGS | METH_KEYWORDS'
    else:
        pycfunc_voidcast = '(void(*)(void))'
        flags = 'METH_VARARGS | METH_KEYWORDS'

    if module == "torch":
        flags += ' | METH_STATIC'

    if name in BINARY_OP_NAMES:
        def_template = PY_VARIABLE_METHOD_BINOP_DEF
    else:
        def_template = PY_VARIABLE_METHOD_DEF

    return def_template.substitute(
        name=name,
        pycname=pycname,
        pycfunc_voidcast=pycfunc_voidcast,
        flags=flags,
    )

#
# overload sorting and grouping
#

def group_overloads(declarations, is_python_method):
    """Returns a list of dictionaries containing the optional keys:

       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out)
       "signature": the signature used for Python argument parsing

       Note that we merge pairs of declarations with signatures that
       are equivalent mod output arguments, and use a single entry in
       the python_arg_parser sig list for both (output arguments become
       optional)
    """
    grouped = defaultdict(dict)

    # first group by signature ignoring out arguments
    for declaration in declarations:
        signature = get_python_signature(declaration, is_python_method, skip_outputs=True)
        v = grouped[signature]
        if declaration['name'].endswith('_out'):
            v['out'] = declaration
            # prefer the signature with optional out=... arguments
            v['signature'] = get_python_signature(declaration, is_python_method)
        else:
            v['base'] = declaration
            if 'signature' not in v:
                v['signature'] = signature

    result = []
    for x, dictionary in sorted(grouped.items()):
        if 'base' not in dictionary:
            raise RuntimeError(
                "'base' not in dictionary for {}. keys are {}".format(
                    x, list(dictionary.keys())))
        result.append(dictionary)
    return sort_declarations(result)
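# Toy illustration of the ordering problem handled below: given two grouped
# signatures for a hypothetical op that differ in a single position,
#
#   foo(Tensor self, Scalar other)
#   foo(Tensor self, Tensor other)
#
# sort_declarations() places the Tensor variant ahead of the Scalar variant in
# the parser's signature list, so an argument that could parse as either is
# tried against the Tensor overload first (see Note [Order of overloads
# matters] for the rationale).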
# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note [Order of overloads matters]
def sort_declarations(grouped_decls):

    def dynamic_type(arg):
        return arg['dynamic_type']

    def is_coord_smaller(arg1, arg2):
        return dynamic_type(arg1) == 'Scalar' and arg2['dynamic_type'] == 'Tensor'

    def is_smaller(d1, d2):
        """Returns True if d1 < d2 in the partial order."""
        args1, args2 = d1['base']['arguments'], d2['base']['arguments']
        if len(args1) != len(args2):
            return False
        any_smaller = any(is_coord_smaller(arg1, arg2) for arg1, arg2 in zip(args1, args2))
        all_smaller_or_equal = all(dynamic_type(arg1) == dynamic_type(arg2) or
                                   is_coord_smaller(arg1, arg2)
                                   for arg1, arg2 in zip(args1, args2))
        return any_smaller and all_smaller_or_equal

    # Construct the relation graph
    larger_than = defaultdict(set)
    for i1, decl1 in enumerate(grouped_decls):
        for i2, decl2 in enumerate(grouped_decls):
            if is_smaller(decl1, decl2):
                larger_than[i1].add(i2)

    if not larger_than:
        return grouped_decls

    # Use a topological sort to sort decls according to the partial order.
    sorted_deps = [(i, decl) for i, decl in enumerate(grouped_decls)
                   if i not in larger_than]
    for i, decl in sorted_deps:
        for i2 in sorted(larger_than.keys()):
            larger = larger_than[i2]
            larger.discard(i)
            if not larger:
                del larger_than[i2]
                sorted_deps.append((i2, grouped_decls[i2]))

    return [decl for i, decl in sorted_deps]

#
# python signature codegen
#

SCHEMA_DEFAULT_CONVERSION_HACKS = {
    'nullptr': 'None',
    'c10::nullopt': 'None',
    '{}': 'None',
}

def get_schema_formal(arg, is_python_method):
    name = arg['name']
    typename = arg['simple_type']

    # TODO: remove this and make optional types in simple_type to be consistent across
    # tensor and other types after make Tensor? be optional instead of undefined
    if arg.get('is_nullable') and '?' not in typename:
        typename = '{}?'.format(typename)

    # s/self/input/ outside method bindings.
    # TODO remove this? doesn't rename in codegen, it's just for the parse string
    if name == 'self' and typename == 'Tensor' and not is_python_method:
        name = 'input'

    size = arg.get('size')
    if size is not None:
        if typename.endswith('?'):
            typename = '{}[{}]?'.format(typename[:-1], size)
        else:
            typename = '{}[{}]'.format(typename, size)

    # default
    default = arg.get('default')
    if default is not None:
        default = SCHEMA_DEFAULT_CONVERSION_HACKS.get(default, default)
        return '{} {}={}'.format(typename, name, default)
    else:
        return '{} {}'.format(typename, name)

PYTHON_ARG_PARSER_SCHEMA = CodeTemplate("""\
${name}(${schema_formals})${deprecated}""")

def get_python_signature(declaration, is_python_method, skip_outputs=False):
    # Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # tools/gen_pyi.py.  If you change any logic here, please
    # check that file too.

    python_args = get_python_args(declaration)
    if skip_outputs:
        python_args = [arg for arg in python_args if not is_output(arg)]

    schema_formals = [get_schema_formal(arg, is_python_method) for arg in python_args]
    positional_argc = len(declaration['python_arglists']['input_args'])
    if len(python_args) > positional_argc:
        schema_formals.insert(positional_argc, '*')

    # Python function signature.
    # This is the string that we give to FunctionParameter, which is
    # then parsed into the actual structure which we do parsing with.
    name = op_name(declaration)
    deprecated = '|deprecated' if declaration.get('deprecated', False) else ''
    return PYTHON_ARG_PARSER_SCHEMA.substitute(
        name=name,
        schema_formals=schema_formals,
        deprecated=deprecated,
    )
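# For reference, the parser strings produced here look roughly like the
# following (hypothetical example, not emitted verbatim by this file):
#
#   "addcmul(Tensor input, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor out=None)"
#
# i.e. positional args, then a '*' separator, then kwarg-only and (optional)
# output args, with schema defaults translated via SCHEMA_DEFAULT_CONVERSION_HACKS.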
#
# op args to python parsed args transform
#

def get_python_args(decl):
    arglists = decl['python_arglists']
    return \
        arglists['input_args'] + \
        arglists['input_kwargs'] + \
        arglists['output_args'] + \
        arglists['python_binding_args']

def get_python_argc(decl):
    return sum([len(arglist) for arglist in decl['python_arglists'].values()])

def get_python_output_index(decl):
    arglists = decl['python_arglists']
    return len(arglists['input_args'] + arglists['input_kwargs'])

def make_python_arglists(declaration, is_python_method):
    # produces python-ready args converted from declaration['arguments'],
    # partitioned into sublists by category. sublists are ordered, so
    # the final python arglist can be recovered by simple flattening
    # (see get_python_args())

    # partition args into sublists

    args = declaration['arguments']

    input_args = []
    input_kwargs = []
    output_args = []

    current_input_args = input_args
    for arg in args:
        if is_output(arg):
            output_args.append(arg)
        else:
            if arg.get('kwarg_only', False):
                current_input_args = input_kwargs
            current_input_args.append(arg)

    # adjustments

    # positional inputs:
    # - filter self when we're generating a method binding - there, it comes in as
    #   a separate Python param, not in args array
    def include(arg):
        return not (is_tensor_self(arg) and is_python_method)
    input_args = [arg for arg in input_args if include(arg)]

    # keyword inputs:
    # - filter options. after loading the yaml, an upstream step has gathered dtype,
    #   layout et al into a single tensor options arg. here we reintroduce the originals
    input_kwargs = [arg for arg in input_kwargs if not is_tensor_options(arg)]

    # outputs:
    # - coalesce multiple output args into a single 'out' arg w/type TensorList.
    # - force a default. This is so we can use this sig for both out and non-out variants
    num_outputs = len(output_args)
    if num_outputs > 1:
        for arg in output_args:
            if not arg['simple_type'] == 'Tensor':
                raise RuntimeError(
                    '{}: unsupported output argument type {}'.
                    format(declaration['name'], arg['type']))
        typename = 'TensorList'
        output_args = [{
            'default': 'None',
            'kwarg_only': True,
            'name': 'out',
            'output': True,
            'scatter_args': output_args,
            'simple_type': typename,
            'size': num_outputs,
            'type': typename,
        }]
    elif num_outputs == 1:
        output_arg = output_args[0].copy()
        output_arg['default'] = 'None'
        output_args = [output_arg]

    # make python binding args
    # these are the (re)scattered versions of the options arg omitted above.
    # TODO because these aren't guaranteed to be 100% faithful to the original
    # versions in the yaml, this recreation is a potential source of drift between
    # eager and JIT. Pull this logic out to a shared place.
    python_binding_args = make_python_binding_args(declaration)

    return {
        'input_args': input_args,
        'input_kwargs': input_kwargs,
        'output_args': output_args,
        'python_binding_args': python_binding_args,
    }

#
# python binding args
#

# TODO blowtorch
def dtype_default_type_hack(name):
    if name.startswith('randperm') or name == 'tril_indices' or name == 'triu_indices':
        return 'torch.int64'
    else:
        return 'None'
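# Sketch of the kind of scattered args make_python_binding_args() (below)
# appends for a hypothetical factory op such as ones(IntArrayRef size): the
# TensorOptions argument elided from input_kwargs above comes back as the
# kwarg-only fields
#
#   dtype=None, layout=torch.strided, device=None, pin_memory=False, requires_grad=False
#
# which handle_python_binding_args() later gathers back into a single
# TensorOptions value for the dispatch call.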
""" name = declaration['name'] python_binding_arguments = [] has_tensor_input_arg = False has_options_arg = False for arg in declaration['arguments']: if is_output(arg): continue typename = arg['simple_type'] if typename in ['Tensor', 'TensorList']: has_tensor_input_arg = True elif typename == 'TensorOptions': has_options_arg = True if arg['name'] == 'requires_grad': raise ValueError("argument named requires_grad not supported") has_tensor_return = False for ret in declaration['returns']: if ret['dynamic_type'] in ['Tensor', 'TensorList']: # this probably won't work if one of the returns is not a tensor, but it will # produce a compile-time error that is obvious has_tensor_return = True category_override = declaration['category_override'] is_like_function = name.endswith('_like') or category_override == 'like' is_like_function_with_options = is_like_function and has_options_arg is_new_function = name.startswith('new_') or category_override == 'new' is_new_function_with_options = is_new_function and has_options_arg is_factory_function = has_tensor_return and not has_tensor_input_arg or category_override == 'factory' is_factory_or_like_or_new_function = has_tensor_return and (is_factory_function or is_like_function or is_new_function) is_like_or_new_function_with_options = is_like_function_with_options or is_new_function_with_options if is_factory_function or has_options_arg: default_type = dtype_default_type_hack(name) py_default_dtype = 'self.scalar_type()' if is_like_or_new_function_with_options else None dtype_arg = { 'default': default_type, 'dynamic_type': 'ScalarType', 'kwarg_only': True, 'name': 'dtype', 'type': 'const ScalarType &', 'simple_type': 'ScalarType', 'python_default_init': py_default_dtype, } python_binding_arguments.append(dtype_arg) if is_factory_function or is_like_or_new_function_with_options: py_default_layout = 'layout_from_backend(self.options().backend())' if is_like_or_new_function_with_options else None layout_arg = { 'default': 'torch.strided', 'dynamic_type': 'Layout', 'kwarg_only': True, 'name': 'layout', 'type': 'c10::optional', 'simple_type': 'Layout', 'python_default_init': py_default_layout, } python_binding_arguments.append(layout_arg) py_default_device = 'self.device()' if is_like_or_new_function_with_options else None device_arg = { 'default': 'None', 'dynamic_type': 'Device', 'kwarg_only': True, 'name': 'device', 'type': 'const Device &', 'simple_type': 'Device', 'python_default_init': py_default_device } python_binding_arguments.append(device_arg) pin_memory_arg = { 'default': False, 'dynamic_type': 'bool', 'kwarg_only': True, 'name': 'pin_memory', 'type': 'bool', 'simple_type': 'bool', } python_binding_arguments.append(pin_memory_arg) if is_factory_or_like_or_new_function: requires_grad_arg = { 'default': False, 'dynamic_type': 'bool', 'kwarg_only': True, 'name': 'requires_grad', 'type': 'bool', 'simple_type': 'bool', } python_binding_arguments.append(requires_grad_arg) return python_binding_arguments # # declaration derived props, utils, etc. 
# declarations are dicts loaded from Declarations.yaml,
# passed to our codegen methods by callers in gen_autograd
#

def is_tensor_self(arg):
    return arg['name'] == 'self' and arg['simple_type'] == 'Tensor'

def is_tensor_options(arg):
    return arg['simple_type'] == 'TensorOptions'

def is_scatter(arg):
    return arg.get('scatter_args') is not None

def is_output(arg):
    return arg.get('output', False)

def has_outputs(declaration):
    return any([is_output(arg) for arg in declaration['arguments']])

def get_tensor_options(declaration):
    args = [arg for arg in declaration['arguments'] if is_tensor_options(arg)]
    if len(args) == 0:
        return None
    if len(args) != 1:
        raise RuntimeError(
            '{}: multiple tensor options arguments'.
            format(declaration['name']))
    return args[0]

def has_tensor_options(declaration):
    return get_tensor_options(declaration) is not None

def is_torch_function(declaration):
    return 'namespace' in declaration['method_of']

def is_nn_module_function(declaration):
    return declaration.get('python_module') == 'nn'

def is_fft_module_function(declaration):
    return declaration.get('python_module') == 'fft'

def is_linalg_module_function(declaration):
    return declaration.get('python_module') == 'linalg'

def function_namespace(declaration):
    # TODO look into why these can't all be 'torch' calls
    if has_tensor_options(declaration) or op_name(declaration).endswith('_like'):
        return 'torch'
    else:
        return 'at'

def op_name(declaration):
    name = declaration['name']
    if has_outputs(declaration):
        if not name.endswith("_out"):
            raise RuntimeError(
                '{} has output params, expecting name ending with \'_out\''.
                format(declaration['name']))
        return name[:-4]
    else:
        if name.endswith("_out"):
            raise RuntimeError(
                '{}: name ends with \'_out\', expecting output params'.
                format(declaration['name']))
        return name
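# Rough calling pattern, for orientation only (the real driver lives in
# gen_autograd.py and the exact helper names there may differ):
#
#   declarations = load_aten_declarations(declarations_path)   # parsed Declarations.yaml
#   gen_py_variable_methods(out, declarations, template_path)  # python_variable_methods.cpp
#   gen_py_torch_functions(out, declarations, template_path)   # python_torch_functions.cpp
#   gen_py_nn_functions(out, declarations, template_path)      # python_nn_functions.cpp
#   gen_py_fft_functions(out, declarations, template_path)     # python_fft_functions.cpp
#   gen_py_linalg_functions(out, declarations, template_path)  # python_linalg_functions.cpp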