pytorch/torch/csrc/jit/script/module.cpp
Zachary DeVito 6ce799edd6 Tuples/Lists can now be inputs/outputs to script and other simple fixes. (#10812)
Summary:
* Fix the necessary pathways so that tuples and lists can be inputs to the script.

* Prevent linear algebra functions from being run in shape propagation because
they will frequently error out on nonsense data.

* Favor schema-driven Python input conversion where possible.
The remaining cases where we directly create Stacks without a schema are
only for debugging.

* Make the error messages when calling script/trace functions more pythonic

* Simplify FlattenTuples -- now that tuples are supported, we can choose to flatten tuples only when needed. This may have to be revisited pending ONNX test results, but it is necessary for making tuple I/O work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10812

Differential Revision: D9477982

Pulled By: zdevito

fbshipit-source-id: ed06fc426e6ef6deb404602a26c435a7fc40ea0c
2018-08-27 14:40:40 -07:00

79 lines
2.4 KiB
C++

#include "torch/csrc/jit/assertions.h"
#include "torch/csrc/jit/script/module.h"
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/jit/script/error_report.h"
#include "torch/csrc/jit/export.h"
#include "torch/csrc/jit/operator.h"
namespace torch { namespace jit { namespace script {
// Marker exception with no payload. It is thrown by placeholderCreator and
// caught in Method::emit_call_to, where it is turned into an ErrorReport
// about a recursive method call.
struct RecursiveMethodCallError : std::exception {
};
// Stand-in creator that unconditionally throws RecursiveMethodCallError.
// Method::ensure_defined installs it as `method_creator` while the real
// creator is running, so asking for the same method to be defined again
// during that window raises instead of re-running the creator.
void placeholderCreator(Method& /*unused*/) {
  throw RecursiveMethodCallError{};
}
// Builds a FunctionSchema for `method` directly from its graph.
//
// Arguments: one per the first `method.num_inputs()` graph inputs. Each is
// named after the input's unique name when it has one, otherwise
// "argument_<index>", and typed with unshapedType() of the input's type.
// Returns: one unnamed entry per graph output, typed the same way.
static FunctionSchema defaultSchemaFor(const Method& method) {
  Graph& graph = *method.graph();

  std::vector<Argument> arguments;
  const size_t input_count = method.num_inputs();
  for (size_t idx = 0; idx < input_count; ++idx) {
    // .at() keeps the bounds check on the graph's input list.
    const Value* input = graph.inputs().at(idx);
    std::string arg_name;
    if (input->hasUniqueName()) {
      arg_name = input->uniqueName();
    } else {
      arg_name = "argument_" + std::to_string(idx);
    }
    arguments.push_back({std::move(arg_name), unshapedType(input->type())});
  }

  std::vector<Argument> results;
  for (const Value* output : graph.outputs()) {
    // Outputs carry no name in the default schema.
    results.push_back({"", unshapedType(output->type())});
  }

  return { method.name(), std::move(arguments), std::move(results) };
}
// Returns this method's schema. If none has been set yet, a default one is
// derived from the graph via defaultSchemaFor() and stored in `schema`, so
// later calls return the same object.
// NOTE(review): no locking is visible here -- confirm callers do not race on
// the first call.
const FunctionSchema& Method::getSchema() const {
  if (schema != nullptr) {
    return *schema;
  }
  schema.reset(new FunctionSchema(defaultSchemaFor(*this)));
  return *schema;
}
// Emits a call to `callee` into this method's graph by inlining the callee's
// graph.
//
//   loc    - source range attached to any ErrorReport raised from here
//   callee - method being called; defined on demand via ensure_defined()
//   args   - positional arguments, matched against callee.getSchema()
//   kwargs - keyword arguments, matched against callee.getSchema()
//
// Returns whatever inlineCallTo() yields for the inlined callee graph.
// Throws ErrorReport when defining the callee raises RecursiveMethodCallError,
// or when the arguments cannot be matched to the callee's schema.
std::vector<Value*> Method::emit_call_to(SourceRange loc, Method & callee, ArrayRef<NamedValue> args, ArrayRef<NamedValue> kwargs) {
  // This method must not have an executor yet when its graph is modified.
  JIT_ASSERT(!executor);
  try {
    callee.ensure_defined();
  } catch (const RecursiveMethodCallError&) {
    // The callee's creator re-entered its own definition: report it at the
    // call site rather than leaking the internal marker exception.
    throw ErrorReport(loc) << " method '" << callee.name()
      << "' is called recursively involving this call site. Recursive calls are not supported";
  }

  // Collects the reasons for a schema mismatch so they can be reported.
  std::stringstream failure_messages;
  auto all_inputs = tryMatchSchema(
    callee.getSchema(),
    loc, *graph(), args, kwargs, failure_messages, /*conv_tensors_to_nums*/true);
  if (!all_inputs) {
    throw ErrorReport(loc) << failure_messages.str();
  }

  // parameters to callee method (which become parameters to _this_ method
  // if they were not already)
  for (at::Tensor* member : callee.member_inputs) {
    all_inputs->push_back(get_or_add_parameter(member));
  }
  return inlineCallTo(*graph(), *callee.graph(), *all_inputs);
}
// Runs the pending `method_creator`, if any, and clears it afterwards.
//
// While the creator runs, `method_creator` is temporarily replaced with
// placeholderCreator, so a nested ensure_defined() on this same method throws
// RecursiveMethodCallError instead of running the creator again.
// If the creator itself throws, `method_creator` is left pointing at
// placeholderCreator (there is no cleanup on that path).
void Method::ensure_defined() {
  if (!method_creator) {
    return;  // nothing pending
  }
  auto pending_creator = method_creator;
  method_creator = placeholderCreator;
  pending_creator(*this);
  method_creator = nullptr;
}
void Module::save(const std::string& filename) {
ExportModule(*this, filename);
}
}}}