Add deprecated add_out overload (#5088)

We have a few calls that use this signature on Tensors. This also
updates the binding code to support deprecated xxx_out signatures.
This commit is contained in:
Sam Gross 2018-02-06 17:08:23 -05:00 committed by GitHub
parent 36bbaf0d85
commit c1b98f0841
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 32 additions and 18 deletions

View File

@ -4,6 +4,9 @@
- name: add(Tensor self, Scalar alpha, Tensor other)
aten: add(self, other, alpha)
- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor out)
aten: add_out(out, self, other, alpha)
- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2)
aten: addbmm(self, batch1, batch2, beta, alpha)

View File

@ -88,32 +88,46 @@ def load_deprecated_signatures(aten_decls):
def get_signature(name, params, call_args):
# create a mapping of parameter name to parameter type
types = dict([param.split(' ')[::-1] for param in params])
types = dict([param.split(' ')[::-1] for param in params if param != '*'])
# if the name in the call is not in the parameter list, assume it's
# a literal Scalar
rearranged_types = [types.get(arg, 'Scalar') for arg in call_args]
return '{}({})'.format(name, ', '.join(rearranged_types))
for deprecated in deprecated_defs:
python_signature = deprecated['name']
call_args = split_name_params(deprecated['aten'])[1]
name, params = split_name_params(python_signature)
signature = get_signature(name, params, call_args)
aten_name, call_args = split_name_params(deprecated['aten'])
name, params = split_name_params(deprecated['name'])
signature = get_signature(aten_name, params, call_args)
for declaration in declarations_by_signature[signature]:
declaration = copy.deepcopy(declaration)
declaration['deprecated'] = True
declaration['call_args'] = call_args
if declaration['inplace']:
declaration['python_signature'] = python_signature.replace(name, name + '_')
else:
declaration['python_signature'] = python_signature
args_by_name = {arg['name']: arg for arg in declaration['arguments']}
declaration['arguments'] = []
for arg in params:
_, arg_name = arg.split(' ')
declaration['arguments'].append(args_by_name[arg_name])
call_arg_to_idx = {arg: i for i, arg in enumerate(call_args)}
original_args = declaration['arguments']
# Create an arguments list that uses the types from the original
# ATen declaration, but the ordering and parameter names from
# the deprecated overload. Any default parameter values from the
# original ATen declaration are ignored.
arguments = []
kwarg_only = False
for param in params:
if param == '*':
kwarg_only = True
continue
_, param_name = param.split(' ')
original = original_args[call_arg_to_idx[param_name]]
arguments.append({
'name': param_name,
'kwarg_only': kwarg_only,
'type': original['type'],
'simple_type': original['simple_type'],
'dynamic_type': original['dynamic_type'],
'output': original.get('output', False),
})
declaration['arguments'] = arguments
declarations.append(declaration)
return declarations

View File

@ -458,10 +458,6 @@ def group_declarations(declarations):
def get_python_signature(declaration, include_out):
# Use the saved signature for deprecated pseudo-declarations
if 'python_signature' in declaration:
return declaration['python_signature']
# Compute the Python function signature for argument parsing
typed_args = []
output_args = []

View File

@ -21,6 +21,7 @@ inputs = [
'torch/csrc/generic/TensorMethods.cwrap',
'torch/lib/tmp_install/share/ATen/Declarations.yaml',
'tools/autograd/derivatives.yaml',
'tools/autograd/deprecated.yaml',
] + glob.glob('torch/csrc/generic/methods/*.cwrap')
outputs = [