pytorch/torch/csrc/jit/script/module.cpp
Michael Suo 18f721fb9a support serialization of classes (#17856)
Summary:
Stack:
      **#17856 [jit] support serialization of classes**  [💛](https://our.intern.facebook.com/intern/diff/D14402599/)

Add support for saving/loading TorchScript modules that depend on user-defined classes.

We track class dependencies the same way we track tensor constants, then write them
all out so that we can simply compile them in order before compiling the module
hierarchy.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17856

Reviewed By: shannonzhu

Differential Revision: D14461599

Pulled By: suo

fbshipit-source-id: 7115f87e069fd00dc8381d7de9997864fef7ea9f
2019-03-15 12:06:23 -07:00

142 lines
3.9 KiB
C++

#include <torch/csrc/jit/script/module.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/schema_matching.h>
namespace torch {
namespace jit {
namespace script {
// Thrown by placeholderCreator when a method's definition is requested while
// that same method is still being defined (see Method::ensure_defined).
// try_emit_call_to converts it into a user-facing ErrorReport.
struct RecursiveMethodCallError : std::exception {};
// Stand-in creator that Method::ensure_defined installs while the real creator
// runs. If defining the method re-enters ensure_defined on the same method,
// this is what gets invoked, and it reports the recursion by throwing.
void placeholderCreator(Method& /*method*/) {
  throw RecursiveMethodCallError{};
}
// Tries to inline a call to `callee` into `graph` at source location `loc`.
//
// Returns the single output Value of the inlined call, or nullptr if the
// arguments (`self`, `args`, `kwargs`) cannot be matched against callee's
// schema. In that case tryMatchSchema has been handed `failure_messages`, which
// callers such as Method::emit_call_to report on failure.
//
// Throws ErrorReport if:
//  - callee is being defined recursively (its definition calls back into it), or
//  - callee has parameters/attributes but no `caller` Method is available to
//    receive them as parameters.
Value* try_emit_call_to(
    Graph& graph,
    const SourceRange& loc,
    Method& callee,
    c10::optional<NamedValue> self,
    ArrayRef<NamedValue> args,
    ArrayRef<NamedValue> kwargs,
    std::stringstream& failure_messages,
    Method* caller,
    bool conv_tensors_to_nums) {
  try {
    callee.ensure_defined();
  } catch (const RecursiveMethodCallError&) {
    // ensure_defined() was re-entered for a method still being defined.
    throw ErrorReport(loc)
        << " method '" << callee.name()
        << "' is called recursively involving this call site. "
        << "Recursive calls are not supported";
  }
  auto matched_schema = tryMatchSchema(
      callee.getSchema(),
      loc,
      graph,
      std::move(self),
      args,
      kwargs,
      failure_messages,
      conv_tensors_to_nums);
  if (!matched_schema) {
    return nullptr;
  }
  // The callee's parameters/attributes are passed as extra inputs. They become
  // parameters of the caller method if they were not already.
  for (auto member : callee.initial_ivalues()) {
    if (!caller) {
      throw ErrorReport(loc)
          << " attempting to call a method with parameters/attributes"
             " from a raw graph. File a bug report";
    }
    matched_schema->inputs.push_back(caller->get_or_add_parameter(member));
  }
  callee.check_single_output();
  return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
}
// Inlines a call to `callee` into this method's graph and returns the call's
// single output. Any parameters/attributes that `callee` needs are added to
// this method. Throws ErrorReport with the schema-matching diagnostics if
// `args`/`kwargs` do not fit callee's schema.
Value* Method::emit_call_to(
    const SourceRange& loc,
    Method& callee,
    ArrayRef<NamedValue> args,
    ArrayRef<NamedValue> kwargs) {
  // Emission is only allowed while no executor exists for this method.
  AT_ASSERT(!executor);
  std::stringstream match_errors;
  Value* const output = try_emit_call_to(
      *graph(),
      loc,
      callee,
      /*self=*/c10::nullopt,
      args,
      kwargs,
      match_errors,
      /*caller=*/this,
      /*conv_tensors_to_nums=*/true);
  if (output == nullptr) {
    throw ErrorReport(loc) << match_errors.str();
  }
  return output;
}
// Runs this method's pending creator (if any) exactly once, so the method's
// definition is materialized before use. Afterwards the creator is cleared.
void Method::ensure_defined() {
  if (!method_creator) {
    return; // Already defined (or never had a creator).
  }
  // While the real creator runs, swap in placeholderCreator. A re-entrant
  // ensure_defined() on this same method then throws RecursiveMethodCallError
  // instead of recursing without bound.
  // NOTE(review): if the creator throws, the placeholder stays installed, so
  // later calls report "recursive" rather than the original failure.
  auto pending_creator = method_creator;
  method_creator = placeholderCreator;
  pending_creator(*this);
  method_creator = nullptr;
}
// Converts every parameter of this module and its submodules to `device` and
// `dtype`.
void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {
  const c10::optional<at::Device> target_device = device;
  const c10::optional<at::ScalarType> target_dtype = dtype;
  to_impl(target_device, target_dtype, non_blocking);
}
// Converts every parameter to `dtype`; each tensor keeps its current device.
void Module::to(at::ScalarType dtype, bool non_blocking) {
  const c10::optional<at::Device> keep_device = c10::nullopt;
  const c10::optional<at::ScalarType> target_dtype = dtype;
  to_impl(keep_device, target_dtype, non_blocking);
}
// Moves every parameter to `device`; each tensor keeps its current dtype.
void Module::to(at::Device device, bool non_blocking) {
  const c10::optional<at::Device> target_device = device;
  const c10::optional<at::ScalarType> keep_dtype = c10::nullopt;
  to_impl(target_device, keep_dtype, non_blocking);
}
// Serializes this module to the file named `filename` by delegating to
// ExportModule, forwarding `extra_files` as well.
void Module::save(
    const std::string& filename,
    const ExtraFilesMap& extra_files) {
  ExportModule(*this, filename, extra_files);
}
// Serializes this module to the stream `out` by delegating to ExportModule,
// forwarding `extra_files` as well.
void Module::save(
    std::ostream& out,
    const ExtraFilesMap& extra_files) {
  ExportModule(*this, out, extra_files);
}
// Shared implementation of the Module::to overloads. Recursively converts the
// parameters of all submodules, then those of this module. A missing `device`
// or `dtype` means "keep each tensor's current device/dtype".
void Module::to_impl(
    const c10::optional<at::Device>& device,
    const c10::optional<at::ScalarType>& dtype,
    bool non_blocking) {
  // Submodules first, then our own parameters.
  for (auto& submodule : modules) {
    submodule->module->to_impl(device, dtype, non_blocking);
  }
  for (auto& param : parameters) {
    // Access the slot's tensor as an autograd::Variable, which provides
    // set_data() for swapping in the converted data.
    autograd::Variable var = param.value().slot()->toTensor();
    at::Tensor old_data = var.data();
    const auto target_device = device.value_or(old_data.device());
    const auto target_dtype = dtype.value_or(old_data.scalar_type());
    var.set_data(old_data.to(target_device, target_dtype, non_blocking));
  }
}
} // namespace script
} // namespace jit
} // namespace torch