pytorch/tools/autograd/utils.py
Edward Yang 6ea89166bd Rewrite of ATen code generator (#42629)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42629

How to approach reviewing this diff:

- The new codegen itself lives in `tools/codegen`. Start with `gen.py`, then read `model.py` and then the `api/` folder. The comments at the top of the files describe what is going on. The CLI interface of the new codegen is similar to the old one, but (1) it is no longer necessary to explicitly specify cwrap inputs (and now we will error if you do so) and (2) the default settings for source and install dir are much better; to the extent that if you run the codegen from the root source directory as just `python -m tools.codegen.gen`, something reasonable will happen.
- The old codegen is (nearly) entirely deleted; every Python file in `aten/src/ATen` was deleted except for `common_with_cwrap.py`, which now permanently finds its home in `tools/shared/cwrap_common.py` (previously cmake copied the file there), and `code_template.py`, which now lives in `tools/codegen/code_template.py`. We remove the copying logic for `common_with_cwrap.py`.
- All of the inputs to the old codegen are deleted.
- Build rules now have to be adjusted to not refer to files that no longer exist, and to abide by the (slightly modified) CLI.
- LegacyTHFunctions files have been generated and checked in. We expect these to be deleted as these final functions get ported to ATen. The deletion process is straightforward; just delete the functions of the ones you are porting. There are 39 more functions left to port.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D23183978

Pulled By: ezyang

fbshipit-source-id: 6073ba432ad182c7284a97147b05f0574a02f763
2020-08-31 09:00:22 -07:00

91 lines
2.9 KiB
Python

import re
import os
import yaml
from .nested_dict import nested_dict
__all__ = [
'CodeTemplate', 'IDENT_REGEX', 'YamlLoader', 'nested_dict',
'split_name_params', 'write',
]
from tools.codegen.code_template import CodeTemplate
# You should use these lines, rather than doing it manually.
# Especially if you see this error!
#
# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load
# loader = Loader(stream)
# TypeError: 'module' object is not callable
try:
# use faster C loader if available
from yaml import CLoader as YamlLoader
except ImportError:
from yaml import Loader as YamlLoader
# Template for the marker comment stamped into generated files; write() below
# substitutes ${filename} with the filename of the template being rendered.
# NOTE(review): the marker is assembled as "@" + "generated", presumably so
# that this source file does not itself contain the contiguous marker text
# and get treated as a generated file by tooling -- confirm.
GENERATED_COMMENT = CodeTemplate(
"@" + "generated from ${filename}")
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
#
# This is a pattern *template*, not a ready-to-use regex: the `{}` placeholder
# must be filled in with the identifier to look for (e.g. via str.format)
# before the result is handed to the `re` module. The identifier must be
# preceded by start-of-string or a non-word character and followed by
# end-of-string or a non-word character.
IDENT_REGEX = r'(^|\W){}($|\W)'
# TODO: Use a real parser here; this will get bamboozled
# by signatures that contain things like std::array<bool, 2> (note the space)
def split_name_params(prototype):
    """Split a function prototype string into its name and parameter list.

    The prototype is expected to look like ``name(params)`` or
    ``name.overload(params)``. The overload part, if present, is discarded.

    Args:
        prototype: the prototype string, e.g.
            ``"add.Tensor(Tensor self, Tensor other)"``.

    Returns:
        A ``(name, params)`` tuple where ``params`` is the text between the
        parentheses split on ``', '``. An empty parameter list yields ``['']``.

    Raises:
        ValueError: if ``prototype`` does not match the expected shape.
    """
    match = re.match(r'(\w+)(\.\w+)?\((.*)\)', prototype)
    if match is None:
        # Fail with a message naming the offending input instead of the
        # opaque "'NoneType' object has no attribute 'groups'".
        raise ValueError('Unsupported function prototype: {}'.format(prototype))
    name, _, params = match.groups()
    return name, params.split(', ')
# When tracing, we record inplace operations as out-of-place operations,
# because we don't have a story for side effects in the IR yet.
#
# Doing this un-inplacing is a little delicate however; __and__ is NOT inplace!
# TODO: Do something more robust
def uninplace_api_name(api_name):
    """Map an in-place or out-variant API name to its plain name.

    A single trailing underscore is removed (names ending in a double
    underscore, such as ``__and__``, are left alone); after that, a trailing
    ``_out`` suffix is removed as well.
    """
    has_inplace_suffix = api_name.endswith('_') and not api_name.endswith('__')
    stripped = api_name[:-1] if has_inplace_suffix else api_name
    if not stripped.endswith('_out'):
        return stripped
    return stripped[:-len('_out')]
def write(dirname, name, template, env):
    """Render ``template`` with ``env`` into ``dirname/name`` if it changed.

    ``env`` is modified in place: its ``'generated_comment'`` entry is set
    from GENERATED_COMMENT using ``template.filename``. The rendered text is
    compared with the file's current contents (a file that cannot be read is
    treated as having no contents) and the file is only rewritten when the
    two differ. A line is printed saying whether the file was written or
    skipped.
    """
    env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template.filename)
    target = os.path.join(dirname, name)

    # See Note [Unchanging results for ninja]
    try:
        with open(target, 'r') as existing:
            previous = existing.read()
    except IOError:
        previous = None

    rendered = template.substitute(env)
    if previous == rendered:
        print("Skipped writing {}".format(target))
        return

    with open(target, 'w') as out:
        print("Writing {}".format(target))
        out.write(rendered)
def is_tensor_method(declaration):
    """Return whether ``'Tensor'`` appears in ``declaration['method_of']``."""
    method_of = declaration['method_of']
    return 'Tensor' in method_of
def is_out_variant(decl):
    """Return whether ``decl['name']`` ends with the ``_out`` suffix."""
    name = decl['name']
    return name.endswith('_out')
def op_name_without_overload(decl):
    """Return ``decl['name']`` prefixed with ``aten::``.

    When the declaration is an out variant (per is_out_variant), the trailing
    ``_out`` is dropped from the name first.
    """
    base = decl['name']
    if is_out_variant(decl):
        # Drop the 4-character '_out' suffix.
        base = base[:-4]
    return 'aten::{}'.format(base)
def load_op_list_and_strip_overload(op_list, op_list_path):
    """Combine operator names from a list and/or a YAML file, minus overloads.

    Args:
        op_list: an iterable of operator names (e.g. ``'aten::add.Tensor'``),
            or None. It is not modified.
        op_list_path: path to a YAML file containing more operator names,
            or None.

    Returns:
        None when both arguments are None; otherwise a set of the operator
        names with everything from the first ``'.'`` onwards removed.
    """
    if op_list is None and op_list_path is None:
        return None

    # Work on a copy so the caller's list is never extended in place.
    names = [] if op_list is None else list(op_list)

    if op_list_path is not None:
        with open(op_list_path, 'r') as f:
            # NOTE: YamlLoader is a full (non-safe) loader, so this must only
            # be pointed at trusted files; do not use it on untrusted input.
            loaded = yaml.load(f, Loader=YamlLoader)
        # An empty YAML document loads as None; treat it as "no extra ops".
        if loaded is not None:
            names.extend(loaded)

    # strip out the overload part
    return {opname.split('.', 1)[0] for opname in names}