mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Migrate return type void to () for native functions. (#28290)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28290 ghstack-source-id: 92368250 Test Plan: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28290 ghstack-source-id: 92368250 Differential Revision: D17565528 fbshipit-source-id: f4870bb9ee4f4e7c48df4d68508b512d25ed277c
This commit is contained in:
parent
f94b6cef43
commit
177c95e9bc
|
|
@ -887,7 +887,9 @@ def create_generic(top_env, declarations):
|
|||
|
||||
def format_return_type(return_types):
|
||||
# type: (List[ReturnType]) -> str
|
||||
if len(return_types) == 1:
|
||||
if len(return_types) == 0:
|
||||
return "void"
|
||||
elif len(return_types) == 1:
|
||||
return return_types[0]['type']
|
||||
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
|
||||
|
||||
|
|
@ -1116,9 +1118,6 @@ def create_generic(top_env, declarations):
|
|||
if isinstance(t_raw, string_type):
|
||||
t = t_raw
|
||||
name = None
|
||||
elif t_raw is None:
|
||||
t = 'void'
|
||||
name = None
|
||||
else:
|
||||
t = t_raw['type']
|
||||
name = t_raw['name']
|
||||
|
|
@ -1149,7 +1148,6 @@ def create_generic(top_env, declarations):
|
|||
assert option['python_module'] == '' or option['python_module'] == 'nn', \
|
||||
"Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
|
||||
option['python_module'], option['name'])
|
||||
|
||||
formals = native_get_formals(option)
|
||||
option['formals_list'] = formals
|
||||
option['formals'] = [format_formal(f) for f in formals]
|
||||
|
|
|
|||
|
|
@ -37,10 +37,11 @@
|
|||
use_c10_dispatcher: full
|
||||
variants: function
|
||||
|
||||
- func: backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void
|
||||
- func: backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> ()
|
||||
variants: method
|
||||
|
||||
- func: set_data(Tensor(a!) self, Tensor new_data) -> void
|
||||
- func: set_data(Tensor(a!) self, Tensor new_data) -> ()
|
||||
use_c10_dispatcher: unboxed_only
|
||||
variants: method
|
||||
|
||||
- func: data(Tensor self) -> Tensor
|
||||
|
|
@ -1403,9 +1404,11 @@
|
|||
- func: _cufft_get_plan_cache_max_size(int device_index) -> int
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> void
|
||||
- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> ()
|
||||
use_c10_dispatcher: unboxed_only
|
||||
|
||||
- func: _cufft_clear_plan_cache(int device_index) -> void
|
||||
- func: _cufft_clear_plan_cache(int device_index) -> ()
|
||||
use_c10_dispatcher: unboxed_only
|
||||
|
||||
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
|
||||
variants: function, method
|
||||
|
|
@ -1459,7 +1462,7 @@
|
|||
variants: function
|
||||
device_guard: False
|
||||
supports_named_tensor: True
|
||||
|
||||
|
||||
- func: is_distributed(Tensor self) -> bool
|
||||
use_c10_dispatcher: full
|
||||
variants: function, method
|
||||
|
|
|
|||
|
|
@ -310,7 +310,7 @@ def parse_arguments(args, func_variants, declaration, func_return):
|
|||
|
||||
# TODO: Explicit checking for void is a hack and should disappear after a more
|
||||
# functionally complete implementation of Tensor aliases.
|
||||
if declaration['inplace'] and len(func_return) > 0 and func_return[0]['type'] != "void":
|
||||
if declaration['inplace'] and len(func_return) > 0:
|
||||
found_self = False
|
||||
for arg_idx, argument in enumerate(arguments):
|
||||
if argument['name'] == "self":
|
||||
|
|
@ -331,12 +331,15 @@ def parse_arguments(args, func_variants, declaration, func_return):
|
|||
|
||||
def parse_return_arguments(return_decl, inplace, func_decl):
|
||||
arguments = []
|
||||
if return_decl == '()':
|
||||
return arguments
|
||||
|
||||
# TODO: Use a real parser here; this will get bamboozled
|
||||
# by signatures that contain things like std::array<bool, 2> (note the space)
|
||||
if return_decl[0] == '(' and return_decl[-1] == ')':
|
||||
return_decl = return_decl[1:-1]
|
||||
multiple_args = len(return_decl.split(', ')) > 1
|
||||
|
||||
multiple_args = len(return_decl.split(', ')) > 1
|
||||
for arg_idx, arg in enumerate(return_decl.split(', ')):
|
||||
t, name, default, nullable, size, annotation = type_argument_translations(arg)
|
||||
# name of arguments and name of return sometimes have collision
|
||||
|
|
|
|||
|
|
@ -124,6 +124,10 @@ def supports(o, factory_methods):
|
|||
if "_out" in o['name']:
|
||||
return False
|
||||
|
||||
# skip if no return, previously it is 'void'
|
||||
if len(o['returns']) == 0:
|
||||
return False
|
||||
|
||||
# skip return types we cannot handle
|
||||
for ret in o['returns']:
|
||||
if not value_has_tensors(ret) and ret['type'] not in RETURN_MAP:
|
||||
|
|
|
|||
|
|
@ -33,6 +33,8 @@ class TestNamedTupleAPI(unittest.TestCase):
|
|||
continue
|
||||
if not ret.startswith('('):
|
||||
continue
|
||||
if ret == '()':
|
||||
continue
|
||||
ret = ret[1:-1].split(',')
|
||||
for r in ret:
|
||||
r = r.strip()
|
||||
|
|
|
|||
|
|
@ -509,7 +509,7 @@ def emit_body(declaration):
|
|||
inplace = declaration['inplace']
|
||||
is_out_fn = name.endswith('_out')
|
||||
modifies_arguments = inplace or is_out_fn
|
||||
returns_void = len(returns) == 1 and returns[0]['type'] == 'void'
|
||||
returns_void = len(returns) == 0
|
||||
|
||||
base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
|
||||
view_info = VIEW_FUNCTIONS.get(base_name, None)
|
||||
|
|
|
|||
|
|
@ -203,7 +203,8 @@ def is_jit_arg(i, arg):
|
|||
|
||||
def is_jit_op(decl):
|
||||
# We currently don't support functions that return nothing
|
||||
if all(r['type'] == 'void' for r in decl['returns']):
|
||||
assert all(r['type'] != 'void' for r in decl['returns'])
|
||||
if len(decl['returns']) == 0:
|
||||
return False
|
||||
|
||||
arguments = decl['arguments']
|
||||
|
|
|
|||
|
|
@ -284,8 +284,10 @@ def generate_type_hints(fname, decls, is_tensor=False):
|
|||
|
||||
if len(python_returns) > 1:
|
||||
python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
|
||||
else:
|
||||
elif len(python_returns) == 1:
|
||||
python_returns_s = python_returns[0]
|
||||
else:
|
||||
python_returns_s = 'None'
|
||||
|
||||
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
|
||||
numargs = len(decl['arguments'])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user