mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert D18864774: polish up overloads on free functions
Test Plan: revert-hammer Differential Revision: D18864774 Original commit changeset: 6c566738bd6f fbshipit-source-id: 669192605a3bc1a6ba06bbb5cae54f61637a45ae
This commit is contained in:
parent
446488960a
commit
73dd8c005a
|
|
@ -176,9 +176,9 @@ struct FunctionSchema {
|
|||
// this schema must provide default values.
|
||||
bool isBackwardCompatibleWith(
|
||||
const FunctionSchema& old,
|
||||
std::ostream* why_not = nullptr) const;
|
||||
std::ostream* why_not=nullptr) const;
|
||||
|
||||
private:
|
||||
private:
|
||||
OperatorName name_;
|
||||
std::vector<Argument> arguments_;
|
||||
std::vector<Argument> returns_;
|
||||
|
|
|
|||
118
test/test_jit.py
118
test/test_jit.py
|
|
@ -15177,7 +15177,7 @@ a")
|
|||
pass
|
||||
|
||||
def test_simple(x1): # noqa: F811
|
||||
return x1
|
||||
return x1 + 5
|
||||
|
||||
def invoke_function():
|
||||
return test_simple(1.0), test_simple(.5)
|
||||
|
|
@ -15188,34 +15188,7 @@ a")
|
|||
compiled_fns_1 = torch.jit._get_overloads(test_simple)
|
||||
compiled_fns_2 = torch.jit._get_overloads(test_simple)
|
||||
for a, b in zip(compiled_fns_1, compiled_fns_2):
|
||||
self.assertIs(a.graph, b.graph)
|
||||
|
||||
old_func = test_simple
|
||||
|
||||
# testing that new functions added work with caching
|
||||
@torch.jit._overload # noqa: F811
|
||||
def test_simple(x1): # noqa: F811
|
||||
# type: (str) -> str
|
||||
pass
|
||||
|
||||
@torch.jit.script
|
||||
def my_func():
|
||||
return old_func("hi")
|
||||
|
||||
# testing new function same qualified name
|
||||
@torch.jit._overload # noqa: F811
|
||||
def test_simple(a, b): # noqa: F811
|
||||
# type: (int, int) -> int
|
||||
pass
|
||||
|
||||
def test_simple(a, b):
|
||||
return a + b
|
||||
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
return test_simple(3, 4)
|
||||
|
||||
self.assertEqual(fn(), 7)
|
||||
self.assertIs(a, b)
|
||||
|
||||
# currently we take the default values have to be specified in the
|
||||
# overload as well - TODO take them from implementation and apply
|
||||
|
|
@ -15224,9 +15197,9 @@ a")
|
|||
def identity(x1): # noqa: F811
|
||||
# type: (str) -> str
|
||||
pass
|
||||
#
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def identity(x1): # noqa: F811
|
||||
def identity(x1=1.0): # noqa: F811
|
||||
# type: (float) -> float
|
||||
pass
|
||||
|
||||
|
|
@ -15272,87 +15245,6 @@ a")
|
|||
with self.assertRaisesRegex(Exception, "Arguments for call are not valid"):
|
||||
torch.jit.script(test)
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def good_overload(x=1): # noqa: F811
|
||||
# type: (int) -> (int)
|
||||
pass
|
||||
|
||||
def good_overload(x=1): # noqa: F811
|
||||
return x
|
||||
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
return good_overload()
|
||||
|
||||
self.assertEqual(foo(), 1)
|
||||
|
||||
|
||||
with self.assertRaisesRegex(Exception, "must equal to the default parameter"):
|
||||
@torch.jit._overload # noqa: F811
|
||||
def bad_default_on_overload(x, y=2): # noqa: F811
|
||||
# type: (int, int) -> (int)
|
||||
pass
|
||||
|
||||
def bad_default_on_overload(x, y=1): # noqa: F811
|
||||
# type: (int, int) -> (int)
|
||||
pass
|
||||
|
||||
@torch.jit.script
|
||||
def test():
|
||||
return bad_default_on_overload(1, 2)
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def diff_default(x): # noqa: F811
|
||||
# type: (int) -> int
|
||||
pass
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def diff_default(x): # noqa: F811
|
||||
# type: (str) -> str
|
||||
pass
|
||||
|
||||
def diff_default(x="hi"): # noqa: F811
|
||||
return x
|
||||
|
||||
def test():
|
||||
return diff_default(), diff_default(2), diff_default("abc")
|
||||
|
||||
self.assertEqual(test(), torch.jit.script(test)())
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def diff_num_params(x): # noqa: F811
|
||||
# type: (float) -> float
|
||||
pass
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def diff_num_params(x, y): # noqa: F811
|
||||
# type: (int, int) -> int
|
||||
pass
|
||||
|
||||
def diff_num_params(x, y=2, z=3): # noqa: F811
|
||||
# type: (Union[float, int], int, int)
|
||||
return x + y + z
|
||||
|
||||
def test():
|
||||
return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3)
|
||||
|
||||
self.assertEqual(test(), torch.jit.script(test)())
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def diff_num_params_no_annot():
|
||||
# type: () -> int
|
||||
pass
|
||||
|
||||
def diff_num_params_no_annot(x=1): # noqa: F811
|
||||
return x
|
||||
|
||||
def test():
|
||||
return diff_num_params_no_annot(1.0)
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Parameters not specified on the overloaded declaration must have a type annotation"):
|
||||
torch.jit.script(test)
|
||||
|
||||
|
||||
def test_function_overloading_isinstance(self):
|
||||
@torch.jit._overload # noqa: F811
|
||||
def my_conv(x, y): # noqa: F811
|
||||
|
|
@ -15360,7 +15252,7 @@ a")
|
|||
pass
|
||||
|
||||
@torch.jit._overload # noqa: F811
|
||||
def my_conv(x, y): # noqa: F811
|
||||
def my_conv(x, y=2.0): # noqa: F811
|
||||
# type: (float, float) -> (float)
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -195,46 +195,6 @@ void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
|
|||
old_params[i].ident());
|
||||
}
|
||||
}
|
||||
|
||||
c10::optional<IValue> tryCalculateDefaultParam(
|
||||
const Argument& arg,
|
||||
const py::object& def_value) {
|
||||
auto n = arg.N();
|
||||
auto list_type = arg.type()->cast<ListType>();
|
||||
try {
|
||||
if (n && *n > 0 && list_type) {
|
||||
// BroadcastingList, allow default values T for arg types List[T]
|
||||
return toIValue(def_value, list_type->getElementType());
|
||||
} else {
|
||||
return toIValue(def_value, arg.type());
|
||||
}
|
||||
} catch (py::cast_error& e) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// An overloaded function may have a default that does not subtype all overloads
|
||||
// @overload
|
||||
// def foo(x: str)
|
||||
// def foo(x=1)
|
||||
FunctionDefaults calcOverloadedFunctionDefaults(
|
||||
const FunctionSchema& schema,
|
||||
const FunctionDefaults& defaults) {
|
||||
FunctionDefaults updated_defaults;
|
||||
for (const auto& arg : schema.arguments()) {
|
||||
const std::string& arg_name = arg.name();
|
||||
auto value = defaults.find(arg_name);
|
||||
if (value == defaults.end()) {
|
||||
continue;
|
||||
}
|
||||
auto maybe_ivalue = tryCalculateDefaultParam(arg, value->second);
|
||||
if (maybe_ivalue) {
|
||||
updated_defaults[arg_name] = value->second;
|
||||
}
|
||||
}
|
||||
return updated_defaults;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool checkMutableFunctionDefault(const py::object& def_arg) {
|
||||
|
|
@ -276,18 +236,28 @@ FunctionSchema getSchemaWithNameAndDefaults(
|
|||
auto it = default_args.find(arg.name());
|
||||
if (it != default_args.end()) {
|
||||
checkMutableFunctionDefault(range, arg, it->second);
|
||||
c10::optional<IValue> value = tryCalculateDefaultParam(arg, it->second);
|
||||
if (!value) {
|
||||
try {
|
||||
IValue value;
|
||||
auto n = arg.N();
|
||||
auto list_type = arg.type()->cast<ListType>();
|
||||
if (n && *n > 0 && list_type) {
|
||||
// BroadcastingList, allow default values T for arg types List[T]
|
||||
value = toIValue(it->second, list_type->getElementType());
|
||||
} else {
|
||||
value = toIValue(it->second, arg.type());
|
||||
}
|
||||
new_args.emplace_back(
|
||||
arg.name(), arg.type(), arg.N(), value, arg.kwarg_only());
|
||||
} catch (py::cast_error& e) {
|
||||
throw ErrorReport(range)
|
||||
<< "Expected a default value of type " << arg.type()->python_str()
|
||||
<< " on parameter \"" << arg.name() << "\"";
|
||||
}
|
||||
new_args.emplace_back(
|
||||
arg.name(), arg.type(), arg.N(), *value, arg.kwarg_only());
|
||||
} else {
|
||||
new_args.push_back(arg);
|
||||
}
|
||||
}
|
||||
|
||||
return FunctionSchema(
|
||||
new_name.value_or(schema.name()),
|
||||
schema.overload_name(),
|
||||
|
|
@ -297,104 +267,6 @@ FunctionSchema getSchemaWithNameAndDefaults(
|
|||
schema.is_varret());
|
||||
}
|
||||
|
||||
static Decl mergeDefaultsAndExtraParametersToOverloadDecl(
|
||||
const Decl& overload_decl,
|
||||
const Decl& impl_decl,
|
||||
const FunctionDefaults& defaults) {
|
||||
std::vector<Param> adjusted_params;
|
||||
const auto& overload_params = overload_decl.params();
|
||||
const auto& impl_params = impl_decl.params();
|
||||
|
||||
// following PEP specification that the following should work:
|
||||
// @overload
|
||||
// def mouse_event(x1: int, y1: int) -> ClickEvent: ...
|
||||
// ...
|
||||
// def mouse_event(x1: int, y1: int, x2: Optional[int] = None, y2:
|
||||
// Optional[int] = None)
|
||||
TORCH_CHECK(
|
||||
overload_params.size() <= impl_params.size(),
|
||||
"Overload should not have more parameters than implementation function",
|
||||
overload_decl.range(),
|
||||
impl_decl.range());
|
||||
|
||||
for (size_t i = 0; i < overload_params.size(); ++i) {
|
||||
auto overload_name = overload_params[i].ident().name();
|
||||
auto impl_name = impl_params[i].ident().name();
|
||||
if (overload_name != impl_name) {
|
||||
throw ErrorReport(overload_decl.range())
|
||||
<< "Overload parameters must have the same names. "
|
||||
<< "Found " << overload_name << " and " << impl_name
|
||||
<< " on argument " << i;
|
||||
}
|
||||
adjusted_params.push_back(overload_params[i]);
|
||||
}
|
||||
for (size_t i = overload_params.size(); i < impl_params.size(); ++i) {
|
||||
if (!defaults.count(impl_params[i].ident().name())) {
|
||||
throw ErrorReport(impl_decl.range())
|
||||
<< "Expected to find default parameter on argument"
|
||||
<< impl_params[i].ident().name()
|
||||
<< " because it is not defined on the overloaded declaration";
|
||||
}
|
||||
if (!impl_params[i].type().present()) {
|
||||
throw ErrorReport(impl_decl.range())
|
||||
<< "Parameters not specified on the overloaded declaration must have a type annotation in the implementation function."
|
||||
<< " Did not find type for param " << impl_params[i].ident().name();
|
||||
}
|
||||
adjusted_params.push_back(impl_params[i]);
|
||||
}
|
||||
return Decl::create(
|
||||
overload_decl.range(),
|
||||
List<Param>::create(overload_decl.range(), adjusted_params),
|
||||
overload_decl.return_type());
|
||||
}
|
||||
|
||||
static StrongFunctionPtr script_compile_overloaded_function(
|
||||
const c10::QualifiedName& name,
|
||||
const Decl& overload_decl,
|
||||
const Def& implementation_def,
|
||||
ResolutionCallback rcb,
|
||||
const FunctionDefaults& implementation_defaults,
|
||||
const FunctionDefaults& overload_defaults,
|
||||
const py::object& signature) {
|
||||
if (signature.is(py::none())) {
|
||||
throw ErrorReport(overload_decl.range())
|
||||
<< "Must explicitly add type annotations to overloaded functions";
|
||||
}
|
||||
for (const auto& default_val : overload_defaults) {
|
||||
auto impl_default = implementation_defaults.find(default_val.first);
|
||||
if (impl_default == implementation_defaults.end() ||
|
||||
!impl_default->second.equal(default_val.second)) {
|
||||
throw ErrorReport(overload_decl.range())
|
||||
<< "Default parameters on overloads do not "
|
||||
<< "effect the runtime so they must equal to the default parameter on the implementation function."
|
||||
<< " found on parameter " << impl_default->first;
|
||||
}
|
||||
}
|
||||
|
||||
auto adjusted_decl = mergeDefaultsAndExtraParametersToOverloadDecl(
|
||||
overload_decl, implementation_def.decl(), implementation_defaults);
|
||||
auto new_def = implementation_def.withDecl(adjusted_decl);
|
||||
auto cu = get_python_cu();
|
||||
auto defined_functions = cu->define(
|
||||
QualifiedName(name.prefix()),
|
||||
{new_def},
|
||||
{pythonResolver(std::move(rcb))},
|
||||
nullptr,
|
||||
true);
|
||||
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
|
||||
auto& defined = defined_functions[0];
|
||||
FunctionDefaults updated_defaults = calcOverloadedFunctionDefaults(
|
||||
defined->getSchema(), implementation_defaults);
|
||||
defined->setSchema(getSchemaWithNameAndDefaults(
|
||||
new_def.range(),
|
||||
defined->getSchema(),
|
||||
new_def.name().name(),
|
||||
updated_defaults));
|
||||
StrongFunctionPtr ret(std::move(cu), defined);
|
||||
didFinishEmitFunction(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static StrongFunctionPtr script_compile_function(
|
||||
const c10::QualifiedName& name,
|
||||
const Def& def,
|
||||
|
|
@ -978,18 +850,11 @@ void initJitScriptBindings(PyObject* module) {
|
|||
const Decl& overload_decl,
|
||||
const Def& implementation_def,
|
||||
ResolutionCallback rcb,
|
||||
const FunctionDefaults& implementation_defaults,
|
||||
const FunctionDefaults& overload_defaults,
|
||||
const py::object& signature) {
|
||||
const FunctionDefaults& defaults) {
|
||||
const auto name = c10::QualifiedName(qualname);
|
||||
return script_compile_overloaded_function(
|
||||
name,
|
||||
overload_decl,
|
||||
implementation_def,
|
||||
rcb,
|
||||
implementation_defaults,
|
||||
overload_defaults,
|
||||
signature);
|
||||
checkOverloadDecl(overload_decl, implementation_def.decl());
|
||||
auto new_def = implementation_def.withDecl(overload_decl);
|
||||
return script_compile_function(name, new_def, defaults, std::move(rcb));
|
||||
});
|
||||
m.def(
|
||||
"_replace_overloaded_method_decl",
|
||||
|
|
|
|||
|
|
@ -1276,7 +1276,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
|
|||
return obj
|
||||
else:
|
||||
_check_directly_compile_overloaded(obj)
|
||||
maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
|
||||
maybe_already_compiled_fn = _try_get_jit_cached_key(obj)
|
||||
if maybe_already_compiled_fn:
|
||||
return maybe_already_compiled_fn
|
||||
ast = get_jit_def(obj)
|
||||
|
|
@ -1285,7 +1285,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None):
|
|||
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
|
||||
# Forward docstrings
|
||||
fn.__doc__ = obj.__doc__
|
||||
_set_jit_function_cache(obj, fn)
|
||||
_set_jit_cache(obj, fn)
|
||||
return fn
|
||||
|
||||
def interface(obj):
|
||||
|
|
@ -1977,41 +1977,27 @@ _builtin_ops = [
|
|||
|
||||
|
||||
|
||||
# Caching: we currently cache compilation of free functions and overloaded functions.
|
||||
# Caching: we currently cache compilation of free functions.
|
||||
# To cache free functions we hold a weak ref to the function object and
|
||||
# map to the compiled fn's qualified name.
|
||||
# To cache overloaded functions we hold a weak ref to the function obj and
|
||||
# map to all of its overloaded compiled fns.
|
||||
# In the future we could consider caching more types of objects so that
|
||||
# aliasing is preserved across separate compilations of the same object.
|
||||
|
||||
_jit_caching_layer = weakref.WeakKeyDictionary()
|
||||
_jit_function_overload_caching = weakref.WeakKeyDictionary()
|
||||
|
||||
def _try_get_jit_cached_overloads(key):
|
||||
qual_names = _jit_function_overload_caching.get(key, None)
|
||||
if qual_names:
|
||||
return [_python_cu.find_function(qual_name) for qual_name in qual_names]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set_jit_overload_cache(key, compiled_fns):
|
||||
_jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
|
||||
|
||||
def _try_get_jit_cached_function(key):
|
||||
def _try_get_jit_cached_key(key):
|
||||
qual_name = _jit_caching_layer.get(key, None)
|
||||
if qual_name:
|
||||
return _python_cu.find_function(qual_name)
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set_jit_function_cache(key, value):
|
||||
def _set_jit_cache(key, value):
|
||||
# only free functions currently supported
|
||||
assert isinstance(value, torch.jit.ScriptFunction)
|
||||
_jit_caching_layer[key] = value.qualified_name
|
||||
|
||||
|
||||
|
||||
# lazily built to ensure the correct initialization order
|
||||
def _get_builtin_table():
|
||||
global _builtin_table
|
||||
|
|
@ -2065,40 +2051,57 @@ def _get_script_class(name):
|
|||
# overloads are registered in _jit_internal and compiled here so that _overload
|
||||
# can be used in nn/functional.py without an import cycle
|
||||
|
||||
def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
|
||||
overload_decl = torch.jit.get_jit_def(overload_fn).decl()
|
||||
overload_signature = torch.jit.annotations.get_signature(overload_fn, None, None)
|
||||
# qualified name => list[compiled fns]
|
||||
_compiled_overloaded_fns = {}
|
||||
|
||||
def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_defaults):
|
||||
impl_ast = torch.jit.get_jit_def(impl_fn)
|
||||
overload_defaults = get_default_args(overload_fn)
|
||||
implementation_defaults = get_default_args(impl_fn)
|
||||
_rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)
|
||||
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb,
|
||||
implementation_defaults, overload_defaults, overload_signature)
|
||||
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults)
|
||||
return fn
|
||||
|
||||
def _check_no_signature(func):
|
||||
signature = torch.jit.annotations.get_signature(func, None, None)
|
||||
if signature is None:
|
||||
qual_name = _qualified_name(func)
|
||||
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
|
||||
|
||||
def _get_overload_decl_and_defaults(func):
|
||||
_check_no_signature(func)
|
||||
return (torch.jit.get_jit_def(func).decl(), get_default_args(func))
|
||||
|
||||
def _get_overloads(obj):
|
||||
# check for cached compiled fns
|
||||
existing_compiled_fns = _try_get_jit_cached_overloads(obj)
|
||||
qual_name = _qualified_name(obj)
|
||||
uncompiled_overloads = _jit_internal._get_fn_overloads(qual_name)
|
||||
if uncompiled_overloads is None:
|
||||
return existing_compiled_fns
|
||||
global _compiled_overloaded_fns
|
||||
compiled_overloads = _compiled_overloaded_fns.get(qual_name, None)
|
||||
if compiled_overloads is not None:
|
||||
return compiled_overloads
|
||||
|
||||
# check for not yet compiled overloads
|
||||
overloads = _jit_internal._get_fn_overloads(qual_name)
|
||||
if overloads is None:
|
||||
return None
|
||||
compiled_fns = []
|
||||
for overload_fn in uncompiled_overloads:
|
||||
compiled_fns.append(_compile_function_with_overload(overload_fn, qual_name, obj))
|
||||
|
||||
if existing_compiled_fns:
|
||||
compiled_fns = existing_compiled_fns + compiled_fns
|
||||
# TODO: use default args from the implementation, not from the overload
|
||||
# This is more complicated because you can have a default arg with a type
|
||||
# incompatible with a type of parameter in an overload, and other validation.
|
||||
# This is still an internal api so for now use defaults from overload
|
||||
for overload_fn in overloads:
|
||||
overload_decl, overload_defaults = _get_overload_decl_and_defaults(overload_fn)
|
||||
compiled_fn = _compile_function_with_overload(qual_name, obj, overload_decl, overload_defaults)
|
||||
compiled_fns.append(compiled_fn)
|
||||
|
||||
# cache compilation, remove information stored to do compilation
|
||||
_set_jit_overload_cache(obj, compiled_fns)
|
||||
_compiled_overloaded_fns[qual_name] = compiled_fns
|
||||
_jit_internal._clear_fn_overloads(qual_name)
|
||||
return compiled_fns
|
||||
|
||||
def _check_directly_compile_overloaded(obj):
|
||||
qual_name = _qualified_name(obj)
|
||||
if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj):
|
||||
global _compiled_overloaded_fns
|
||||
global _overloaded_fns
|
||||
if qual_name in _compiled_overloaded_fns or _jit_internal._get_fn_overloads(qual_name):
|
||||
raise RuntimeError("Function {} cannot be directly compiled because it"
|
||||
" is overloaded. It must be used in a context of a function"
|
||||
" where its inputs can determine which overload to call.".format(qual_name))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user