mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42629 How to approach reviewing this diff: - The new codegen itself lives in `tools/codegen`. Start with `gen.py`, then read `model.py` and then the `api/` folder. The comments at the top of the files describe what is going on. The CLI interface of the new codegen is similar to the old one, but (1) it is no longer necessary to explicitly specify cwrap inputs (and now we will error if you do so) and (2) the default settings for source and install dir are much better; to the extent that if you run the codegen from the root source directory as just `python -m tools.codegen.gen`, something reasonable will happen. - The old codegen is (nearly) entirely deleted; every Python file in `aten/src/ATen` was deleted except for `common_with_cwrap.py`, which now permanently finds its home in `tools/shared/cwrap_common.py` (previously cmake copied the file there), and `code_template.py`, which now lives in `tools/codegen/code_template.py`. We remove the copying logic for `common_with_cwrap.py`. - All of the inputs to the old codegen are deleted. - Build rules now have to be adjusted to not refer to files that no longer exist, and to abide by the (slightly modified) CLI. - LegacyTHFunctions files have been generated and checked in. We expect these to be deleted as these final functions get ported to ATen. The deletion process is straightforward; just delete the functions of the ones you are porting. There are 39 more functions left to port. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D23183978 Pulled By: ezyang fbshipit-source-id: 6073ba432ad182c7284a97147b05f0574a02f763
196 lines
7.2 KiB
Python
196 lines
7.2 KiB
Python
# This code should be common to both cwrap and ATen preprocessing.
# For now it has been put in one place, but at the moment it is copied out of cwrap.
|
|
|
|
import copy
|
|
|
|
def parse_arguments(args):
    """Normalize a list of argument declarations into a list of dicts.

    Each entry of ``args`` is either a string of the form "<type> <name>"
    or a dict.  A string becomes ``{'type': ..., 'name': ...}``.  A dict that
    carries an ``'arg'`` key has that value split into ``'type'`` and
    ``'name'`` and the ``'arg'`` key removed; the dict is updated in place
    and the very same object is placed in the result.

    Returns the new list.  Raises AssertionError for an entry that is neither
    a str nor a dict.
    """
    normalized = []
    for entry in args:
        if isinstance(entry, str):
            # Simple arg declaration of form "<type> <name>"
            type_part, _, name_part = entry.partition(' ')
            normalized.append({'type': type_part, 'name': name_part})
            continue
        if not isinstance(entry, dict):
            raise AssertionError()
        if 'arg' in entry:
            # Same "<type> <name>" shorthand, stored under the 'arg' key.
            combined = entry['arg']
            entry['type'], _, entry['name'] = combined.partition(' ')
            del entry['arg']
        normalized.append(entry)
    return normalized
|
|
|
|
|
|
def set_declaration_defaults(declaration):
    """Fill in missing entries of the ``declaration`` dict, in place.

    Defaults applied when the key is absent: ``schema_string`` (''),
    ``matches_jit_signature`` (False), ``arguments`` ([]), ``return``
    ('void'), ``cname`` (the declaration's ``name``) and ``backends``
    (['CPU', 'CUDA']).  Derived entries ``api_name``, ``type_wrapper_name``,
    ``operator_name_with_overload``, ``unqual_schema_string`` and
    ``unqual_operator_name_with_overload`` are always (re)computed.

    If the declaration has no ``options`` list, a single option is built from
    deep copies of its ``arguments`` and ``schema_order_arguments``, which are
    then removed from the declaration.  Every option has its argument lists
    run through ``parse_arguments`` and inherits every declaration key
    (other than ``options``) that it does not define itself.

    Asserts that ``api_name`` is not already present.  Returns None.
    """
    # A missing schema string happens for legacy TH bindings like
    # _thnn_conv_depthwise2d_backward
    declaration.setdefault('schema_string', '')
    declaration.setdefault('matches_jit_signature', False)
    declaration.setdefault('arguments', [])
    declaration.setdefault('return', 'void')
    if 'cname' not in declaration:
        declaration['cname'] = declaration['name']
    declaration.setdefault('backends', ['CPU', 'CUDA'])

    assert 'api_name' not in declaration
    base_name = declaration['name']
    declaration['api_name'] = base_name

    # NB: keep this in sync with gen_autograd.py
    overload = declaration.get('overload_name')
    if overload:
        wrapper_name = "{}_{}".format(base_name, overload)
    else:
        wrapper_name = base_name
    declaration['type_wrapper_name'] = wrapper_name

    # TODO: Uggggh, parsing the schema string here, really???
    schema = declaration['schema_string']
    qualified_op = schema.split('(')[0]
    declaration['operator_name_with_overload'] = qualified_op
    if not schema:
        declaration['unqual_schema_string'] = ''
        declaration['unqual_operator_name_with_overload'] = ''
    else:
        # Keep the segment that follows the first '::'.
        declaration['unqual_schema_string'] = schema.split('::')[1]
        declaration['unqual_operator_name_with_overload'] = qualified_op.split('::')[1]

    # Simulate multiple dispatch, even if it's not necessary
    if 'options' not in declaration:
        single_option = {
            'arguments': copy.deepcopy(declaration['arguments']),
            'schema_order_arguments': copy.deepcopy(declaration['schema_order_arguments']),
        }
        declaration['options'] = [single_option]
        # The argument lists now live on the option only.
        del declaration['arguments']
        del declaration['schema_order_arguments']

    # Parse arguments (some of them can be strings)
    for option in declaration['options']:
        for list_key in ('arguments', 'schema_order_arguments'):
            option[list_key] = parse_arguments(option[list_key])

    # Propagate defaults from declaration to options
    for option in declaration['options']:
        for key, value in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if key == 'options':
                continue
            option.setdefault(key, value)
|
|
|
|
# TODO(zach): added option to remove keyword handling for C++ which cannot
|
|
# support it.
|
|
|
|
|
|
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    """Return the options whose call signature has not been seen before.

    A signature is built from the types of an option's ``'arguments'``
    (mapped through ``type_to_signature`` when the type is a key there).
    Arguments of type 'CONSTANT' never contribute; arguments named 'self'
    are skipped from the positional part when ``remove_self`` is true.

    For each option the plain positional signature is tried first.  When
    ``allow_kwarg`` is true and that signature is taken, progressively more
    trailing arguments are treated as keyword-only (contributing
    "name#type") until an unused signature is found; those trailing
    arguments are then marked with ``arg['kwarg_only'] = True``.  An option
    for which no unused signature exists is dropped.

    The returned list holds the original option objects in input order.
    """
    def is_constant(arg):
        return arg['type'] == 'CONSTANT'

    def skip_positional(arg):
        if is_constant(arg):
            return True
        return remove_self and arg['name'] == 'self'

    def signature(option, kwarg_only_count):
        arguments = option['arguments']
        # Index where the keyword-only tail starts (== len when count is 0).
        split_at = len(arguments) - kwarg_only_count
        positional_part = '#'.join(
            type_to_signature.get(a['type'], a['type'])
            for a in arguments[:split_at]
            if not skip_positional(a))
        if kwarg_only_count == 0:
            return positional_part
        keyword_part = '#'.join(
            a['name'] + '#' + a['type']
            for a in arguments[split_at:]
            if not is_constant(a))
        return positional_part + "#-#" + keyword_part

    taken = set()
    kept = []
    for option in options:
        # Only num_kwarg_only == 0 is tried when allow_kwarg is false.
        max_kwarg_only = len(option['arguments']) if allow_kwarg else 0
        for num_kwarg_only in range(max_kwarg_only + 1):
            sig = signature(option, num_kwarg_only)
            if sig in taken:
                continue
            if num_kwarg_only:
                for trailing in option['arguments'][-num_kwarg_only:]:
                    trailing['kwarg_only'] = True
            taken.add(sig)
            kept.append(option)
            break
    return kept
|
|
|
|
|
|
def sort_by_number_of_args(declaration, reverse=True):
    """Sort ``declaration['options']`` in place by their argument count.

    With the default ``reverse=True`` the options with the most
    ``'arguments'`` come first; pass ``reverse=False`` for ascending order.
    Returns None.
    """
    declaration['options'].sort(
        key=lambda option: len(option['arguments']),
        reverse=reverse)
|
|
|
|
|
|
class Function(object):
    """A named function together with its ordered list of arguments."""

    def __init__(self, name):
        # Arguments are collected afterwards through add_argument().
        self.name = name
        self.arguments = []

    def add_argument(self, arg):
        """Append ``arg`` to this function; it must be an ``Argument``."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
        """Render as ``name(arg1, arg2, ...)`` using each argument's repr."""
        rendered_args = ', '.join(repr(argument) for argument in self.arguments)
        return self.name + '(' + rendered_args + ')'
|
|
|
|
|
|
class Argument(object):
    """A single argument: a type string, a name and an optional flag."""

    def __init__(self, _type, name, is_optional):
        self.type, self.name, self.is_optional = _type, name, is_optional

    def __repr__(self):
        """Render as "<type> <name>"."""
        return ' '.join((self.type, self.name))
|
|
|
|
|
|
def parse_header(path):
    """Parse the header file at ``path`` into a list of ``Function`` objects.

    The file is read as text and reduced to comma-separated fragments, each
    paired with the ``//`` comment found on its source line.  A fragment
    starting with 'TH_API void THNN_' or 'THC_API void THNN_' opens a new
    ``Function``; any other fragment must be exactly two whitespace-separated
    tokens ("<type> <name>") and is added as an ``Argument`` to the most
    recently opened function.  An argument is optional when its line comment
    contains '[OPTIONAL]'.

    NOTE(review): presumably meant for THNN-style C headers -- confirm
    against callers.
    """
    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # Reduce the file to (fragment, line comment) pairs.
    fragments = []
    for raw in raw_lines:
        # Skip empty lines and lines starting with '#' (preprocessor directives).
        if not raw or raw.startswith('#'):
            continue
        # Separate the code from a trailing line comment.
        code, _, comment = raw.partition('//')
        comment = comment.strip()
        # Remove trailing special signs: any run of ')' / ';', then commas.
        code = code.strip().rstrip(');').rstrip(',')
        # Split arguments; every piece keeps the comment of its line.
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    declaration_prefixes = ('TH_API void THNN_', 'THC_API void THNN_')
    generic_functions = []
    for fragment, comment in fragments:
        prefix = next(
            (p for p in declaration_prefixes if fragment.startswith(p)), None)
        if prefix is not None:
            fn_name = fragment[len(prefix):]
            if fn_name[0] == '(' and fn_name[-2] == ')':
                # "(name)X" form: drop the parentheses and the last character.
                fn_name = fn_name[1:-2]
            else:
                # "nameX" form: drop only the last character.
                fn_name = fn_name[:-1]
            generic_functions.append(Function(fn_name))
            continue

        arg_type, arg_name = fragment.split()
        if '*' in arg_name:
            # Move the pointer marker onto the type; the first character of
            # the name (presumably the '*') is dropped.
            arg_type = arg_type + '*'
            arg_name = arg_name[1:]
        generic_functions[-1].add_argument(
            Argument(arg_type, arg_name, '[OPTIONAL]' in comment))
    return generic_functions
|