Revert D25152559: T66557700 Support default argument values of a method

Test Plan: revert-hammer

Differential Revision: D25152559 (6bde0ca6d3)

Original commit changeset: bbf52f1fbdbf

fbshipit-source-id: 592fdb3078b1ac86cd394adc6c1bfd6b10d829e1
Martin Yuan authored on 2020-12-17 14:01:54 -08:00; committed by Facebook GitHub Bot
parent 0d411c4216
commit 2b61e4d84c
14 changed files with 164 additions and 348 deletions
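
For context, the reverted change (D25152559) taught the lite interpreter to fill in defaulted method arguments from the serialized function schema. Below is a minimal sketch of the behavior being removed, using the same helpers the tests in this diff exercise (torch.jit.mobile._load_for_lite_interpreter and _save_to_buffer_for_lite_interpreter); the module M is illustrative only:

    import io
    import torch
    from torch.jit.mobile import _load_for_lite_interpreter

    class M(torch.nn.Module):
        def forward(self, x, b: int = 4):  # defaulted method argument
            return x + b

    script_module = torch.jit.script(M())
    buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
    mobile_module = _load_for_lite_interpreter(buffer)
    # With D25152559 applied, mobile_module.forward(torch.tensor([5])) resolved
    # b to 4 from the saved schema; after this revert, no schema is stored, so
    # mobile callers must pass every argument explicitly:
    print(mobile_module.forward(torch.tensor([5]), 4))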

View File

@@ -102,7 +102,7 @@ struct BuiltinOpFunction : public Function {
std::string pretty_print_schema() const override {
TORCH_INTERNAL_ASSERT(false);
return ""; // TODO: suppress unreachable code warning
return "";
}
Function& setSchema(c10::FunctionSchema schema) override {

View File

@@ -1,12 +1,12 @@
#pragma once
#include <ATen/core/alias_info.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <c10/util/StringUtil.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/alias_info.h>
#include <ATen/core/operator_name.h>
#include <c10/util/StringUtil.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <unordered_map>
namespace c10 {
@@ -33,7 +33,8 @@ struct Argument {
N_(std::move(N)),
default_value_(std::move(default_value)),
kwarg_only_(kwarg_only),
alias_info_(std::move(alias_info)) {}
alias_info_(std::move(alias_info)) {
}
const std::string& name() const {
return name_;
}
@@ -84,8 +85,7 @@ struct Argument {
}
Argument cloneWithType(TypePtr new_type) const {
return Argument(
name_, std::move(new_type), N_, default_value_, kwarg_only_, alias_info_);
return Argument(name_, new_type, N_, default_value_, kwarg_only_, alias_info_);
}
// this function checks whether this Argument is backward compatible with
@@ -95,9 +95,9 @@ struct Argument {
// 3) this arg must provide the same default value if old arg has one,
bool isBackwardCompatibleWith(
const Argument& old,
std::ostream* why_not = nullptr) const;
std::ostream* why_not=nullptr) const;
private:
private:
std::string name_;
TypePtr type_;
// for list types, an optional statically known length for the list
@@ -113,10 +113,12 @@ struct Argument {
};
inline bool operator==(const Argument& lhs, const Argument& rhs) {
return lhs.name() == rhs.name() && *lhs.type() == *rhs.type() &&
lhs.N() == rhs.N() && lhs.default_value() == rhs.default_value() &&
lhs.kwarg_only() == rhs.kwarg_only() &&
lhs.alias_info() == rhs.alias_info();
return lhs.name() == rhs.name()
&& *lhs.type() == *rhs.type()
&& lhs.N() == rhs.N()
&& lhs.default_value() == rhs.default_value()
&& lhs.kwarg_only() == rhs.kwarg_only()
&& lhs.alias_info() == rhs.alias_info();
}
bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);
@@ -198,10 +200,7 @@ struct FunctionSchema {
// this should always be set no matter what
c10::optional<AliasAnalysisKind> alias_kind_;
void checkArg(
const IValue& value,
const Argument& argument,
optional<size_t> pos) const;
void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;
void checkSchema() const {
bool seen_default_arg = false;
@@ -224,7 +223,8 @@ struct FunctionSchema {
}
}
public:
public:
void dump() const;
const OperatorName& operator_name() const {
@@ -257,22 +257,21 @@ struct FunctionSchema {
}
c10::optional<int> argumentIndexWithName(const std::string& name) const {
for (size_t i = 0; i < arguments().size(); ++i) {
if (name == arguments()[i].name()) {
for(size_t i = 0; i < arguments().size(); ++i) {
if(name == arguments()[i].name())
return i;
}
}
return c10::nullopt;
}
FunctionSchema cloneWithName(std::string name, std::string overload_name)
const {
FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
return FunctionSchema(
std::move(name),
std::move(overload_name),
arguments(),
returns(),
is_vararg(),
is_varret());
std::move(name),
std::move(overload_name),
arguments(),
returns(),
is_vararg(),
is_varret()
);
}
FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
return FunctionSchema(
@@ -306,8 +305,7 @@ struct FunctionSchema {
// values.
void checkAndNormalizeInputs(
std::vector<IValue>& inputs,
const std::unordered_map<std::string, IValue>& kwargs =
std::unordered_map<std::string, IValue>{}) const;
const std::unordered_map<std::string, IValue>& kwargs) const;
std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const;
@@ -325,6 +323,7 @@ struct FunctionSchema {
return false;
}
// TODO remove the mutation here
bool isDefaultAliasAnalysisKind() const {
return !alias_kind_;
@@ -350,17 +349,16 @@ struct FunctionSchema {
// schema and have the program typecheck?
// as_method - if true, treat this schema as a method and ignore
// the first argument, which will be the object in both cases
bool isSubtypeOf(
const FunctionSchema& rhs,
bool as_method,
std::ostream* why_not = nullptr) const;
bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
};
inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
return lhs.name() == rhs.name() &&
lhs.overload_name() == rhs.overload_name() &&
lhs.arguments() == rhs.arguments() && lhs.returns() == rhs.returns() &&
lhs.is_vararg() == rhs.is_vararg() && lhs.is_varret() == rhs.is_varret();
return lhs.name() == rhs.name()
&& lhs.overload_name() == rhs.overload_name()
&& lhs.arguments() == rhs.arguments()
&& lhs.returns() == rhs.returns()
&& lhs.is_vararg() == rhs.is_vararg()
&& lhs.is_varret() == rhs.is_varret();
}
inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
@@ -370,14 +368,14 @@ inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
// print out Argument, which is compatible with FunctionSchema parser
// full format: Type(alias)? name=default_value
inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
// For adjusting the position of the '?':
// in a schema we can have both Tensor?(a!) and t(a!)?,
// but t?(a!) doesn't work with the schema parser,
// so we always use the Type(alias)? format.
auto type = arg.type();
bool is_opt = type->kind() == OptionalType::Kind;
auto unopt_type =
is_opt ? type->cast<OptionalType>()->getElementType() : type;
auto unopt_type = is_opt ? type->cast<OptionalType>()->getElementType() : type;
if (unopt_type->kind() == ListType::Kind && arg.N()) {
// sized lists get size N from arg, not type
@@ -411,9 +409,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
return out;
}
inline std::ostream& operator<<(
std::ostream& out,
const FunctionSchema& schema);
inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
inline std::string toString(const FunctionSchema& schema) {
std::ostringstream str;

View File

@@ -133,20 +133,14 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
// when given bool or integer fill values.
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
// The version we write when the archive contains bytecode.
// the version we write when the archive contains bytecode.
// It must be higher than or equal to kProducedFileFormatVersion, because
// TorchScript changes are likely to introduce bytecode changes.
// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion
// should be increased too. The relationship is:
// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion
// >= kProducedFileFormatVersion
// Versions:
// 0x1L: Initial version
// 0x2L: (Comment missing)
// 0x3L: (Comment missing)
// 0x4L: (Comment missing)
// 0x5L: Added schema to function tuple
constexpr uint64_t kProducedBytecodeVersion = 0x5L;
constexpr uint64_t kProducedBytecodeVersion = 0x4L;
static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
"kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion.");

View File

@@ -65,49 +65,38 @@ TEST(LiteInterpreterTest, CheckAttrAccess) {
AT_ASSERT(!mobile_optimized);
}
TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
const std::vector<std::string> test_programs{
// test invoking a method with default parameter
R"(
def test_func(self, x, b : int = 4):
return self.foo + x + b
)",
// inner method call with default parameter (gets inlined)
R"(
def add_with_default_arg(self, x, b : int = 4):
return self.foo + x + b
def test_func(self, x):
return self.add_with_default_arg(x) # invoke method w/ default arg
)",
// simple method call
R"(
def test_func(self, x):
b = 4
return self.foo + x + b
)",
};
for (const auto& test_program : test_programs) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(test_program);
TEST(LiteInterpreterTest, Add) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
// TODO: support default param val, which was pushed in
// function schema's checkAndNormalizeInputs()
// m.define(R"(
// def add_it(self, x, b : int = 4):
// return self.foo + x + b
// )");
m.define(R"(
def add_it(self, x):
b = 4
return self.foo + x + b
)");
const int fortyTwo = 42; // (keep linter happy)
auto minput = fortyTwo * torch::ones({});
auto ref = m.run_method("test_func", minput);
std::vector<IValue> inputs;
auto minput = 5 * torch::ones({});
inputs.emplace_back(minput);
auto ref = m.run_method("add_it", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
const auto& test_func = bc.get_method("test_func");
IValue res;
for (int i = 0; i < 3; ++i) {
res = test_func({minput});
}
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
auto bcinputs = inputs;
res = bc.get_method("add_it")(bcinputs);
}
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
}
TEST(LiteInterpreterTest, Conv) {

View File

@@ -190,37 +190,6 @@ class TestLiteScriptModule(unittest.TestCase):
mobile_module_result = mobile_module.forward(*bundled_inputs[0])
torch.testing.assert_allclose(script_module_result, mobile_module_result)
def test_method_calls_with_optional_arg(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self, x, two: int = 2): # opt arg in script-to-script invocation
return x + two
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
self.A0 = A()
def forward(self, x, one: int = 1): # opt arg in Python-to-script invocation
return self.A0(x) + one
script_module = torch.jit.script(B())
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
mobile_module = _load_for_lite_interpreter(buffer)
input = torch.tensor([5])
script_module_forward_result = script_module.forward(input)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_allclose(script_module_forward_result, mobile_module_forward_result)
script_module_forward_result = script_module.forward(input, 2) # change ref only
self.assertFalse((script_module_forward_result == mobile_module_forward_result).all().item())
mobile_module_forward_result = mobile_module.forward(input, 2) # now both match again
torch.testing.assert_allclose(script_module_forward_result, mobile_module_forward_result)
def test_unsupported_createobject(self):
class Foo():
def __init__(self):

View File

@@ -107,13 +107,13 @@ Module Method::owner() const {
return Module(owner_);
}
void Method::run(Stack& stack) {
stack.insert(stack.begin(), owner()._ivalue()); // self
stack.insert(stack.begin(), owner()._ivalue());
RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
function_->run(stack);
}
IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
stack.insert(stack.begin(), owner()._ivalue()); // self
stack.insert(stack.begin(), owner()._ivalue());
RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
return (*function_)(std::move(stack), kwargs);
}

View File

@@ -58,12 +58,12 @@ bool Function::append_operator(
}
if (model_version == 0x3L &&
model_version < caffe2::serialize::kProducedBytecodeVersion &&
opname == c10::OperatorName("aten::_convolution", "")) {
// Since byte-code version 0x4L, convolution has an additional
// default-value argument (allow_tf32=True, see
// https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
// backward compatibility with models of byte-code version <= 0x3L, where
// this bool argument does not yet exist.
// A default-value argument will be added in
// https://github.com/pytorch/pytorch/pull/40737. This wrapper is used to
// handle backward compatibility, where there is no default bool value in
// old models.
fn = [fn](Stack& stack) {
stack.push_back(true);
fn(stack);
@@ -107,26 +107,14 @@ std::string Function::get_module_debug_info(size_t pc) const {
return pc_to_module_debug_info_[pc];
}
void Function::setSchema(c10::FunctionSchema schema) {
schema_ = std::move(schema);
}
const at::optional<c10::FunctionSchema>& Function::getSchema() const {
return schema_;
}
bool Function::run(Stack& stack) const {
const auto& schema = getSchema();
if (schema) { // if we have a schema then resolve optional args if any
schema->checkAndNormalizeInputs(
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
}
InterpreterState interp_state(code_);
return interp_state.run(stack);
}
c10::IValue Function::operator()(Stack& stack) const {
run(stack);
c10::IValue Function::operator()(Stack& stack) {
InterpreterState interp_state(code_);
interp_state.run(stack);
return stack.front();
}

View File

@@ -1,7 +1,6 @@
#pragma once
#include <ATen/core/ivalue.h>
//#include <aten/src/Aten/core/operator_name.h>
#include <ATen/core/function_schema.h>
#include <vector>
namespace torch {
@@ -16,7 +15,7 @@ class Function {
public:
Function(c10::QualifiedName name);
bool run(Stack& stack) const;
c10::IValue operator()(Stack& stack) const;
c10::IValue operator()(Stack& stack);
const std::string& name() const;
const c10::QualifiedName& qualname() const;
void append_instruction(OpCode op, int X, int N);
@@ -33,13 +32,9 @@ class Function {
std::string get_module_debug_info(size_t pc) const;
void setSchema(c10::FunctionSchema schema);
const at::optional<c10::FunctionSchema>& getSchema() const;
private:
c10::QualifiedName name_;
std::shared_ptr<Code> code_;
at::optional<c10::FunctionSchema> schema_; // (byte-code version 5+ only)
std::vector<std::string> pc_to_module_debug_info_;
};

View File

@@ -15,47 +15,29 @@
// The import process deserializes the bytecode package.
// An example for bytecode.pkl of a small mobile_module looks like:
// (5, # model version number (caffe2::serialize::kProducedBytecodeVersion)
// # first method
// (
// # function name
// '__torch__.m.forward',
// # code
// (('instructions',
// (('STOREN', 1, 2),
// ('DROPR', 1, 0),
// ('MOVE', 2, 0),
// ('OP', 0, 0),
// ('RET', 0, 0))),
// ('operators', (('aten::Int', 'Tensor'),)),
// ('constants', ()),
// ('types', ()),
// ('register_size', 2)),
// # schema
// (('arguments',
// ((('name', 'x'), ('type', 'Tensor'), ('default_value', 13)),
// ...)), # more args follow here
// ('returns',
// ((('name', ''), ('type', 'Tensor'), ('default_value', None)),
// ...)), # more return values follow here
// )),
// # more methods follow here
// ...)
// (3,
// ('__torch__.m.forward',
// (('instructions',
// (('STOREN', 1, 2),
// ('DROPR', 1, 0),
// ('MOVE', 2, 0),
// ('OP', 0, 0),
// ('RET', 0, 0))),
// ('operators', (('aten::Int', 'Tensor'),)),
// ('constants', ()),
// ('types', ()),
// ('register_size', 2))))
// In addition, the module debugging information can be saved
// in mobile_debug.pkl. An example for it looks like:
// (5,
// ('__torch__.m.forward',
// (('module_debug_info', (top(A).foo(B).forward)))))
// (3,
// ('__torch__.m.forward',
// (('module_debug_info', (top(A).foo(B).forward)))))
// Note that backward compatibility is currently not supported by bytecode.
// This format and process need to be revisited and redesigned if we want
// to support backward compatibility in the future.
// Note that the following function-schema fields are not supported:
// - Argument::{known_length_,kwarg_only_}
// - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
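
The pickled record sketched above can be inspected straight out of a saved archive. A minimal sketch: the TolerantUnpickler is an illustration-only helper (tensor constants are stored via pickle persistent ids), and the module M is hypothetical:

    import io
    import pickle
    import zipfile
    import torch

    class TolerantUnpickler(pickle.Unpickler):
        # Stub out persistent ids (tensor constants); only the tuple
        # skeleton shown in the comment above is of interest here.
        def persistent_load(self, pid):
            return "<tensor>"

    class M(torch.nn.Module):
        def forward(self, x):
            return int(x)  # lowers to the aten::Int op from the example

    buf = io.BytesIO(torch.jit.script(M())._save_to_buffer_for_lite_interpreter())
    with zipfile.ZipFile(buf) as zf:
        name = next(n for n in zf.namelist() if n.endswith("bytecode.pkl"))
        vals = TolerantUnpickler(io.BytesIO(zf.read(name))).load()
    print(vals[0])  # leading element: bytecode version (4 after this revert)

Where available, python -m torch.utils.show_pickle "model.ptl@*/bytecode.pkl" pretty-prints the same record from the command line.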
namespace c10 {
// std::string serializeType(const Type &t);
TypePtr parseType(const std::string& pythonStr);
@@ -108,57 +90,7 @@ void print_unsupported_ops_and_throw(
error_message);
}
// The deserializer class which loads the bytecode package from bc files.
class BytecodeDeserializer final {
public:
explicit BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader);
mobile::Module deserialize(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files);
std::unordered_map<std::string, std::string> deserializeMetadata(
c10::optional<at::Device> device);
private:
TypePtr resolveTypeName(const c10::QualifiedName& qn);
void parseMethods(
const std::vector<IValue>& vals,
const c10::optional<std::vector<IValue>>& debug_info_vals,
mobile::CompilationUnit& mcu);
c10::IValue readArchive(
const std::string& archive_name,
std::shared_ptr<mobile::CompilationUnit> mcu);
std::unordered_map<std::string, std::string> readMobileMetadata(
std::shared_ptr<mobile::CompilationUnit> mcu);
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unordered_set<std::string> imported_libs_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
};
BytecodeDeserializer::BytecodeDeserializer(
std::unique_ptr<PyTorchStreamReader> reader)
: compilation_unit_(std::make_shared<CompilationUnit>()),
reader_(std::move(reader)) {}
TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
static const c10::QualifiedName torchPrefix = "__torch__";
// HACK: first we check whether the name starts with `__torch__` to
// tell if it's "supposed" to be a class type. This is a reliable
// check today, but there is no guarantee that this is the case. The
// real solution is to merge type parsers so we can share class
// resolution logic.
if (torchPrefix.isPrefixOf(qn)) {
if (compilation_unit_->get_class(qn) == nullptr) {
auto typeptr = ClassType::create(qn, compilation_unit_, true);
compilation_unit_->register_type(typeptr);
}
return compilation_unit_->get_class(qn);
} else {
return c10::parseType(qn.qualifiedName());
}
}
void BytecodeDeserializer::parseMethods(
void parseMethods(
const std::vector<IValue>& vals,
const c10::optional<std::vector<IValue>>& debug_info_vals,
mobile::CompilationUnit& mcu) {
@@ -194,32 +126,27 @@ void BytecodeDeserializer::parseMethods(
const auto& element = vals[i];
const auto& m_tuple = element.toTuple()->elements();
const std::string& function_name = m_tuple[0].toStringRef();
IValue codeTable = m_tuple[1];
auto schemaTable = // older files do not store function schema
(model_version >= 0x5L) ? at::optional<IValue>{m_tuple[2]} // NOLINT
: at::nullopt;
IValue table = m_tuple[1];
auto function = std::unique_ptr<mobile::Function>(
new mobile::Function(c10::QualifiedName(function_name)));
const auto& ins_list =
expect_field(codeTable, "instructions", BYTECODE_INDEX_INSTRUCTION)
expect_field(table, "instructions", BYTECODE_INDEX_INSTRUCTION)
.toTuple()
->elements();
const auto& ops_list =
expect_field(codeTable, "operators", BYTECODE_INDEX_OPERATOR)
expect_field(table, "operators", BYTECODE_INDEX_OPERATOR)
.toTuple()
->elements();
const auto& consts_list =
expect_field(codeTable, "constants", BYTECODE_INDEX_CONSTANT)
expect_field(table, "constants", BYTECODE_INDEX_CONSTANT)
.toTuple()
->elements();
const auto& types_list =
expect_field(codeTable, "types", BYTECODE_INDEX_TYPE)
.toTuple()
->elements();
expect_field(table, "types", BYTECODE_INDEX_TYPE).toTuple()->elements();
const auto& register_size =
expect_field(codeTable, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
expect_field(table, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
.toInt();
std::vector<IValue> module_debug_info_list;
@@ -291,51 +218,37 @@ void BytecodeDeserializer::parseMethods(
function->set_register_size(register_size);
// function schema
if (schemaTable) { // (schema is optional for back compat)
auto parseArgList = [this](const std::vector<IValue>& argTables) {
std::vector<c10::Argument> args;
for (auto&& argTable : argTables) {
auto name =
expect_field(argTable, "name", BYTECODE_INDEX_ARGUMENT_NAME)
.toStringRef();
const auto& type = resolveTypeName(
(expect_field(argTable, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
.toStringRef());
auto default_value = expect_field(
argTable,
"default_value",
BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE)
.toIValue();
auto arg =
c10::Argument(name, type, c10::nullopt /*N*/, default_value);
args.emplace_back(std::move(arg));
}
return args;
};
const auto& arg_list =
expect_field(
*schemaTable, "arguments", BYTECODE_INDEX_SCHEMA_ARGUMENTS)
.toTuple()
->elements();
const auto& ret_list =
expect_field(*schemaTable, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
.toTuple()
->elements();
c10::FunctionSchema schema(
function_name,
"" /*overload_name*/,
parseArgList(arg_list),
parseArgList(ret_list),
false /*is_varargs*/,
false /*is_varret*/);
function->setSchema(std::move(schema));
}
mcu.register_function(std::move(function));
}
}
// The deserializer class which loads the bytecode package from bc files.
class BytecodeDeserializer final {
public:
explicit BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader);
mobile::Module deserialize(
c10::optional<at::Device> device,
ExtraFilesMap& extra_files);
std::unordered_map<std::string, std::string> deserializeMetadata(
c10::optional<at::Device> device);
private:
c10::IValue readArchive(
const std::string& archive_name,
std::shared_ptr<mobile::CompilationUnit> mcu);
std::unordered_map<std::string, std::string> readMobileMetadata(
std::shared_ptr<mobile::CompilationUnit> mcu);
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unordered_set<std::string> imported_libs_;
std::unique_ptr<PyTorchStreamReader> reader_;
c10::optional<at::Device> device_;
};
BytecodeDeserializer::BytecodeDeserializer(
std::unique_ptr<PyTorchStreamReader> reader)
: compilation_unit_(std::make_shared<CompilationUnit>()),
reader_(std::move(reader)) {}
std::unordered_map<std::string, std::string> BytecodeDeserializer::
deserializeMetadata(c10::optional<at::Device> device) {
device_ = device;
@@ -407,8 +320,23 @@ c10::IValue BytecodeDeserializer::readArchive(
return len;
};
auto type_resolver = [this](const c10::QualifiedName& qn) {
return c10::StrongTypePtr(compilation_unit_, resolveTypeName(qn));
static const c10::QualifiedName torchPrefix = "__torch__";
auto type_resolver = [&](const c10::QualifiedName& qn) {
TypePtr type;
// HACK: first we check whether the name starts with `__torch__` to tell if
// it's "supposed" to be a class type. This is a reliable check today, but
// there is no guarantee that this is the case. The real solution is to
// merge type parsers so we can share class resolution logic.
if (torchPrefix.isPrefixOf(qn)) {
if (compilation_unit_->get_class(qn) == nullptr) {
auto typeptr = ClassType::create(qn, compilation_unit_, true);
compilation_unit_->register_type(typeptr);
}
type = compilation_unit_->get_class(qn);
} else {
type = c10::parseType(qn.qualifiedName());
}
return c10::StrongTypePtr(compilation_unit_, type);
};
auto obj_loader = [&](at::StrongTypePtr type, IValue input) {

View File

@@ -10,12 +10,12 @@ class Module;
struct TORCH_API Method {
Method(const Module* owner, Function* function);
void run(Stack& stack) const;
void run(Stack&& stack) const {
void run(Stack& stack);
void run(Stack&& stack) {
run(stack);
}
c10::IValue operator()(std::vector<c10::IValue> stack) const;
c10::IValue operator()(std::vector<c10::IValue> stack);
const std::string& name() const {
return function_->name();

View File

@@ -119,7 +119,7 @@ bool Module::is_training() const {
Method::Method(const Module* owner, Function* function)
: owner_(owner), function_(function) {}
void Method::run(Stack& stack) const {
void Method::run(Stack& stack) {
auto observer = torch::observerConfig().getModuleObserver();
auto instance_key = std::rand();
/* if the metadata dict doesn't contain "model_name", copy the metadata and
@@ -141,7 +141,7 @@ void Method::run(Stack& stack) const {
at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info);
try {
stack.insert(stack.begin(), owner_->_ivalue()); // self
stack.insert(stack.begin(), owner_->_ivalue());
function_->run(stack);
if (observer) {
observer->onExitRunMethod(instance_key);
@@ -172,7 +172,7 @@ void Method::run(Stack& stack) const {
}
}
c10::IValue Method::operator()(std::vector<IValue> stack) const {
c10::IValue Method::operator()(std::vector<IValue> stack) {
run(stack);
TORCH_INTERNAL_ASSERT(!stack.empty());
return stack.front();

View File

@@ -1,4 +1,5 @@
#pragma once
//#include <ATen/core/function_schema.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/method.h>

View File

@@ -222,47 +222,12 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
// register size
auto register_size = static_cast<int>(code.register_size());
auto codeTable = Table({{"instructions", Tup(instructions)},
{"operators", Tup(operators)},
{"constants", Tup(constants)},
{"types", Tup(types)},
{"register_size", register_size}});
// schema
const auto& schema = func.getSchema();
TORCH_CHECK(
schema.overload_name().empty(), // @TODO: is this check correct?
"Overloads are not supported in mobile modules.");
TORCH_CHECK(
!schema.is_vararg(), "Python *args are not supported in mobile modules.");
TORCH_CHECK(
!schema.is_varret(),
"A variable number of return values is not supported in mobile modules.");
auto makeArgTuple = [](const std::vector<Argument>& args) {
std::vector<IValue> argTables;
for (auto&& arg : args) {
TORCH_CHECK(
!arg.N(),
"Arguments with known list lengths are not supported in mobile modules.");
TORCH_CHECK(
!arg.kwarg_only(),
"Keyword-only arguments are not supported in mobile modules.");
argTables.emplace_back(Table({
{"name", arg.name()},
{"type", arg.type()->annotation_str()},
{"default_value", arg.default_value()},
}));
}
return Tup(argTables);
};
auto schemaTable = Table({
{"arguments", makeArgTuple(schema.arguments())},
{"returns", makeArgTuple(schema.returns())},
});
// function tuple
auto bytecode_vals =
Tup({func.qualname().qualifiedName(), codeTable, schemaTable});
auto table = Table({{"instructions", Tup(instructions)},
{"operators", Tup(operators)},
{"constants", Tup(constants)},
{"types", Tup(types)},
{"register_size", register_size}});
auto bytecode_vals = Tup({func.qualname().qualifiedName(), table});
c10::optional<IValue> debug_info_vals;
if (save_mobile_debug_info) {
@@ -313,7 +278,7 @@ void setstateTuple(
void moduleMethodsTuple(
const Module& module,
std::vector<c10::IValue>& elements, // note: appended to in-place
std::vector<c10::IValue>& elements,
c10::optional<std::vector<c10::IValue>>& debug_info_elements,
bool save_mobile_debug_info) {
auto methods = module.get_methods();

View File

@@ -1,5 +1,4 @@
#pragma once
#include <cstddef>
namespace torch {
namespace jit {
@@ -8,14 +7,6 @@ constexpr size_t BYTECODE_INDEX_OPERATOR = 1;
constexpr size_t BYTECODE_INDEX_CONSTANT = 2;
constexpr size_t BYTECODE_INDEX_TYPE = 3;
constexpr size_t BYTECODE_INDEX_REGISTER_SIZE = 4;
constexpr size_t BYTECODE_INDEX_SCHEMA_ARGUMENTS = 0;
constexpr size_t BYTECODE_INDEX_SCHEMA_RETURNS = 1;
constexpr size_t BYTECODE_INDEX_ARGUMENT_NAME = 0;
constexpr size_t BYTECODE_INDEX_ARGUMENT_TYPE = 1;
constexpr size_t BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE = 2;
constexpr size_t BYTECODE_INDEX_MODULE_DEBUG_INFO = 0;
} // namespace jit
} // namespace torch