mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Migrate return type void to () for native functions. (#28290)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28290 ghstack-source-id: 92368250 Test Plan: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28290 ghstack-source-id: 92368250 Differential Revision: D17565528 fbshipit-source-id: f4870bb9ee4f4e7c48df4d68508b512d25ed277c
This commit is contained in:
parent
f94b6cef43
commit
177c95e9bc
|
|
@ -887,7 +887,9 @@ def create_generic(top_env, declarations):
|
||||||
|
|
||||||
def format_return_type(return_types):
|
def format_return_type(return_types):
|
||||||
# type: (List[ReturnType]) -> str
|
# type: (List[ReturnType]) -> str
|
||||||
if len(return_types) == 1:
|
if len(return_types) == 0:
|
||||||
|
return "void"
|
||||||
|
elif len(return_types) == 1:
|
||||||
return return_types[0]['type']
|
return return_types[0]['type']
|
||||||
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
|
return "std::tuple<{}>".format(','.join(r['type'] for r in return_types))
|
||||||
|
|
||||||
|
|
@ -1116,9 +1118,6 @@ def create_generic(top_env, declarations):
|
||||||
if isinstance(t_raw, string_type):
|
if isinstance(t_raw, string_type):
|
||||||
t = t_raw
|
t = t_raw
|
||||||
name = None
|
name = None
|
||||||
elif t_raw is None:
|
|
||||||
t = 'void'
|
|
||||||
name = None
|
|
||||||
else:
|
else:
|
||||||
t = t_raw['type']
|
t = t_raw['type']
|
||||||
name = t_raw['name']
|
name = t_raw['name']
|
||||||
|
|
@ -1149,7 +1148,6 @@ def create_generic(top_env, declarations):
|
||||||
assert option['python_module'] == '' or option['python_module'] == 'nn', \
|
assert option['python_module'] == '' or option['python_module'] == 'nn', \
|
||||||
"Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
|
"Found python_module of {} for decl {}, but only \'\' string or \'nn\' are supported".format(
|
||||||
option['python_module'], option['name'])
|
option['python_module'], option['name'])
|
||||||
|
|
||||||
formals = native_get_formals(option)
|
formals = native_get_formals(option)
|
||||||
option['formals_list'] = formals
|
option['formals_list'] = formals
|
||||||
option['formals'] = [format_formal(f) for f in formals]
|
option['formals'] = [format_formal(f) for f in formals]
|
||||||
|
|
|
||||||
|
|
@ -37,10 +37,11 @@
|
||||||
use_c10_dispatcher: full
|
use_c10_dispatcher: full
|
||||||
variants: function
|
variants: function
|
||||||
|
|
||||||
- func: backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void
|
- func: backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> ()
|
||||||
variants: method
|
variants: method
|
||||||
|
|
||||||
- func: set_data(Tensor(a!) self, Tensor new_data) -> void
|
- func: set_data(Tensor(a!) self, Tensor new_data) -> ()
|
||||||
|
use_c10_dispatcher: unboxed_only
|
||||||
variants: method
|
variants: method
|
||||||
|
|
||||||
- func: data(Tensor self) -> Tensor
|
- func: data(Tensor self) -> Tensor
|
||||||
|
|
@ -1403,9 +1404,11 @@
|
||||||
- func: _cufft_get_plan_cache_max_size(int device_index) -> int
|
- func: _cufft_get_plan_cache_max_size(int device_index) -> int
|
||||||
use_c10_dispatcher: full
|
use_c10_dispatcher: full
|
||||||
|
|
||||||
- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> void
|
- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> ()
|
||||||
|
use_c10_dispatcher: unboxed_only
|
||||||
|
|
||||||
- func: _cufft_clear_plan_cache(int device_index) -> void
|
- func: _cufft_clear_plan_cache(int device_index) -> ()
|
||||||
|
use_c10_dispatcher: unboxed_only
|
||||||
|
|
||||||
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
|
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
|
||||||
variants: function, method
|
variants: function, method
|
||||||
|
|
@ -1459,7 +1462,7 @@
|
||||||
variants: function
|
variants: function
|
||||||
device_guard: False
|
device_guard: False
|
||||||
supports_named_tensor: True
|
supports_named_tensor: True
|
||||||
|
|
||||||
- func: is_distributed(Tensor self) -> bool
|
- func: is_distributed(Tensor self) -> bool
|
||||||
use_c10_dispatcher: full
|
use_c10_dispatcher: full
|
||||||
variants: function, method
|
variants: function, method
|
||||||
|
|
|
||||||
|
|
@ -310,7 +310,7 @@ def parse_arguments(args, func_variants, declaration, func_return):
|
||||||
|
|
||||||
# TODO: Explicit checking for void is a hack and should disappear after a more
|
# TODO: Explicit checking for void is a hack and should disappear after a more
|
||||||
# functionally complete implementation of Tensor aliases.
|
# functionally complete implementation of Tensor aliases.
|
||||||
if declaration['inplace'] and len(func_return) > 0 and func_return[0]['type'] != "void":
|
if declaration['inplace'] and len(func_return) > 0:
|
||||||
found_self = False
|
found_self = False
|
||||||
for arg_idx, argument in enumerate(arguments):
|
for arg_idx, argument in enumerate(arguments):
|
||||||
if argument['name'] == "self":
|
if argument['name'] == "self":
|
||||||
|
|
@ -331,12 +331,15 @@ def parse_arguments(args, func_variants, declaration, func_return):
|
||||||
|
|
||||||
def parse_return_arguments(return_decl, inplace, func_decl):
|
def parse_return_arguments(return_decl, inplace, func_decl):
|
||||||
arguments = []
|
arguments = []
|
||||||
|
if return_decl == '()':
|
||||||
|
return arguments
|
||||||
|
|
||||||
# TODO: Use a real parser here; this will get bamboozled
|
# TODO: Use a real parser here; this will get bamboozled
|
||||||
# by signatures that contain things like std::array<bool, 2> (note the space)
|
# by signatures that contain things like std::array<bool, 2> (note the space)
|
||||||
if return_decl[0] == '(' and return_decl[-1] == ')':
|
if return_decl[0] == '(' and return_decl[-1] == ')':
|
||||||
return_decl = return_decl[1:-1]
|
return_decl = return_decl[1:-1]
|
||||||
multiple_args = len(return_decl.split(', ')) > 1
|
|
||||||
|
|
||||||
|
multiple_args = len(return_decl.split(', ')) > 1
|
||||||
for arg_idx, arg in enumerate(return_decl.split(', ')):
|
for arg_idx, arg in enumerate(return_decl.split(', ')):
|
||||||
t, name, default, nullable, size, annotation = type_argument_translations(arg)
|
t, name, default, nullable, size, annotation = type_argument_translations(arg)
|
||||||
# name of arguments and name of return sometimes have collision
|
# name of arguments and name of return sometimes have collision
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,10 @@ def supports(o, factory_methods):
|
||||||
if "_out" in o['name']:
|
if "_out" in o['name']:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# skip if no return, previously it is 'void'
|
||||||
|
if len(o['returns']) == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
# skip return types we cannot handle
|
# skip return types we cannot handle
|
||||||
for ret in o['returns']:
|
for ret in o['returns']:
|
||||||
if not value_has_tensors(ret) and ret['type'] not in RETURN_MAP:
|
if not value_has_tensors(ret) and ret['type'] not in RETURN_MAP:
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,8 @@ class TestNamedTupleAPI(unittest.TestCase):
|
||||||
continue
|
continue
|
||||||
if not ret.startswith('('):
|
if not ret.startswith('('):
|
||||||
continue
|
continue
|
||||||
|
if ret == '()':
|
||||||
|
continue
|
||||||
ret = ret[1:-1].split(',')
|
ret = ret[1:-1].split(',')
|
||||||
for r in ret:
|
for r in ret:
|
||||||
r = r.strip()
|
r = r.strip()
|
||||||
|
|
|
||||||
|
|
@ -509,7 +509,7 @@ def emit_body(declaration):
|
||||||
inplace = declaration['inplace']
|
inplace = declaration['inplace']
|
||||||
is_out_fn = name.endswith('_out')
|
is_out_fn = name.endswith('_out')
|
||||||
modifies_arguments = inplace or is_out_fn
|
modifies_arguments = inplace or is_out_fn
|
||||||
returns_void = len(returns) == 1 and returns[0]['type'] == 'void'
|
returns_void = len(returns) == 0
|
||||||
|
|
||||||
base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
|
base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
|
||||||
view_info = VIEW_FUNCTIONS.get(base_name, None)
|
view_info = VIEW_FUNCTIONS.get(base_name, None)
|
||||||
|
|
|
||||||
|
|
@ -203,7 +203,8 @@ def is_jit_arg(i, arg):
|
||||||
|
|
||||||
def is_jit_op(decl):
|
def is_jit_op(decl):
|
||||||
# We currently don't support functions that return nothing
|
# We currently don't support functions that return nothing
|
||||||
if all(r['type'] == 'void' for r in decl['returns']):
|
assert all(r['type'] != 'void' for r in decl['returns'])
|
||||||
|
if len(decl['returns']) == 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
arguments = decl['arguments']
|
arguments = decl['arguments']
|
||||||
|
|
|
||||||
|
|
@ -284,8 +284,10 @@ def generate_type_hints(fname, decls, is_tensor=False):
|
||||||
|
|
||||||
if len(python_returns) > 1:
|
if len(python_returns) > 1:
|
||||||
python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
|
python_returns_s = 'Tuple[' + ', '.join(python_returns) + ']'
|
||||||
else:
|
elif len(python_returns) == 1:
|
||||||
python_returns_s = python_returns[0]
|
python_returns_s = python_returns[0]
|
||||||
|
else:
|
||||||
|
python_returns_s = 'None'
|
||||||
|
|
||||||
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
|
type_hint = "def {}({}) -> {}: ...".format(fname, python_args_s, python_returns_s)
|
||||||
numargs = len(decl['arguments'])
|
numargs = len(decl['arguments'])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user