pytorch/tools/autograd/gen_python_functions.py
Sam Gross 829d763c69 Implement add, sub, mul, div using TensorIterator (#8919)
Summary:
```
This adds TensorIterator, a helper class for computing element-wise
operations that's intended to replace the CPU and CUDA apply utils
functions.

CPU kernels are implemented as functions that operate on strided 1-d
tensors, in contrast to CPUApplyUtils, which operated on individual elements.
This allows the kernels to handle vectorization, while TensorIterator handles
parallelization and non-coalesced dimensions.

GPU kernels continue to operate on elements, but the number of
specializations is reduced. The contiguous case remains the same. The
non-contiguous case uses a single (reduced) shape for all operands and
the fast integer division from THCIntegerDivider. To avoid extra
specializations for indexing with 64-bits, large operations are split
into smaller operations that can be indexed with 32-bits.

Major semantic changes:

 - No more s_add, s_mul, s_div, or s_sub. Broadcasting is handled by
   TensorIterator. The autograd engine performs the reduction assuming
   standard broadcasting if the gradient shape does not match the
   expected shape. Functions that do not use standard broadcasting rules
   should either continue to trace the expand calls or handle the
   reduction in their derivative formula.

 - Use ONNX v7, which supports broadcasting ops.

Performance impact:

 - Small increased fixed overhead (~0.5 us)
 - Larger overhead for wrapped numbers (~2.5 us)
 - No significant change for ops on contiguous tensors
 - Much faster worst-case performance for non-contiguous GPU tensors
 - Faster CPU bias addition (~2x)
 - Faster GPU bias addition (~30% faster)

Future work:

 - Decrease overhead, especially for wrapping numbers in Tensors
 - Handle general inter-type operations
 - Extend to unary ops and reductions
 - Use buffering for compute-bound operations on non-contiguous tensors
   (pull in from CPUApplyUtils)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8919

Differential Revision: D8677600

Pulled By: colesbury

fbshipit-source-id: 61bc9cc2a36931dfd00eb7153501003fe0584afd
2018-07-27 14:43:24 -07:00

812 lines
33 KiB
Python

# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn object.
#
from collections import defaultdict
import re
from .nested_dict import nested_dict
from .gen_variable_type import should_trace
from .utils import write
try:
from src.ATen.code_template import CodeTemplate
except ImportError:
from tools.shared.module_loader import import_module
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
# These functions require manual Python bindings or are not exposed to Python
#
# Each entry is a regular expression. should_generate_python_binding() wraps it
# in '^...$' and matches it against the declaration name, so an entry must
# describe the whole name (e.g. 'clamp.*' rather than 'clamp').
SKIP_PYTHON_BINDINGS = [
'alias', 'contiguous', 'clamp.*', 'is_cuda', 'is_sparse', 'size', 'stride',
'.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
'.*_forward_out', 'sparse_raw_resize_', '_unsafe_view', 'tensor',
'sparse_coo_tensor', 'th_sparse_coo_tensor', 'native_sparse_coo_tensor',
'_arange.*', '_range.*', '_linspace.*', '_logspace.*',
'_sparse_add.*', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
'index',
'_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
'_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_.*',
'arange.*', 'range.*', '_gesv.*', '_getri.*', 'slice',
'_local_scalar', '_local_scalar_dense',
'max_pool1d', 'max_pool2d', 'max_pool3d'
]
# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
#
# Entries are compared for exact equality against
# '<name>(<simple_type>, <simple_type>, ...)' as built by
# should_generate_python_binding(), so spacing after the commas matters.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]
# C++ wrapper for a binding that parses its arguments with PythonArgParser.
# process_function() selects this template for every binding that is not the
# single-declaration, self-only case and pairs it with the
# 'METH_VARARGS | METH_KEYWORDS' flags.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});
${unpack_self}
ParsedArgs<${max_args}> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
${dispatch}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
""")
# C++ wrapper used by process_function() when a method has exactly one
# declaration whose only argument is `self`; paired with 'METH_NOARGS'.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
HANDLE_TH_ERRORS
${unpack_self}
return wrap(${dispatch_name}(${actuals}));
END_HANDLE_TH_ERRORS
}
""")
# One branch of the overload chain. emit_dispatch() passes cond='if' for the
# first signature and '} else if' for the rest; process_function() appends the
# final closing '}' itself.
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
${call_dispatch}
""")
# Chooses between the base overload and its `_out` variant depending on
# whether the parsed `out` argument is None.
PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
${call_dispatch}
} else {
${call_dispatch_out}
}
""")
# Same as PY_VARIABLE_OUT, but first calls check_out_type_matches() with the
# parsed dtype/layout/device arguments. emit_dispatch() uses it when the
# `_out` declaration has a 'dtype' python binding argument.
PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
${call_dispatch}
} else {
check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
r.layout(${layout_idx}), r.isNone(${layout_idx}),
r.device(${device_idx}), r.isNone(${device_idx}));
${call_dispatch_out}
}
""")
# Expression that calls the generated dispatch function with the unpacked arguments.
PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")
# Wraps a dispatch call so requires_grad is applied to its result.
PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
${call_dispatch}.set_requires_grad(${requires_grad})""")
# Return statement that converts the dispatch result with wrap().
PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${call_dispatch});""")
# Inline C++ dispatch function emitted once per declaration by
# emit_single_dispatch(); ${AutoNoGIL} is empty when the declaration sets 'with_gil'.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${return_type} ${dispatch_name}(${formal_args}) {
${initialize_cuda}
${AutoNoGIL}
return ${dispatch_call}(${dispatch_args});
}
""")
# Entry for the generated method table (name, C function, flags, doc pointer).
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")
# Statement placed in ${unpack_self} when bindings are generated with has_self=True.
UNPACK_SELF = "auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;"
# Signature string built by get_python_signature() and listed in ${signatures}.
PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${py_formal_args})""")
# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
#
# emit_single_dispatch() strips ' &' from a declaration's return type and
# asserts that the result is a member of this set.
SUPPORTED_RETURN_TYPES = {
'Tensor', 'std::tuple<Tensor,Tensor>',
'std::tuple<Tensor,Tensor,Tensor>',
'std::tuple<Tensor,Tensor,Tensor,Tensor>',
'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
'std::vector<Tensor>',
'Scalar', 'bool', 'int64_t', 'void*', 'void'
}
# C++ snippet that builds the `options` local from the parsed dtype, device,
# layout and requires_grad expressions. emit_single_dispatch() emits it only
# for declarations that take a TensorOptions argument and for which all four
# expressions were parsed.
TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
.dtype(${dtype})
.device(${device})
.layout(${layout}.layout)
.requires_grad(${requires_grad});
""")
def should_generate_python_binding(declaration):
    """Return True if a Python binding should be generated for `declaration`.

    A declaration is rejected when its name fully matches one of the regexes
    in SKIP_PYTHON_BINDINGS, when its `name(type, ...)` signature is listed
    verbatim in SKIP_PYTHON_BINDINGS_SIGNATURES, or when it takes a
    SparseTensorRef argument (with `_sparse_mask` as the one exception).
    """
    name = declaration['name']

    # Name-based exclusions: every pattern has to match the entire name.
    if any(re.match('^' + pattern + '$', name) for pattern in SKIP_PYTHON_BINDINGS):
        return False

    # Signature-based exclusions: plain string equality, no regex.
    arg_types = ', '.join([arg['simple_type'] for arg in declaration['arguments']])
    signature = '{}({})'.format(name, arg_types)
    if signature in SKIP_PYTHON_BINDINGS_SIGNATURES:
        return False

    # TODO: fix handling of SparseTensor. We don't want to generate Python
    # bindings to SparseTensor overloads, such as add(Tensor, SparseTensorRef),
    # since the Tensor-based signature already dynamically dispatches correctly.
    # However, _sparse_mask only has a SparseTensor signature so we need to bind
    # that function.
    has_unbound_sparse_arg = any(
        arg['type'] == 'SparseTensorRef' and declaration['name'] != '_sparse_mask'
        for arg in declaration['arguments'])
    return not has_unbound_sparse_arg
def gen_py_variable_methods(out, declarations, template_path):
    """Write python_variable_methods.cpp and its dispatch header into `out`.

    Only declarations that pass should_generate_python_binding(), are not in
    'NN' mode and list 'Tensor' in `method_of` are bound.
    """
    methods_cpp = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    dispatch_h = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')

    def should_bind(declaration):
        if not should_generate_python_binding(declaration):
            return False
        if declaration['mode'] == 'NN':
            return False
        return 'Tensor' in declaration['method_of']

    grouped = group_declarations_by_name(declarations, should_bind)
    # Tensor methods receive `self`, hence has_self is True here.
    env = create_python_bindings(grouped, True)
    for filename, template in (('python_variable_methods.cpp', methods_cpp),
                               ('python_variable_methods_dispatch.h', dispatch_h)):
        write(out, filename, template, env)
def gen_py_nn_functions(out, declarations, template_path):
    """Write python_nn_functions.{cpp,h} and the NN dispatch header into `out`.

    Only declarations that pass should_generate_python_binding() and are in
    'NN' mode are bound; they are generated as module functions without `self`.
    """
    functions_cpp = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    functions_h = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
    dispatch_h = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')

    def should_bind(declaration):
        if not should_generate_python_binding(declaration):
            return False
        return declaration['mode'] == 'NN'

    grouped = group_declarations_by_name(declarations, should_bind)
    env = create_python_bindings(grouped, has_self=False, is_module=True)
    for filename, template in (('python_nn_functions.cpp', functions_cpp),
                               ('python_nn_functions.h', functions_h),
                               ('python_nn_functions_dispatch.h', dispatch_h)):
        write(out, filename, template, env)
def gen_py_torch_functions(out, declarations, template_path):
    """Write python_torch_functions.cpp and its dispatch header into `out`.

    Only declarations that pass should_generate_python_binding(), are not in
    'NN' mode and list 'namespace' in `method_of` are bound; they are
    generated without `self`.
    """
    functions_cpp = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    dispatch_h = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')

    def should_bind(declaration):
        if not should_generate_python_binding(declaration):
            return False
        if declaration['mode'] == 'NN':
            return False
        return 'namespace' in declaration['method_of']

    grouped = group_declarations_by_name(declarations, should_bind)
    env = create_python_bindings(grouped, has_self=False)
    for filename, template in (('python_torch_functions.cpp', functions_cpp),
                               ('python_torch_functions_dispatch.h', dispatch_h)):
        write(out, filename, template, env)
def group_declarations_by_name(declarations, should_bind_fn):
    """Bucket the bindable declarations by name, folding `foo_out` into `foo`.

    Declarations for which `should_bind_fn` returns a false value are dropped.
    The result is a defaultdict(list) mapping each base name to its
    declarations in input order.
    """
    out_suffix = '_out'
    by_name = defaultdict(list)
    for decl in declarations:
        decl_name = decl['name']
        if not should_bind_fn(decl):
            continue
        if decl_name.endswith(out_suffix):
            key = decl_name[:-len(out_suffix)]
        else:
            key = decl_name
        by_name[key].append(decl)
    return by_name
def get_type_default(declaration):
    """Return the default dtype string used in the Python signature.

    Declarations whose name starts with 'randperm' default to 'torch.int64';
    every other declaration gets the string 'None'.
    """
    is_randperm = declaration['name'].startswith('randperm')
    return 'torch.int64' if is_randperm else 'None'
def create_python_bindings(python_functions, has_self, is_module=False):
    """Generates Python bindings to ATen functions.

    Args:
        python_functions: mapping from binding name to the list of
            declarations (overloads and their `_out` variants) that are
            exposed under that name.
        has_self: when True, the bindings are generated as methods: `self` is
            unpacked from the receiver instead of being parsed as an argument.
        is_module: when False and `has_self` is False, ' | METH_STATIC' is
            added to the method flags.

    Returns:
        A dict with the keys 'py_methods', 'py_method_defs' and
        'py_method_dispatch', each holding the list of rendered templates for
        all bindings, in sorted name order.
    """
    py_methods = []
    py_method_defs = []
    py_method_dispatch = []

    # Maps a C++ argument type to the PythonArgParser accessor that unpacks
    # it. Types that are not listed fall back to `typename.lower()`.
    unpack_methods = {
        'const Tensor &': 'tensor',
        'SparseTensorRef': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'const Type &': 'scalartype',
        'const THPLayout &': 'layout',
        'const Device &': 'device',
        'optional<ScalarType>': 'scalartypeOptional',
        'int64_t': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
        'std::string': 'string',
    }

    # Accessors used instead of the ones above when the argument carries a
    # `python_default_init` expression.
    unpack_with_default_methods = {
        'IntList': 'setDefaultIntlist',
        'Scalar': 'scalarWithDefault',
        'int64_t': 'toInt64WithDefault',
        'bool': 'setDefaultBool',
        'double': 'setDefaultDouble',
        'const Type &': 'scalartypeWithDefault',
        'const THPLayout &': 'layoutWithDefault',
        'const Device &': 'deviceWithDefault',
        'ScalarType': 'scalartypeWithDefault',
    }

    def emit_single_dispatch(declaration, out_idx, base_env):
        """Emit the call for one declaration.

        Returns the list of C++ statements that unpack the parsed arguments
        and call the dispatch function. As a side effect, the dispatch
        function itself is appended to `py_method_dispatch`.

        `out_idx` is the parser index of the `out` argument, or None when the
        binding has no `_out` variant.
        """
        env = {}
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []
        actuals = []
        formal_args = []
        arg_idx = 0

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        has_tensor_options = any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments'])

        def get_type_args(args):
            return [arg for arg in args if arg['simple_type'] == 'Type']

        type_actual_args = get_type_args(declaration['arguments'])
        type_binding_args = get_type_args(declaration['python_binding_arguments'])
        assert len(type_actual_args + type_binding_args) <= 1
        if type_binding_args and len(outputs) == 0:
            # out(s) determines the dtype if it is present, so only use this if there are no outputs.
            type_args = type_binding_args
        else:
            type_args = type_actual_args

        if type_args and len(outputs) > 1:
            raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

        def parse_arg(arg, arg_index, unpack_args=False):
            """Return `(expr, formal)` for one argument.

            `expr` is the C++ expression passed to the dispatch function and
            `formal` is the matching formal parameter. With `unpack_args`,
            the value is first stored in a local named after the argument.
            """
            name = arg['name']
            typename = arg['type']
            if typename.startswith('IntList['):
                typename = 'IntList'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                # TODO: Type currently maps to ScalarType, figure out a cleaner solution
                if typename == 'const Type &':
                    default_expr += '.scalarType()'
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'Storage &':
                expr = '*' + expr
            if typename == 'SparseTensorRef':
                expr = 'SparseTensorRef({})'.format(expr)

            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                dispatch_type = 'Tensor'
            elif dispatch_type == 'const Device &':
                dispatch_type = 'at::optional<int32_t>'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        # We always want to unpack when we have TensorOptions.
        unpack = any(arg.get('python_default_init') for arg in inputs) or has_tensor_options
        for arg in inputs:
            # Type / TensorOptions inputs are handled further down, together
            # with the python binding arguments.
            if arg['simple_type'] in ['Type', 'TensorOptions']:
                continue
            if has_self and arg['name'] == 'self':
                # `self` comes from the receiver and takes no parser slot.
                formal_args.append('Tensor & self')
                actuals.append('self')
                continue
            append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
            arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            # Several outputs are passed from Python as one `out` tensor list.
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        layout = None
        parsed_type_args = None
        # type args go after the outputs to match the signature generation.
        arg_idx = arg_idx if out_idx is None else out_idx + 1
        for arg in type_args:
            parsed_type_args = parse_arg(arg, arg_idx, unpack)
            arg_idx += 1

        # check python_binding_arguments
        requires_grad = None
        python_binding_arguments = declaration.get('python_binding_arguments', [])
        # NOTE(review): when the dtype binding argument was already consumed
        # by the type_args loop above (no outputs, no TensorOptions), this
        # looks like it advances arg_idx a second time -- confirm whether
        # that combination can occur.
        if 'dtype' in (a['name'] for a in python_binding_arguments):
            if not has_tensor_options:
                arg_idx += 1

        # NOTE(review): if only `requires_grad` is bound (neither layout nor
        # device), requires_grad_idx still leaves room for a device slot --
        # confirm against the generated signature.
        if 'layout' in (a['name'] for a in python_binding_arguments):
            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
        else:
            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)

        device = None
        for arg in python_binding_arguments:
            if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                pass  # already handled by type_dispatched_args
            elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                if len(outputs) == 0:
                    layout = parse_arg(arg, layout_idx, arg.get('python_default_init'))[0]
            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                if len(outputs) == 0:
                    assert parsed_type_args
                    assert layout
                    device, device_type = parse_arg(arg, device_idx, True)

                    if not has_tensor_options:
                        # add type, device formals and corresponding actuals.
                        # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
                        # The device actual is the corresponding AutoGPU index for the Device.
                        formal_args.append(parsed_type_args[1])
                        formal_args.append(device_type)
                        actuals.append("torch::getType({}, {}, {})".format(parsed_type_args[0], layout, device))
                        actuals.append('{}.index()'.format(device))
            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                requires_grad = parse_arg(arg, requires_grad_idx)[0]
            else:
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                    "\"Device device\" are supported".format(arg)))

        dtype = parsed_type_args[0] if parsed_type_args else None
        if has_tensor_options and all([dtype, device, layout, requires_grad]):
            # Bundle the four parsed values into a single `options` local.
            body.append(TENSOR_OPTIONS.substitute({
                'dtype': dtype,
                'layout': layout,
                'device': device,
                'requires_grad': requires_grad
            }))
            formal_args.append('const TensorOptions & options')
            actuals.append('options')

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals

        if has_tensor_options:
            env['initialize_cuda'] = 'maybe_initialize_cuda(options.type());'
        else:
            env['initialize_cuda'] = 'maybe_initialize_cuda({});'.format(type_args[0]['name']) if type_args else ''

        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]

        if 'Tensor' in declaration['method_of']:
            # Called as a method on `self`, so `self` is not passed again.
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            namespace = 'torch' if (has_tensor_options or declaration['name'].endswith('_like')) else 'at'
            env['dispatch_call'] = '{}::{}'.format(namespace, declaration['name'])
        else:
            raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

        env['AutoNoGIL'] = 'AutoNoGIL no_gil;' if not declaration['with_gil'] else ''

        # Lookup order for template substitution: env, then base_env, then
        # the declaration itself.
        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        if requires_grad and not has_tensor_options:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
                                                                     requires_grad=requires_grad)
        if simple_return_type == 'void':
            body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
            body.append('Py_RETURN_NONE;')
        else:
            body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body

    def emit_dispatch(i, dictionary, base_env):
        """Emit the `r.idx == i` branch for one grouped signature.

        `dictionary` holds the 'base' declaration and, optionally, its 'out'
        variant (see group_declarations).
        """
        if 'out' in dictionary:
            # The `out` argument sits right after the non-output arguments.
            out_idx = len([arg for arg in dictionary['out']['arguments']
                           if not arg.get('output', False)])
            env = {}
            env['call_dispatch_out'] = emit_single_dispatch(dictionary['out'], out_idx, base_env)
            env['call_dispatch'] = emit_single_dispatch(dictionary['base'], out_idx, base_env)

            has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
            if has_dtype_bind:
                body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
                                                             layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
            else:
                body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
        else:
            body = emit_single_dispatch(dictionary['base'], None, base_env)

        cond = 'if' if i == 0 else '} else if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)

    def get_python_binding_arguments(declaration):
        """Return the extra Python-only keyword arguments for `declaration`.

        Depending on the declaration, these are `dtype`, `layout`, `device`
        and `requires_grad`.

        Raises:
            ValueError: if the declaration already has an input argument
                named `requires_grad`.
        """
        python_binding_arguments = []
        has_tensor_input_arg = False
        has_type_input_arg = False
        has_options_arg = False
        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['simple_type'] == 'Type':
                has_type_input_arg = True
            elif arg['simple_type'] == 'TensorOptions':
                has_options_arg = True
            if arg['name'] == 'requires_grad':
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                # this probably won't work if one of the returns is not a tensor, but it will
                # produce a compile-time error that is obvious
                has_tensor_return = True

        # Derive the binding name from the declaration rather than relying on
        # a `name` variable leaking in from the enclosing loop. Stripping
        # '_out' yields the group key used by group_declarations_by_name.
        binding_name = declaration['name']
        if binding_name.endswith('_out'):
            binding_name = binding_name[:-4]
        is_like_function = binding_name.endswith('_like')
        is_like_function_with_options = is_like_function and has_options_arg
        is_factory_function = has_tensor_return and not has_tensor_input_arg
        is_factory_or_like_function = has_tensor_return and (not has_tensor_input_arg or is_like_function)

        if (is_factory_function and not has_type_input_arg) or has_options_arg:
            default_type = get_type_default(declaration)
            py_default_dtype = 'self.type()' if is_like_function_with_options else None
            dtype_arg = {
                'default': default_type,
                'dynamic_type': 'Type',
                'kwarg_only': True,
                'name': 'dtype',
                'type': 'const Type &',
                'simple_type': 'Type',
                'is_type_dispatched': True,
                'python_default_init': py_default_dtype,
            }
            python_binding_arguments.append(dtype_arg)
        if is_factory_function or is_like_function_with_options:
            py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_function_with_options else None
            layout_arg = {
                'default': 'torch.strided',
                'dynamic_type': 'Layout',
                'kwarg_only': True,
                'name': 'layout',
                'type': 'const THPLayout &',
                'simple_type': 'Layout',
                'python_default_init': py_default_layout,
            }
            python_binding_arguments.append(layout_arg)
            py_default_device = 'self.device()' if is_like_function_with_options else None
            device_arg = {
                'default': 'None',
                'default_init': 'None',
                'dynamic_type': 'Device',
                'kwarg_only': True,
                'name': 'device',
                'type': 'const Device &',
                'simple_type': 'Device',
                'python_default_init': py_default_device
            }
            python_binding_arguments.append(device_arg)
        if is_factory_or_like_function:
            requires_grad_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)
        return python_binding_arguments

    def process_function(name, declarations):
        """Generate the wrapper and method-table entry for one binding name.

        Appends to `py_methods` and `py_method_defs`, and stores the computed
        'python_binding_arguments' on every declaration.
        """
        for declaration in declarations:
            declaration['python_binding_arguments'] = get_python_binding_arguments(declaration)

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations),
            'unpack_self': [],
            'dispatch': [],
        }

        if has_self:
            env['unpack_self'] = [UNPACK_SELF]

        grouped = group_declarations(declarations)
        for i, dictionary in enumerate(grouped):
            signature = dictionary['signature']
            if has_self:
                # `self` is the receiver, so it is not part of the parsed signature.
                signature = signature.replace('Tensor self, ', '')
                signature = signature.replace('Tensor self', '')
            if not has_self:
                # Use 'input' instead of 'self' for NN functions
                signature = signature.replace('Tensor self', 'Tensor input')
            signature = signature.replace('SparseTensorRef', 'Tensor')
            if dictionary['base'].get('deprecated', False):
                signature += '|deprecated'
            env['signatures'].append('"{}",'.format(signature))
            env['dispatch'].append(emit_dispatch(i, dictionary, env))

        # Closes the last branch opened by PY_VARIABLE_CASE.
        env['dispatch'].append('}')

        env['traceable'] = 'true' if all(should_trace(d) for d in declarations) else 'false'

        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            # A method whose only argument is `self` needs no argument parsing.
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self']
            env['flags'] = 'METH_NOARGS'
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not is_module and not has_self:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    # Sorted iteration keeps the generated output in a stable order.
    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])

    return {
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
        'py_method_dispatch': py_method_dispatch,
    }
def group_declarations(declarations):
    """Pair each declaration with its `_out` variant and order the pairs.

    Returns a list of dictionaries with these keys:
       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out), if there is one
       "signature": the signature used for Python argument parsing

    Raises RuntimeError when an out variant has no matching base declaration.
    """
    by_signature = defaultdict(dict)

    # Declarations belong together when their signatures agree once the out
    # arguments are ignored.
    for declaration in declarations:
        plain_signature = get_python_signature(declaration, False)
        entry = by_signature[plain_signature]
        if declaration['name'].endswith('_out'):
            entry['out'] = declaration
            # prefer the signature with optional out=... arguments
            entry['signature'] = get_python_signature(declaration, True)
            continue
        entry['base'] = declaration
        entry.setdefault('signature', plain_signature)

    ordered = []
    for key in sorted(by_signature):
        entry = by_signature[key]
        if 'base' not in entry:
            raise RuntimeError("'base' not in dictionary", entry)
        ordered.append(entry)

    return sort_declarations(ordered)
# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):
    """Order grouped declarations so Tensor overloads come before Scalar ones.

    Declaration d1 counts as smaller than d2 when both have the same number
    of arguments, every argument pair has equal dynamic types or is
    (Scalar, Tensor), and at least one pair is (Scalar, Tensor). Smaller
    declarations are placed after every declaration they are smaller than.
    The input list is returned unchanged when no two declarations are related.
    """
    # TODO: This is a hack!
    #
    # A Scalar argument of a native function shows up in Declarations.yaml
    # with `dynamic_type: Scalar`, for example:
    #
    #   - default: 1
    #     dynamic_type: Scalar
    #     is_nullable: false
    #     kwarg_only: true
    #     name: alpha
    #     type: Scalar
    #
    # This is in contrast to a 'real' argument in TH Declarations.cwrap,
    # which gets (correctly?) translated into dynamic_type: real and
    # type: Scalar. Both spellings are treated as 'Scalar' here until this
    # is fixed at the source.
    def dyn_type(arg):
        kind = arg['dynamic_type']
        return 'Scalar' if kind == 'real' else kind

    def scalar_vs_tensor(lhs, rhs):
        return dyn_type(lhs) == 'Scalar' and rhs['dynamic_type'] == 'Tensor'

    def precedes(lhs_decl, rhs_decl):
        """Returns True if lhs_decl < rhs_decl in the partial order."""
        lhs_args = lhs_decl['base']['arguments']
        rhs_args = rhs_decl['base']['arguments']
        if len(lhs_args) != len(rhs_args):
            return False
        pairs = list(zip(lhs_args, rhs_args))
        if not any(scalar_vs_tensor(lhs, rhs) for lhs, rhs in pairs):
            return False
        for lhs, rhs in pairs:
            if dyn_type(lhs) != dyn_type(rhs) and not scalar_vs_tensor(lhs, rhs):
                return False
        return True

    # blocked_by[i] holds the indices of the declarations that i is smaller than.
    blocked_by = defaultdict(set)
    for lo_idx, lo_decl in enumerate(grouped_decls):
        for hi_idx, hi_decl in enumerate(grouped_decls):
            if precedes(lo_decl, hi_decl):
                blocked_by[lo_idx].add(hi_idx)

    if not blocked_by:
        return grouped_decls

    # Topological sort: begin with the declarations that are not smaller than
    # anything and release the others once everything above them is placed.
    ordered = [idx for idx in range(len(grouped_decls)) if idx not in blocked_by]
    cursor = 0
    while cursor < len(ordered):
        placed = ordered[cursor]
        cursor += 1
        for pending in sorted(blocked_by.keys()):
            remaining = blocked_by[pending]
            remaining.discard(placed)
            if not remaining:
                del blocked_by[pending]
                ordered.append(pending)

    return [grouped_decls[idx] for idx in ordered]
def get_python_signature(declaration, include_out):
    """Build the Python signature string used for argument parsing.

    The formal arguments are ordered as: regular inputs, the optional `out`
    argument (only when `include_out` is set and the declaration has output
    arguments), the type dispatched argument, and finally the python binding
    arguments. A '*' marker is inserted once, before the first keyword-only
    entry. A trailing `_out` is dropped from the function name.
    """
    py_formal_args = []
    output_args = []
    type_args = []

    def ensure_kwarg_only_marker():
        # No formatted parameter can equal '*' (each contains a space), so
        # the marker's presence tells whether we are still positional.
        if '*' not in py_formal_args:
            py_formal_args.append('*')

    def get_py_formal_arg(arg):
        typename = arg['simple_type']
        opt_match = re.match(r'optional<(.+)>', typename)
        if opt_match:
            typename = opt_match.group(1)
        if typename == 'Type':
            typename = 'ScalarType'
        if arg.get('is_nullable') or opt_match:
            typename = '{}?'.format(typename)
        if arg.get('size') is not None:
            typename = '{}[{}]'.format(typename, arg['size'])

        default = arg.get('default')
        if default is not None and default in ('nullptr', 'nullopt', '{}'):
            default = 'None'
        if arg.get('python_default_init') is not None:
            default = 'None'
        if default is None and arg.get('is_type_dispatched', False):
            # this is necessary because ATen does not have default_types; in this case,
            # the type exists in the public API (at:: namespace), but not in the type interface;
            # to match the PyTorch default_type API, we set the default to None.
            default = get_type_default(declaration)

        param = typename + ' ' + arg['name']
        if default is None:
            return param
        return param + '=' + str(default)

    for arg in declaration['arguments']:
        if arg.get('output', False):
            output_args.append(arg)
        elif arg['simple_type'] == 'Type':
            type_args.append(arg)
        elif arg['simple_type'] == 'TensorOptions':
            # Skip `TensorOptions` in Python, as it is only used on the C++ side.
            pass
        else:
            if arg.get('kwarg_only', False):
                ensure_kwarg_only_marker()
            py_formal_args.append(get_py_formal_arg(arg))

    # The Python-visible name never carries the `_out` suffix.
    name = declaration['name']
    if name.endswith('_out'):
        name = name[:-4]

    # add output arguments
    if len(output_args) > 0 and include_out:
        assert declaration['name'].endswith('_out')
        ensure_kwarg_only_marker()
        typenames = [arg['simple_type'] for arg in output_args]
        if len(typenames) > 1:
            out_typename = 'TensorList[{}]'.format(len(typenames))
        else:
            out_typename = typenames[0]
        py_formal_args.append(out_typename + ' out=None')

    # we could put this in the loop above but we want to ensure both type dispatched args
    # and python binding arguments are after the out argument; this matches the case
    # where there is a python binding argument dtype, which is necessary to match
    # the function signatures between the out and non-out variant.
    assert len(type_args) <= 1
    for arg in type_args:
        # assume type_args should be kwarg_only.
        ensure_kwarg_only_marker()
        py_formal_args.append(get_py_formal_arg(arg))

    for arg in declaration['python_binding_arguments']:
        if arg.get('kwarg_only', False):
            ensure_kwarg_only_marker()
        py_formal_args.append(get_py_formal_arg(arg))

    # Python function signature.
    # This is the string that we give to FunctionParameter, which is
    # then parsed into the actual structure which we do parsing
    # with.
    return PYTHON_FUNCTION_SIGNATURE.substitute(name=name, py_formal_args=py_formal_args)