mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Add deprecated add_out overload (#5088)
We have a few calls that use this signature on Tensors. This also updates the binding code to support deprecated xxx_out signatures.
This commit is contained in:
parent
36bbaf0d85
commit
c1b98f0841
|
|
@ -4,6 +4,9 @@
|
||||||
- name: add(Tensor self, Scalar alpha, Tensor other)
|
- name: add(Tensor self, Scalar alpha, Tensor other)
|
||||||
aten: add(self, other, alpha)
|
aten: add(self, other, alpha)
|
||||||
|
|
||||||
|
- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor out)
|
||||||
|
aten: add_out(out, self, other, alpha)
|
||||||
|
|
||||||
- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2)
|
- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2)
|
||||||
aten: addbmm(self, batch1, batch2, beta, alpha)
|
aten: addbmm(self, batch1, batch2, beta, alpha)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -88,32 +88,46 @@ def load_deprecated_signatures(aten_decls):
|
||||||
|
|
||||||
def get_signature(name, params, call_args):
|
def get_signature(name, params, call_args):
|
||||||
# create a mapping of parameter name to parameter type
|
# create a mapping of parameter name to parameter type
|
||||||
types = dict([param.split(' ')[::-1] for param in params])
|
types = dict([param.split(' ')[::-1] for param in params if param != '*'])
|
||||||
# if the name in the call is not in the parameter list, assume it's
|
# if the name in the call is not in the parameter list, assume it's
|
||||||
# a literal Scalar
|
# a literal Scalar
|
||||||
rearranged_types = [types.get(arg, 'Scalar') for arg in call_args]
|
rearranged_types = [types.get(arg, 'Scalar') for arg in call_args]
|
||||||
return '{}({})'.format(name, ', '.join(rearranged_types))
|
return '{}({})'.format(name, ', '.join(rearranged_types))
|
||||||
|
|
||||||
for deprecated in deprecated_defs:
|
for deprecated in deprecated_defs:
|
||||||
python_signature = deprecated['name']
|
aten_name, call_args = split_name_params(deprecated['aten'])
|
||||||
call_args = split_name_params(deprecated['aten'])[1]
|
name, params = split_name_params(deprecated['name'])
|
||||||
name, params = split_name_params(python_signature)
|
signature = get_signature(aten_name, params, call_args)
|
||||||
signature = get_signature(name, params, call_args)
|
|
||||||
|
|
||||||
for declaration in declarations_by_signature[signature]:
|
for declaration in declarations_by_signature[signature]:
|
||||||
declaration = copy.deepcopy(declaration)
|
declaration = copy.deepcopy(declaration)
|
||||||
declaration['deprecated'] = True
|
declaration['deprecated'] = True
|
||||||
declaration['call_args'] = call_args
|
declaration['call_args'] = call_args
|
||||||
if declaration['inplace']:
|
|
||||||
declaration['python_signature'] = python_signature.replace(name, name + '_')
|
|
||||||
else:
|
|
||||||
declaration['python_signature'] = python_signature
|
|
||||||
|
|
||||||
args_by_name = {arg['name']: arg for arg in declaration['arguments']}
|
call_arg_to_idx = {arg: i for i, arg in enumerate(call_args)}
|
||||||
declaration['arguments'] = []
|
original_args = declaration['arguments']
|
||||||
for arg in params:
|
|
||||||
_, arg_name = arg.split(' ')
|
# Create an arguments list that uses the types from the original
|
||||||
declaration['arguments'].append(args_by_name[arg_name])
|
# ATen declaration, but the ordering and parameter names from
|
||||||
|
# the deprecated overload. Any default parameter values from the
|
||||||
|
# original ATen declaration are ignored.
|
||||||
|
arguments = []
|
||||||
|
kwarg_only = False
|
||||||
|
for param in params:
|
||||||
|
if param == '*':
|
||||||
|
kwarg_only = True
|
||||||
|
continue
|
||||||
|
_, param_name = param.split(' ')
|
||||||
|
original = original_args[call_arg_to_idx[param_name]]
|
||||||
|
arguments.append({
|
||||||
|
'name': param_name,
|
||||||
|
'kwarg_only': kwarg_only,
|
||||||
|
'type': original['type'],
|
||||||
|
'simple_type': original['simple_type'],
|
||||||
|
'dynamic_type': original['dynamic_type'],
|
||||||
|
'output': original.get('output', False),
|
||||||
|
})
|
||||||
|
declaration['arguments'] = arguments
|
||||||
declarations.append(declaration)
|
declarations.append(declaration)
|
||||||
return declarations
|
return declarations
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -458,10 +458,6 @@ def group_declarations(declarations):
|
||||||
|
|
||||||
|
|
||||||
def get_python_signature(declaration, include_out):
|
def get_python_signature(declaration, include_out):
|
||||||
# Use the saved signature for deprecated pseudo-declarations
|
|
||||||
if 'python_signature' in declaration:
|
|
||||||
return declaration['python_signature']
|
|
||||||
|
|
||||||
# Compute the Python function signature for argument parsing
|
# Compute the Python function signature for argument parsing
|
||||||
typed_args = []
|
typed_args = []
|
||||||
output_args = []
|
output_args = []
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@ inputs = [
|
||||||
'torch/csrc/generic/TensorMethods.cwrap',
|
'torch/csrc/generic/TensorMethods.cwrap',
|
||||||
'torch/lib/tmp_install/share/ATen/Declarations.yaml',
|
'torch/lib/tmp_install/share/ATen/Declarations.yaml',
|
||||||
'tools/autograd/derivatives.yaml',
|
'tools/autograd/derivatives.yaml',
|
||||||
|
'tools/autograd/deprecated.yaml',
|
||||||
] + glob.glob('torch/csrc/generic/methods/*.cwrap')
|
] + glob.glob('torch/csrc/generic/methods/*.cwrap')
|
||||||
|
|
||||||
outputs = [
|
outputs = [
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user