Migrate return type void to () for native functions. (#28290)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28290

ghstack-source-id: 92368250

Test Plan:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28290
ghstack-source-id: 92368250

Differential Revision: D17565528

fbshipit-source-id: f4870bb9ee4f4e7c48df4d68508b512d25ed277c
This commit is contained in:
Xingying Cheng 2019-10-22 15:21:23 -07:00 committed by Facebook Github Bot
parent f94b6cef43
commit 177c95e9bc
8 changed files with 28 additions and 15 deletions

View File

@ -887,7 +887,9 @@ def create_generic(top_env, declarations):
def format_return_type(return_types): def format_return_type(return_types):
# type: (List[ReturnType]) -> str # type: (List[ReturnType]) -> str
if len(return_types) == 1: if len(return_types) == 0:
return "void"
elif len(return_types) == 1:
return return_types[0]['type'] return return_types[0]['type']
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types)) return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
@ -1116,9 +1118,6 @@ def create_generic(top_env, declarations):
if isinstance(t_raw, string_type): if isinstance(t_raw, string_type):
t = t_raw t = t_raw
name = None name = None
elif t_raw is None:
t = 'void'
name = None
else: else:
t = t_raw['type'] t = t_raw['type']
name = t_raw['name'] name = t_raw['name']
@ -1149,7 +1148,6 @@ def create_generic(top_env, declarations):
assert option['python_module'] == '' or option['python_module'] == 'nn', \ assert option['python_module'] == '' or option['python_module'] == 'nn', \
"Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format( "Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
option['python_module'], option['name']) option['python_module'], option['name'])
formals = native_get_formals(option) formals = native_get_formals(option)
option['formals_list'] = formals option['formals_list'] = formals
option['formals'] = [format_formal(f) for f in formals] option['formals'] = [format_formal(f) for f in formals]

View File

@ -37,10 +37,11 @@
use_c10_dispatcher: full use_c10_dispatcher: full
variants: function variants: function
- func: backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void - func: backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> ()
variants: method variants: method
- func: set_data(Tensor(a!) self, Tensor new_data) -> void - func: set_data(Tensor(a!) self, Tensor new_data) -> ()
use_c10_dispatcher: unboxed_only
variants: method variants: method
- func: data(Tensor self) -> Tensor - func: data(Tensor self) -> Tensor
@ -1403,9 +1404,11 @@
- func: _cufft_get_plan_cache_max_size(int device_index) -> int - func: _cufft_get_plan_cache_max_size(int device_index) -> int
use_c10_dispatcher: full use_c10_dispatcher: full
- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> void - func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> ()
use_c10_dispatcher: unboxed_only
- func: _cufft_clear_plan_cache(int device_index) -> void - func: _cufft_clear_plan_cache(int device_index) -> ()
use_c10_dispatcher: unboxed_only
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
variants: function, method variants: function, method
@ -1459,7 +1462,7 @@
variants: function variants: function
device_guard: False device_guard: False
supports_named_tensor: True supports_named_tensor: True
- func: is_distributed(Tensor self) -> bool - func: is_distributed(Tensor self) -> bool
use_c10_dispatcher: full use_c10_dispatcher: full
variants: function, method variants: function, method

View File

@ -310,7 +310,7 @@ def parse_arguments(args, func_variants, declaration, func_return):
# TODO: Explicit checking for void is a hack and should disappear after a more # TODO: Explicit checking for void is a hack and should disappear after a more
# functionally complete implementation of Tensor aliases. # functionally complete implementation of Tensor aliases.
if declaration['inplace'] and len(func_return) > 0 and func_return[0]['type'] != "void": if declaration['inplace'] and len(func_return) > 0:
found_self = False found_self = False
for arg_idx, argument in enumerate(arguments): for arg_idx, argument in enumerate(arguments):
if argument['name'] == "self": if argument['name'] == "self":
@ -331,12 +331,15 @@ def parse_arguments(args, func_variants, declaration, func_return):
def parse_return_arguments(return_decl, inplace, func_decl): def parse_return_arguments(return_decl, inplace, func_decl):
arguments = [] arguments = []
if return_decl == '()':
return arguments
# TODO: Use a real parser here; this will get bamboozled # TODO: Use a real parser here; this will get bamboozled
# by signatures that contain things like std::array<bool, 2> (note the space) # by signatures that contain things like std::array<bool, 2> (note the space)
if return_decl[0] == '(' and return_decl[-1] == ')': if return_decl[0] == '(' and return_decl[-1] == ')':
return_decl = return_decl[1:-1] return_decl = return_decl[1:-1]
multiple_args = len(return_decl.split(', ')) > 1
multiple_args = len(return_decl.split(', ')) > 1
for arg_idx, arg in enumerate(return_decl.split(', ')): for arg_idx, arg in enumerate(return_decl.split(', ')):
t, name, default, nullable, size, annotation = type_argument_translations(arg) t, name, default, nullable, size, annotation = type_argument_translations(arg)
# name of arguments and name of return sometimes have collision # name of arguments and name of return sometimes have collision

View File

@ -124,6 +124,10 @@ def supports(o, factory_methods):
if "_out" in o['name']: if "_out" in o['name']:
return False return False
# skip if no return, previously it is 'void'
if len(o['returns']) == 0:
return False
# skip return types we cannot handle # skip return types we cannot handle
for ret in o['returns']: for ret in o['returns']:
if not value_has_tensors(ret) and ret['type'] not in RETURN_MAP: if not value_has_tensors(ret) and ret['type'] not in RETURN_MAP:

View File

@ -33,6 +33,8 @@ class TestNamedTupleAPI(unittest.TestCase):
continue continue
if not ret.startswith('('): if not ret.startswith('('):
continue continue
if ret == '()':
continue
ret = ret[1:-1].split(',') ret = ret[1:-1].split(',')
for r in ret: for r in ret:
r = r.strip() r = r.strip()

View File

@ -509,7 +509,7 @@ def emit_body(declaration):
inplace = declaration['inplace'] inplace = declaration['inplace']
is_out_fn = name.endswith('_out') is_out_fn = name.endswith('_out')
modifies_arguments = inplace or is_out_fn modifies_arguments = inplace or is_out_fn
returns_void = len(returns) == 1 and returns[0]['type'] == 'void' returns_void = len(returns) == 0
base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
view_info = VIEW_FUNCTIONS.get(base_name, None) view_info = VIEW_FUNCTIONS.get(base_name, None)

View File

@ -203,7 +203,8 @@ def is_jit_arg(i, arg):
def is_jit_op(decl): def is_jit_op(decl):
# We currently don't support functions that return nothing # We currently don't support functions that return nothing
if all(r['type'] == 'void' for r in decl['returns']): assert all(r['type'] != 'void' for r in decl['returns'])
if len(decl['returns']) == 0:
return False return False
arguments = decl['arguments'] arguments = decl['arguments']

View File

@ -284,8 +284,10 @@ def generate_type_hints(fname, decls, is_tensor=False):
if len(python_returns) > 1: if len(python_returns) > 1:
python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']' python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
else: elif len(python_returns) == 1:
python_returns_s = python_returns[0] python_returns_s = python_returns[0]
else:
python_returns_s = 'None'
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s) type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
numargs = len(decl['arguments']) numargs = len(decl['arguments'])