mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42629 How to approach reviewing this diff: - The new codegen itself lives in `tools/codegen`. Start with `gen.py`, then read `model.py` and then the `api/` folder. The comments at the top of the files describe what is going on. The CLI interface of the new codegen is similar to the old one, but (1) it is no longer necessary to explicitly specify cwrap inputs (and now we will error if you do so) and (2) the default settings for source and install dir are much better, to the extent that if you run the codegen from the root source directory as just `python -m tools.codegen.gen`, something reasonable will happen. - The old codegen is (nearly) entirely deleted; every Python file in `aten/src/ATen` was deleted except for `common_with_cwrap.py`, which now permanently finds its home in `tools/shared/cwrap_common.py` (previously cmake copied the file there), and `code_template.py`, which now lives in `tools/codegen/code_template.py`. We remove the copying logic for `common_with_cwrap.py`. - All of the inputs to the old codegen are deleted. - Build rules now have to be adjusted to not refer to files that no longer exist, and to abide by the (slightly modified) CLI. - LegacyTHFunctions files have been generated and checked in. We expect these to be deleted as these final functions get ported to ATen. The deletion process is straightforward; just delete the functions of the ones you are porting. There are 39 more functions left to port. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D23183978 Pulled By: ezyang fbshipit-source-id: 6073ba432ad182c7284a97147b05f0574a02f763
91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
import re
|
|
import os
|
|
import yaml
|
|
from .nested_dict import nested_dict
|
|
|
|
|
|
# Public names re-exported by `from <this module> import *`.
# Only these six are listed; the other helpers in this module must be
# imported by name.
__all__ = [
    'CodeTemplate',
    'IDENT_REGEX',
    'YamlLoader',
    'nested_dict',
    'split_name_params',
    'write',
]
|
|
|
|
from tools.codegen.code_template import CodeTemplate
|
|
|
|
# You should use these lines, rather than doing it manually.
|
|
# Especially if you see this error!
|
|
#
|
|
# File "/usr/local/lib/python2.7/dist-packages/yaml/__init__.py", line 69, in load
|
|
# loader = Loader(stream)
|
|
# TypeError: 'module' object is not callable
|
|
try:
|
|
# use faster C loader if available
|
|
from yaml import CLoader as YamlLoader
|
|
except ImportError:
|
|
from yaml import Loader as YamlLoader
|
|
|
|
# Template for the "generated" header comment. write() substitutes it into
# env['generated_comment'], with ${filename} set to the source template's
# filename.
# The marker is deliberately spelled as "@" + "generated". Presumably this keeps
# tools that scan for the literal marker from flagging this source file itself
# as generated, so keep the split.
GENERATED_COMMENT = CodeTemplate(
    "@" + "generated from ${filename}")
|
|
|
|
# Format template for locating a whole identifier. IDENT_REGEX.format(name)
# yields a pattern that finds "foo" in "foo, bar" but not in "foobar".
# It is used to search for the occurrence of a parameter in the derivative
# formula.
# The neighbouring non-word character (if any) is part of the match and is
# captured in group 1 (before) and group 2 (after). `name` is inserted
# verbatim, so it is interpreted as regex syntax.
IDENT_REGEX = r"(^|\W){}($|\W)"
|
|
|
|
|
|
# TODO: Use a real parser here; this will get bamboozled
# by signatures that contain things like std::array<bool, 2> (note the space)
def split_name_params(prototype):
    """Split a prototype string into its name and raw parameter strings.

    Example: ``'add.Tensor(Tensor self, Tensor other)'`` returns
    ``('add', ['Tensor self', 'Tensor other'])``. An optional ``.overload``
    suffix on the name is discarded.

    Parameters are split on the literal ``', '`` separator, so an empty
    parameter list such as ``'foo()'`` yields ``('foo', [''])``.

    Raises:
        ValueError: if ``prototype`` does not start with
            ``name[.overload](...)``.
    """
    match = re.match(r'(\w+)(\.\w+)?\((.*)\)', prototype)
    if match is None:
        # Previously this surfaced as an opaque
        # "'NoneType' object has no attribute 'groups'".
        raise ValueError('cannot parse prototype: {!r}'.format(prototype))
    name, _overload_name, params = match.groups()
    return name, params.split(', ')
|
|
|
|
|
|
# When tracing, we record inplace operations as out-of-place operations,
# because we don't have a story for side effects in the IR yet.
#
# Doing this un-inplacing is a little delicate however; __and__ is NOT inplace!
# TODO: Do something more robust
def uninplace_api_name(api_name):
    """Map an in-place or out= API name to its functional counterpart.

    First, one trailing underscore is removed (``add_`` -> ``add``), unless
    the name ends in a double underscore (``__and__`` is left alone).
    Then an ``_out`` suffix is removed (``add_out`` -> ``add``).
    """
    is_inplace_spelling = api_name.endswith('_') and not api_name.endswith('__')
    base = api_name[:-1] if is_inplace_spelling else api_name
    out_suffix = '_out'
    return base[:-len(out_suffix)] if base.endswith(out_suffix) else base
|
|
|
|
|
|
def write(dirname, name, template, env):
    """Render ``template`` with ``env`` and store the result at ``dirname/name``.

    Side effect: ``env['generated_comment']`` is set in the caller's dict
    before rendering. Its value comes from GENERATED_COMMENT filled in with
    ``template.filename``.

    The file is rewritten only when the rendered text differs from what is
    already on disk. Either "Writing <path>" or "Skipped writing <path>" is
    printed.
    """
    env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template.filename)
    path = os.path.join(dirname, name)
    # See Note [Unchanging results for ninja]
    # (Presumably this avoids touching unchanged outputs so the build tool does
    # not rebuild dependents.)
    try:
        with open(path, 'r') as existing:
            current_contents = existing.read()
    except IOError:
        # No readable previous output; any rendered text counts as a change.
        current_contents = None
    rendered = template.substitute(env)
    if current_contents == rendered:
        print("Skipped writing {}".format(path))
        return
    with open(path, 'w') as out:
        print("Writing {}".format(path))
        out.write(rendered)
|
|
|
|
def is_tensor_method(declaration):
    """Return whether ``'Tensor'`` is contained in ``declaration['method_of']``."""
    owners = declaration['method_of']
    return 'Tensor' in owners
|
|
|
|
def is_out_variant(decl):
    """Return whether ``decl['name']`` ends with the ``_out`` suffix."""
    op_name = decl['name']
    return op_name.endswith('_out')
|
|
|
|
def op_name_without_overload(decl):
    """Return ``'aten::<name>'`` for ``decl``, dropping an ``_out`` suffix.

    Out variants (as judged by is_out_variant) are reported under the name
    without that suffix; e.g. a declaration named ``add_out`` gives
    ``'aten::add'``. Only ``_out`` is stripped here; no other overload
    information is touched.
    """
    base_name = decl['name']
    if is_out_variant(decl):
        base_name = base_name[:-4]  # drop the 4-character '_out' suffix
    return 'aten::{}'.format(base_name)
|
|
|
|
def load_op_list_and_strip_overload(op_list, op_list_path):
    """Collect operator names and strip their overload suffixes.

    Names are taken from ``op_list``, from the YAML file at ``op_list_path``,
    or from both combined.

    Args:
        op_list: iterable of operator-name strings (e.g. ``'aten::add.Tensor'``)
            or None. It is copied, never modified.
        op_list_path: path to a YAML file whose document is a list of operator
            names, or None. An empty YAML document is treated as an empty list.

    Returns:
        None when both arguments are None. Otherwise, a set of the names with
        everything from the first '.' onward removed, e.g.
        ``'aten::add.Tensor'`` -> ``'aten::add'``.
    """
    if op_list is None and op_list_path is None:
        return None
    # Copy so that appending the file's entries never mutates the caller's
    # list; this also lets tuples and sets be passed in.
    names = [] if op_list is None else list(op_list)
    if op_list_path is not None:
        with open(op_list_path, 'r') as f:
            # NOTE(review): YamlLoader is yaml.CLoader / yaml.Loader, not a
            # Safe* loader. Only point this at trusted build-config files.
            loaded = yaml.load(f, Loader=YamlLoader)
        # yaml.load returns None for an empty document.
        if loaded is not None:
            names.extend(loaded)
    # strip out the overload part
    return {opname.split('.', 1)[0] for opname in names}
|