mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add deprecated add_out overload (#5088)
We have a few calls that use this signature on Tensors. This also updates the binding code to support deprecated xxx_out signatures.
This commit is contained in:
parent
36bbaf0d85
commit
c1b98f0841
|
|
@ -4,6 +4,9 @@
|
|||
- name: add(Tensor self, Scalar alpha, Tensor other)
|
||||
aten: add(self, other, alpha)
|
||||
|
||||
- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor out)
|
||||
aten: add_out(out, self, other, alpha)
|
||||
|
||||
- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2)
|
||||
aten: addbmm(self, batch1, batch2, beta, alpha)
|
||||
|
||||
|
|
|
|||
|
|
@ -88,32 +88,46 @@ def load_deprecated_signatures(aten_decls):
|
|||
|
||||
def get_signature(name, params, call_args):
|
||||
# create a mapping of parameter name to parameter type
|
||||
types = dict([param.split(' ')[::-1] for param in params])
|
||||
types = dict([param.split(' ')[::-1] for param in params if param != '*'])
|
||||
# if the name in the call is not in the parameter list, assume it's
|
||||
# a literal Scalar
|
||||
rearranged_types = [types.get(arg, 'Scalar') for arg in call_args]
|
||||
return '{}({})'.format(name, ', '.join(rearranged_types))
|
||||
|
||||
for deprecated in deprecated_defs:
|
||||
python_signature = deprecated['name']
|
||||
call_args = split_name_params(deprecated['aten'])[1]
|
||||
name, params = split_name_params(python_signature)
|
||||
signature = get_signature(name, params, call_args)
|
||||
aten_name, call_args = split_name_params(deprecated['aten'])
|
||||
name, params = split_name_params(deprecated['name'])
|
||||
signature = get_signature(aten_name, params, call_args)
|
||||
|
||||
for declaration in declarations_by_signature[signature]:
|
||||
declaration = copy.deepcopy(declaration)
|
||||
declaration['deprecated'] = True
|
||||
declaration['call_args'] = call_args
|
||||
if declaration['inplace']:
|
||||
declaration['python_signature'] = python_signature.replace(name, name + '_')
|
||||
else:
|
||||
declaration['python_signature'] = python_signature
|
||||
|
||||
args_by_name = {arg['name']: arg for arg in declaration['arguments']}
|
||||
declaration['arguments'] = []
|
||||
for arg in params:
|
||||
_, arg_name = arg.split(' ')
|
||||
declaration['arguments'].append(args_by_name[arg_name])
|
||||
call_arg_to_idx = {arg: i for i, arg in enumerate(call_args)}
|
||||
original_args = declaration['arguments']
|
||||
|
||||
# Create an arguments list that uses the types from the original
|
||||
# ATen declaration, but the ordering and parameter names from
|
||||
# the deprecated overload. Any default parameter values from the
|
||||
# original ATen declaration are ignored.
|
||||
arguments = []
|
||||
kwarg_only = False
|
||||
for param in params:
|
||||
if param == '*':
|
||||
kwarg_only = True
|
||||
continue
|
||||
_, param_name = param.split(' ')
|
||||
original = original_args[call_arg_to_idx[param_name]]
|
||||
arguments.append({
|
||||
'name': param_name,
|
||||
'kwarg_only': kwarg_only,
|
||||
'type': original['type'],
|
||||
'simple_type': original['simple_type'],
|
||||
'dynamic_type': original['dynamic_type'],
|
||||
'output': original.get('output', False),
|
||||
})
|
||||
declaration['arguments'] = arguments
|
||||
declarations.append(declaration)
|
||||
return declarations
|
||||
|
||||
|
|
|
|||
|
|
@ -458,10 +458,6 @@ def group_declarations(declarations):
|
|||
|
||||
|
||||
def get_python_signature(declaration, include_out):
|
||||
# Use the saved signature for deprecated pseudo-declarations
|
||||
if 'python_signature' in declaration:
|
||||
return declaration['python_signature']
|
||||
|
||||
# Compute the Python function signature for argument parsing
|
||||
typed_args = []
|
||||
output_args = []
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ inputs = [
|
|||
'torch/csrc/generic/TensorMethods.cwrap',
|
||||
'torch/lib/tmp_install/share/ATen/Declarations.yaml',
|
||||
'tools/autograd/derivatives.yaml',
|
||||
'tools/autograd/deprecated.yaml',
|
||||
] + glob.glob('torch/csrc/generic/methods/*.cwrap')
|
||||
|
||||
outputs = [
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user