mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch Mobile] Support torchbind custom classes in lite interpreter (#51432)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51432 ghstack-source-id: 120976584 torchbind is a convenient way to include custom class to both python and torchscript. CREATE_OBJECT is used to create an object of custom class. CREATE_OBJECT was not supported by lite interpreter. The major reason was that for custom class directly defined in Python, there's no language parser in lite interpreter. It's still the case. However, for torchbind classes that are defined in C++, a python/torchscript parser is not needed. This diff is to support the case of torchbind custom classes. 1. The class type can be resolved at import level. 2. If the class is not the supported torchbind class, an error message is provided at export stage. Workaround is also suggested. 3. Unit tests. C++: ```LiteInterpreterTest::BuiltinClass``` is added as an end-to-end test on supported class. Python: ```test_unsupported_createobject``` is changed to ```test_unsupported_classtype``` to test unsupported classes. Test Plan: CI Reviewed By: raziel Differential Revision: D26168913 fbshipit-source-id: 74e8b6a12682ad8e9c39afdfd2b605c5f8e65427
This commit is contained in:
parent
1ffd26f8d8
commit
23c50a4a50
|
|
@ -3,6 +3,7 @@
|
|||
#include <c10/core/TensorOptions.h>
|
||||
#include <torch/csrc/autograd/generated/variable_factories.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/frontend/resolver.h>
|
||||
#include <torch/csrc/jit/mobile/import.h>
|
||||
#include <torch/csrc/jit/mobile/module.h>
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
|
|
@ -348,6 +349,87 @@ class TorchBindLiteInterpreterTestStruct
|
|||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct ClassNamespaceValue : public SugaredValue {
|
||||
explicit ClassNamespaceValue(c10::QualifiedName name)
|
||||
: basename_(std::move(name)) {}
|
||||
|
||||
std::shared_ptr<SugaredValue> attr(
|
||||
const SourceRange& loc,
|
||||
Function& m,
|
||||
const std::string& name) override {
|
||||
const auto fullName = c10::QualifiedName(basename_, name);
|
||||
|
||||
// Check to see if it is a custom class.
|
||||
if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
|
||||
return std::make_shared<ClassValue>(custom_class);
|
||||
}
|
||||
|
||||
// If it's not a custom class, assume it's another namespace
|
||||
return std::make_shared<ClassNamespaceValue>(std::move(fullName));
|
||||
}
|
||||
|
||||
std::string kind() const override {
|
||||
return "Class Namespace";
|
||||
}
|
||||
|
||||
private:
|
||||
c10::QualifiedName basename_;
|
||||
};
|
||||
|
||||
struct TestModuleResolver : public Resolver {
|
||||
std::shared_ptr<SugaredValue> resolveValue(
|
||||
const std::string& name,
|
||||
Function& m,
|
||||
const SourceRange& loc) override {
|
||||
if (name == "torch") {
|
||||
return std::make_shared<BuiltinModule>("aten");
|
||||
} else if (name == "__torch__") {
|
||||
return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TypePtr resolveType(const std::string& name, const SourceRange& loc)
|
||||
override {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(LiteInterpreterTest, BuiltinClass) {
|
||||
script::Module m("m");
|
||||
|
||||
auto cls = getCustomClass(
|
||||
"__torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest");
|
||||
TORCH_INTERNAL_ASSERT(cls);
|
||||
c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
|
||||
m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));
|
||||
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(
|
||||
R"(
|
||||
def __getstate__(self):
|
||||
return 1
|
||||
def __setstate__(self, a):
|
||||
self.my_obj = __torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest()
|
||||
|
||||
def forward(self, x) -> str:
|
||||
return self.my_obj.get(x)
|
||||
)",
|
||||
std::make_shared<TestModuleResolver>());
|
||||
|
||||
std::stringstream ss;
|
||||
m._save_for_mobile(ss);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
auto res =
|
||||
bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
|
||||
const auto& str = res.toStringRef();
|
||||
std::string expected = "Hello! Your tensor has 12 elements!";
|
||||
AT_ASSERT(str == expected);
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, BuiltinFunction) {
|
||||
script::Module m("m");
|
||||
auto custom_class_obj =
|
||||
|
|
@ -828,6 +910,7 @@ static auto reg =
|
|||
torch::class_<TorchBindLiteInterpreterTestStruct>(
|
||||
"_TorchScriptTesting",
|
||||
"_LiteInterpreterTest")
|
||||
.def(torch::init<>())
|
||||
.def("get", &TorchBindLiteInterpreterTestStruct::get)
|
||||
.def_pickle(
|
||||
// __getattr__
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ class TestLiteScriptModule(TestCase):
|
|||
mobile_module_forward_result
|
||||
)
|
||||
|
||||
def test_unsupported_createobject(self):
|
||||
def test_unsupported_classtype(self):
|
||||
class Foo():
|
||||
def __init__(self):
|
||||
return
|
||||
|
|
@ -252,7 +252,6 @@ class TestLiteScriptModule(TestCase):
|
|||
|
||||
script_module = torch.jit.script(MyTestModule())
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"^CREATE_OBJECT is not supported in mobile module\. "
|
||||
r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
|
||||
r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\.$"):
|
||||
script_module._save_to_buffer_for_lite_interpreter()
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ namespace jit {
|
|||
char const* toString(OpCode op);
|
||||
namespace mobile {
|
||||
Function::Function(c10::QualifiedName name)
|
||||
: name_(name), code_(std::make_shared<Code>()) {}
|
||||
: name_(std::move(name)), code_(std::make_shared<Code>()) {}
|
||||
|
||||
const c10::QualifiedName& Function::qualname() const {
|
||||
return name_;
|
||||
|
|
@ -23,11 +23,6 @@ const std::string& Function::name() const {
|
|||
}
|
||||
|
||||
void Function::append_instruction(OpCode op, int X, int N) {
|
||||
TORCH_CHECK(
|
||||
op != CREATE_OBJECT,
|
||||
"CREATE_OBJECT is not supported in mobile module. ",
|
||||
"Workaround: instead of using arbitrary class type (class Foo()), ",
|
||||
"define a pytorch class (class Foo(torch.nn.Module)).");
|
||||
TORCH_CHECK(
|
||||
isOpSupportedInMobile(op),
|
||||
toString(op),
|
||||
|
|
@ -43,7 +38,7 @@ bool Function::append_operator(
|
|||
code_->op_names_.emplace_back(name, overload_name);
|
||||
auto opname = code_->op_names_.back();
|
||||
|
||||
auto opname_c10 = opname;
|
||||
const auto& opname_c10 = opname;
|
||||
std::function<void(Stack&)> fn;
|
||||
|
||||
auto jit_op = findOperatorFor(opname);
|
||||
|
|
@ -58,7 +53,7 @@ bool Function::append_operator(
|
|||
}
|
||||
}
|
||||
|
||||
if (model_version == 0x3L &&
|
||||
if (model_version == 0x3LL &&
|
||||
opname == c10::OperatorName("aten::_convolution", "")) {
|
||||
// Since byte-code versions 0x4L, convolution has an additional
|
||||
// default-value argument (allow_tf32=True, see
|
||||
|
|
|
|||
|
|
@ -294,8 +294,20 @@ void BytecodeDeserializer::parseMethods(
|
|||
function->append_constant(constant);
|
||||
}
|
||||
|
||||
static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
|
||||
for (const auto& t : types_list) {
|
||||
function->append_type(c10::parseType(t.toStringRef()));
|
||||
c10::QualifiedName qn(t.toStringRef());
|
||||
if (classPrefix.isPrefixOf(qn)) {
|
||||
auto classType = getCustomClass(qn.qualifiedName());
|
||||
TORCH_CHECK(
|
||||
classType,
|
||||
"The implementation of class ",
|
||||
qn.qualifiedName(),
|
||||
" cannot be found.");
|
||||
function->append_type(classType);
|
||||
} else {
|
||||
function->append_type(c10::parseType(t.toStringRef()));
|
||||
}
|
||||
}
|
||||
|
||||
function->set_register_size(register_size);
|
||||
|
|
|
|||
|
|
@ -19,6 +19,15 @@ InterpreterState::InterpreterState(std::shared_ptr<Code> code)
|
|||
registers_.resize(code_->register_size_);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void createObject(Stack& stack, const at::ClassTypePtr& type) {
|
||||
auto userObj = c10::ivalue::Object::create(
|
||||
c10::StrongTypePtr(type->compilation_unit(), type),
|
||||
type->numAttributes());
|
||||
push(stack, std::move(userObj));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
using namespace at;
|
||||
|
||||
bool InterpreterState::run(Stack& stack) {
|
||||
|
|
@ -175,6 +184,11 @@ bool InterpreterState::run(Stack& stack) {
|
|||
namedTupleConstruct(stack, type, inst.N);
|
||||
++pc;
|
||||
} break;
|
||||
case CREATE_OBJECT: {
|
||||
auto type = code_->types_[inst.X]->expect<c10::ClassType>();
|
||||
createObject(stack, type);
|
||||
++pc;
|
||||
} break;
|
||||
case WARN: {
|
||||
drop(stack, 1);
|
||||
TORCH_WARN(pop(stack).toStringRef());
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ bool isOpSupportedInMobile(OpCode op) {
|
|||
OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP,
|
||||
RET, GET_ATTR, SET_ATTR, LIST_CONSTRUCT, TUPLE_CONSTRUCT, WARN,
|
||||
INTERFACE_CALL, LIST_UNPACK, TUPLE_SLICE, DICT_CONSTRUCT,
|
||||
NAMED_TUPLE_CONSTRUCT
|
||||
NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
|
|
|
|||
|
|
@ -181,11 +181,6 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
ins.op != CREATE_OBJECT,
|
||||
"CREATE_OBJECT is not supported in mobile module. ",
|
||||
"Workaround: instead of using arbitrary class type (class Foo()), ",
|
||||
"define a pytorch class (class Foo(torch.nn.Module)).");
|
||||
TORCH_CHECK(
|
||||
isOpSupportedInMobile(ins.op),
|
||||
toString(ins.op),
|
||||
|
|
@ -219,8 +214,19 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
|
|||
// types
|
||||
std::vector<IValue> types;
|
||||
types.reserve(code.type_table().size());
|
||||
static const std::string torch_prefix("__torch__");
|
||||
static const std::string class_prefix("__torch__.torch.classes");
|
||||
for (const TypePtr& t : code.type_table()) {
|
||||
types.emplace_back(t->annotation_str());
|
||||
auto type_str = t->annotation_str();
|
||||
if (type_str.find(torch_prefix) == 0) {
|
||||
TORCH_CHECK(
|
||||
type_str.find(class_prefix) == 0,
|
||||
"__torch__ types other than torchbind (__torch__.torch.classes)"
|
||||
"are not supported in lite interpreter. ",
|
||||
"Workaround: instead of using arbitrary class type (class Foo()), ",
|
||||
"define a pytorch class (class Foo(torch.nn.Module)).");
|
||||
}
|
||||
types.emplace_back(type_str);
|
||||
}
|
||||
|
||||
// since the register location is embedded into the bytecode, pass the
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user