Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25187

The bytecode export flow: dump the bytecode format for the lightweight interpreter.
* The bytecode is generated without input-spec optimization. It is therefore more generic (input independent), with no obvious performance degradation (to be tested).
* Main API: torch::jit::script::Module::save(filename, extra_files, bool bytecode_format = false).
* Both the bytecode and the module object are exported in pickle format.
* The module object (in data.pkl) is the same as in the original JIT model.
* The serializer depends on pickle only (no protobuf or JSON).
* The major functionality is forked in ScriptModuleSerializer2::serialize().
* The test loader is test_bc_export.cpp.
* Simple APIs are added to Code and its implementation to get the necessary information (instructions, operators, and constants).
* Since there is no dependency on graph/node, GetAttr is promoted from an operator to a first-class instruction (https://github.com/pytorch/pytorch/pull/25151).
* Some definitions (instructions, writeArchive, etc.) that are shared by the full JIT and the bytecode path are pulled out of the local namespace (https://github.com/pytorch/pytorch/pull/25148).

The output layout looks like:
* folders of methods.
* In each method folder (for example, forward/):
  * bytecode.pkl: instructions and operators.
  * constants{.pkl,/}: the constant list in constants.pkl; if there are tensors among the constants, their binary tensor files go in the constants/ folder.
  * data{.pkl,/}: the module object, with binary tensor files in the data/ folder, the same as in TorchScript.

Test Plan: Imported from OSS

Differential Revision: D17076411

fbshipit-source-id: 46eb298e7320d1e585b0101effc0fcfd09219046
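
For orientation, here is a minimal sketch of the round trip this change enables, separate from the test file below. It assumes the stream-based _save_for_mobile / _load_for_mobile entry points that the test exercises; the function name, module name, method body, and extra includes are illustrative only, not part of this PR.

#include <sstream>
#include <vector>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/autograd/generated/variable_factories.h>

// Sketch only: script a trivial module, serialize it in the bytecode
// (mobile) format, and run it with the lite interpreter.
void bytecode_roundtrip_sketch() {
  torch::jit::script::Module m("sketch");
  m.define(R"(
    def forward(self, x):
        return x + 1
  )");

  // Bytecode and module object are written in pickle format (no protobuf/JSON).
  std::stringstream ss;
  m._save_for_mobile(ss);

  // Load with the lite interpreter and call a method by name, as the test does.
  torch::jit::mobile::Module bc = torch::jit::_load_for_mobile(ss);
  std::vector<c10::IValue> inputs;
  inputs.emplace_back(torch::ones({}));
  c10::IValue out = bc.run_method("forward", inputs);
  (void)out;
}
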
48 lines
1.2 KiB
C++
#include <sstream>
#include <test/cpp/jit/test_base.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/import.h>

// Tests go in torch::jit
namespace torch {
namespace jit {

void testLiteInterpreter() {
  script::Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // TODO: support default param val, which was pushed in
  // function schema's checkAndNormalizeInputs()
  // m.define(R"(
  //   def add_it(self, x, b : int = 4):
  //     return self.foo + x + b
  // )");
  m.define(R"(
    def add_it(self, x):
        b = 4
        return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  // Reference result from the full JIT interpreter.
  auto ref = m.run_method("add_it", minput);

  // Round-trip through the bytecode format and run on the lite interpreter.
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    res = bc.run_method("add_it", bcinputs);
  }

  // The lite interpreter must agree with the full JIT result.
  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

} // namespace jit
} // namespace torch