Revert D18864774: polish up overloads on free functions

Test Plan: revert-hammer

Differential Revision:
D18864774

Original commit changeset: 6c566738bd6f

fbshipit-source-id: 669192605a3bc1a6ba06bbb5cae54f61637a45ae
This commit is contained in:
Wanchao Liang 2019-12-09 15:39:15 -08:00 committed by Facebook Github Bot
parent 446488960a
commit 73dd8c005a
4 changed files with 65 additions and 305 deletions

View File

@ -176,9 +176,9 @@ struct FunctionSchema {
// this schema must provide default values.
bool isBackwardCompatibleWith(
const FunctionSchema& old,
std::ostream* why_not = nullptr) const;
std::ostream* why_not=nullptr) const;
private:
private:
OperatorName name_;
std::vector<Argument> arguments_;
std::vector<Argument> returns_;

View File

@ -15177,7 +15177,7 @@ a")
pass
def test_simple(x1): # noqa: F811
return x1
return x1 + 5
def invoke_function():
return test_simple(1.0), test_simple(.5)
@ -15188,34 +15188,7 @@ a")
compiled_fns_1 = torch.jit._get_overloads(test_simple)
compiled_fns_2 = torch.jit._get_overloads(test_simple)
for a, b in zip(compiled_fns_1, compiled_fns_2):
self.assertIs(a.graph, b.graph)
old_func = test_simple
# testing that new functions added work with caching
@torch.jit._overload # noqa: F811
def test_simple(x1): # noqa: F811
# type: (str) -> str
pass
@torch.jit.script
def my_func():
return old_func("hi")
# testing new function same qualified name
@torch.jit._overload # noqa: F811
def test_simple(a, b): # noqa: F811
# type: (int, int) -> int
pass
def test_simple(a, b):
return a + b
@torch.jit.script
def fn():
return test_simple(3, 4)
self.assertEqual(fn(), 7)
self.assertIs(a, b)
# currently we take the default values have to be specified in the
# overload as well - TODO take them from implementation and apply
@ -15224,9 +15197,9 @@ a")
def identity(x1): # noqa: F811
# type: (str) -> str
pass
#
@torch.jit._overload # noqa: F811
def identity(x1): # noqa: F811
def identity(x1=1.0): # noqa: F811
# type: (float) -> float
pass
@ -15272,87 +15245,6 @@ a")
with self.assertRaisesRegex(Exception, "Arguments for call are not valid"):
torch.jit.script(test)
@torch.jit._overload # noqa: F811
def good_overload(x=1): # noqa: F811
# type: (int) -> (int)
pass
def good_overload(x=1): # noqa: F811
return x
@torch.jit.script
def foo():
return good_overload()
self.assertEqual(foo(), 1)
with self.assertRaisesRegex(Exception, "must equal to the default parameter"):
@torch.jit._overload # noqa: F811
def bad_default_on_overload(x, y=2): # noqa: F811
# type: (int, int) -> (int)
pass
def bad_default_on_overload(x, y=1): # noqa: F811
# type: (int, int) -> (int)
pass
@torch.jit.script
def test():
return bad_default_on_overload(1, 2)
@torch.jit._overload # noqa: F811
def diff_default(x): # noqa: F811
# type: (int) -> int
pass
@torch.jit._overload # noqa: F811
def diff_default(x): # noqa: F811
# type: (str) -> str
pass
def diff_default(x="hi"): # noqa: F811
return x
def test():
return diff_default(), diff_default(2), diff_default("abc")
self.assertEqual(test(), torch.jit.script(test)())
@torch.jit._overload # noqa: F811
def diff_num_params(x): # noqa: F811
# type: (float) -> float
pass
@torch.jit._overload # noqa: F811
def diff_num_params(x, y): # noqa: F811
# type: (int, int) -> int
pass
def diff_num_params(x, y=2, z=3): # noqa: F811
# type: (Union[float, int], int, int)
return x + y + z
def test():
return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3)
self.assertEqual(test(), torch.jit.script(test)())
@torch.jit._overload # noqa: F811
def diff_num_params_no_annot():
# type: () -> int
pass
def diff_num_params_no_annot(x=1): # noqa: F811
return x
def test():
return diff_num_params_no_annot(1.0)
with self.assertRaisesRegex(Exception, "Parameters not specified on the overloaded declaration must have a type annotation"):
torch.jit.script(test)
def test_function_overloading_isinstance(self):
@torch.jit._overload # noqa: F811
def my_conv(x, y): # noqa: F811
@ -15360,7 +15252,7 @@ a")
pass
@torch.jit._overload # noqa: F811
def my_conv(x, y): # noqa: F811
def my_conv(x, y=2.0): # noqa: F811
# type: (float, float) -> (float)
pass

View File

@ -195,46 +195,6 @@ void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
old_params[i].ident());
}
}
c10::optional<IValue> tryCalculateDefaultParam(
const Argument& arg,
const py::object& def_value) {
auto n = arg.N();
auto list_type = arg.type()->cast<ListType>();
try {
if (n && *n > 0 && list_type) {
// BroadcastingList, allow default values T for arg types List[T]
return toIValue(def_value, list_type->getElementType());
} else {
return toIValue(def_value, arg.type());
}
} catch (py::cast_error& e) {
return c10::nullopt;
}
}
// An overloaded function may have a default that does not subtype all overloads
// @overload
// def foo(x: str)
// def foo(x=1)
FunctionDefaults calcOverloadedFunctionDefaults(
const FunctionSchema& schema,
const FunctionDefaults& defaults) {
FunctionDefaults updated_defaults;
for (const auto& arg : schema.arguments()) {
const std::string& arg_name = arg.name();
auto value = defaults.find(arg_name);
if (value == defaults.end()) {
continue;
}
auto maybe_ivalue = tryCalculateDefaultParam(arg, value->second);
if (maybe_ivalue) {
updated_defaults[arg_name] = value->second;
}
}
return updated_defaults;
}
} // namespace
bool checkMutableFunctionDefault(const py::object& def_arg) {
@ -276,18 +236,28 @@ FunctionSchema getSchemaWithNameAndDefaults(
auto it = default_args.find(arg.name());
if (it != default_args.end()) {
checkMutableFunctionDefault(range, arg, it->second);
c10::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
if (!value) {
try {
IValue value;
auto n = arg.N();
auto list_type = arg.type()->cast<ListType>();
if (n && *n > 0 && list_type) {
// BroadcastingList, allow default values T for arg types List[T]
value = toIValue(it->second, list_type->getElementType());
} else {
value = toIValue(it->second, arg.type());
}
new_args.emplace_back(
arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
} catch (py::cast_error& e) {
throw ErrorReport(range)
<< "Expected a default value of type " << arg.type()->python_str()
<< " on parameter \"" << arg.name() << "\"";
}
new_args.emplace_back(
arg.name(), arg.type(), arg.N(), *value, arg.kwarg_only());
} else {
new_args.push_back(arg);
}
}
return FunctionSchema(
new_name.value_or(schema.name()),
schema.overload_name(),
@ -297,104 +267,6 @@ FunctionSchema getSchemaWithNameAndDefaults(
schema.is_varret());
}
static Decl mergeDefaultsAndExtraParametersToOverloadDecl(
const Decl& overload_decl,
const Decl& impl_decl,
const FunctionDefaults& defaults) {
std::vector<Param> adjusted_params;
const auto& overload_params = overload_decl.params();
const auto& impl_params = impl_decl.params();
// following PEP specification that the following should work:
// @overload
// def mouse_event(x1: int, y1: int) -> ClickEvent: ...
// ...
// def mouse_event(x1: int, y1: int, x2: Optional[int] = None, y2:
// Optional[int] = None)
TORCH_CHECK(
overload_params.size() <= impl_params.size(),
"Overload should not have more parameters than implementation function",
overload_decl.range(),
impl_decl.range());
for (size_t i = 0; i < overload_params.size(); ++i) {
auto overload_name = overload_params[i].ident().name();
auto impl_name = impl_params[i].ident().name();
if (overload_name != impl_name) {
throw ErrorReport(overload_decl.range())
<< "Overload parameters must have the same names. "
<< "Found " << overload_name << " and " << impl_name
<< " on argument " << i;
}
adjusted_params.push_back(overload_params[i]);
}
for (size_t i = overload_params.size(); i < impl_params.size(); ++i) {
if (!defaults.count(impl_params[i].ident().name())) {
throw ErrorReport(impl_decl.range())
<< "Expected to find default parameter on argument"
<< impl_params[i].ident().name()
<< " because it is not defined on the overloaded declaration";
}
if (!impl_params[i].type().present()) {
throw ErrorReport(impl_decl.range())
<< "Parameters not specified on the overloaded declaration must have a type annotation in the implementation function."
<< " Did not find type for param " << impl_params[i].ident().name();
}
adjusted_params.push_back(impl_params[i]);
}
return Decl::create(
overload_decl.range(),
List<Param>::create(overload_decl.range(), adjusted_params),
overload_decl.return_type());
}
static StrongFunctionPtr script_compile_overloaded_function(
const c10::QualifiedName& name,
const Decl& overload_decl,
const Def& implementation_def,
ResolutionCallback rcb,
const FunctionDefaults& implementation_defaults,
const FunctionDefaults& overload_defaults,
const py::object& signature) {
if (signature.is(py::none())) {
throw ErrorReport(overload_decl.range())
<< "Must explicitly add type annotations to overloaded functions";
}
for (const auto& default_val : overload_defaults) {
auto impl_default = implementation_defaults.find(default_val.first);
if (impl_default == implementation_defaults.end() ||
!impl_default->second.equal(default_val.second)) {
throw ErrorReport(overload_decl.range())
<< "Default parameters on overloads do not "
<< "effect the runtime so they must equal to the default parameter on the implementation function."
<< " found on parameter " << impl_default->first;
}
}
auto adjusted_decl = mergeDefaultsAndExtraParametersToOverloadDecl(
overload_decl, implementation_def.decl(), implementation_defaults);
auto new_def = implementation_def.withDecl(adjusted_decl);
auto cu = get_python_cu();
auto defined_functions = cu->define(
QualifiedName(name.prefix()),
{new_def},
{pythonResolver(std::move(rcb))},
nullptr,
true);
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
auto& defined = defined_functions[0];
FunctionDefaults updated_defaults = calcOverloadedFunctionDefaults(
defined->getSchema(), implementation_defaults);
defined->setSchema(getSchemaWithNameAndDefaults(
new_def.range(),
defined->getSchema(),
new_def.name().name(),
updated_defaults));
StrongFunctionPtr ret(std::move(cu), defined);
didFinishEmitFunction(ret);
return ret;
}
static StrongFunctionPtr script_compile_function(
const c10::QualifiedName& name,
const Def& def,
@ -978,18 +850,11 @@ void initJitScriptBindings(PyObject* module) {
const Decl& overload_decl,
const Def& implementation_def,
ResolutionCallback rcb,
const FunctionDefaults& implementation_defaults,
const FunctionDefaults& overload_defaults,
const py::object& signature) {
const FunctionDefaults& defaults) {
const auto name = c10::QualifiedName(qualname);
return script_compile_overloaded_function(
name,
overload_decl,
implementation_def,
rcb,
implementation_defaults,
overload_defaults,
signature);
checkOverloadDecl(overload_decl, implementation_def.decl());
auto new_def = implementation_def.withDecl(overload_decl);
return script_compile_function(name, new_def, defaults, std::move(rcb));
});
m.def(
"_replace_overloaded_method_decl",

View File

@ -1276,7 +1276,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
return obj
else:
_check_directly_compile_overloaded(obj)
maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
maybe_already_compiled_fn = _try_get_jit_cached_key(obj)
if maybe_already_compiled_fn:
return maybe_already_compiled_fn
ast = get_jit_def(obj)
@ -1285,7 +1285,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
# Forward docstrings
fn.__doc__ = obj.__doc__
_set_jit_function_cache(obj, fn)
_set_jit_cache(obj, fn)
return fn
def interface(obj):
@ -1977,41 +1977,27 @@ _builtin_ops = [
# Caching: we currently cache compilation of free functions and overloaded functions.
# Caching: we currently cache compilation of free functions.
# To cache free functions we hold a weak ref to the function object and
# map to the compiled fn's qualified name.
# To cache overloaded functions we hold a weak ref to the function obj and
# map to all of its overloaded compiled fns.
# In the future we could consider caching more types of objects so that
# aliasing is preserved across separate compilations of the same object.
_jit_caching_layer = weakref.WeakKeyDictionary()
_jit_function_overload_caching = weakref.WeakKeyDictionary()
def _try_get_jit_cached_overloads(key):
qual_names = _jit_function_overload_caching.get(key, None)
if qual_names:
return [_python_cu.find_function(qual_name) for qual_name in qual_names]
else:
return None
def _set_jit_overload_cache(key, compiled_fns):
_jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
def _try_get_jit_cached_function(key):
def _try_get_jit_cached_key(key):
qual_name = _jit_caching_layer.get(key, None)
if qual_name:
return _python_cu.find_function(qual_name)
else:
return None
def _set_jit_function_cache(key, value):
def _set_jit_cache(key, value):
# only free functions currently supported
assert isinstance(value, torch.jit.ScriptFunction)
_jit_caching_layer[key] = value.qualified_name
# lazily built to ensure the correct initialization order
def _get_builtin_table():
global _builtin_table
@ -2065,40 +2051,57 @@ def _get_script_class(name):
# overloads are registered in _jit_internal and compiled here so that _overload
# can be used in nn/functional.py without an import cycle
def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
overload_decl = torch.jit.get_jit_def(overload_fn).decl()
overload_signature = torch.jit.annotations.get_signature(overload_fn, None, None)
# qualified name => list[compiled fns]
_compiled_overloaded_fns = {}
def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_defaults):
impl_ast = torch.jit.get_jit_def(impl_fn)
overload_defaults = get_default_args(overload_fn)
implementation_defaults = get_default_args(impl_fn)
_rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb,
implementation_defaults, overload_defaults, overload_signature)
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults)
return fn
def _check_no_signature(func):
signature = torch.jit.annotations.get_signature(func, None, None)
if signature is None:
qual_name = _qualified_name(func)
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
def _get_overload_decl_and_defaults(func):
_check_no_signature(func)
return (torch.jit.get_jit_def(func).decl(), get_default_args(func))
def _get_overloads(obj):
# check for cached compiled fns
existing_compiled_fns = _try_get_jit_cached_overloads(obj)
qual_name = _qualified_name(obj)
uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name)
if uncompiled_overloads is None:
return existing_compiled_fns
global _compiled_overloaded_fns
compiled_overloads = _compiled_overloaded_fns.get(qual_name, None)
if compiled_overloads is not None:
return compiled_overloads
# check for not yet compiled overloads
overloads = _jit_internal._get_fn_overloads(qual_name)
if overloads is None:
return None
compiled_fns = []
for overload_fn in uncompiled_overloads:
compiled_fns.append(_compile_function_with_overload(overload_fn, qual_name, obj))
if existing_compiled_fns:
compiled_fns = existing_compiled_fns + compiled_fns
# TODO: use default args from the implementation, not from the overload
# This is more complicated because you can have a default arg with a type
# incompatible with a type of parameter in an overload, and other validation.
# This is still an internal api so for now use defaults from overload
for overload_fn in overloads:
overload_decl, overload_defaults = _get_overload_decl_and_defaults(overload_fn)
compiled_fn = _compile_function_with_overload(qual_name, obj, overload_decl, overload_defaults)
compiled_fns.append(compiled_fn)
# cache compilation, remove information stored to do compilation
_set_jit_overload_cache(obj, compiled_fns)
_compiled_overloaded_fns[qual_name] = compiled_fns
_jit_internal._clear_fn_overloads(qual_name)
return compiled_fns
def _check_directly_compile_overloaded(obj):
qual_name = _qualified_name(obj)
if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj):
global _compiled_overloaded_fns
global _overloaded_fns
if qual_name in _compiled_overloaded_fns or _jit_internal._get_fn_overloads(qual_name):
raise RuntimeError("Function {} cannot be directly compiled because it"
" is overloaded. It must be used in a context of a function"
" where its inputs can determine which overload to call.".format(qual_name))