mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[jit][edge] Implement torch::jit::Function for mobile funciton. (#65970)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65970 ghstack-source-id: 141842338 mobile::Function should inherit from jit::Function, because for interface call support, we need an abstract jit::Function type stored in corresponding ClassTypes, so that we can look up methods in there. Previously mobile::Function is implemented separately which prevents this. Since we get rid of all the unneeded virtual methods from jit::Function, we can inherit from torch::jit::Function without too much cost. NOTE that torch::jit::Function is already in dependency because we need it to support custom class call. We should be able to use Function uniformly without looking into whether it's a builtin function or mobile::Function. Test Plan: no behavior change. Reviewed By: iseeyuan, mrshenli Differential Revision: D31326148 fbshipit-source-id: 36caeaf3c8c5f54c23a1a7c8c9e2fd6e78b19622
This commit is contained in:
parent
5ef62c88a9
commit
60472594e1
|
|
@ -23,7 +23,7 @@ struct BuiltinOpFunction : public Function {
|
|||
TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
|
||||
}
|
||||
|
||||
const std::string& doc_string() const override {
|
||||
c10::string_view doc_string() const override {
|
||||
return doc_string_;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ TORCH_API void preoptimizeGraph(std::shared_ptr<Graph>& graph);
|
|||
// execution of the function. Method is a wrapper around an
|
||||
// underlying Function that also provides a `self` object.
|
||||
struct TORCH_API Function {
|
||||
virtual const std::string& doc_string() const {
|
||||
static const std::string no_doc_string = "";
|
||||
virtual c10::string_view doc_string() const {
|
||||
static constexpr c10::string_view no_doc_string = "";
|
||||
return no_doc_string;
|
||||
}
|
||||
|
||||
|
|
@ -49,7 +49,10 @@ struct TORCH_API Function {
|
|||
|
||||
virtual c10::intrusive_ptr<c10::ivalue::Future> runAsync(
|
||||
Stack& stack,
|
||||
TaskLauncher taskLauncher = at::launch) = 0;
|
||||
TaskLauncher taskLauncher = at::launch) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
|
||||
return {};
|
||||
}
|
||||
|
||||
at::IValue operator()(
|
||||
Stack stack,
|
||||
|
|
|
|||
|
|
@ -17,10 +17,6 @@ const c10::QualifiedName& Function::qualname() const {
|
|||
return name_;
|
||||
}
|
||||
|
||||
const std::string& Function::name() const {
|
||||
return name_.name();
|
||||
}
|
||||
|
||||
void Function::append_instruction(OpCode op, int X, int N, int64_t dbg_handle) {
|
||||
TORCH_CHECK(
|
||||
isOpSupportedInMobile(op),
|
||||
|
|
@ -157,29 +153,42 @@ int64_t Function::get_debug_handle(size_t pc) const {
|
|||
return code_->debug_handles_[pc];
|
||||
}
|
||||
|
||||
void Function::setSchema(c10::FunctionSchema schema) {
|
||||
torch::jit::Function& Function::setSchema(c10::FunctionSchema schema) {
|
||||
schema_ = std::move(schema);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const at::optional<c10::FunctionSchema>& Function::getSchema() const {
|
||||
return schema_;
|
||||
bool Function::hasSchema() const {
|
||||
return schema_.has_value();
|
||||
}
|
||||
|
||||
bool Function::run(Stack& stack) const {
|
||||
const auto& schema = getSchema();
|
||||
if (schema) { // if we have a schema then resolve optional args if any
|
||||
schema->checkAndNormalizeInputs(
|
||||
const c10::FunctionSchema& Function::getSchema() const {
|
||||
return *schema_;
|
||||
}
|
||||
|
||||
void Function::run(Stack& stack) {
|
||||
if (hasSchema()) { // if we have a schema then resolve optional args if any
|
||||
getSchema().checkAndNormalizeInputs(
|
||||
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
|
||||
}
|
||||
InterpreterState interp_state(*code_);
|
||||
return interp_state.run(stack);
|
||||
interp_state.run(stack);
|
||||
}
|
||||
|
||||
c10::IValue Function::operator()(Stack& stack) const {
|
||||
at::IValue Function::operator()(Stack& stack) {
|
||||
run(stack);
|
||||
return stack.front();
|
||||
}
|
||||
|
||||
size_t Function::num_inputs() const {
|
||||
return schema_->arguments().size();
|
||||
}
|
||||
|
||||
bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
|
||||
f(*code_);
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Code> Function::get_code() const {
|
||||
return code_;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/core/function.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -12,13 +14,16 @@ enum OpCode : uint8_t;
|
|||
namespace mobile {
|
||||
struct Code;
|
||||
|
||||
class Function {
|
||||
class TORCH_API Function : public torch::jit::Function {
|
||||
public:
|
||||
TORCH_API Function(c10::QualifiedName name);
|
||||
TORCH_API bool run(Stack& stack) const;
|
||||
c10::IValue operator()(Stack& stack) const;
|
||||
const std::string& name() const;
|
||||
TORCH_API const c10::QualifiedName& qualname() const;
|
||||
explicit Function(c10::QualifiedName name);
|
||||
void run(Stack& stack) override;
|
||||
at::IValue operator()(Stack& stack);
|
||||
void ensure_defined() override {}
|
||||
size_t num_inputs() const override;
|
||||
const c10::QualifiedName& qualname() const override;
|
||||
bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override;
|
||||
|
||||
void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
|
||||
void append_instruction(OpCode op, int X, int N);
|
||||
bool append_operator(
|
||||
|
|
@ -29,15 +34,16 @@ class Function {
|
|||
are removed */
|
||||
void append_constant(const c10::IValue& constant);
|
||||
void append_type(const c10::TypePtr& type);
|
||||
TORCH_API void append_function(mobile::Function& func);
|
||||
void append_function(mobile::Function& func);
|
||||
|
||||
void set_register_size(size_t size);
|
||||
|
||||
int64_t get_debug_handle(size_t pc) const;
|
||||
const std::shared_ptr<Code> get_code() const;
|
||||
|
||||
void setSchema(c10::FunctionSchema schema);
|
||||
const at::optional<c10::FunctionSchema>& getSchema() const;
|
||||
torch::jit::Function& setSchema(c10::FunctionSchema schema) override;
|
||||
bool hasSchema() const;
|
||||
const c10::FunctionSchema& getSchema() const override;
|
||||
|
||||
// Returns the debug handle corresponding to where the execution
|
||||
// is halted due to exception.
|
||||
|
|
|
|||
|
|
@ -67,6 +67,12 @@ void InterpreterState::saveExceptionDebugHandle() {
|
|||
}
|
||||
}
|
||||
|
||||
void InterpreterState::callFunction(torch::jit::Function& f, Stack& stack) {
|
||||
bool newFrame =
|
||||
f.call(stack, [&](const mobile::Code& code) { enterFrame(code); });
|
||||
(frames_.rbegin() + (newFrame ? 1 : 0))->step();
|
||||
}
|
||||
|
||||
bool InterpreterState::run(Stack& stack) {
|
||||
while (true) {
|
||||
try {
|
||||
|
|
@ -125,9 +131,8 @@ bool InterpreterState::run(Stack& stack) {
|
|||
frame.step();
|
||||
} break;
|
||||
case CALL: {
|
||||
auto& function = frame.getCode().functions_.at(inst.X);
|
||||
frame.step();
|
||||
enterFrame(*function->get_code());
|
||||
auto& function = *frame.getCode().functions_.at(inst.X);
|
||||
callFunction(function, stack);
|
||||
} break;
|
||||
case INTERFACE_CALL: {
|
||||
torch::jit::Function& method =
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ struct InterpreterState {
|
|||
void enterFrame(const Code&);
|
||||
void leaveFrame();
|
||||
void saveExceptionDebugHandle();
|
||||
void callFunction(torch::jit::Function& f, Stack& stack);
|
||||
|
||||
c10::IValue& reg(size_t reg);
|
||||
std::vector<c10::IValue> registers_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user