Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[Lite Interpreter] Support features from to_backend (#52870)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52870

Add the missing parts needed for the lite interpreter to support to_backend modules:

1. Add ISINSTANCE instruction support, which to_backend uses for output type checks.
2. Bypass the lite interpreter's type parser by checking the qualified name: if it starts with "torch.jit", use the same type resolver as for nn modules (whose names start with "__torch__").

Tests: A mobile module is serialized and loaded in ```BackendTest.TestCompiler```; its results are compared to those of the original TorchScript module.

Test Plan: Imported from OSS

Reviewed By: raziel

Differential Revision: D26715351

Pulled By: iseeyuan

fbshipit-source-id: ad9d74ee81c6aa692ab9e5dd7a9003bae5d4f01f
This commit is contained in:
parent
8467e5cad3
commit
b5ae8e69a7
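Before the diffs: the core of the summary's first point is a runtime subtype test on a value produced by the delegated backend. A minimal standalone sketch of that test, using public type APIs (the value `out` and the candidate type are illustrative, not taken from the patch):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

int main() {
  // The check ISINSTANCE implements, in miniature: ask whether a runtime
  // value's dynamic type is a subtype of a candidate static type.
  c10::IValue out = at::ones({2, 2});
  bool ok = out.type()->isSubtypeOf(at::TensorType::get());  // true
  return ok ? 0 : 1;
}
```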
```diff
@@ -2,6 +2,7 @@
 #include <test/cpp/jit/test_utils.h>
 #include <torch/csrc/jit/api/module.h>
 #include <torch/csrc/jit/backends/backend_detail.h>
+#include <torch/csrc/jit/mobile/import.h>
 #include <torch/torch.h>
 
 // Tests go in torch::jit
```
```diff
@@ -100,6 +101,12 @@ TEST(BackendTest, TestCompiler) {
       "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
   auto res = lm.forward(inputs);
   AT_ASSERT(res.toTensor().equal(ref.toTensor()));
+
+  std::stringstream ss;
+  lm._save_for_mobile(ss);
+  auto mlm = _load_for_mobile(ss);
+  auto mres = mlm.forward(inputs);
+  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
 }
 
 TEST(BackendTest, TestCompilerNotSupport) {
```
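The added lines give the test a save/load round trip through the lite-interpreter format. Extracted as a standalone pattern (assuming `lm` is an already-lowered `torch::jit::Module` and `inputs` a `std::vector<c10::IValue>`, as in the test):

```cpp
#include <sstream>
#include <vector>

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/import.h>

// Sketch of the round trip the test now covers.
c10::IValue runOnLiteInterpreter(
    torch::jit::Module& lm,
    const std::vector<c10::IValue>& inputs) {
  std::stringstream ss;
  lm._save_for_mobile(ss);                      // bytecode serialization
  auto mlm = torch::jit::_load_for_mobile(ss);  // yields a mobile::Module
  return mlm.forward(inputs);                   // runs on the lite interpreter
}
```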
```diff
@@ -85,9 +85,9 @@ class BackendWithCompiler : public PyTorchBackendInterface {
           auto sub = instruction.substr(15);
           const_val = stod(sub);
         } else if (token == "aten::add") {
-          output_list.emplace_back(x.add_(h, const_val));
+          output_list.emplace_back(x.add(h, const_val));
         } else if (token == "aten::sub") {
-          output_list.emplace_back(x.sub_(h, const_val));
+          output_list.emplace_back(x.sub(h, const_val));
         } else {
           TORCH_CHECK(
               false,
```
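The demo backend switches from in-place `add_`/`sub_` to out-of-place `add`/`sub`, presumably because the round-trip test now runs `forward` twice on the same inputs and an in-place op would mutate them between runs. A quick illustration of the difference (illustrative tensors, not from the patch):

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor x = at::ones({2});
  at::Tensor h = at::ones({2});
  at::Tensor y = x.add(h, 2.0);  // out-of-place: y = x + 2.0 * h; x untouched
  x.add_(h, 2.0);                // in-place: x itself becomes x + 2.0 * h
  return 0;
}
```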
```diff
@@ -146,13 +146,16 @@ BytecodeDeserializer::BytecodeDeserializer(
       reader_(std::move(reader)) {}
 
 TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
-  static const c10::QualifiedName torchPrefix = "__torch__";
-  // HACK: first we check whether the name starts with `__torch__` to
-  // tell if it's "supposed" to be a class type. This is a reliable
+  // HACK: first we check whether the name starts with special prefix to
+  // tell if it's a supported pytorch class type. There are two special
+  // prefixes. "__torch__" for nn module, and "torch.jit" from to_backend.
+  // This is a reliable
   // check today, but there is no guarantee that this is the case. The
   // real solution is to merge type parsers so we can share class
   // resolution logic.
-  if (torchPrefix.isPrefixOf(qn)) {
+  static const c10::QualifiedName torchPrefix = "__torch__";
+  static const c10::QualifiedName jitPrefix = "torch.jit";
+  if (torchPrefix.isPrefixOf(qn) || jitPrefix.isPrefixOf(qn)) {
     if (compilation_unit_->get_class(qn) == nullptr) {
       auto typeptr = ClassType::create(qn, compilation_unit_, true);
       compilation_unit_->register_type(typeptr);
```
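`c10::QualifiedName::isPrefixOf` compares dotted atoms rather than raw characters, so `"torch.jit"` matches `"torch.jit.Foo"` but not a name like `"torch.jitter"`. A small sketch of the new dual-prefix check (the qualified names below are hypothetical examples):

```cpp
#include <ATen/core/qualified_name.h>

int main() {
  static const c10::QualifiedName torchPrefix = "__torch__";
  static const c10::QualifiedName jitPrefix = "torch.jit";

  // Hypothetical names of the two supported flavors:
  c10::QualifiedName nnName("__torch__.MyModule");          // nn module
  c10::QualifiedName backendName("torch.jit.LoweredModule"); // from to_backend

  bool a = torchPrefix.isPrefixOf(nnName) || jitPrefix.isPrefixOf(nnName);
  bool b = torchPrefix.isPrefixOf(backendName) || jitPrefix.isPrefixOf(backendName);
  return (a && b) ? 0 : 1;  // both resolve as class types
}
```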
```diff
@@ -26,6 +26,17 @@ void createObject(Stack& stack, const at::ClassTypePtr& type) {
       type->numAttributes());
   push(stack, std::move(userObj));
 }
+
+void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
+  at::TypePtr ty = pop(stack).type();
+  for (const at::TypePtr& candidate : types) {
+    if (ty->isSubtypeOf(candidate)) {
+      push(stack, true);
+      return;
+    }
+  }
+  push(stack, false);
+}
 } // namespace
 
 using namespace at;
```
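The helper pops the value under test and pushes a bool: true as soon as one candidate type matches, false after all candidates fail. Since it sits in an anonymous namespace, this standalone sketch re-creates the same logic rather than calling it:

```cpp
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/stack.h>

using torch::jit::Stack;
using torch::jit::push;
using torch::jit::pop;

// Re-creation of the file-local helper's behavior: pop one value, push
// true iff its type matches any candidate.
void checkIsinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
  at::TypePtr ty = pop(stack).type();
  for (const at::TypePtr& candidate : types) {
    if (ty->isSubtypeOf(candidate)) {
      push(stack, true);
      return;
    }
  }
  push(stack, false);
}

int main() {
  Stack stack;
  push(stack, at::ones({2}));  // operand to test
  std::vector<at::TypePtr> candidates{at::TensorType::get()};
  checkIsinstance(stack, candidates);
  bool is_tensor = pop(stack).toBool();  // true: a Tensor matched
  return is_tensor ? 0 : 1;
}
```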
```diff
@@ -189,6 +200,12 @@ bool InterpreterState::run(Stack& stack) {
         createObject(stack, type);
         ++pc;
       } break;
+      case ISINSTANCE: {
+        at::ArrayRef<TypePtr> types(
+            &(code_->types_[inst.X]), &(code_->types_[inst.X + inst.N]));
+        isinstance(stack, types);
+        ++pc;
+      } break;
       case WARN: {
         drop(stack, 1);
         TORCH_WARN(pop(stack).toStringRef());
```
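In the bytecode, `inst.X` is the start offset into the function's type table and `inst.N` the number of candidate types, so the `ArrayRef` spans `types_[X] .. types_[X + N - 1]`. A toy illustration with made-up operand values:

```cpp
#include <vector>

#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>

int main() {
  // Stand-in for the function's type table (code_->types_ in the hunk):
  std::vector<at::TypePtr> types_{
      at::TensorType::get(), at::IntType::get(), at::FloatType::get()};

  // Made-up operands: an ISINSTANCE instruction with X = 1, N = 2 selects
  // the two candidates types_[1] and types_[2].
  size_t X = 1, N = 2;
  at::ArrayRef<at::TypePtr> candidates(types_.data() + X, N);
  return candidates.size() == 2 ? 0 : 1;
}
```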
```diff
@@ -75,7 +75,7 @@ bool isOpSupportedInMobile(OpCode op) {
       OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP,
       RET, GET_ATTR, SET_ATTR, LIST_CONSTRUCT, TUPLE_CONSTRUCT, WARN,
       INTERFACE_CALL, LIST_UNPACK, TUPLE_SLICE, DICT_CONSTRUCT,
-      NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT
+      NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE
   };
   // clang-format on
```
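`isOpSupportedInMobile` is the allowlist consulted when validating bytecode for mobile; adding ISINSTANCE lets the new instruction pass that check. A usage sketch (the header path and namespace here are assumptions, not confirmed by the patch):

```cpp
// Assumption: OpCode and isOpSupportedInMobile are declared in this header.
#include <torch/csrc/jit/runtime/instruction.h>

int main() {
  // Membership query against the allowlist patched above; with this change
  // it returns true for ISINSTANCE.
  return torch::jit::isOpSupportedInMobile(torch::jit::ISINSTANCE) ? 0 : 1;
}
```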