[jit][edge] Implement torch::jit::Function for mobile function. (#65970)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65970

ghstack-source-id: 141842338

mobile::Function should inherit from jit::Function because, for interface call support, we need an abstract jit::Function type stored in the corresponding ClassTypes, so that we can look up methods there. Previously, mobile::Function was implemented separately, which prevented this. Since we got rid of all the unneeded virtual methods on jit::Function, we can inherit from torch::jit::Function without too much cost.

NOTE that torch::jit::Function is already a dependency because we need it to support custom class calls. We should be able to use Function uniformly without checking whether it's a builtin function or a mobile::Function.

Test Plan: no behavior change.

Reviewed By: iseeyuan, mrshenli

Differential Revision: D31326148

fbshipit-source-id: 36caeaf3c8c5f54c23a1a7c8c9e2fd6e78b19622
This commit is contained in:
Zhengxu Chen 2021-10-28 13:30:52 -07:00 committed by Facebook GitHub Bot
parent 5ef62c88a9
commit 60472594e1
6 changed files with 54 additions and 30 deletions

View File

@ -23,7 +23,7 @@ struct BuiltinOpFunction : public Function {
TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
}
const std::string& doc_string() const override {
c10::string_view doc_string() const override {
return doc_string_;
}

View File

@ -36,8 +36,8 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
// execution of the function. Method is a wrapper around an
// underlying Function that also provides a `self` object.
struct TORCH_API Function {
virtual const std::string& doc_string() const {
static const std::string no_doc_string = "";
virtual c10::string_view doc_string() const {
static constexpr c10::string_view no_doc_string = "";
return no_doc_string;
}
@ -49,7 +49,10 @@ struct TORCH_API Function {
virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
Stack& stack,
TaskLauncher taskLauncher = at::launch) = 0;
TaskLauncher taskLauncher = at::launch) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
return {};
}
at::IValue operator()(
Stack stack,

View File

@ -17,10 +17,6 @@ const c10::QualifiedName& Function::qualname() const {
return name_;
}
const std::string& Function::name() const {
return name_.name();
}
void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
TORCH_CHECK(
isOpSupportedInMobile(op),
@ -157,29 +153,42 @@ int64_t Function::get_debug_handle(size_t pc) const {
return code_->debug_handles_[pc];
}
void Function::setSchema(c10::FunctionSchema schema) {
torch::jit::Function& Function::setSchema(c10::FunctionSchema schema) {
schema_ = std::move(schema);
return *this;
}
const at::optional<c10::FunctionSchema>& Function::getSchema() const {
return schema_;
bool Function::hasSchema() const {
return schema_.has_value();
}
bool Function::run(Stack& stack) const {
const auto& schema = getSchema();
if (schema) { // if we have a schema then resolve optional args if any
schema->checkAndNormalizeInputs(
const c10::FunctionSchema& Function::getSchema() const {
return *schema_;
}
void Function::run(Stack& stack) {
if (hasSchema()) { // if we have a schema then resolve optional args if any
getSchema().checkAndNormalizeInputs(
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
}
InterpreterState interp_state(*code_);
return interp_state.run(stack);
interp_state.run(stack);
}
c10::IValue Function::operator()(Stack& stack) const {
at::IValue Function::operator()(Stack& stack) {
run(stack);
return stack.front();
}
size_t Function::num_inputs() const {
return schema_->arguments().size();
}
bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
f(*code_);
return true;
}
const std::shared_ptr<Code> Function::get_code() const {
return code_;
}

View File

@ -1,8 +1,10 @@
#pragma once
#include <vector>
#include <ATen/core/function.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <vector>
namespace torch {
namespace jit {
@ -12,13 +14,16 @@ enum OpCode : uint8_t;
namespace mobile {
struct Code;
class Function {
class TORCH_API Function : public torch::jit::Function {
public:
TORCH_API Function(c10::QualifiedName name);
TORCH_API bool run(Stack& stack) const;
c10::IValue operator()(Stack& stack) const;
const std::string& name() const;
TORCH_API const c10::QualifiedName& qualname() const;
explicit Function(c10::QualifiedName name);
void run(Stack& stack) override;
at::IValue operator()(Stack& stack);
void ensure_defined() override {}
size_t num_inputs() const override;
const c10::QualifiedName& qualname() const override;
bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override;
void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
void append_instruction(OpCode op, int X, int N);
bool append_operator(
@ -29,15 +34,16 @@ class Function {
are removed */
void append_constant(const c10::IValue& constant);
void append_type(const c10::TypePtr& type);
TORCH_API void append_function(mobile::Function& func);
void append_function(mobile::Function& func);
void set_register_size(size_t size);
int64_t get_debug_handle(size_t pc) const;
const std::shared_ptr<Code> get_code() const;
void setSchema(c10::FunctionSchema schema);
const at::optional<c10::FunctionSchema>& getSchema() const;
torch::jit::Function& setSchema(c10::FunctionSchema schema) override;
bool hasSchema() const;
const c10::FunctionSchema& getSchema() const override;
// Returns the debug handle corresponding to where the execution
// is halted due to exception.

View File

@ -67,6 +67,12 @@ void InterpreterState::saveExceptionDebugHandle() {
}
}
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
bool newFrame =
f.call(stack, [&](const mobile::Code& code) { enterFrame(code); });
(frames_.rbegin() + (newFrame ? 1 : 0))->step();
}
bool InterpreterState::run(Stack& stack) {
while (true) {
try {
@ -125,9 +131,8 @@ bool InterpreterState::run(Stack& stack) {
frame.step();
} break;
case CALL: {
auto& function = frame.getCode().functions_.at(inst.X);
frame.step();
enterFrame(*function->get_code());
auto& function = *frame.getCode().functions_.at(inst.X);
callFunction(function, stack);
} break;
case INTERFACE_CALL: {
torch::jit::Function& method =

View File

@ -17,6 +17,7 @@ struct InterpreterState {
void enterFrame(const Code&);
void leaveFrame();
void saveExceptionDebugHandle();
void callFunction(torch::jit::Function& f, Stack& stack);
c10::IValue& reg(size_t reg);
std::vector<c10::IValue> registers_;