[Lite Interpreter] Support features from to_backend (#52870)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52870

Add the missing parts needed for the lite interpreter to run modules lowered with to_backend:
1. Add support for the ISINSTANCE instruction, which to_backend uses for output type checks.
2. Bypass the lite interpreter's type parser by checking the qualified name: if it starts with "torch.jit" (as to_backend class names do), use the same type resolver as for nn modules (whose names start with "__torch__"). Both changes are sketched below and shown in full in the diffs.
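
For orientation before the diffs, here is a condensed sketch of both pieces, assuming the stack helpers from `ATen/core/stack.h`; the `isinstance` body mirrors the interpreter change below, while `looksLikeBackendClassName` is a hypothetical helper that only groups the prefix check for illustration (the real change inlines it in `resolveTypeName`).

```cpp
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/stack.h>

using torch::jit::Stack;
using torch::jit::pop;
using torch::jit::push;

// (1) ISINSTANCE: pop a value and push true if its runtime type is a subtype
// of any of the candidate types referenced by the instruction, else false.
void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
  at::TypePtr ty = pop(stack).type();
  for (const at::TypePtr& candidate : types) {
    if (ty->isSubtypeOf(candidate)) {
      push(stack, true);
      return;
    }
  }
  push(stack, false);
}

// (2) Type resolution: class names produced by to_backend start with
// "torch.jit", so they are resolved the same way as the "__torch__" names
// of regular nn modules. Hypothetical helper, for illustration only.
bool looksLikeBackendClassName(const c10::QualifiedName& qn) {
  static const c10::QualifiedName torchPrefix = "__torch__";
  static const c10::QualifiedName jitPrefix = "torch.jit";
  return torchPrefix.isPrefixOf(qn) || jitPrefix.isPrefixOf(qn);
}
```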

Tests
The mobile module is serialized and loaded in ```BackendTest.TestCompiler```, and its results are compared with those of the original TorchScript module; the round trip is sketched below.
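
Concretely, the added part of the test saves the lowered module to a stream and reloads it with the lite interpreter. A standalone sketch of that round trip, where `checkMobileRoundTrip` is a hypothetical wrapper and `lm`, `inputs`, and `ref` are assumed to come from the existing test setup:

```cpp
#include <sstream>
#include <vector>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/import.h>

// Hypothetical wrapper: `lm` is a module already lowered via to_backend,
// `inputs` its sample inputs, `ref` the reference output from the JIT run.
void checkMobileRoundTrip(
    torch::jit::Module& lm,
    std::vector<c10::IValue> inputs,
    const c10::IValue& ref) {
  std::stringstream ss;
  lm._save_for_mobile(ss);                      // serialize mobile bytecode
  auto mlm = torch::jit::_load_for_mobile(ss);  // load with the lite interpreter
  auto mres = mlm.forward(std::move(inputs));   // run through the mobile Module
  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}
```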

Test Plan: Imported from OSS

Reviewed By: raziel

Differential Revision: D26715351

Pulled By: iseeyuan

fbshipit-source-id: ad9d74ee81c6aa692ab9e5dd7a9003bae5d4f01f
Author: Martin Yuan, 2021-03-01 17:53:50 -08:00 (committed by Facebook GitHub Bot)
Commit: b5ae8e69a7 (parent: 8467e5cad3)
5 changed files with 34 additions and 7 deletions


@@ -2,6 +2,7 @@
 #include <test/cpp/jit/test_utils.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/backends/backend_detail.h>
+#include <torch/csrc/jit/mobile/import.h>
 #include <torch/torch.h>
 
 // Tests go in torch::jit
@@ -100,6 +101,12 @@ TEST(BackendTest, TestCompiler) {
       "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
   auto res = lm.forward(inputs);
   AT_ASSERT(res.toTensor().equal(ref.toTensor()));
+
+  std::stringstream ss;
+  lm._save_for_mobile(ss);
+  auto mlm = _load_for_mobile(ss);
+  auto mres = mlm.forward(inputs);
+  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
 }
 
 TEST(BackendTest, TestCompilerNotSupport) {


@@ -85,9 +85,9 @@ class BackendWithCompiler : public PyTorchBackendInterface {
           auto sub = instruction.substr(15);
           const_val = stod(sub);
         } else if (token == "aten::add") {
-          output_list.emplace_back(x.add_(h, const_val));
+          output_list.emplace_back(x.add(h, const_val));
         } else if (token == "aten::sub") {
-          output_list.emplace_back(x.sub_(h, const_val));
+          output_list.emplace_back(x.sub(h, const_val));
         } else {
           TORCH_CHECK(
               false,


@@ -146,13 +146,16 @@ BytecodeDeserializer::BytecodeDeserializer(
       reader_(std::move(reader)) {}
 
 TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
-  static const c10::QualifiedName torchPrefix = "__torch__";
-  // HACK: first we check whether the name starts with `__torch__` to
-  // tell if it's "supposed" to be a class type. This is a reliable
+  // HACK: first we check whether the name starts with special prefix to
+  // tell if it's a supported pytorch class type. There are two special
+  // prefixes. "__torch__" for nn module, and "torch.jit" from to_backend.
+  // This is a reliable
   // check today, but there is no guarantee that this is the case. The
   // real solution is to merge type parsers so we can share class
   // resolution logic.
-  if (torchPrefix.isPrefixOf(qn)) {
+  static const c10::QualifiedName torchPrefix = "__torch__";
+  static const c10::QualifiedName jitPrefix = "torch.jit";
+  if (torchPrefix.isPrefixOf(qn) || jitPrefix.isPrefixOf(qn)) {
     if (compilation_unit_->get_class(qn) == nullptr) {
       auto typeptr = ClassType::create(qn, compilation_unit_, true);
       compilation_unit_->register_type(typeptr);


@@ -26,6 +26,17 @@ void createObject(Stack& stack, const at::ClassTypePtr& type) {
       type->numAttributes());
   push(stack, std::move(userObj));
 }
+
+void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
+  at::TypePtr ty = pop(stack).type();
+  for (const at::TypePtr& candidate : types) {
+    if (ty->isSubtypeOf(candidate)) {
+      push(stack, true);
+      return;
+    }
+  }
+  push(stack, false);
+}
 } // namespace
 
 using namespace at;
@@ -189,6 +200,12 @@ bool InterpreterState::run(Stack& stack) {
           createObject(stack, type);
           ++pc;
         } break;
+        case ISINSTANCE: {
+          at::ArrayRef<TypePtr> types(
+              &(code_->types_[inst.X]), &(code_->types_[inst.X + inst.N]));
+          isinstance(stack, types);
+          ++pc;
+        } break;
         case WARN: {
           drop(stack, 1);
           TORCH_WARN(pop(stack).toStringRef());


@@ -75,7 +75,7 @@ bool isOpSupportedInMobile(OpCode op) {
       OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP,
       RET, GET_ATTR, SET_ATTR, LIST_CONSTRUCT, TUPLE_CONSTRUCT, WARN,
       INTERFACE_CALL, LIST_UNPACK, TUPLE_SLICE, DICT_CONSTRUCT,
-      NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT
+      NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE
   };
 
   // clang-format on