mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36232 The purpose of this PR is to replace `at::Generator generator = nullptr` with `c10::optional<at::Generator> = c10::nullopt` all over the code * #36230 Replace std::shared_ptr with c10::intrusive_ptr in at::Generator Test Plan: Imported from OSS Differential Revision: D20943603 Pulled By: pbelevich fbshipit-source-id: 65d335990f01fcc706867d5344e73793fad68ae6
1483 lines
51 KiB
Python
1483 lines
51 KiB
Python
# Generates Python bindings for ATen functions
|
|
#
|
|
# The bindings are generated as methods on python_variable or functions on the
|
|
# torch._C._nn object.
|
|
#
|
|
|
|
# Code tries to stick to the following rules:
|
|
#
|
|
# - templates should be colocated with the functions that use them.
|
|
# no templates are currently shared between functions, but if that
|
|
# happens, maybe put the template with the first one
|
|
#
|
|
# - don't use environment dictionaries when calling template.substitute().
|
|
# pass named arguments directly for everything, otherwise it's much too
|
|
# hard to track what's actually being used and by whom
|
|
#
|
|
# - colocate any new hacks/adjustments with existing ones of the same kind.
|
|
# ideally in a data structure rather than code if possible. See e.g.
|
|
# SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
|
|
#
|
|
# - similarly, conversions from one format to another should ideally happen
|
|
# all at once in a single place.
|
|
#
|
|
# - no nontrivial nested functions. couple-liners are ok but please no more.
|
|
# especially avoid functions that read/write outer variables defined far away.
|
|
#
|
|
# - raise RuntimeError instead of asserting, and put as much
|
|
# information as is available into the message. I.e. no need to
|
|
# plumb in new params whose only purpose is to fill out an error
|
|
# message, but use what's there
|
|
#
|
|
|
|
from collections import defaultdict
|
|
import re
|
|
from .gen_variable_type import should_trace
|
|
from .utils import write
|
|
|
|
try:
|
|
from src.ATen.code_template import CodeTemplate
|
|
except ImportError:
|
|
from tools.shared.module_loader import import_module
|
|
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
|
|
|
|
#
|
|
# declarations blacklist
|
|
# We skip codegen for these functions, for various reasons.
|
|
# Future PRs will categorize this list and eliminate or hoist
|
|
# them out of eager-only codegen.
|
|
# See https://github.com/pytorch/pytorch/issues/30788
|
|
#
|
|
|
|
# Op-name regexes (each matched against the whole op name by
# should_generate_python_binding). Matching functions either require manual
# Python bindings or are not exposed to Python at all.
SKIP_PYTHON_BINDINGS = [
    'alias',
    'contiguous',
    'is_cuda',
    'is_sparse',
    'size',
    'stride',
    '.*_backward',
    '.*_backward_(out|input|weight|bias)',
    '.*_forward',
    '.*_forward_out',
    '_unsafe_view',
    'tensor',
    '_?sparse_coo_tensor.*',
    '_arange.*',
    '_range.*',
    '_linspace.*',
    '_logspace.*',
    '_sparse_add_out',
    '_sparse_div.*',
    '_sparse_mul.*',
    '_sparse_sub.*',
    '_sparse_dense_add_out',
    'index',
    'unique_dim_consecutive',
    '_indexCopy_',
    'max_values',
    'min_values',
    '_cumsum.*',
    '_cumprod.*',
    '_sum.*',
    '_prod.*',
    '_th_.*',
    '_thnn_.*',
    'arange.*',
    'range.*',
    '_solve.*',
    '_inverse.*',
    'full(_out)?',
    '_cholesky.*',
    '_triangular_solve.*',
    '_qr.*',
    '_symeig.*',
    '_svd.*',
    'slice',
    'randint(_out)?',
    'item',
    '_local_scalar_dense',
    'to',
    'copy_sparse_to_sparse_',
    'copy_',
    'numpy_T',  # this needs to be an attribute in Python, not a function
    'nonzero(_(out|numpy))?',
    'set_quantizer_',  # return types not supported yet
    'set_data',
    '.*_overrideable',  # overrideable functions for backend extension
    'data',
    'is_leaf',
    'output_nr',
    '_version',
    'requires_grad_',
    'retain_grad',
]
|
|
|
|
# Exact signature strings ('<name>(<simple_type>, ...)') that are not exposed
# to Python. Compared verbatim by should_generate_python_binding, so unlike
# SKIP_PYTHON_BINDINGS these are not regexes.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)',
    'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)',
    'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)',
    'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)',
    'div_(Tensor, Scalar)',
]
|
|
|
|
# Python module name -> C++ module object passed to handle_torch_function in
# the generated __torch_function__ check (see TORCH_FUNCTION_CHECK).
NATIVE_NAMESPACE_MAPPING = dict([
    ("torch", "THPVariableFunctionsModule"),
    ("torch.nn", "THPNNVariableFunctionsModule"),
])
|
|
|
|
def should_generate_python_binding(declaration):
    """Return True if codegen should emit a Python binding for ``declaration``.

    A declaration is skipped when its name fully matches any regex in
    SKIP_PYTHON_BINDINGS, or when its simplified signature string, e.g.
    ``'add(Tensor, Scalar, Scalar)'``, appears verbatim in
    SKIP_PYTHON_BINDINGS_SIGNATURES.
    """
    name = declaration['name']
    # re.fullmatch anchors the whole pattern. Concatenating '^' + p + '$'
    # would anchor only the first/last branch of a pattern with a top-level '|'.
    if any(re.fullmatch(pattern, name) for pattern in SKIP_PYTHON_BINDINGS):
        return False

    simple_types = [arg['simple_type'] for arg in declaration['arguments']]
    signature = '{}({})'.format(name, ', '.join(simple_types))
    return signature not in SKIP_PYTHON_BINDINGS_SIGNATURES
|
|
|
|
#
|
|
# top-level codegen functions, called from gen_autograd
|
|
#
|
|
|
|
def get_py_variable_methods(declarations):
    """
    Select the declarations bound as methods on Tensor, grouped by op name.

    A declaration qualifies when it passes should_generate_python_binding,
    is not an nn-module function, and is a tensor method.
    """
    selected = []
    for decl in declarations:
        if not should_generate_python_binding(decl):
            continue
        if is_nn_module_function(decl):
            continue
        if is_tensor_method(decl):
            selected.append(decl)
    return group_declarations_by_op_name(selected)
|
|
|
|
|
|
def gen_py_variable_methods(out, declarations, template_path):
    """
    Render the Tensor method bindings into ``out``/python_variable_methods.cpp,
    using the template of the same name found under ``template_path``.
    """
    template = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    methods_by_name = get_py_variable_methods(declarations)
    bindings_env = create_python_bindings(methods_by_name, is_python_method=True, module=None)
    write(out, 'python_variable_methods.cpp', template, bindings_env)
|
|
|
|
|
|
def get_py_nn_functions(declarations):
    """
    Select the declarations bound as functions in the "nn" module, grouped by
    op name: those that pass should_generate_python_binding and are nn-module
    functions.
    """
    nn_decls = [
        decl for decl in declarations
        if should_generate_python_binding(decl) and is_nn_module_function(decl)
    ]
    return group_declarations_by_op_name(nn_decls)
|
|
|
|
|
|
def gen_py_nn_functions(out, declarations, template_path):
    """
    Render the "nn" module function bindings into ``out``/python_nn_functions.cpp,
    using the template of the same name found under ``template_path``.
    """
    template = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    functions_by_name = get_py_nn_functions(declarations)
    bindings_env = create_python_bindings(functions_by_name, is_python_method=False, module="torch.nn")
    write(out, 'python_nn_functions.cpp', template, bindings_env)
|
|
|
|
|
|
def get_py_torch_functions(declarations):
    """
    Select the declarations bound as functions in the "torch" module, grouped
    by op name: those that pass should_generate_python_binding, are not
    nn-module functions, and are torch functions.
    """
    def keep(decl):
        if not should_generate_python_binding(decl):
            return False
        if is_nn_module_function(decl):
            return False
        return is_torch_function(decl)

    return group_declarations_by_op_name(list(filter(keep, declarations)))
|
|
|
|
|
|
def gen_py_torch_functions(out, declarations, template_path):
    """
    Render the "torch" module function bindings into
    ``out``/python_torch_functions.cpp, using the template of the same name
    found under ``template_path``.
    """
    template = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    functions_by_name = get_py_torch_functions(declarations)
    bindings_env = create_python_bindings(functions_by_name, is_python_method=False, module="torch")
    write(out, 'python_torch_functions.cpp', template, bindings_env)
|
|
|
|
|
|
def group_declarations_by_op_name(declarations):
    """
    Bucket ``declarations`` by op_name(), keeping input order within each bucket.

    Returns a defaultdict(list) mapping op name -> list of declarations.
    """
    by_name = defaultdict(list)
    for key, decl in ((op_name(d), d) for d in declarations):
        by_name[key].append(decl)
    return by_name
|
|
|
|
|
|
def create_python_bindings(python_functions, is_python_method, module):
    """Generates Python bindings to ATen functions

    ``python_functions`` maps op name -> list of overload declarations.
    Returns a template environment with the lists 'py_forwards', 'py_methods'
    and 'py_method_defs', each filled in sorted op-name order.
    """
    env = {'py_forwards': [], 'py_methods': [], 'py_method_defs': []}
    for op in sorted(python_functions):
        overloads = python_functions[op]
        # method_impl runs first for each op: it annotates the overload
        # declarations (e.g. 'python_arglists') in place.
        env['py_methods'].append(method_impl(op, overloads, is_python_method, module))
        env['py_method_defs'].append(method_def(op, overloads, is_python_method, module))
        env['py_forwards'].extend(forward_decls(op, overloads, is_python_method, module))
    return env
|
|
|
|
|
|
#
|
|
# extracting and storing parsed args
|
|
#
|
|
|
|
# C++ arg type -> PythonArgParser accessor used to read it; parsed_arg_expr
# emits '_r.<accessor>(<index>)'.
UNPACK_METHODS = dict([
    ('const Tensor &', 'tensor'),
    ('Tensor &', 'tensor'),
    ('c10::optional<Generator>', 'generator'),
    ('Storage', 'storage'),
    ('Storage &', 'storage'),
    ('const ScalarType &', 'scalartype'),
    ('const Device &', 'device'),
    ('c10::optional<DimnameList>', 'toDimnameListOptional'),
    ('c10::optional<ScalarType>', 'scalartypeOptional'),
    ('c10::optional<Layout>', 'layoutOptional'),
    ('c10::optional<MemoryFormat>', 'memoryformatOptional'),
    ('c10::optional<Scalar>', 'scalarOptional'),
    ('c10::optional<int64_t>', 'toInt64Optional'),
    ('c10::optional<bool>', 'toBoolOptional'),
    ('c10::optional<double>', 'toDoubleOptional'),
    ('IntArrayRef', 'intlist'),
    ('Scalar', 'scalar'),
    ('ScalarType', 'scalartype'),
    ('Dimname', 'dimname'),
    ('DimnameList', 'dimnamelist'),
    ('TensorList', 'tensorlist'),
    ('int64_t', 'toInt64'),
    ('bool', 'toBool'),
    ('double', 'toDouble'),
    ('std::string', 'string'),
])
|
|
|
|
# Accessors for args with a definite 'size'. Values are format strings:
# parsed_arg_expr fills in the size (e.g. 'tensorlist_n<2>').
UNPACK_WITH_SIZE_METHODS = dict([
    ('TensorList', 'tensorlist_n<{}>'),
    ('DimnameList', 'dimnamelist'),
    ('IntArrayRef', 'intlist'),
])
|
|
|
|
# Accessors for args carrying a 'python_default_init'; parsed_arg_expr emits
# '_r.<accessor>(<index>, <default_init>)'.
UNPACK_WITH_DEFAULT_METHODS = dict([
    ('const ScalarType &', 'scalartypeWithDefault'),
    ('const Device &', 'deviceWithDefault'),
    ('c10::optional<Layout>', 'layoutWithDefault'),
])
|
|
|
|
def parsed_arg_expr(arg, arg_index):
    """
    Return the C++ expression that reads parsed arg ``arg_index`` from ``_r``.
    e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'

    The accessor is chosen by precedence:
      1. arg has a non-None 'python_default_init' (only introduced by
         make_python_binding_args): UNPACK_WITH_DEFAULT_METHODS, with the
         default passed as a second argument;
      2. arg has a non-None 'size': UNPACK_WITH_SIZE_METHODS, with the size
         formatted into the accessor name;
      3. otherwise: UNPACK_METHODS.

    Raises RuntimeError if the arg's type is missing from the selected table.
    """
    typename = arg['type']

    default_init = arg.get('python_default_init')
    if default_init is not None:
        if typename not in UNPACK_WITH_DEFAULT_METHODS:
            raise RuntimeError(
                "type '{}' is not supported in python_default_init".format(typename))
        accessor = UNPACK_WITH_DEFAULT_METHODS[typename]
        return '_r.{}({}, {})'.format(accessor, arg_index, default_init)

    size = arg.get('size')
    if size is not None:
        if typename not in UNPACK_WITH_SIZE_METHODS:
            raise RuntimeError(
                "type '{}' with definite size ({}) is not supported".format(typename, size))
        accessor = UNPACK_WITH_SIZE_METHODS[typename].format(size)
        return '_r.{}({})'.format(accessor, arg_index)

    accessor = UNPACK_METHODS.get(typename)
    if accessor is None:
        raise RuntimeError("type '{}' is not supported".format(typename))
    return '_r.{}({})'.format(accessor, arg_index)
|
|
|
|
|
|
# TODO make this part of something more general, or get rid of it
def unpack_optional_dimname_list_hack(name, expr):
    """
    Return C++ statements that bind an optional<DimnameList> local ``name``.

    optional<ArrayRef<T>> are special. The PythonArgParser returns an
    optional<vector<T>>, which cannot be implicitly converted to
    optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.

    The returned list has the two statements followed by a trailing ''.
    """
    typ = 'DimnameList'
    text = (
        'auto __{name} = {expr};\n'
        'c10::optional<{typ}> {name} = __{name} ? '
        'c10::make_optional({typ}(__{name}.value())) : c10::nullopt;\n'
    ).format(name=name, expr=expr, typ=typ)
    return [line.strip() for line in text.split('\n')]
|
|
|
|
|
|
def parse_arg(arg, arg_index, unpack_to_local=False):
    """
    Return ``(expr, inits)`` for reading parsed arg ``arg_index``.

    ``expr`` is the C++ expression to use for the arg, and ``inits`` is a
    list of C++ statements to emit before it. Args of type
    c10::optional<DimnameList> are always bound to a local named after the
    arg (see unpack_optional_dimname_list_hack). Other args are bound with
    'auto <name> = <expr>;' only when ``unpack_to_local`` is true; otherwise
    the parsed expression is returned directly with no inits.
    """
    rhs = parsed_arg_expr(arg, arg_index)
    local_name = arg['name']

    if arg['type'] == 'c10::optional<DimnameList>':
        return local_name, unpack_optional_dimname_list_hack(local_name, rhs)
    if unpack_to_local:
        return local_name, ['auto {} = {};'.format(local_name, rhs)]
    return rhs, []
|
|
|
|
|
|
#
|
|
# schema type to cpp type conversions
|
|
# some of these are to prevent dangling refs to temps, others are more obscure
|
|
# TODO don't know if these fold into more general conversions somewhere, hope so
|
|
#
|
|
|
|
# Schema type -> type to use instead for a lambda formal, so that the formal
# does not bind a mutable reference (see get_cpp_decl_type).
TEMP_SAFE_CPP_DECL_TYPE = dict([
    ('Tensor &', 'Tensor'),
])
|
|
|
|
def get_cpp_decl_type(typename, ensure_temp_safe=True):
    """
    Return the C++ type to declare for ``typename``.

    With ``ensure_temp_safe`` set, types listed in TEMP_SAFE_CPP_DECL_TYPE are
    swapped for their replacement; all other types pass through unchanged.
    """
    if not ensure_temp_safe:
        return typename
    return TEMP_SAFE_CPP_DECL_TYPE.get(typename, typename)
|
|
|
|
|
|
def get_cpp_formal(arg, ensure_temp_safe=True):
    """Render ``arg`` as a C++ formal parameter, '<decl type> <name>'."""
    return '{} {}'.format(get_cpp_decl_type(arg['type'], ensure_temp_safe), arg['name'])
|
|
|
|
|
|
# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    'Tensor',
    'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,int64_t>',
    'std::tuple<Tensor,Tensor,double,Tensor,int64_t>',
    'std::tuple<double,int64_t>',
    'std::vector<Tensor>',
    'Scalar',
    'bool',
    'int64_t',
    'void*',
    'void',
    'QScheme',
    'double',
    'IntArrayRef',
    'ScalarType',
}
|
|
|
|
def get_simple_return_type(declaration):
    # Use the simple_return_type (Tensor) rather than the fancy return type
    # (Tensor &). This is important because the dispatch lambdas take
    # mutable arguments *by value*, not by reference. If you then return
    # a reference to such an argument, you will now have a pointer to a
    # dangling stack entry. Not good.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #                                            ^^^^^^
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #                                            ^^^^^^^
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to temporary. Maybe we could assign it to a
    # variable itself.)
    #
    # Raises RuntimeError when the stripped type is not in SUPPORTED_RETURN_TYPES.
    stripped = declaration['return_type'].replace(' &', '')
    if stripped in SUPPORTED_RETURN_TYPES:
        return stripped
    raise RuntimeError(declaration['name'] + " returns unsupported type " + stripped)
|
|
|
|
#
|
|
# dispatch codegen
|
|
#
|
|
|
|
def get_dispatch_callee(declaration):
    """
    Format the name of the receiving function or method: 'self.<name>' for
    tensor methods, '<namespace>::<name>' for torch functions.

    Raises RuntimeError if the declaration is neither.
    """
    if is_tensor_method(declaration):
        receiver = 'self.'
    elif is_torch_function(declaration):
        receiver = '{}::'.format(function_namespace(declaration))
    else:
        raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')
    return '{}{}'.format(receiver, declaration['name'])
|
|
|
|
|
|
def get_op_args(declaration, argmap):
    """
    Return argmap values in op call order, with two wrinkles:
      1. 'self' is eliminated for methods, it's baked into the callee
         expression elsewhere
      2. declaration['call_args'] shims legacy overrides and may contain
         constant values, not just names (see load_deprecated_signatures()
         in gen_autograd.py). With an override, keys missing from ``argmap``
         are assumed to be constants and passed through unchanged.

    Without an override every key must be present in ``argmap`` (KeyError
    otherwise).
    """
    override = declaration.get('call_args')
    keys = override if override else [param['name'] for param in declaration['arguments']]

    if is_tensor_method(declaration):
        keys = [k for k in keys if k != 'self']

    if not override:
        return [argmap[k] for k in keys]
    return [argmap.get(k, k) for k in keys]
|
|
|
|
|
|
TENSOR_OPTIONS_DECL = CodeTemplate("""\
|
|
const auto ${name} = TensorOptions()
|
|
.dtype(${dtype})
|
|
.device(${device})
|
|
.layout(${layout})
|
|
.requires_grad(${requires_grad})
|
|
.pinned_memory(${pin_memory});
|
|
""")
|
|
|
|
# addition to output-variant handler in which tensor options params
# (if present) are checked against properties of a tensor output param
# TODO remove hardcoding, use unpack logic from emit_single_dispatch
# ${out_idx} is the parsed index of the out tensor. ${type_idx},
# ${layout_idx} and ${device_idx} are the parsed indices of the dtype,
# layout and device binding args.
PY_VARIABLE_CHECK_OUT_TYPE_HACK = CodeTemplate("""\
check_out_type_matches(_r.tensor(${out_idx}), _r.scalartype(${type_idx}),
                       _r.isNone(${type_idx}), _r.layoutOptional(${layout_idx}),
                       _r.device(${device_idx}), _r.isNone(${device_idx}));
""")
|
|
|
|
# Unpack parsed args to locals, call the op, and wrap the result.
# Lambda is so GIL is back on by wrap() time (wrap can allocate)
# Filled by emit_single_dispatch:
#   ${auto_no_gil} is empty or a pybind11::gil_scoped_release line
#   ${namedtuple_typeref} is '' or '&<NamedTuple type>, ' for named-tuple returns
#   ${set_requires_grad} is '' or '.set_requires_grad(...)'
PY_VARIABLE_WRAP = CodeTemplate("""\
${inits}
auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} {
  ${auto_no_gil}
  return ${dispatch_callee}(${dispatch_args});
};
return wrap(${namedtuple_typeref}dispatch_${name}(${lambda_args})${set_requires_grad});
""")
|
|
|
|
# void return variant: the dispatch lambda's result is not wrapped, and the
# binding returns Python None. Selected by emit_single_dispatch when the
# simple return type is 'void'.
PY_VARIABLE_RETURN_VOID = CodeTemplate("""\
${inits}
auto dispatch_${name} = [](${lambda_formals}) -> ${simple_return_type} {
  ${auto_no_gil}
  ${dispatch_callee}(${dispatch_args});
};
dispatch_${name}(${lambda_args})${set_requires_grad};
Py_RETURN_NONE;
""")
|
|
|
|
|
|
def emit_single_dispatch(declaration, is_python_method, output_gap=0):
    """
    Emit dispatch code for a single declared overload.

    - declaration: the overload to dispatch. It must already carry
      'python_arglists' and 'namedtuple_typeref' entries.
    - is_python_method: if true, 'self' is bound directly from the binding's
      self argument rather than parsed.
    - output_gap: forwarded to handle_python_binding_args to offset the parsed
      indices of synthetic python binding args (nonzero when this no-output
      overload shares a parser signature with an out overload).

    Returns C++ text rendered from PY_VARIABLE_WRAP, or from
    PY_VARIABLE_RETURN_VOID when the op returns void.
    """
    deprecated = '[deprecated] ' if declaration.get('deprecated', False) else ''
    schema_comment = '// ' + deprecated + declaration['schema_string']
    inits = [schema_comment]

    pa = declaration['python_arglists']
    args = pa['input_args'] + pa['input_kwargs'] + pa['output_args']
    has_options = has_tensor_options(declaration)

    # op arg name -> {'value': C++ expr passed to the lambda, 'formal': lambda parameter decl}
    argmap = {}

    if is_python_method:
        # self is passed directly to python binding, rather than parsed
        argmap['self'] = {'value': 'self', 'formal': 'Tensor & self'}

    for i, arg in enumerate(args):
        # scattered args, and self when the op takes tensor options, are bound
        # to a named local first (parse_arg's unpack_to_local)
        unpack = is_scatter(arg) or (has_options and is_tensor_self(arg))
        arg_expr, unpack_stmts = parse_arg(arg, i, unpack_to_local=unpack)
        inits.extend(unpack_stmts)
        if is_scatter(arg):
            # one parsed arg fans out into several op args, read by position
            for j, elem in enumerate(arg['scatter_args']):
                argmap[elem['name']] = {
                    'value': '{}[{}]'.format(arg_expr, j),
                    'formal': get_cpp_formal(elem, ensure_temp_safe=False),
                }
        else:
            argmap[arg['name']] = {'value': arg_expr, 'formal': get_cpp_formal(arg)}

    # synthetic python binding args deliver op args
    binding_argmap, binding_inits, set_requires_grad = \
        handle_python_binding_args(declaration, output_gap)
    argmap.update(binding_argmap)
    inits.extend(binding_inits)

    lambda_formals = [argmap[arg['name']]['formal'] for arg in declaration['arguments']]
    lambda_args = [argmap[arg['name']]['value'] for arg in declaration['arguments']]

    dispatch_callee = get_dispatch_callee(declaration)
    # inside the lambda the op args are its formals, so each name maps to itself
    dispatch_args = get_op_args(declaration, {name: name for name, _ in argmap.items()})

    # release the GIL inside the lambda unless the declaration sets with_gil
    auto_no_gil = [] if declaration['with_gil'] else ['pybind11::gil_scoped_release no_gil;']

    simple_return_type = get_simple_return_type(declaration)
    if simple_return_type == 'void':
        template = PY_VARIABLE_RETURN_VOID
    else:
        template = PY_VARIABLE_WRAP

    return template.substitute(
        name=declaration['name'],
        inits=inits,
        lambda_formals=lambda_formals,
        lambda_args=lambda_args,
        dispatch_callee=dispatch_callee,
        dispatch_args=dispatch_args,
        auto_no_gil=auto_no_gil,
        set_requires_grad=set_requires_grad,
        simple_return_type=simple_return_type,
        namedtuple_typeref=declaration['namedtuple_typeref'],
    )
|
|
|
|
|
|
# arg['name'] -> expected arg['simple_type'] for each scattered tensor-options
# field. handle_python_binding_args requires all of them to be present.
TENSOR_OPTIONS_FIELDS = dict([
    ('dtype', 'ScalarType'),
    ('device', 'Device'),
    ('layout', 'Layout'),
    ('pin_memory', 'bool'),
    ('requires_grad', 'bool'),
])
|
|
|
|
def handle_python_binding_args(declaration, output_gap):
    # map synthetic python binding args to op args and misc other stuff
    # note: this logic shares arcane knowledge with make_python_binding_args
    # and isn't completely airtight w.r.t. the possible contents of
    # python_binding_args. TODO
    #
    # Returns (argmap, inits, set_requires_grad):
    #   argmap: op-arg entries in emit_single_dispatch's format ({'value', 'formal'})
    #   inits: C++ statements to emit before dispatch
    #   set_requires_grad: '' or a '.set_requires_grad(...)' suffix for the result
    # Raises RuntimeError when the binding args don't fit either expected shape.

    argmap = {}
    inits = []
    set_requires_grad = ''

    pa = declaration['python_arglists']
    python_binding_args = pa['python_binding_args']

    if len(python_binding_args) == 0:
        # nothing to see here
        return argmap, inits, set_requires_grad

    args = pa['input_args'] + pa['input_kwargs'] + pa['output_args']
    # binding args are parsed after all regular args. output_gap accounts for
    # out args that appear in a shared parser signature but not in this overload.
    binding_arg_base = len(args) + output_gap
    binding_arg_offsets = {arg['name']: i for i, arg in enumerate(python_binding_args)}

    def binding_arg_index(name):
        return binding_arg_base + binding_arg_offsets[name]

    def parse_binding_arg(name):
        binding_arg = python_binding_args[binding_arg_offsets[name]]
        expr, _ = parse_arg(binding_arg, binding_arg_index(name))
        return expr

    has_output = len(pa['output_args']) == 1
    tensor_options_arg = get_tensor_options(declaration)

    if tensor_options_arg is not None:
        # if our op has a tensor options arg, these are its scattered fields.
        # first some checks
        if has_output:
            raise RuntimeError('{}: tensor options with output arg'.format(declaration['name']))
        for arg in python_binding_args:
            typename = TENSOR_OPTIONS_FIELDS.get(arg['name'])
            if typename is None:
                raise RuntimeError(
                    '{}: unrecognized tensor options field \'{}\' in python binding arguments'.
                    format(declaration['name'], arg['name']))
            if typename != arg['simple_type']:
                raise RuntimeError(
                    '{}: unrecognized type \'{}\' for tensor options field \'{}\' in python binding arguments'.
                    format(declaration['name'], arg['type'], arg['name']))
        python_binding_argnames = [arg['name'] for arg in python_binding_args]
        if not all([key in python_binding_argnames for key in TENSOR_OPTIONS_FIELDS.keys()]):
            raise RuntimeError(
                '{}: incomplete tensor options args: {}'.
                format(declaration['name'], [arg['name'] for arg in python_binding_args]))
        # generate a gathering initialization of options struct
        argname = tensor_options_arg['name']
        inits.append(TENSOR_OPTIONS_DECL.substitute({
            'name': argname,
            'dtype': parse_binding_arg('dtype'),
            'layout': parse_binding_arg('layout'),
            'device': parse_binding_arg('device'),
            'requires_grad': parse_binding_arg('requires_grad'),
            'pin_memory': parse_binding_arg('pin_memory'),
        }))
        inits.append('torch::utils::maybe_initialize_cuda({});'.format(argname))
        # and add to op arg map
        argmap['options'] = {
            'value': argname,
            'formal': get_cpp_formal(tensor_options_arg),
        }

    else:
        # not the scattered fields of a tensor options - sort of a grab bag
        if 'dtype' in binding_arg_offsets:
            # we're an output-arg variant, check these args against output tensor
            if not has_output:
                raise RuntimeError(
                    '{}: dtype in python_binding_args without output arg'.
                    format(declaration['name']))
            if not all([name in binding_arg_offsets for name in ['layout', 'device']]):
                raise RuntimeError(
                    '{}: incomplete tensor options for output check'.
                    format(declaration['name']))
            check_type = PY_VARIABLE_CHECK_OUT_TYPE_HACK.substitute(
                out_idx=get_python_output_index(declaration),
                type_idx=binding_arg_index('dtype'),
                layout_idx=binding_arg_index('layout'),
                device_idx=binding_arg_index('device'),
            )
            inits.append(check_type)
        # we'll set requires_grad on outgoing tensor
        if 'requires_grad' not in binding_arg_offsets:
            raise RuntimeError(
                '{}: expected "requires_grad" in python_binding_args absent tensor options arg but found [{}]'.
                format(declaration['name'], [arg['name'] for arg in python_binding_args]))
        requires_grad = parse_binding_arg('requires_grad')
        set_requires_grad = '.set_requires_grad({})'.format(requires_grad)

    return argmap, inits, set_requires_grad
|
|
|
|
|
|
# handler for output/no-output overload pair
# (plugged into PY_VARIABLE_CASE as ${body} by emit_dispatch_case)
# Branches on whether the out arg at parsed index ${out_idx} is None:
# ${call_dispatch} dispatches the no-output overload, and
# ${call_dispatch_out} dispatches the out overload.
PY_VARIABLE_OUT = CodeTemplate("""\
if (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")
|
|
|
|
# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures only differ in output params.
# ${i} is the signature's index in the parser's signature list.
PY_VARIABLE_CASE = CodeTemplate("""\
case ${i}: {
  ${body}
}
""")
|
|
|
|
|
|
def emit_dispatch_case(i, dictionary, is_python_method):
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single overload, or a pair that differ only in output params. In the latter
    case, a single signature is used for both and dispatching switches on the
    presence/absence of passed output args.

    - i: this signature's position in generated binding's signature list if number of
      signatures > 1, otherwise None
    - dictionary: contains a no-output overload declaration under 'base', and optionally
      a second overload with outputs under 'out'
    - is_python_method: true if we're generating a python method, in which case self
      is not parsed but passed directly
    """
    base_decl = dictionary['base']

    if 'out' not in dictionary:
        # no-output version only
        body = emit_single_dispatch(base_decl, is_python_method)
    else:
        out_decl = dictionary['out']
        out_idx = get_python_output_index(out_decl)
        # both overloads are parsed with one shared signature, so the no-output
        # overload's binding args sit after the out overload's extra args
        output_gap = get_python_argc(out_decl) - get_python_argc(base_decl)
        dispatch_without_out = emit_single_dispatch(base_decl, is_python_method, output_gap)
        dispatch_with_out = emit_single_dispatch(out_decl, is_python_method)
        # branch on _r.isNone(<out_idx>)
        body = PY_VARIABLE_OUT.substitute(
            out_idx=out_idx,
            call_dispatch=dispatch_without_out,
            call_dispatch_out=dispatch_with_out,
        )

    if i is None:
        # only one overload, omit case wrapper
        return body
    return PY_VARIABLE_CASE.substitute(i=i, body=body)
|
|
|
|
#
|
|
# named tuple codegen
|
|
#
|
|
|
|
def namedtuple_fieldnames(declaration):
    """
    Return the named-tuple field names for ``declaration['returns']``.

    Returns [] when there is at most one return or when no return has a
    'field_name'. Otherwise every return must have one, and the names are
    returned in order.

    Raises ValueError if only some of the returns are named.
    """
    returns = declaration['returns']
    if len(returns) <= 1 or not any('field_name' in ret for ret in returns):
        return []

    names = []
    for ret in returns:
        # See Note [field_name versus name]
        if 'field_name' not in ret:
            # When building on Windows, `PyStructSequence_UnnamedField` could not be
            # resolved by the linker for some reason, which cause error in building:
            #
            # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
            # PyStructSequence_UnnamedField
            #
            # Thus, at this point in time, we do not support unnamed
            # fields in namedtuple; you must either name all fields,
            # or none of them.
            raise ValueError("Unnamed field is not supported by codegen")
        names.append(ret['field_name'])
    return names
|
|
|
|
PY_NAMEDTUPLE_FIELDSDEF = CodeTemplate("""\
|
|
static PyStructSequence_Field ${fieldsname}[] = { ${fields,} {nullptr} };
|
|
""")
|
|
|
|
PY_NAMEDTUPLE_TYPEDEF = CodeTemplate("""\
|
|
static PyTypeObject ${typename};
|
|
static bool ${typename}_initialized = false;
|
|
if (!${typename}_initialized) {
|
|
${typename}_initialized = true;
|
|
static PyStructSequence_Desc desc = { "torch.return_types.${name}", nullptr, ${fieldsname}, ${size} };
|
|
PyStructSequence_InitType(&${typename}, &desc);
|
|
${typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
|
|
}
|
|
""")
|
|
|
|
|
|
def emit_namedtuple_typedefs(declarations):
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them.

    Side effect: sets decl['namedtuple_typeref'] on every declaration. It is ''
    when the declaration does not return a named tuple, and '&<typename>, '
    otherwise.

    Field-definition arrays are shared by declarations with identical field
    names. Type objects are shared by declarations with the same op name and
    field names. The first of each kind is unsuffixed ('NamedTuple_fields',
    'NamedTuple'); later ones are suffixed with their position in the list.

    Returns all field-definition snippets followed by all typedef snippets.
    Propagates ValueError from namedtuple_fieldnames for partially named returns.
    """
    flddefnames = {}  # map from unique field name lists to field def name
    flddefs = []  # field def declarations
    typenames = {}  # map from unique name + field name lists to typedef name
    typedefs = []  # typedef declarations and init code

    for decl in declarations:
        fieldnames = namedtuple_fieldnames(decl)
        if not fieldnames:
            decl['namedtuple_typeref'] = ''
            continue

        fn_key = '_'.join(fieldnames)
        fieldsname = flddefnames.get(fn_key)
        if fieldsname is None:
            # number later arrays by position so every static name is unique
            fieldsname = 'NamedTuple_fields{}'.format(len(flddefs) if flddefs else '')
            fields = ['{{"{}", ""}}'.format(fn) for fn in fieldnames]
            fieldsdef = PY_NAMEDTUPLE_FIELDSDEF.substitute(
                fieldsname=fieldsname,
                fields=fields
            )
            flddefnames[fn_key] = fieldsname
            flddefs.append(fieldsdef)

        name = decl['name']
        key = '{}_{}'.format(name, '_'.join(fieldnames))
        typename = typenames.get(key)
        if typename is None:
            typename = 'NamedTuple{}'.format(len(typedefs) if typedefs else '')
            typedef = PY_NAMEDTUPLE_TYPEDEF.substitute(
                name=name,
                typename=typename,
                size=len(fieldnames),
                fieldsname=fieldsname
            )
            typenames[key] = typename
            typedefs.append(typedef)

        decl['namedtuple_typeref'] = '&{}, '.format(typename)

    return flddefs + typedefs
|
|
|
|
#
|
|
# method impl codegen
|
|
#
|
|
|
|
def get_pycname(name):
    """Return the C++ function name of the binding for op ``name``."""
    return 'THPVariable_{op}'.format(op=name)
|
|
|
|
|
|
def is_noarg_binding(overloads):
    """True iff there is exactly one overload and get_python_argc reports 0 args for it."""
    if len(overloads) != 1:
        return False
    return get_python_argc(overloads[0]) == 0
|
|
|
|
|
|
# python binding for all overloads of a particular function/method.
# The parser is built from ${signatures}. The index of the matched signature
# (_r.idx) selects one of the `case i:` blocks in ${dispatch}.
# This is a plain (non-raw) literal like the sibling templates. In a raw
# literal, the leading backslash-newline is kept verbatim instead of acting
# as a line continuation, which emitted a stray '\' line into the C++.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

""")
|
|
|
|
# python binding for single-overload function/method
# Same as PY_VARIABLE_METHOD_VARARGS but without the switch: ${dispatch} is
# emitted directly, since there is only one parsed signature.
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

""")
|
|
|
|
# python binding for a method with no args, shortcuts parsing
# No PythonArgParser and no kwargs parameter: ${dispatch} runs directly.
# Chosen by method_impl when is_noarg_binding() holds.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${dispatch}
  ${method_footer}
}

""")
|
|
|
|
TORCH_FUNCTION_CHECK = CodeTemplate("""\
|
|
if(_r.has_torch_function()) {
|
|
return handle_torch_function(_r, args, kwargs, ${namespace}, ${modulename});
|
|
}
|
|
""")
|
|
|
|
# NOTE: we type the unpacked self as Tensor not Variable to avoid return type
# discrepancies on method resolution (e.g. Variable::detach_ returns void
# rather than Tensor &). Added to method_header for Tensor methods.
UNPACK_SELF = (
    "Tensor& self = "
    "reinterpret_cast<THPVariable*>(self_)->cdata;"
)
|
|
|
|
|
|
def method_impl(name, declarations, is_python_method, module):
    """
    Generate a python binding for all overloads of an op.

    Side effect: every declaration in ``declarations`` gets a
    'python_arglists' entry (built by make_python_arglists()), which the
    signature / argc helpers used below read.

    Args:
        name: op name; also yields the C function name via get_pycname().
        declarations: declaration dicts, one per overload of the op.
        is_python_method: True for Tensor method bindings; self is then
            unpacked from the Python object (UNPACK_SELF) instead of parsed.
        module: if truthy, the module name used to emit a __torch_function__
            check (TORCH_FUNCTION_CHECK); otherwise no check is emitted.

    Returns:
        The filled-in binding template (generated C++ source).

    Raises:
        RuntimeError: if ``declarations`` is empty.
    """
    if not declarations:
        raise RuntimeError(
            '{}: no declarations to generate a python binding for'.format(name))

    for declaration in declarations:
        # formals for python binding signature
        declaration['python_arglists'] = make_python_arglists(declaration, is_python_method)

    pycname = get_pycname(name)

    method_header = ['HANDLE_TH_ERRORS']
    method_header += emit_namedtuple_typedefs(declarations)
    method_header += [UNPACK_SELF] if is_python_method else []

    method_footer = ['END_HANDLE_TH_ERRORS']

    # emit dispatch
    if is_noarg_binding(declarations):
        # is_noarg_binding() guarantees exactly one overload. Name it
        # explicitly instead of relying on the loop variable above leaking.
        dispatch = emit_single_dispatch(declarations[0], is_python_method)
        return PY_VARIABLE_METHOD_NOARGS.substitute(
            name=name,
            pycname=pycname,
            method_header=method_header,
            dispatch=dispatch,
            method_footer=method_footer,
        )

    # the varargs templates place the footer after the dispatch code, so
    # control that reaches it returns None
    method_footer = ['Py_RETURN_NONE;'] + method_footer

    grouped = group_overloads(declarations, is_python_method)
    is_singleton = len(grouped) == 1

    signatures = []
    dispatch = []
    for i, dictionary in enumerate(grouped):
        signature = dictionary['signature']
        signatures.append('"{}",'.format(signature))
        # the singleton template has no switch, so no case index is needed
        overload_index = i if not is_singleton else None
        dispatch.append(emit_dispatch_case(overload_index, dictionary, is_python_method))

    if is_singleton:
        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
    else:
        template = PY_VARIABLE_METHOD_VARARGS

    if module:
        check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
            namespace=NATIVE_NAMESPACE_MAPPING[module],
            modulename='"' + module + '"',
        )
    else:
        check_has_torch_function = ''

    # ParsedArgs must be large enough for the overload with the most args
    max_args = max(get_python_argc(decl) for decl in declarations)
    traceable = 'true' if all(should_trace(d) for d in declarations) else 'false'

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max_args,
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=check_has_torch_function,
        dispatch=dispatch,
        method_footer=method_footer,
    )
|
|
|
|
|
|
#
# forward declarations
#

# Prototype for a parsed (varargs + kwargs) function binding; must match the
# signature in PY_VARIABLE_METHOD_VARARGS[_SINGLETON].
PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
""")

# Prototype for a no-arg function binding; must match the signature in
# PY_VARIABLE_METHOD_NOARGS.
PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args);
""")
|
|
|
|
|
|
def forward_decls(name, declarations, is_python_method, module):
    """Return the C++ forward declarations needed for the binding of ``name``.

    Method bindings (``is_python_method``) get none and yield an empty list.
    Otherwise a one-element list holds the prototype matching whichever body
    template method_impl() will use. ``module`` is accepted for interface
    parity with the other generators and is unused.
    """
    if is_python_method:
        return []

    prototype = (PY_VARIABLE_FUNCTION_NOARGS_FORWARD_DECLARATION
                 if is_noarg_binding(declarations)
                 else PY_VARIABLE_FUNCTION_VARARGS_FORWARD_DECLARATION)
    return [prototype.substitute(pycname=get_pycname(name))]
|
|
|
|
|
|
#
# method def (binding table entry) codegen
#

# Python binary operator dunder methods.
# method_def() registers bindings whose name appears here through
# PY_VARIABLE_METHOD_BINOP_DEF (wrapped in TypeError_to_NotImplemented_)
# instead of PY_VARIABLE_METHOD_DEF.
# NOTE(review): Python's data model defines no __idivmod__ slot; the entry is
# harmless (it only matters if an op with that name exists) but looks unneeded.
BINARY_OP_NAMES = [
    '__lt__', '__le__',
    '__gt__', '__ge__',
    '__eq__', '__ne__',

    '__add__', '__radd__', '__iadd__',
    '__sub__', '__rsub__', '__isub__',
    '__mul__', '__rmul__', '__imul__',
    '__matmul__', '__rmatmul__', '__imatmul__',
    '__truediv__', '__rtruediv__', '__itruediv__',
    '__floordiv__', '__rfloordiv__', '__ifloordiv__',
    '__mod__', '__rmod__', '__imod__',
    '__divmod__', '__rdivmod__', '__idivmod__',
    '__pow__', '__rpow__', '__ipow__',
    '__lshift__', '__rlshift__', '__ilshift__',
    '__rshift__', '__rrshift__', '__irshift__',
    '__and__', '__rand__', '__iand__',
    '__xor__', '__rxor__', '__ixor__',
    '__or__', '__ror__', '__ior__',
]
|
|
|
|
# PyMethodDef entry for a binary-operator dunder. The binding is wrapped in
# TypeError_to_NotImplemented_ (defined elsewhere; presumably it turns a
# TypeError from argument parsing into a NotImplemented return so Python can
# fall back to the reflected operator -- confirm at its definition).
# ${pycfunc_voidcast} is either empty or a "(void(*)(void))" cast chosen by
# method_def().
PY_VARIABLE_METHOD_BINOP_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycfunc_voidcast}TypeError_to_NotImplemented_<${pycname}>, ${flags}, NULL},""")

# PyMethodDef entry (name, function pointer, METH_* flags, no docstring)
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycfunc_voidcast}${pycname}, ${flags}, NULL},""")
|
|
|
|
|
|
def method_def(name, declarations, is_python_method, module):
    """
    Generate the PyMethodDef table entry for the binding of op ``name``.

    Chooses the METH_* flags and the function-pointer cast from the binding
    shape (no-arg vs parsed), adds METH_STATIC for the "torch" module, and uses
    the NotImplemented-converting template for binary operator dunders.
    """
    if is_noarg_binding(declarations):
        # NOTE(review): a no-arg *function* binding is emitted with the
        # 2-parameter PY_VARIABLE_METHOD_NOARGS signature but registered with
        # METH_VARARGS | METH_KEYWORDS (which CPython calls with 3 params) and
        # without the void cast; confirm this is intended.
        voidcast = ''
        flag_parts = ['METH_NOARGS'] if is_python_method else ['METH_VARARGS', 'METH_KEYWORDS']
    else:
        voidcast = '(void(*)(void))'
        flag_parts = ['METH_VARARGS', 'METH_KEYWORDS']

    if module == "torch":
        flag_parts.append('METH_STATIC')

    entry_template = PY_VARIABLE_METHOD_BINOP_DEF if name in BINARY_OP_NAMES else PY_VARIABLE_METHOD_DEF
    return entry_template.substitute(
        name=name,
        pycname=get_pycname(name),
        pycfunc_voidcast=voidcast,
        flags=' | '.join(flag_parts),
    )
|
|
|
|
#
# overload sorting and grouping
#

def group_overloads(declarations, is_python_method):
    """Group overload declarations by their Python signature.

    Returns a list of dictionaries, each with the keys:

        "base": the regular ATen declaration (e.g. conv2d)
        "out": (only when present) the out variant (e.g. conv2d_out)
        "signature": the signature used for Python argument parsing

    Declarations whose signatures are equal once output arguments are
    dropped are merged into one entry sharing a single python_arg_parser
    signature (the out variant's, in which the output arguments are
    optional). Entries are ordered by signature and then by
    sort_declarations().

    Raises:
        RuntimeError: if a group has an out variant but no base declaration.
    """
    by_signature = defaultdict(dict)

    # Bucket each declaration under its signature with outputs stripped.
    for decl in declarations:
        key = get_python_signature(decl, is_python_method, skip_outputs=True)
        entry = by_signature[key]
        if not decl['name'].endswith('_out'):
            entry['base'] = decl
            entry.setdefault('signature', key)
            continue
        entry['out'] = decl
        # the out variant's full signature (with optional out=...) wins
        entry['signature'] = get_python_signature(decl, is_python_method)

    groups = []
    for key in sorted(by_signature):
        entry = by_signature[key]
        if 'base' not in entry:
            raise RuntimeError(
                "'base' not in dictionary for {}. keys are {}".format(
                    key, list(entry.keys())))
        groups.append(entry)
    return sort_declarations(groups)
|
|
|
|
|
|
# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):
    """Topologically sort grouped overloads by a "Scalar < Tensor" partial order.

    ``grouped_decls`` is the list built by group_overloads(); each element's
    'base' declaration supplies the argument list that is compared.
    Overloads that are not smaller than any other come first (in their
    original relative order), and an overload is emitted only after every
    overload it is smaller than. For example, an overload taking
    (Tensor, Tensor) is placed before an otherwise identical one taking
    (Tensor, Scalar). If no pair is comparable, ``grouped_decls`` itself is
    returned unchanged.
    """

    # TODO: This is a hack!
    #
    # For some reason, when you specify a Scalar argument in a native
    # function, you get a Declarations.yaml entry that looks like this:
    #
    #   - default: 1
    #     dynamic_type: Scalar
    #     is_nullable: false
    #     kwarg_only: true
    #     name: alpha
    #     type: Scalar
    #
    # This is contrast to when there is a 'real' argument in TH
    # Declarations.cwrap; this gets (correctly?) translated into
    # dynamic_type: real, and type: Scalar.  I would like to fix this
    # at the source but I have never understood what dynamic_type is
    # supposed to be.
    def normalized_dynamic_type(arg):
        if arg['dynamic_type'] == 'real':
            return 'Scalar'
        return arg['dynamic_type']

    # one coordinate is "smaller" when it is a Scalar where the other is a Tensor
    def is_coord_smaller(arg1, arg2):
        return normalized_dynamic_type(arg1) == 'Scalar' and arg2['dynamic_type'] == 'Tensor'

    def is_smaller(d1, d2):
        """Returns True if d1 < d2 in the partial order."""
        args1, args2 = d1['base']['arguments'], d2['base']['arguments']
        # only overloads of the same arity are comparable
        if len(args1) != len(args2):
            return False
        any_smaller = any(is_coord_smaller(arg1, arg2) for arg1, arg2 in zip(args1, args2))
        all_smaller_or_equal = all(normalized_dynamic_type(arg1) == normalized_dynamic_type(arg2) or
                                   is_coord_smaller(arg1, arg2)
                                   for arg1, arg2 in zip(args1, args2))
        return any_smaller and all_smaller_or_equal

    # Construct the relation graph.
    # larger_than[i1] holds the indices i2 with grouped_decls[i1] < grouped_decls[i2],
    # i.e. the overloads that must be emitted before i1.
    larger_than = defaultdict(set)
    for i1, decl1 in enumerate(grouped_decls):
        for i2, decl2 in enumerate(grouped_decls):
            if is_smaller(decl1, decl2):
                larger_than[i1].add(i2)

    if not larger_than:
        return grouped_decls

    # Use a topological sort to sort decls according to the partial order.
    # Start with every decl that has no pending predecessors.
    sorted_deps = [(i, decl) for i, decl in enumerate(grouped_decls)
                   if i not in larger_than]
    # sorted_deps doubles as the work queue: entries appended inside the loop
    # are visited by this same for-loop once all their predecessors are placed.
    for i, decl in sorted_deps:
        # snapshot the keys so entries can be deleted while iterating
        for i2 in sorted(larger_than.keys()):
            larger = larger_than[i2]
            larger.discard(i)
            if not larger:
                del larger_than[i2]
                sorted_deps.append((i2, grouped_decls[i2]))

    # NOTE(review): an entry still left in larger_than here (only possible if
    # the relation had a cycle) would be silently dropped; is_smaller appears
    # antisymmetric, so this should not occur.
    return [decl for i, decl in sorted_deps]
|
|
|
|
|
|
#
# python signature codegen
#

# Maps C++-spelled default values found in declarations to the spelling used
# in python_arg_parser signature strings (see get_schema_formal()). Defaults
# not listed here are emitted unchanged.
SCHEMA_DEFAULT_CONVERSION_HACKS = {
    'nullptr': 'None',
    'c10::nullopt': 'None',
    '{}': 'None',
}
|
|
|
|
def get_schema_formal(arg, is_python_method):
    """Render one argument as a formal of a python_arg_parser signature string.

    Produces ``"<type> <name>"`` or ``"<type> <name>=<default>"``:

    - nullable args get a ``?`` type suffix (unless already present)
    - a Tensor ``self`` is renamed ``input`` outside method bindings
    - fixed-size args get a ``[size]`` type suffix
    - defaults listed in SCHEMA_DEFAULT_CONVERSION_HACKS are translated
    """
    formal_name = arg['name']
    formal_type = arg['simple_type']

    # TODO: remove this and make optional types in simple_type to be consistent across
    # tensor and other types after make Tensor? be optional instead of undefined
    if arg.get('is_nullable') and '?' not in formal_type:
        formal_type = '{}?'.format(formal_type)

    # s/self/input/ outside method bindings.
    # TODO remove this? doesn't rename in codegen, it's just for the parse string
    renames_self = formal_name == 'self' and formal_type == 'Tensor' and not is_python_method
    if renames_self:
        formal_name = 'input'

    fixed_size = arg.get('size')
    if fixed_size is not None:
        formal_type = '{}[{}]'.format(formal_type, fixed_size)

    raw_default = arg.get('default')
    if raw_default is None:
        return '{} {}'.format(formal_type, formal_name)
    python_default = SCHEMA_DEFAULT_CONVERSION_HACKS.get(raw_default, raw_default)
    return '{} {}={}'.format(formal_type, formal_name, python_default)
|
|
|
|
|
|
PYTHON_ARG_PARSER_SCHEMA = CodeTemplate("""\
|
|
${name}(${schema_formals})${deprecated}""")
|
|
|
|
|
|
def get_python_signature(declaration, is_python_method, skip_outputs=False):
    # Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # tools/gen_pyi.py.  If you change any logic here, please
    # check that file too.
    #
    # Requires declaration['python_arglists'] (see make_python_arglists()).
    # With skip_outputs=True the output args are omitted, which is how
    # group_overloads() matches an op with its out variant.
    args = get_python_args(declaration)
    if skip_outputs:
        args = [a for a in args if not is_output(a)]

    formals = [get_schema_formal(a, is_python_method) for a in args]

    # Everything after the positional inputs is keyword-only; mark the
    # boundary with a bare '*'.
    num_positional = len(declaration['python_arglists']['input_args'])
    if len(args) > num_positional:
        formals.insert(num_positional, '*')

    # Python function signature.
    # This is the string that we give to FunctionParameter, which is
    # then parsed into the actual structure which we do parsing with.
    return PYTHON_ARG_PARSER_SCHEMA.substitute(
        name=op_name(declaration),
        schema_formals=formals,
        deprecated='|deprecated' if declaration.get('deprecated', False) else '',
    )
|
|
|
|
#
# op args to python parsed args transform
#

def get_python_args(decl):
    """Return decl's full Python argument list as a new list.

    The sublists built by make_python_arglists() are concatenated in order:
    positional inputs, keyword inputs, outputs, then python binding args.
    """
    sublists = decl['python_arglists']
    return (sublists['input_args']
            + sublists['input_kwargs']
            + sublists['output_args']
            + sublists['python_binding_args'])
|
|
|
|
|
|
def get_python_argc(decl):
    """Return the total number of Python-visible args across all of decl's arglists."""
    # generator, not a list: no throwaway intermediate list is built
    return sum(len(arglist) for arglist in decl['python_arglists'].values())
|
|
|
|
|
|
def get_python_output_index(decl):
    """Return the index of the first output arg in decl's flat Python arglist.

    Outputs follow the positional and keyword inputs (see get_python_args()).
    """
    arglists = decl['python_arglists']
    # sum the lengths instead of concatenating the lists just to measure them
    return len(arglists['input_args']) + len(arglists['input_kwargs'])
|
|
|
|
|
|
def make_python_arglists(declaration, is_python_method):
    """
    Build the Python-facing argument lists for ``declaration``.

    The args in declaration['arguments'] are converted and partitioned into
    ordered sublists, so the final Python arglist can be recovered by simple
    concatenation (see get_python_args()):

        'input_args':          positional inputs (self dropped for methods)
        'input_kwargs':        keyword-only inputs, minus the TensorOptions arg
        'output_args':         at most one output arg, defaulting to None
        'python_binding_args': the scattered dtype/layout/... args produced by
                               make_python_binding_args()

    The caller's argument dicts are not mutated.

    Raises:
        RuntimeError: if there are several output args and one is not a Tensor.
    """
    input_args = []
    input_kwargs = []
    output_args = []

    # Once a kwarg_only input has been seen, every later input is treated as
    # keyword-only too.
    current_input_args = input_args
    for arg in declaration['arguments']:
        if is_output(arg):
            output_args.append(arg)
        else:
            if arg.get('kwarg_only', False):
                current_input_args = input_kwargs
            current_input_args.append(arg)

    # adjustments

    # positional inputs:
    # - drop self when generating a method binding: there it comes in as a
    #   separate Python param, not in the args array
    input_args = [arg for arg in input_args
                  if not (is_tensor_self(arg) and is_python_method)]

    # keyword inputs:
    # - drop the options arg. after loading the yaml, an upstream step has gathered
    #   dtype, layout et al into a single tensor options arg; the originals are
    #   reintroduced below as python binding args
    input_kwargs = [arg for arg in input_kwargs if not is_tensor_options(arg)]

    # outputs:
    # - coalesce multiple output args into a single 'out' arg w/type TensorList.
    # - force a default of None, so the same signature serves both the out and
    #   the non-out variants
    num_outputs = len(output_args)
    if num_outputs > 1:
        for arg in output_args:
            if arg['simple_type'] != 'Tensor':
                raise RuntimeError(
                    '{}: unsupported output argument type {}'.
                    format(declaration['name'], arg['type']))
        typename = 'TensorList'
        output_args = [{
            'default': 'None',
            'kwarg_only': True,
            'name': 'out',
            'output': True,
            'scatter_args': output_args,
            'simple_type': typename,
            'size': num_outputs,
            'type': typename,
        }]
    elif num_outputs == 1:
        # copy so the declaration's own argument dict is left untouched
        output_arg = output_args[0].copy()
        output_arg['default'] = 'None'
        output_args = [output_arg]

    # make python binding args
    # these are the (re)scattered versions of the options arg omitted above.
    # TODO because these aren't guaranteed to be 100% faithful to the original
    # versions in the yaml, this recreation is a potential source of drift between
    # eager and JIT. Pull this logic out to a shared place.
    python_binding_args = make_python_binding_args(declaration)

    return {
        'input_args': input_args,
        'input_kwargs': input_kwargs,
        'output_args': output_args,
        'python_binding_args': python_binding_args,
    }
|
|
|
|
#
# python binding args
#

# TODO blowtorch
def dtype_default_type_hack(name):
    """Return the default dtype (as signature text) for the generated dtype arg.

    randperm* and tril_indices/triu_indices default to int64; every other op
    gets 'None'.
    """
    wants_int64 = name.startswith('randperm') or name in ('tril_indices', 'triu_indices')
    return 'torch.int64' if wants_int64 else 'None'
|
|
|
|
|
|
def make_python_binding_args(declaration):
    """
    Given various properties of a declaration, build a set of scattered python binding args.

    These are the keyword-only dtype / layout / device / pin_memory /
    requires_grad args that make_python_arglists() appends after the
    declaration's own args (it drops the gathered TensorOptions arg from the
    keyword inputs). Which of them are produced depends on whether the op
    returns a Tensor, takes Tensor inputs or a TensorOptions arg, and on its
    name / 'category_override' ('factory', 'like' or 'new').

    Returns:
        list of argument dicts in the order dtype, layout, device,
        pin_memory, requires_grad (each only when applicable).

    Raises:
        ValueError: if a non-output argument is named requires_grad (a
            requires_grad binding arg may be generated below).
    """
    name = declaration['name']
    python_binding_arguments = []
    # classify the non-output inputs
    has_tensor_input_arg = False
    has_options_arg = False
    for arg in declaration['arguments']:
        if is_output(arg):
            continue
        typename = arg['simple_type']
        if typename in ['Tensor', 'TensorList']:
            has_tensor_input_arg = True
        elif typename == 'TensorOptions':
            has_options_arg = True
        if arg['name'] == 'requires_grad':
            raise ValueError("argument named requires_grad not supported")

    has_tensor_return = False
    for ret in declaration['returns']:
        if ret['dynamic_type'] in ['Tensor', 'TensorList']:
            # this probably won't work if one of the returns is not a tensor, but it will
            # produce a compile-time error that is obvious
            has_tensor_return = True

    category_override = declaration['category_override']
    is_like_function = name.endswith('_like') or category_override == 'like'
    is_like_function_with_options = is_like_function and has_options_arg
    is_new_function = name.startswith('new_') or category_override == 'new'
    is_new_function_with_options = is_new_function and has_options_arg
    # NB: `and` binds tighter than `or`, so this reads
    # (has_tensor_return and not has_tensor_input_arg) or category_override == 'factory'
    is_factory_function = has_tensor_return and not has_tensor_input_arg or category_override == 'factory'
    is_factory_or_like_or_new_function = has_tensor_return and (is_factory_function or is_like_function or is_new_function)
    is_like_or_new_function_with_options = is_like_function_with_options or is_new_function_with_options

    if is_factory_function or has_options_arg:
        default_type = dtype_default_type_hack(name)
        # for *_like / new_* ops that take options, 'python_default_init' holds a
        # C++ expression in terms of self (presumably used by the parser codegen
        # when dtype is not passed -- confirm at its consumer)
        py_default_dtype = 'self.scalar_type()' if is_like_or_new_function_with_options else None
        dtype_arg = {
            'default': default_type,
            'dynamic_type': 'ScalarType',
            'kwarg_only': True,
            'name': 'dtype',
            'type': 'const ScalarType &',
            'simple_type': 'ScalarType',
            'python_default_init': py_default_dtype,
        }
        python_binding_arguments.append(dtype_arg)

    if is_factory_function or is_like_or_new_function_with_options:
        py_default_layout = 'layout_from_backend(self.options().backend())' if is_like_or_new_function_with_options else None
        layout_arg = {
            'default': 'torch.strided',
            'dynamic_type': 'Layout',
            'kwarg_only': True,
            'name': 'layout',
            'type': 'c10::optional<Layout>',
            'simple_type': 'Layout',
            'python_default_init': py_default_layout,
        }
        python_binding_arguments.append(layout_arg)
        py_default_device = 'self.device()' if is_like_or_new_function_with_options else None
        device_arg = {
            'default': 'None',
            'dynamic_type': 'Device',
            'kwarg_only': True,
            'name': 'device',
            'type': 'const Device &',
            'simple_type': 'Device',
            'python_default_init': py_default_device
        }
        python_binding_arguments.append(device_arg)
        # note: a bool default (not a string); get_schema_formal() formats it
        # as 'False'
        pin_memory_arg = {
            'default': False,
            'dynamic_type': 'bool',
            'kwarg_only': True,
            'name': 'pin_memory',
            'type': 'bool',
            'simple_type': 'bool',
        }
        python_binding_arguments.append(pin_memory_arg)

    if is_factory_or_like_or_new_function:
        requires_grad_arg = {
            'default': False,
            'dynamic_type': 'bool',
            'kwarg_only': True,
            'name': 'requires_grad',
            'type': 'bool',
            'simple_type': 'bool',
        }
        python_binding_arguments.append(requires_grad_arg)

    return python_binding_arguments
|
|
|
|
#
# declaration derived props, utils, etc.
# declarations are dicts loaded from Declarations.yaml,
# passed to our codegen methods by callers in gen_autograd
#

def is_tensor_self(arg):
    """True if ``arg`` is named 'self' and has simple_type 'Tensor'."""
    if arg['name'] != 'self':
        return False
    return arg['simple_type'] == 'Tensor'
|
|
|
|
|
|
def is_tensor_options(arg):
    """True if ``arg`` is a TensorOptions argument (simple_type 'TensorOptions')."""
    simple_type = arg['simple_type']
    return simple_type == 'TensorOptions'
|
|
|
|
|
|
def is_scatter(arg):
    """True if ``arg`` carries 'scatter_args', i.e. it coalesces several args.

    make_python_arglists() creates such an arg for multi-output ops.
    """
    scatter_args = arg.get('scatter_args')
    return scatter_args is not None
|
|
|
|
def is_output(arg):
    """Return arg's 'output' flag, or False when the key is absent."""
    return arg['output'] if 'output' in arg else False
|
|
|
|
|
|
def has_outputs(declaration):
    """True if any of the declaration's arguments is an output argument."""
    # generator, not a list: any() can stop at the first output arg
    return any(is_output(arg) for arg in declaration['arguments'])
|
|
|
|
|
|
def get_tensor_options(declaration):
    """Return the declaration's single TensorOptions argument, or None.

    Raises:
        RuntimeError: if the declaration has more than one TensorOptions arg.
    """
    matches = [arg for arg in declaration['arguments'] if is_tensor_options(arg)]
    if len(matches) > 1:
        raise RuntimeError(
            '{}: multiple tensor options arguments'.
            format(declaration['name']))
    return matches[0] if matches else None
|
|
|
|
|
|
def has_tensor_options(declaration):
    """True if the declaration has a TensorOptions argument.

    Propagates get_tensor_options()'s RuntimeError when there is more than one.
    """
    options_arg = get_tensor_options(declaration)
    return options_arg is not None
|
|
|
|
|
|
def is_tensor_method(declaration):
    """True if 'Tensor' is listed in declaration['method_of']."""
    owners = declaration['method_of']
    return 'Tensor' in owners
|
|
|
|
|
|
def is_torch_function(declaration):
    """True if 'namespace' is listed in declaration['method_of']."""
    owners = declaration['method_of']
    return 'namespace' in owners
|
|
|
|
|
|
def is_nn_module_function(declaration):
    """True if the declaration's 'python_module' is 'nn' (a missing key gives False)."""
    python_module = declaration.get('python_module')
    return python_module == 'nn'
|
|
|
|
|
|
def function_namespace(declaration):
    """Return the C++ namespace ('torch' or 'at') used to call ``declaration``.

    Ops taking TensorOptions and *_like ops are called through 'torch';
    everything else goes through 'at'.
    """
    # TODO look into why these can't all be 'torch' calls
    use_torch = has_tensor_options(declaration) or op_name(declaration).endswith('_like')
    return 'torch' if use_torch else 'at'
|
|
|
|
|
|
def op_name(declaration):
    """Return the Python-facing op name of ``declaration``.

    A declaration with output args must be named ``<op>_out`` and maps to
    ``<op>``, so it shares a name with its non-out variant; one without
    output args must not end in ``_out`` and keeps its name.

    Raises:
        RuntimeError: if having output args and ending in ``_out`` disagree.
    """
    name = declaration['name']
    if not has_outputs(declaration):
        if name.endswith("_out"):
            raise RuntimeError(
                '{}: name ends with \'_out\', expecting output params'.
                format(declaration['name']))
        return name

    if not name.endswith("_out"):
        raise RuntimeError(
            '{} has output params, expecting name ending with \'_out\''.
            format(declaration['name']))
    return name[:-len('_out')]
|