pytorch/tools/shared/cwrap_common.py
Edward Yang 6ea89166bd Rewrite of ATen code generator (#42629)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42629

How to approach reviewing this diff:

- The new codegen itself lives in `tools/codegen`. Start with `gen.py`, then read `model.py` and then the `api/` folder. The comments at the top of the files describe what is going on. The CLI interface of the new codegen is similar to the old one, but (1) it is no longer necessary to explicitly specify cwrap inputs (and now we will error if you do so) and (2) the default settings for source and install dir are much better; to the extent that if you run the codegen from the root source directory as just `python -m tools.codegen.gen`, something reasonable will happen.
- The old codegen is (nearly) entirely deleted; every Python file in `aten/src/ATen` was deleted except for `common_with_cwrap.py`, which now permanently finds its home in `tools/shared/cwrap_common.py` (previously cmake copied the file there), and `code_template.py`, which now lives in `tools/codegen/code_template.py`. We remove the copying logic for `common_with_cwrap.py`.
- All of the inputs to the old codegen are deleted.
- Build rules now have to be adjusted to not refer to files that no longer exist, and to abide by the (slightly modified) CLI.
- LegacyTHFunctions files have been generated and checked in. We expect these to be deleted as these final functions get ported to ATen. The deletion process is straightforward; just delete the functions of the ones you are porting. There are 39 more functions left to port.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D23183978

Pulled By: ezyang

fbshipit-source-id: 6073ba432ad182c7284a97147b05f0574a02f763
2020-08-31 09:00:22 -07:00

196 lines
7.2 KiB
Python

# This code should be common to cwrap and ATen preprocessing.
# For now it has been put in one place, but at the moment it is just copied out of cwrap.
import copy
def parse_arguments(args):
    """Normalize argument declarations into a list of argument dicts.

    Args:
        args: iterable whose entries are either strings of the form
            "<type> <name>" or dicts.  A dict may carry the combined
            declaration under the key 'arg'.

    Returns:
        A new list with one dict per entry.  String entries become
        ``{'type': ..., 'name': ...}``, split at the first space.  Dict
        entries are returned as the same objects (not copies); when a dict
        has an 'arg' key, it is split into 'type' and 'name' in place and
        the 'arg' key is deleted.

    Raises:
        AssertionError: if an entry is neither a str nor a dict.
    """
    new_args = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>"
            arg_type, _, arg_name = arg.partition(' ')
            new_args.append({'type': arg_type, 'name': arg_name})
        elif isinstance(arg, dict):
            if 'arg' in arg:
                # Mutates the caller's dict: 'arg' is replaced by the
                # separate 'type' and 'name' keys.
                arg['type'], _, arg['name'] = arg['arg'].partition(' ')
                del arg['arg']
            new_args.append(arg)
        else:
            # Same exception type as before, now naming the offending entry.
            raise AssertionError(
                "unsupported argument declaration: {!r}".format(arg))
    return new_args
def set_declaration_defaults(declaration):
    """Fill in default and derived fields of a declaration dict, in place.

    The following keys are set when missing: 'schema_string' (''),
    'matches_jit_signature' (False), 'arguments' ([]), 'return' ('void'),
    'cname' (the declaration's 'name') and 'backends' (['CPU', 'CUDA']).
    The derived keys 'api_name', 'type_wrapper_name',
    'operator_name_with_overload', 'unqual_schema_string' and
    'unqual_operator_name_with_overload' are always (re)computed.

    If the declaration has no 'options', a single option is synthesized
    from deep copies of its 'arguments' and 'schema_order_arguments', and
    both keys are then removed from the declaration.  Every option has its
    'arguments' and 'schema_order_arguments' run through
    ``parse_arguments`` and finally inherits every declaration key (except
    'options') that it does not define itself.

    Args:
        declaration: dict describing one declaration; must contain 'name'.

    Returns:
        None.  ``declaration`` and its options are modified in place.

    Raises:
        AssertionError: if 'api_name' is already present.
        KeyError: if 'name' is missing, or if 'schema_order_arguments' is
            missing where it is read.
        IndexError: if a non-empty schema string contains no '::'.
    """
    # Legacy TH bindings such as _thnn_conv_depthwise2d_backward carry no
    # schema string at all.
    declaration.setdefault('schema_string', '')
    declaration.setdefault('matches_jit_signature', False)
    declaration.setdefault('arguments', [])
    declaration.setdefault('return', 'void')
    if 'cname' not in declaration:
        declaration['cname'] = declaration['name']
    declaration.setdefault('backends', ['CPU', 'CUDA'])
    assert 'api_name' not in declaration, "declaration already has 'api_name'"
    declaration['api_name'] = declaration['name']

    # NB: keep this in sync with gen_autograd.py
    if declaration.get('overload_name'):
        declaration['type_wrapper_name'] = "{}_{}".format(
            declaration['name'], declaration['overload_name'])
    else:
        declaration['type_wrapper_name'] = declaration['name']

    # TODO: Uggggh, parsing the schema string here, really???
    schema = declaration['schema_string']
    qualified_operator = schema.split('(')[0]
    declaration['operator_name_with_overload'] = qualified_operator
    if schema:
        # Drop everything up to and including the first '::'.
        declaration['unqual_schema_string'] = schema.split('::')[1]
        declaration['unqual_operator_name_with_overload'] = \
            qualified_operator.split('::')[1]
    else:
        declaration['unqual_schema_string'] = ''
        declaration['unqual_operator_name_with_overload'] = ''

    # Simulate multiple dispatch, even if it's not necessary
    if 'options' not in declaration:
        sole_option = {
            'arguments': copy.deepcopy(declaration['arguments']),
            'schema_order_arguments': copy.deepcopy(
                declaration['schema_order_arguments']),
        }
        declaration['options'] = [sole_option]
        del declaration['arguments']
        del declaration['schema_order_arguments']

    options = declaration['options']

    # Parse arguments (some of them can be strings)
    for option in options:
        option['arguments'] = parse_arguments(option['arguments'])
        option['schema_order_arguments'] = parse_arguments(
            option['schema_order_arguments'])

    # Propagate defaults from declaration to options.  Values are shared by
    # reference, not copied.
    for option in options:
        for key, value in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if key != 'options':
                option.setdefault(key, value)
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    """Return the options whose call signature has not been seen before.

    For each option, candidate signatures are tried with 0, 1, 2, ...
    trailing arguments treated as keyword-only (only 0 when ``allow_kwarg``
    is false).  The option is kept at the first candidate that is new; if
    that candidate uses k > 0 keyword-only arguments, the option's last k
    argument dicts get ``'kwarg_only': True`` set in place.  An option all
    of whose candidates were already seen is dropped.

    A signature is built from:
      * the positional part: argument types, each mapped through
        ``type_to_signature`` when present there, skipping 'CONSTANT'
        arguments and, if ``remove_self`` is true, arguments named 'self';
      * the keyword-only part: ``name#type`` for each argument, skipping
        only 'CONSTANT' arguments (types are not mapped here).

    Args:
        options: iterable of dicts, each with an 'arguments' list of dicts
            holding at least 'type' and 'name'.
        allow_kwarg: whether trailing arguments may be turned keyword-only
            to disambiguate otherwise identical signatures.
        type_to_signature: mapping from argument type to the type name used
            in the positional part of the signature.
        remove_self: whether arguments named 'self' are ignored in the
            positional part of the signature.

    Returns:
        A list containing the kept option objects, in their original order.
    """
    def is_constant(arg):
        return arg['type'] == 'CONSTANT'

    def skipped_in_positional(arg):
        return is_constant(arg) or (remove_self and arg['name'] == 'self')

    def signature(option, kwarg_only_count):
        arguments = option['arguments']
        if kwarg_only_count == 0:
            positional, kwarg_only = arguments, None
        else:
            positional = arguments[:-kwarg_only_count]
            kwarg_only = arguments[-kwarg_only_count:]

        positional_signature = '#'.join(
            type_to_signature.get(arg['type'], arg['type'])
            for arg in positional
            if not skipped_in_positional(arg))
        if kwarg_only is None:
            return positional_signature

        kwarg_only_signature = '#'.join(
            arg['name'] + '#' + arg['type']
            for arg in kwarg_only
            if not is_constant(arg))
        return positional_signature + "#-#" + kwarg_only_signature

    seen_signatures = set()
    unique = []
    for option in options:
        # Only num_kwarg_only == 0 is checked when allow_kwarg is false.
        limit = len(option['arguments']) if allow_kwarg else 0
        for num_kwarg_only in range(limit + 1):
            sig = signature(option, num_kwarg_only)
            if sig in seen_signatures:
                continue
            if num_kwarg_only > 0:
                for arg in option['arguments'][-num_kwarg_only:]:
                    arg['kwarg_only'] = True
            unique.append(option)
            seen_signatures.add(sig)
            break
    return unique
def sort_by_number_of_args(declaration, reverse=True):
    """Sort ``declaration['options']`` in place by argument count.

    Args:
        declaration: dict with an 'options' list; every option must have
            an 'arguments' sequence.
        reverse: when true (the default), options with the most arguments
            come first; otherwise those with the fewest come first.

    Returns:
        None.  The existing 'options' list object is reordered.
    """
    declaration['options'].sort(
        key=lambda option: len(option['arguments']), reverse=reverse)
class Function(object):
    """A named function together with its ordered list of arguments."""

    def __init__(self, name):
        self.name = name
        # Arguments in the order they were added.
        self.arguments = []

    def add_argument(self, arg):
        """Append ``arg`` to this function's argument list.

        Raises:
            AssertionError: if ``arg`` is not an ``Argument`` instance.
        """
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
        """Return ``name(arg1, arg2, ...)`` built from each argument's repr."""
        return '{}({})'.format(
            self.name, ', '.join(repr(arg) for arg in self.arguments))
class Argument(object):
    """A single function argument: its type, its name and an optional flag."""

    def __init__(self, _type, name, is_optional):
        # Stored as given; no validation or conversion is performed.
        self.type = _type
        self.name = name
        self.is_optional = is_optional

    def __repr__(self):
        """Return ``"<type> <name>"``; the optional flag is not shown."""
        return '{} {}'.format(self.type, self.name)
def parse_header(path):
    """Parse function declarations out of a header file.

    The file is read as text and processed line by line:

      * empty lines and lines starting with '#' are skipped;
      * everything after the first '//' is the line's comment;
      * trailing ')' / ';' characters and then trailing ',' characters are
        removed from the code part, which is then split on ',' into
        fragments; empty fragments are discarded.

    A fragment starting with 'TH_API void THNN_' or 'THC_API void THNN_'
    begins a new ``Function``.  Its name is the rest of the fragment: when
    that rest starts with '(' and has ')' as its second-to-last character,
    the leading '(' and the last two characters are removed; otherwise only
    the last character is removed.

    Any other fragment must consist of exactly two whitespace-separated
    tokens, ``<type> <name>``, and is added as an ``Argument`` to the most
    recently started function.  If the name token contains '*', a '*' is
    appended to the type and the first character of the name is dropped.
    The argument is marked optional when the comment of its line contains
    '[OPTIONAL]'.

    Args:
        path: path of the header file to read.

    Returns:
        A list of ``Function`` objects in the order they appear in the file.

    Raises:
        IndexError: if an argument fragment appears before any function
            declaration, or a declaration has too little text after its
            prefix.
        ValueError: if an argument fragment is not exactly two tokens.
    """
    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # Phase 1: reduce the file to (fragment, comment) pairs.
    fragments = []
    for raw in raw_lines:
        # Remove empty lines and preprocessor directives
        if not raw or raw.startswith('#'):
            continue
        # Separate the code from a trailing line comment
        code, _, comment = raw.partition('//')
        comment = comment.strip()
        # Remove trailing special signs
        code = code.strip().rstrip(');').rstrip(',')
        # Split arguments; every piece keeps the comment of its line
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    # Phase 2: turn the fragments into Function / Argument objects.
    declaration_prefixes = ('TH_API void THNN_', 'THC_API void THNN_')
    generic_functions = []
    for fragment, comment in fragments:
        prefix = next(
            (p for p in declaration_prefixes if fragment.startswith(p)), None)
        if prefix is not None:
            fn_name = fragment[len(prefix):]
            if fn_name[0] == '(' and fn_name[-2] == ')':
                fn_name = fn_name[1:-2]
            else:
                fn_name = fn_name[:-1]
            generic_functions.append(Function(fn_name))
        else:
            arg_type, name = fragment.split()
            if '*' in name:
                # Move the pointer marker from the name onto the type.
                arg_type = arg_type + '*'
                name = name[1:]
            generic_functions[-1].add_argument(
                Argument(arg_type, name, '[OPTIONAL]' in comment))
    return generic_functions