From 12daa4f6632b20020f1f99e9ab8776ce018f0b2a Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Mon, 25 Oct 2021 14:43:08 -0700 Subject: [PATCH] [jit][edge] Enable CALL instruction in lite interpreter. (#65964) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65964 ghstack-source-id: 141425519 Test Plan: buck run xplat/caffe2:test_lite_interpreter Reviewed By: cccclai Differential Revision: D31326149 fbshipit-source-id: 8a599d92f3fa4e6c125100adb36d89592e71e547 --- test/cpp/jit/test_lite_interpreter.cpp | 87 ++++++++++++++++++++++++++ torch/csrc/jit/mobile/code.h | 7 +++ torch/csrc/jit/mobile/function.cpp | 4 ++ torch/csrc/jit/mobile/function.h | 1 + torch/csrc/jit/mobile/interpreter.cpp | 5 ++ torch/csrc/jit/runtime/instruction.cpp | 2 +- 6 files changed, 105 insertions(+), 1 deletion(-) diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 82d3cba61ba..9a4d1750cd1 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1371,5 +1372,91 @@ TEST(LiteInterpreterTest, OperatorCacheDifferentiatesDefaultArgs) { testLiteModuleCompareResultTensors(m, inputs, "forward3"); } +TEST(RunTimeTest, RuntimeCall) { + // def call(x): + // return x + x + // + // def forward(a): + // x = a + call(a) + // y = a + call(x) + // return y + + std::vector instructionsCall{ + to_tuple({"STORE", 1, 0}), + to_tuple({"LOAD", 1, 0}), + to_tuple({"MOVE", 1, 0}), + to_tuple({"LOADC", 0, 0}), + to_tuple({"OP", 0, 0}), + to_tuple({"RET", 0, 0}), + }; + std::vector instructionsFoo{ + to_tuple({"STORE", 1, 0}), + to_tuple({"LOAD", 1, 0}), + to_tuple({"LOAD", 1, 0}), + to_tuple({"MOVE", 1, 0}), + to_tuple({"CALL", 0, 0}), + to_tuple({"LOADC", 0, 0}), + to_tuple({"OP", 0, 0}), + to_tuple({"CALL", 0, 0}), + to_tuple({"LOADC", 0, 0}), + to_tuple({"OP", 0, 0}), + to_tuple({"RET", 0, 0}), + }; + std::vector operatorsFoo{ + to_tuple({"aten::add", "Tensor", 3}), + }; + std::vector constantsFoo{ + 1, + }; + std::vector operatorsCall{ + to_tuple({"aten::add", "Tensor", 3}), + }; + std::vector constantsCall{ + 1, + }; + int64_t model_version = caffe2::serialize::kProducedBytecodeVersion; + + auto foo = std::make_unique(c10::QualifiedName("foo")); + c10::ivalue::TupleElements debug_handles_m_tuple; + parseInstructions( + "foo", + std::move(*c10::ivalue::Tuple::create(instructionsFoo)).elements(), + debug_handles_m_tuple, + foo.get()); + parseOperators( + std::move(*c10::ivalue::Tuple::create(operatorsFoo)).elements(), + model_version, + 1, + foo.get()); + parseConstants( + std::move(*c10::ivalue::Tuple::create(constantsFoo)).elements(), + foo.get()); + const size_t rsize = 5; + parseRegisterSize(rsize, foo.get()); + + auto call = std::make_unique(c10::QualifiedName("call")); + parseInstructions( + "call", + std::move(*c10::ivalue::Tuple::create(instructionsCall)).elements(), + debug_handles_m_tuple, + call.get()); + parseOperators( + std::move(*c10::ivalue::Tuple::create(operatorsCall)).elements(), + model_version, + 1, + call.get()); + parseConstants( + std::move(*c10::ivalue::Tuple::create(constantsCall)).elements(), + call.get()); + parseRegisterSize(rsize, call.get()); + + foo->append_function(*call); + + std::vector inputs{at::tensor(1)}; + foo->run(inputs); + auto output = inputs[0]; + ASSERT_EQ(output, at::tensor(7)); +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/code.h b/torch/csrc/jit/mobile/code.h index 9cf6e7c41df..5ea66176892 100644 --- a/torch/csrc/jit/mobile/code.h +++ b/torch/csrc/jit/mobile/code.h @@ -13,6 +13,8 @@ namespace mobile { using Stack = std::vector; using DebugHandle = int64_t; +class Function; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct Code { std::vector instructions_; @@ -21,6 +23,11 @@ struct Code { std::vector> operators_; std::vector constants_; std::vector types_; + // TODO After we actually export CALL instructions we can remove this. + // We may need a two-stage importing scheme, where we firstly construct all + // function objects, and then append referenced function pointers. This could + // be done in parseMethods(). + std::vector functions_; size_t register_size_; // Aggregated output size. }; diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 138630c703f..92adf2eddbc 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -141,6 +141,10 @@ void Function::append_type(const at::TypePtr& type) { code_->types_.push_back(type); } +void Function::append_function(mobile::Function& function) { + code_->functions_.push_back(&function); +} + void Function::set_register_size(size_t size) { code_->register_size_ = size; } diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h index 85daca1ffb6..a7ee58b071d 100644 --- a/torch/csrc/jit/mobile/function.h +++ b/torch/csrc/jit/mobile/function.h @@ -29,6 +29,7 @@ class Function { are removed */ void append_constant(const c10::IValue& constant); void append_type(const c10::TypePtr& type); + TORCH_API void append_function(mobile::Function& func); void set_register_size(size_t size); diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 6b4742f7e9b..310ab1eaf91 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -124,6 +124,11 @@ bool InterpreterState::run(Stack& stack) { code.operators_[inst.X](stack); frame.step(); } break; + case CALL: { + auto& function = frame.getCode().functions_.at(inst.X); + frame.step(); + enterFrame(*function->get_code()); + } break; case INTERFACE_CALL: { torch::jit::Function& method = peek(stack, 0, inst.N) diff --git a/torch/csrc/jit/runtime/instruction.cpp b/torch/csrc/jit/runtime/instruction.cpp index 297fe5ec68d..19c4d778c91 100644 --- a/torch/csrc/jit/runtime/instruction.cpp +++ b/torch/csrc/jit/runtime/instruction.cpp @@ -79,7 +79,7 @@ bool isOpSupportedInMobile(OpCode op) { OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP, RET, GET_ATTR, SET_ATTR, LIST_CONSTRUCT, TUPLE_CONSTRUCT, WARN, INTERFACE_CALL, LIST_UNPACK, TUPLE_SLICE, DICT_CONSTRUCT, - NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE + NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE, CALL }; // clang-format on