pytorch/torch/csrc/jit/script/module.cpp
Zachary DeVito ce69d3110b
Improve script builtin checking using schema (#7311)

* This adds aten_schema.h, which provides a barebones amount of type and
  argument information about each builtin operator
* emitBuiltinCall is updated to use this information rather than
  aten_dispatch to ensure the operator is correct.
* handling of keyword and positional arguments now matches Python behavior
* There is no longer a requirement that kwargs be constant or that the
  attributes of an op must be entirely constant or non-constant
* compiler now constructs a non-attributed version of the op first and
  then turns it into the constant-attribute version if all attributes
  are constants.
* default arguments for builtins now work
* SugaredValue::call and similar functions now have SourceRange information
  for their arguments so that error reporting is more accurate
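
As an illustration of the Python-style argument handling described above, the
following is a minimal sketch and not the code in this commit. ArgSpec and
matchArgs are hypothetical names, and values are plain ints rather than graph
Values. It only shows the order of resolution: positional first, then keyword,
then default, with an error for duplicates, missing values, or unknown keywords.

```cpp
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a schema argument; not the real aten_schema.h type.
struct ArgSpec {
  std::string name;
  std::optional<int> default_value;  // empty means the argument is required
};

// Resolve positional and keyword arguments against a schema the way a Python
// call would. Returns the values in schema order, or std::nullopt with a
// message in `err` if the call does not match the schema.
std::optional<std::vector<int>> matchArgs(
    const std::vector<ArgSpec>& schema,
    const std::vector<int>& positional,
    const std::map<std::string, int>& kwargs,
    std::string& err) {
  if (positional.size() > schema.size()) {
    err = "too many positional arguments";
    return std::nullopt;
  }
  std::vector<int> resolved;
  std::size_t used_kwargs = 0;
  for (std::size_t i = 0; i < schema.size(); ++i) {
    const ArgSpec& spec = schema[i];
    auto kw = kwargs.find(spec.name);
    if (i < positional.size()) {
      // A slot filled positionally must not also be passed by keyword.
      if (kw != kwargs.end()) {
        err = "argument '" + spec.name + "' specified twice";
        return std::nullopt;
      }
      resolved.push_back(positional[i]);
    } else if (kw != kwargs.end()) {
      resolved.push_back(kw->second);
      ++used_kwargs;
    } else if (spec.default_value) {
      resolved.push_back(*spec.default_value);
    } else {
      err = "missing value for argument '" + spec.name + "'";
      return std::nullopt;
    }
  }
  // Any keyword that did not name a schema argument is an error.
  if (used_kwargs != kwargs.size()) {
    err = "unknown keyword argument";
    return std::nullopt;
  }
  return resolved;
}
```

With a schema of (self, dim=0, keepdim=0), calling matchArgs with one
positional value and keepdim as a keyword fills dim from its default, while
passing dim both positionally and by keyword is rejected.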

Notes:
* This does not try to merge the builtin checking with the Python arg parser.
  Given that we will eventually have a C10 schema which will replace aten_schema,
  we will eventually have a C++ description of the schema, and working off that
  description directly will be the easiest form to understand.
* python function calls and script method calls do not support keyword arguments yet.
  When we add this support we should refactor the handling in tryEmitSchema
  that resolves keywords into a common function.

* default arguments work
* keyword arguments to builtins work (still need to extend to calling python and other script methods)
* much better error reporting for incorrect builtins

Lift any constants to attributes on nodes when possible

* Schema is usable internally in the compiler as
  the function signatures of script functions as well as for builtin
  operators.
* Adds a List[T] class to better represent the arguments to cat/stack
  as a type rather than with custom checking.
* Support kwargs for calls of script methods

A future commit will be needed to add support for:
* calls to script _functions_, which are currently GraphExecutors without schema info.
* kwargs to python functions, which will require refactoring python op
2018-05-14 14:46:36 -07:00


#include "torch/csrc/jit/script/module.h"
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/script/error_report.h"

#include <exception>
#include <sstream>
#include <string>
#include <vector>

namespace torch { namespace jit { namespace script {
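
// Thrown when a method is reached again while it is still being defined.
struct RecursiveMethodCallError : public std::exception {};

// Installed as the method_creator while the real creator runs (see
// Method::ensure_defined), so re-entering the definition throws.
void placeholderCreator(Method&) {
  throw RecursiveMethodCallError();
}
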
// Build a fallback schema for a method that has none: every graph input
// and output is treated as a DynamicType with no default value.
static FunctionSchema defaultSchemaFor(Method& method) {
  std::vector<Argument> args;
  std::vector<Argument> returns;
  Graph& g = *method.graph();
  size_t num_inputs = method.num_inputs();
  for (size_t i = 0; i < num_inputs; ++i) {
    const Value* v = g.inputs().at(i);
    std::string name = v->hasUniqueName()
        ? v->uniqueName()
        : ("argument_" + std::to_string(i));
    args.push_back({std::move(name), DynamicType::get(), at::nullopt, at::nullopt});
  }
  for (size_t i = 0; i < g.outputs().size(); ++i) {
    returns.push_back({"", DynamicType::get(), at::nullopt, at::nullopt});
  }
  return { method.name(), std::move(args), std::move(returns) };
}

std::vector<Value*> Method::emit_call_to(
    SourceRange loc,
    Method & callee,
    ArrayRef<NamedValue> args,
    ArrayRef<NamedValue> kwargs) {
  JIT_ASSERT(!executor);
  try {
    callee.ensure_defined();
  } catch (RecursiveMethodCallError&) {
    throw ErrorReport(loc) << " method '" << callee.name()
      << "' is called recursively involving this call site. Recursive calls are not supported";
  }
  auto fn = callee.graph();

  std::stringstream failure_messages;
  auto all_inputs = tryMatchSchema(
    callee.schema ? *callee.schema : defaultSchemaFor(callee),
    loc, *graph(), args, kwargs, failure_messages);
  if (!all_inputs)
    throw ErrorReport(loc) << failure_messages.str();

  // parameters to callee method (which become parameters to _this_ method
  // if they were not already)
  for (at::Tensor* member : callee.member_inputs) {
    all_inputs->push_back(get_or_add_parameter(member));
  }
  return inlineCallTo(*graph(), *callee.graph(), *all_inputs);
}

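void Method::ensure_defined() {
  if (method_creator) {
    auto creator = method_creator;
    // Swap in the placeholder while the real creator runs, so that a
    // recursive call back into this method throws RecursiveMethodCallError.
    method_creator = placeholderCreator;
    creator(*this);
    method_creator = nullptr;
  }
}
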
}}}