#include #include #include #include #include #include // Tests go in torch::jit namespace torch { namespace jit { void testLiteInterpreter() { script::Module m("m"); m.register_parameter("foo", torch::ones({}), false); // TODO: support default param val, which was pushed in // function schema's checkAndNormalizeInputs() // m.define(R"( // def add_it(self, x, b : int = 4): // return self.foo + x + b // )"); m.define(R"( def add_it(self, x): b = 4 return self.foo + x + b )"); std::vector inputs; auto minput = 5 * torch::ones({}); inputs.emplace_back(minput); auto ref = m.run_method("add_it", minput); std::stringstream ss; m._save_for_mobile(ss); mobile::Module bc = _load_for_mobile(ss); IValue res; for (int i = 0; i < 3; ++i) { auto bcinputs = inputs; res = bc.run_method("add_it", bcinputs); } auto resd = res.toTensor().item(); auto refd = ref.toTensor().item(); AT_ASSERT(resd == refd); } } // namespace torch } // namespace jit