diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 46161493cf6..00103753c62 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include #include @@ -171,6 +173,19 @@ bool isTensorInBytecodeArchive( namespace { +void tryRegisterMethod(const std::vector& args, Function& func) { + if (args.empty() || args[0].name() != "self") { + return; + } + + if (auto cls = args[0].type()->castRaw()) { + if (C10_UNLIKELY(cls->findMethod(func.name()))) { + return; + } + cls->addMethod(&func); + } +} + // The deserializer class which loads the bytecode package from bc files. class BytecodeDeserializer final { public: @@ -227,7 +242,8 @@ void BytecodeDeserializer::parseFunctionSchema( mobile::Function* function) { // function schema if (schemaTable) { // (schema is optional for back compat) - auto parseArgList = [this](c10::ivalue::TupleElements&& argTables) { + auto parseArgList = [this, + function](c10::ivalue::TupleElements&& argTables) { std::vector args; for (auto&& argTable : std::move(argTables)) { auto argTableElements = @@ -249,6 +265,7 @@ void BytecodeDeserializer::parseFunctionSchema( c10::nullopt /*N*/, std::move(default_value)); } + tryRegisterMethod(args, *function); return args; }; auto schemaTableElements = diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index adff37d2828..97780336732 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -36,7 +36,7 @@ Method Module::get_method(const std::string& name) const { } c10::optional Module::find_method(const std::string& basename) const { - for (auto& fn : cu_->methods()) { + for (const auto& fn : cu_->methods()) { if (fn->name() == basename) { return c10::make_optional(Method(this, fn.get())); }