mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert D25152559: T66557700 Support default argument values of a method
Test Plan: revert-hammer
Differential Revision:
D25152559 (6bde0ca6d3)
Original commit changeset: bbf52f1fbdbf
fbshipit-source-id: 592fdb3078b1ac86cd394adc6c1bfd6b10d829e1
This commit is contained in:
parent
0d411c4216
commit
2b61e4d84c
|
|
@ -102,7 +102,7 @@ struct BuiltinOpFunction : public Function {
|
|||
|
||||
std::string pretty_print_schema() const override {
|
||||
TORCH_INTERNAL_ASSERT(false);
|
||||
return ""; // TODO: suppress unreachable code warning
|
||||
return "";
|
||||
}
|
||||
|
||||
Function& setSchema(c10::FunctionSchema schema) override {
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/core/alias_info.h>
|
||||
#include <ATen/core/dispatch/OperatorOptions.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/interned_strings.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/alias_info.h>
|
||||
#include <ATen/core/operator_name.h>
|
||||
#include <c10/util/StringUtil.h>
|
||||
#include <ATen/core/dispatch/OperatorOptions.h>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace c10 {
|
||||
|
|
@ -33,7 +33,8 @@ struct Argument {
|
|||
N_(std::move(N)),
|
||||
default_value_(std::move(default_value)),
|
||||
kwarg_only_(kwarg_only),
|
||||
alias_info_(std::move(alias_info)) {}
|
||||
alias_info_(std::move(alias_info)) {
|
||||
}
|
||||
const std::string& name() const {
|
||||
return name_;
|
||||
}
|
||||
|
|
@ -84,8 +85,7 @@ struct Argument {
|
|||
}
|
||||
|
||||
Argument cloneWithType(TypePtr new_type) const {
|
||||
return Argument(
|
||||
name_, std::move(new_type), N_, default_value_, kwarg_only_, alias_info_);
|
||||
return Argument(name_, new_type, N_, default_value_, kwarg_only_, alias_info_);
|
||||
}
|
||||
|
||||
// this function check whether this Argument is backward compatible with
|
||||
|
|
@ -95,9 +95,9 @@ struct Argument {
|
|||
// 3) this arg must provide the same default value if old arg has one,
|
||||
bool isBackwardCompatibleWith(
|
||||
const Argument& old,
|
||||
std::ostream* why_not = nullptr) const;
|
||||
std::ostream* why_not=nullptr) const;
|
||||
|
||||
private:
|
||||
private:
|
||||
std::string name_;
|
||||
TypePtr type_;
|
||||
// for list types, an optional statically known length for the list
|
||||
|
|
@ -113,10 +113,12 @@ struct Argument {
|
|||
};
|
||||
|
||||
inline bool operator==(const Argument& lhs, const Argument& rhs) {
|
||||
return lhs.name() == rhs.name() && *lhs.type() == *rhs.type() &&
|
||||
lhs.N() == rhs.N() && lhs.default_value() == rhs.default_value() &&
|
||||
lhs.kwarg_only() == rhs.kwarg_only() &&
|
||||
lhs.alias_info() == rhs.alias_info();
|
||||
return lhs.name() == rhs.name()
|
||||
&& *lhs.type() == *rhs.type()
|
||||
&& lhs.N() == rhs.N()
|
||||
&& lhs.default_value() == rhs.default_value()
|
||||
&& lhs.kwarg_only() == rhs.kwarg_only()
|
||||
&& lhs.alias_info() == rhs.alias_info();
|
||||
}
|
||||
|
||||
bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);
|
||||
|
|
@ -198,10 +200,7 @@ struct FunctionSchema {
|
|||
// this should always be set no matter what
|
||||
c10::optional<AliasAnalysisKind> alias_kind_;
|
||||
|
||||
void checkArg(
|
||||
const IValue& value,
|
||||
const Argument& argument,
|
||||
optional<size_t> pos) const;
|
||||
void checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const;
|
||||
|
||||
void checkSchema() const {
|
||||
bool seen_default_arg = false;
|
||||
|
|
@ -224,7 +223,8 @@ struct FunctionSchema {
|
|||
}
|
||||
}
|
||||
|
||||
public:
|
||||
public:
|
||||
|
||||
void dump() const;
|
||||
|
||||
const OperatorName& operator_name() const {
|
||||
|
|
@ -257,22 +257,21 @@ struct FunctionSchema {
|
|||
}
|
||||
|
||||
c10::optional<int> argumentIndexWithName(const std::string& name) const {
|
||||
for (size_t i = 0; i < arguments().size(); ++i) {
|
||||
if (name == arguments()[i].name()) {
|
||||
for(size_t i = 0; i < arguments().size(); ++i) {
|
||||
if(name == arguments()[i].name())
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
FunctionSchema cloneWithName(std::string name, std::string overload_name)
|
||||
const {
|
||||
FunctionSchema cloneWithName(std::string name, std::string overload_name) const {
|
||||
return FunctionSchema(
|
||||
std::move(name),
|
||||
std::move(overload_name),
|
||||
arguments(),
|
||||
returns(),
|
||||
is_vararg(),
|
||||
is_varret());
|
||||
std::move(name),
|
||||
std::move(overload_name),
|
||||
arguments(),
|
||||
returns(),
|
||||
is_vararg(),
|
||||
is_varret()
|
||||
);
|
||||
}
|
||||
FunctionSchema cloneWithArguments(std::vector<Argument> new_arguments) const {
|
||||
return FunctionSchema(
|
||||
|
|
@ -306,8 +305,7 @@ struct FunctionSchema {
|
|||
// values.
|
||||
void checkAndNormalizeInputs(
|
||||
std::vector<IValue>& inputs,
|
||||
const std::unordered_map<std::string, IValue>& kwargs =
|
||||
std::unordered_map<std::string, IValue>{}) const;
|
||||
const std::unordered_map<std::string, IValue>& kwargs) const;
|
||||
|
||||
std::string findErrorInKwargs(const std::vector<std::string>& kwargs) const;
|
||||
|
||||
|
|
@ -325,6 +323,7 @@ struct FunctionSchema {
|
|||
return false;
|
||||
}
|
||||
|
||||
|
||||
// TODO remove the mutation here
|
||||
bool isDefaultAliasAnalysisKind() const {
|
||||
return !alias_kind_;
|
||||
|
|
@ -350,17 +349,16 @@ struct FunctionSchema {
|
|||
// schema and have the program typecheck?
|
||||
// as_method - if true, treat this schema as a method and ignore
|
||||
// the first argument, which will be the object in both cases
|
||||
bool isSubtypeOf(
|
||||
const FunctionSchema& rhs,
|
||||
bool as_method,
|
||||
std::ostream* why_not = nullptr) const;
|
||||
bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
|
||||
};
|
||||
|
||||
inline bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs) {
|
||||
return lhs.name() == rhs.name() &&
|
||||
lhs.overload_name() == rhs.overload_name() &&
|
||||
lhs.arguments() == rhs.arguments() && lhs.returns() == rhs.returns() &&
|
||||
lhs.is_vararg() == rhs.is_vararg() && lhs.is_varret() == rhs.is_varret();
|
||||
return lhs.name() == rhs.name()
|
||||
&& lhs.overload_name() == rhs.overload_name()
|
||||
&& lhs.arguments() == rhs.arguments()
|
||||
&& lhs.returns() == rhs.returns()
|
||||
&& lhs.is_vararg() == rhs.is_vararg()
|
||||
&& lhs.is_varret() == rhs.is_varret();
|
||||
}
|
||||
|
||||
inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
|
||||
|
|
@ -370,14 +368,14 @@ inline bool operator!=(const FunctionSchema& lhs, const FunctionSchema& rhs) {
|
|||
// print out Argument, which is compatible with FunctionSchema parser
|
||||
// full format: Type(alias)? name=default_value
|
||||
inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
|
||||
|
||||
// for adjusting the ? position.
|
||||
// in schema, we have Tensor?(a!) input, and t(a!)?.
|
||||
// however, t?(a!) doesn't work with schema parser.
|
||||
// so we always use Type(alias)? format
|
||||
auto type = arg.type();
|
||||
bool is_opt = type->kind() == OptionalType::Kind;
|
||||
auto unopt_type =
|
||||
is_opt ? type->cast<OptionalType>()->getElementType() : type;
|
||||
auto unopt_type = is_opt ? type->cast<OptionalType>()->getElementType() : type;
|
||||
|
||||
if (unopt_type->kind() == ListType::Kind && arg.N()) {
|
||||
// sized lists get size N from arg, not type
|
||||
|
|
@ -411,9 +409,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
|
|||
return out;
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const FunctionSchema& schema);
|
||||
inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema);
|
||||
|
||||
inline std::string toString(const FunctionSchema& schema) {
|
||||
std::ostringstream str;
|
||||
|
|
|
|||
|
|
@ -133,20 +133,14 @@ constexpr uint64_t kMaxSupportedFileFormatVersion = 0x5L;
|
|||
// when given bool or integer fill values.
|
||||
constexpr uint64_t kProducedFileFormatVersion = 0x3L;
|
||||
|
||||
// The version we write when the archive contains bytecode.
|
||||
// the version we write when the archive contains bytecode.
|
||||
// It must be higher or eq to kProducedFileFormatVersion.
|
||||
// Because torchscript changes is likely introduce bytecode change.
|
||||
// If kProducedFileFormatVersion is increased, kProducedBytecodeVersion
|
||||
// should be increased too. The relationship is:
|
||||
// kMaxSupportedFileFormatVersion >= (most likely ==) kProducedBytecodeVersion
|
||||
// >= kProducedFileFormatVersion
|
||||
// Versions:
|
||||
// 0x1L: Initial version
|
||||
// 0x2L: (Comment missing)
|
||||
// 0x3L: (Comment missing)
|
||||
// 0x4L: (Comment missing)
|
||||
// 0x5L: Added schema to function tuple
|
||||
constexpr uint64_t kProducedBytecodeVersion = 0x5L;
|
||||
constexpr uint64_t kProducedBytecodeVersion = 0x4L;
|
||||
|
||||
static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
|
||||
"kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion.");
|
||||
|
|
|
|||
|
|
@ -65,49 +65,38 @@ TEST(LiteInterpreterTest, CheckAttrAccess) {
|
|||
AT_ASSERT(!mobile_optimized);
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
|
||||
const std::vector<std::string> test_programs{
|
||||
// test invoking a method with default parameter
|
||||
R"(
|
||||
def test_func(self, x, b : int = 4):
|
||||
return self.foo + x + b
|
||||
)",
|
||||
// inner method call with default parameter (gets inlined)
|
||||
R"(
|
||||
def add_with_default_arg(self, x, b : int = 4):
|
||||
return self.foo + x + b
|
||||
def test_func(self, x):
|
||||
return self.add_with_default_arg(x) # invoke method w/ default arg
|
||||
)",
|
||||
// simple method call
|
||||
R"(
|
||||
def test_func(self, x):
|
||||
b = 4
|
||||
return self.foo + x + b
|
||||
)",
|
||||
};
|
||||
for (const auto& test_program : test_programs) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
m.define(test_program);
|
||||
TEST(LiteInterpreterTest, Add) {
|
||||
Module m("m");
|
||||
m.register_parameter("foo", torch::ones({}), false);
|
||||
// TODO: support default param val, which was pushed in
|
||||
// function schema's checkAndNormalizeInputs()
|
||||
// m.define(R"(
|
||||
// def add_it(self, x, b : int = 4):
|
||||
// return self.foo + x + b
|
||||
// )");
|
||||
m.define(R"(
|
||||
def add_it(self, x):
|
||||
b = 4
|
||||
return self.foo + x + b
|
||||
)");
|
||||
|
||||
const int fortyTwo = 42; // (keep linter happy)
|
||||
auto minput = fortyTwo * torch::ones({});
|
||||
auto ref = m.run_method("test_func", minput);
|
||||
std::vector<IValue> inputs;
|
||||
auto minput = 5 * torch::ones({});
|
||||
inputs.emplace_back(minput);
|
||||
auto ref = m.run_method("add_it", minput);
|
||||
|
||||
std::stringstream ss;
|
||||
m._save_for_mobile(ss);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
const auto& test_func = bc.get_method("test_func");
|
||||
IValue res;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
res = test_func({minput});
|
||||
}
|
||||
|
||||
auto resd = res.toTensor().item<float>();
|
||||
auto refd = ref.toTensor().item<float>();
|
||||
AT_ASSERT(resd == refd);
|
||||
std::stringstream ss;
|
||||
m._save_for_mobile(ss);
|
||||
mobile::Module bc = _load_for_mobile(ss);
|
||||
IValue res;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
auto bcinputs = inputs;
|
||||
res = bc.get_method("add_it")(bcinputs);
|
||||
}
|
||||
|
||||
auto resd = res.toTensor().item<float>();
|
||||
auto refd = ref.toTensor().item<float>();
|
||||
AT_ASSERT(resd == refd);
|
||||
}
|
||||
|
||||
TEST(LiteInterpreterTest, Conv) {
|
||||
|
|
|
|||
|
|
@ -190,37 +190,6 @@ class TestLiteScriptModule(unittest.TestCase):
|
|||
mobile_module_result = mobile_module.forward(*bundled_inputs[0])
|
||||
torch.testing.assert_allclose(script_module_result, mobile_module_result)
|
||||
|
||||
def test_method_calls_with_optional_arg(self):
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(A, self).__init__()
|
||||
|
||||
def forward(self, x, two: int = 2): # opt arg in script-to-script invocation
|
||||
return x + two
|
||||
|
||||
class B(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(B, self).__init__()
|
||||
self.A0 = A()
|
||||
|
||||
def forward(self, x, one: int = 1): # opt arg in Python-to-script invocation
|
||||
return self.A0(x) + one
|
||||
|
||||
script_module = torch.jit.script(B())
|
||||
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
|
||||
mobile_module = _load_for_lite_interpreter(buffer)
|
||||
|
||||
input = torch.tensor([5])
|
||||
script_module_forward_result = script_module.forward(input)
|
||||
mobile_module_forward_result = mobile_module.forward(input)
|
||||
torch.testing.assert_allclose(script_module_forward_result, mobile_module_forward_result)
|
||||
|
||||
script_module_forward_result = script_module.forward(input, 2) # change ref only
|
||||
self.assertFalse((script_module_forward_result == mobile_module_forward_result).all().item())
|
||||
|
||||
mobile_module_forward_result = mobile_module.forward(input, 2) # now both match again
|
||||
torch.testing.assert_allclose(script_module_forward_result, mobile_module_forward_result)
|
||||
|
||||
def test_unsupported_createobject(self):
|
||||
class Foo():
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -107,13 +107,13 @@ Module Method::owner() const {
|
|||
return Module(owner_);
|
||||
}
|
||||
void Method::run(Stack& stack) {
|
||||
stack.insert(stack.begin(), owner()._ivalue()); // self
|
||||
stack.insert(stack.begin(), owner()._ivalue());
|
||||
RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
|
||||
function_->run(stack);
|
||||
}
|
||||
|
||||
IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
|
||||
stack.insert(stack.begin(), owner()._ivalue()); // self
|
||||
stack.insert(stack.begin(), owner()._ivalue());
|
||||
RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
|
||||
return (*function_)(std::move(stack), kwargs);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,12 +58,12 @@ bool Function::append_operator(
|
|||
}
|
||||
|
||||
if (model_version == 0x3L &&
|
||||
model_version < caffe2::serialize::kProducedBytecodeVersion &&
|
||||
opname == c10::OperatorName("aten::_convolution", "")) {
|
||||
// Since byte-code versions 0x4L, convolution has an additional
|
||||
// default-value argument (allow_tf32=True, see
|
||||
// https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
|
||||
// backward compatibility with models of byte-code version <= 0x3L, where
|
||||
// this bool argument does not yet exist.
|
||||
// A default-value argument will be added in
|
||||
// https://github.com/pytorch/pytorch/pull/40737. This wrapper is used to
|
||||
// handle backward compatibility, where there is no default bool value in
|
||||
// old models.
|
||||
fn = [fn](Stack& stack) {
|
||||
stack.push_back(true);
|
||||
fn(stack);
|
||||
|
|
@ -107,26 +107,14 @@ std::string Function::get_module_debug_info(size_t pc) const {
|
|||
return pc_to_module_debug_info_[pc];
|
||||
}
|
||||
|
||||
void Function::setSchema(c10::FunctionSchema schema) {
|
||||
schema_ = std::move(schema);
|
||||
}
|
||||
|
||||
const at::optional<c10::FunctionSchema>& Function::getSchema() const {
|
||||
return schema_;
|
||||
}
|
||||
|
||||
bool Function::run(Stack& stack) const {
|
||||
const auto& schema = getSchema();
|
||||
if (schema) { // if we have a schema then resolve optional args if any
|
||||
schema->checkAndNormalizeInputs(
|
||||
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
|
||||
}
|
||||
InterpreterState interp_state(code_);
|
||||
return interp_state.run(stack);
|
||||
}
|
||||
|
||||
c10::IValue Function::operator()(Stack& stack) const {
|
||||
run(stack);
|
||||
c10::IValue Function::operator()(Stack& stack) {
|
||||
InterpreterState interp_state(code_);
|
||||
interp_state.run(stack);
|
||||
return stack.front();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
#pragma once
|
||||
#include <ATen/core/ivalue.h>
|
||||
//#include <aten/src/Aten/core/operator_name.h>
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
|
|
@ -16,7 +15,7 @@ class Function {
|
|||
public:
|
||||
Function(c10::QualifiedName name);
|
||||
bool run(Stack& stack) const;
|
||||
c10::IValue operator()(Stack& stack) const;
|
||||
c10::IValue operator()(Stack& stack);
|
||||
const std::string& name() const;
|
||||
const c10::QualifiedName& qualname() const;
|
||||
void append_instruction(OpCode op, int X, int N);
|
||||
|
|
@ -33,13 +32,9 @@ class Function {
|
|||
|
||||
std::string get_module_debug_info(size_t pc) const;
|
||||
|
||||
void setSchema(c10::FunctionSchema schema);
|
||||
const at::optional<c10::FunctionSchema>& getSchema() const;
|
||||
|
||||
private:
|
||||
c10::QualifiedName name_;
|
||||
std::shared_ptr<Code> code_;
|
||||
at::optional<c10::FunctionSchema> schema_; // (byte-code version 5+ only)
|
||||
std::vector<std::string> pc_to_module_debug_info_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -15,47 +15,29 @@
|
|||
|
||||
// The import process to serialize the bytecode package.
|
||||
// An example for bytecode.pkl of a small mobile_module looks like:
|
||||
// (5, # model version number (caffe2::serialize::kProducedBytecodeVersion)
|
||||
// # first method
|
||||
// (
|
||||
// # function name
|
||||
// '__torch__.m.forward',
|
||||
// # code
|
||||
// (('instructions',
|
||||
// (('STOREN', 1, 2),
|
||||
// ('DROPR', 1, 0),
|
||||
// ('MOVE', 2, 0),
|
||||
// ('OP', 0, 0),
|
||||
// ('RET', 0, 0))),
|
||||
// ('operators', (('aten::Int', 'Tensor'),)),
|
||||
// ('constants', ()),
|
||||
// ('types', ()),
|
||||
// ('register_size', 2)),
|
||||
// # schema
|
||||
// (('arguments',
|
||||
// ((('name', 'x'), ('type', 'Tensor'), ('default_value', 13)),
|
||||
// ...)), # more args follow here
|
||||
// ('returns',
|
||||
// ((('name', ''), ('type', 'Tensor'), ('default_value', None)),
|
||||
// ...)), # more return values follow here
|
||||
// )),
|
||||
// # more methods follow here
|
||||
// ...)
|
||||
// (3,
|
||||
// ('__torch__.m.forward',
|
||||
// (('instructions',
|
||||
// (('STOREN', 1, 2),
|
||||
// ('DROPR', 1, 0),
|
||||
// ('MOVE', 2, 0),
|
||||
// ('OP', 0, 0),
|
||||
// ('RET', 0, 0))),
|
||||
// ('operators', (('aten::Int', 'Tensor'),)),
|
||||
// ('constants', ()),
|
||||
// ('types', ()),
|
||||
// ('register_size', 2))))
|
||||
|
||||
// In addition, the module debugging information can be saved
|
||||
// in mobile_debug.pkl. An example for it looks like:
|
||||
// (5,
|
||||
// ('__torch__.m.forward',
|
||||
// (('module_debug_info', (top(A).foo(B).forward)))))
|
||||
// (3,
|
||||
// ('__torch__.m.forward',
|
||||
// (('module_debug_info', (top(A).foo(B).forward)))))
|
||||
|
||||
// Note that currently the backward compatibility is not supported by bytecode.
|
||||
// This format and process need to be revisited and redesigned if we want to
|
||||
// This format and process need to be revisted and redesigned if we want to
|
||||
// support backward compatibility in future.
|
||||
|
||||
// Note that the following function-schema fields are not supported:
|
||||
// - Argument::{known_length_,kwarg_only_}
|
||||
// - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
|
||||
|
||||
namespace c10 {
|
||||
// std::string serializeType(const Type &t);
|
||||
TypePtr parseType(const std::string& pythonStr);
|
||||
|
|
@ -108,57 +90,7 @@ void print_unsupported_ops_and_throw(
|
|||
error_message);
|
||||
}
|
||||
|
||||
// The deserializer class which loads the bytecode package from bc files.
|
||||
class BytecodeDeserializer final {
|
||||
public:
|
||||
explicit BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader);
|
||||
mobile::Module deserialize(
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files);
|
||||
std::unordered_map<std::string, std::string> deserializeMetadata(
|
||||
c10::optional<at::Device> device);
|
||||
|
||||
private:
|
||||
TypePtr resolveTypeName(const c10::QualifiedName& qn);
|
||||
void parseMethods(
|
||||
const std::vector<IValue>& vals,
|
||||
const c10::optional<std::vector<IValue>>& debug_info_vals,
|
||||
mobile::CompilationUnit& mcu);
|
||||
c10::IValue readArchive(
|
||||
const std::string& archive_name,
|
||||
std::shared_ptr<mobile::CompilationUnit> mcu);
|
||||
std::unordered_map<std::string, std::string> readMobileMetadata(
|
||||
std::shared_ptr<mobile::CompilationUnit> mcu);
|
||||
std::shared_ptr<CompilationUnit> compilation_unit_;
|
||||
std::unordered_set<std::string> imported_libs_;
|
||||
std::unique_ptr<PyTorchStreamReader> reader_;
|
||||
c10::optional<at::Device> device_;
|
||||
};
|
||||
|
||||
BytecodeDeserializer::BytecodeDeserializer(
|
||||
std::unique_ptr<PyTorchStreamReader> reader)
|
||||
: compilation_unit_(std::make_shared<CompilationUnit>()),
|
||||
reader_(std::move(reader)) {}
|
||||
|
||||
TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
|
||||
static const c10::QualifiedName torchPrefix = "__torch__";
|
||||
// HACK: first we check whether the name starts with `__torch__` to
|
||||
// tell if it's "supposed" to be a class type. This is a reliable
|
||||
// check today, but there is no guarantee that this is the case. The
|
||||
// real solution is to merge type parsers so we can share class
|
||||
// resolution logic.
|
||||
if (torchPrefix.isPrefixOf(qn)) {
|
||||
if (compilation_unit_->get_class(qn) == nullptr) {
|
||||
auto typeptr = ClassType::create(qn, compilation_unit_, true);
|
||||
compilation_unit_->register_type(typeptr);
|
||||
}
|
||||
return compilation_unit_->get_class(qn);
|
||||
} else {
|
||||
return c10::parseType(qn.qualifiedName());
|
||||
}
|
||||
}
|
||||
|
||||
void BytecodeDeserializer::parseMethods(
|
||||
void parseMethods(
|
||||
const std::vector<IValue>& vals,
|
||||
const c10::optional<std::vector<IValue>>& debug_info_vals,
|
||||
mobile::CompilationUnit& mcu) {
|
||||
|
|
@ -194,32 +126,27 @@ void BytecodeDeserializer::parseMethods(
|
|||
const auto& element = vals[i];
|
||||
const auto& m_tuple = element.toTuple()->elements();
|
||||
const std::string& function_name = m_tuple[0].toStringRef();
|
||||
IValue codeTable = m_tuple[1];
|
||||
auto schemaTable = // older files do not store function schema
|
||||
(model_version >= 0x5L) ? at::optional<IValue>{m_tuple[2]} // NOLINT
|
||||
: at::nullopt;
|
||||
IValue table = m_tuple[1];
|
||||
|
||||
auto function = std::unique_ptr<mobile::Function>(
|
||||
new mobile::Function(c10::QualifiedName(function_name)));
|
||||
|
||||
const auto& ins_list =
|
||||
expect_field(codeTable, "instructions", BYTECODE_INDEX_INSTRUCTION)
|
||||
expect_field(table, "instructions", BYTECODE_INDEX_INSTRUCTION)
|
||||
.toTuple()
|
||||
->elements();
|
||||
const auto& ops_list =
|
||||
expect_field(codeTable, "operators", BYTECODE_INDEX_OPERATOR)
|
||||
expect_field(table, "operators", BYTECODE_INDEX_OPERATOR)
|
||||
.toTuple()
|
||||
->elements();
|
||||
const auto& consts_list =
|
||||
expect_field(codeTable, "constants", BYTECODE_INDEX_CONSTANT)
|
||||
expect_field(table, "constants", BYTECODE_INDEX_CONSTANT)
|
||||
.toTuple()
|
||||
->elements();
|
||||
const auto& types_list =
|
||||
expect_field(codeTable, "types", BYTECODE_INDEX_TYPE)
|
||||
.toTuple()
|
||||
->elements();
|
||||
expect_field(table, "types", BYTECODE_INDEX_TYPE).toTuple()->elements();
|
||||
const auto& register_size =
|
||||
expect_field(codeTable, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
|
||||
expect_field(table, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
|
||||
.toInt();
|
||||
|
||||
std::vector<IValue> module_debug_info_list;
|
||||
|
|
@ -291,51 +218,37 @@ void BytecodeDeserializer::parseMethods(
|
|||
|
||||
function->set_register_size(register_size);
|
||||
|
||||
// function schema
|
||||
if (schemaTable) { // (schema is optional for back compat)
|
||||
auto parseArgList = [this](const std::vector<IValue>& argTables) {
|
||||
std::vector<c10::Argument> args;
|
||||
for (auto&& argTable : argTables) {
|
||||
auto name =
|
||||
expect_field(argTable, "name", BYTECODE_INDEX_ARGUMENT_NAME)
|
||||
.toStringRef();
|
||||
const auto& type = resolveTypeName(
|
||||
(expect_field(argTable, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
|
||||
.toStringRef());
|
||||
auto default_value = expect_field(
|
||||
argTable,
|
||||
"default_value",
|
||||
BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE)
|
||||
.toIValue();
|
||||
auto arg =
|
||||
c10::Argument(name, type, c10::nullopt /*N*/, default_value);
|
||||
args.emplace_back(std::move(arg));
|
||||
}
|
||||
return args;
|
||||
};
|
||||
const auto& arg_list =
|
||||
expect_field(
|
||||
*schemaTable, "arguments", BYTECODE_INDEX_SCHEMA_ARGUMENTS)
|
||||
.toTuple()
|
||||
->elements();
|
||||
const auto& ret_list =
|
||||
expect_field(*schemaTable, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
|
||||
.toTuple()
|
||||
->elements();
|
||||
c10::FunctionSchema schema(
|
||||
function_name,
|
||||
"" /*overload_name*/,
|
||||
parseArgList(arg_list),
|
||||
parseArgList(ret_list),
|
||||
false /*is_varargs*/,
|
||||
false /*is_varret*/);
|
||||
function->setSchema(std::move(schema));
|
||||
}
|
||||
|
||||
mcu.register_function(std::move(function));
|
||||
}
|
||||
}
|
||||
|
||||
// The deserializer class which loads the bytecode package from bc files.
|
||||
class BytecodeDeserializer final {
|
||||
public:
|
||||
explicit BytecodeDeserializer(std::unique_ptr<PyTorchStreamReader> reader);
|
||||
mobile::Module deserialize(
|
||||
c10::optional<at::Device> device,
|
||||
ExtraFilesMap& extra_files);
|
||||
std::unordered_map<std::string, std::string> deserializeMetadata(
|
||||
c10::optional<at::Device> device);
|
||||
|
||||
private:
|
||||
c10::IValue readArchive(
|
||||
const std::string& archive_name,
|
||||
std::shared_ptr<mobile::CompilationUnit> mcu);
|
||||
std::unordered_map<std::string, std::string> readMobileMetadata(
|
||||
std::shared_ptr<mobile::CompilationUnit> mcu);
|
||||
std::shared_ptr<CompilationUnit> compilation_unit_;
|
||||
std::unordered_set<std::string> imported_libs_;
|
||||
std::unique_ptr<PyTorchStreamReader> reader_;
|
||||
c10::optional<at::Device> device_;
|
||||
};
|
||||
|
||||
BytecodeDeserializer::BytecodeDeserializer(
|
||||
std::unique_ptr<PyTorchStreamReader> reader)
|
||||
: compilation_unit_(std::make_shared<CompilationUnit>()),
|
||||
reader_(std::move(reader)) {}
|
||||
|
||||
std::unordered_map<std::string, std::string> BytecodeDeserializer::
|
||||
deserializeMetadata(c10::optional<at::Device> device) {
|
||||
device_ = device;
|
||||
|
|
@ -407,8 +320,23 @@ c10::IValue BytecodeDeserializer::readArchive(
|
|||
return len;
|
||||
};
|
||||
|
||||
auto type_resolver = [this](const c10::QualifiedName& qn) {
|
||||
return c10::StrongTypePtr(compilation_unit_, resolveTypeName(qn));
|
||||
static const c10::QualifiedName torchPrefix = "__torch__";
|
||||
auto type_resolver = [&](const c10::QualifiedName& qn) {
|
||||
TypePtr type;
|
||||
// HACK: first we check whether the name starts with `__torch__` to tell if
|
||||
// it's "supposed" to be a class type. This is a reliable check today, but
|
||||
// there is no guarantee that this is the case. The real solution is to
|
||||
// merge type parsers so we can share class resolution logic.
|
||||
if (torchPrefix.isPrefixOf(qn)) {
|
||||
if (compilation_unit_->get_class(qn) == nullptr) {
|
||||
auto typeptr = ClassType::create(qn, compilation_unit_, true);
|
||||
compilation_unit_->register_type(typeptr);
|
||||
}
|
||||
type = compilation_unit_->get_class(qn);
|
||||
} else {
|
||||
type = c10::parseType(qn.qualifiedName());
|
||||
}
|
||||
return c10::StrongTypePtr(compilation_unit_, type);
|
||||
};
|
||||
|
||||
auto obj_loader = [&](at::StrongTypePtr type, IValue input) {
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ class Module;
|
|||
struct TORCH_API Method {
|
||||
Method(const Module* owner, Function* function);
|
||||
|
||||
void run(Stack& stack) const;
|
||||
void run(Stack&& stack) const {
|
||||
void run(Stack& stack);
|
||||
void run(Stack&& stack) {
|
||||
run(stack);
|
||||
}
|
||||
|
||||
c10::IValue operator()(std::vector<c10::IValue> stack) const;
|
||||
c10::IValue operator()(std::vector<c10::IValue> stack);
|
||||
|
||||
const std::string& name() const {
|
||||
return function_->name();
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ bool Module::is_training() const {
|
|||
Method::Method(const Module* owner, Function* function)
|
||||
: owner_(owner), function_(function) {}
|
||||
|
||||
void Method::run(Stack& stack) const {
|
||||
void Method::run(Stack& stack) {
|
||||
auto observer = torch::observerConfig().getModuleObserver();
|
||||
auto instance_key = std::rand();
|
||||
/* if the metadata dict doesn't contain "model_name", copy the metadata and
|
||||
|
|
@ -141,7 +141,7 @@ void Method::run(Stack& stack) const {
|
|||
at::DebugInfoGuard guard(at::DebugInfoKind::MOBILE_RUNTIME_INFO, debug_info);
|
||||
|
||||
try {
|
||||
stack.insert(stack.begin(), owner_->_ivalue()); // self
|
||||
stack.insert(stack.begin(), owner_->_ivalue());
|
||||
function_->run(stack);
|
||||
if (observer) {
|
||||
observer->onExitRunMethod(instance_key);
|
||||
|
|
@ -172,7 +172,7 @@ void Method::run(Stack& stack) const {
|
|||
}
|
||||
}
|
||||
|
||||
c10::IValue Method::operator()(std::vector<IValue> stack) const {
|
||||
c10::IValue Method::operator()(std::vector<IValue> stack) {
|
||||
run(stack);
|
||||
TORCH_INTERNAL_ASSERT(!stack.empty());
|
||||
return stack.front();
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#pragma once
|
||||
//#include <ATen/core/function_schema.h>
|
||||
#include <torch/csrc/jit/mobile/function.h>
|
||||
#include <torch/csrc/jit/mobile/method.h>
|
||||
|
||||
|
|
|
|||
|
|
@ -222,47 +222,12 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
|
|||
// register size
|
||||
auto register_size = static_cast<int>(code.register_size());
|
||||
|
||||
auto codeTable = Table({{"instructions", Tup(instructions)},
|
||||
{"operators", Tup(operators)},
|
||||
{"constants", Tup(constants)},
|
||||
{"types", Tup(types)},
|
||||
{"register_size", register_size}});
|
||||
|
||||
// schema
|
||||
const auto& schema = func.getSchema();
|
||||
TORCH_CHECK(
|
||||
schema.overload_name().empty(), // @TODO: is this check correct?
|
||||
"Overloads are not supported in mobile modules.");
|
||||
TORCH_CHECK(
|
||||
!schema.is_vararg(), "Python *args are not supported in mobile modules.");
|
||||
TORCH_CHECK(
|
||||
!schema.is_varret(),
|
||||
"A variable number of return values is not supported in mobile modules.");
|
||||
auto makeArgTuple = [](const std::vector<Argument>& args) {
|
||||
std::vector<IValue> argTables;
|
||||
for (auto&& arg : args) {
|
||||
TORCH_CHECK(
|
||||
!arg.N(),
|
||||
"Arguments with known list lengths are not supported in mobile modules.");
|
||||
TORCH_CHECK(
|
||||
!arg.kwarg_only(),
|
||||
"Keyword-only arguments are not supported in mobile modules.");
|
||||
argTables.emplace_back(Table({
|
||||
{"name", arg.name()},
|
||||
{"type", arg.type()->annotation_str()},
|
||||
{"default_value", arg.default_value()},
|
||||
}));
|
||||
}
|
||||
return Tup(argTables);
|
||||
};
|
||||
auto schemaTable = Table({
|
||||
{"arguments", makeArgTuple(schema.arguments())},
|
||||
{"returns", makeArgTuple(schema.returns())},
|
||||
});
|
||||
|
||||
// function tuple
|
||||
auto bytecode_vals =
|
||||
Tup({func.qualname().qualifiedName(), codeTable, schemaTable});
|
||||
auto table = Table({{"instructions", Tup(instructions)},
|
||||
{"operators", Tup(operators)},
|
||||
{"constants", Tup(constants)},
|
||||
{"types", Tup(types)},
|
||||
{"register_size", register_size}});
|
||||
auto bytecode_vals = Tup({func.qualname().qualifiedName(), table});
|
||||
|
||||
c10::optional<IValue> debug_info_vals;
|
||||
if (save_mobile_debug_info) {
|
||||
|
|
@ -313,7 +278,7 @@ void setstateTuple(
|
|||
|
||||
void moduleMethodsTuple(
|
||||
const Module& module,
|
||||
std::vector<c10::IValue>& elements, // note: appended to in-place
|
||||
std::vector<c10::IValue>& elements,
|
||||
c10::optional<std::vector<c10::IValue>>& debug_info_elements,
|
||||
bool save_mobile_debug_info) {
|
||||
auto methods = module.get_methods();
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
|
@ -8,14 +7,6 @@ constexpr size_t BYTECODE_INDEX_OPERATOR = 1;
|
|||
constexpr size_t BYTECODE_INDEX_CONSTANT = 2;
|
||||
constexpr size_t BYTECODE_INDEX_TYPE = 3;
|
||||
constexpr size_t BYTECODE_INDEX_REGISTER_SIZE = 4;
|
||||
|
||||
constexpr size_t BYTECODE_INDEX_SCHEMA_ARGUMENTS = 0;
|
||||
constexpr size_t BYTECODE_INDEX_SCHEMA_RETURNS = 1;
|
||||
|
||||
constexpr size_t BYTECODE_INDEX_ARGUMENT_NAME = 0;
|
||||
constexpr size_t BYTECODE_INDEX_ARGUMENT_TYPE = 1;
|
||||
constexpr size_t BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE = 2;
|
||||
|
||||
constexpr size_t BYTECODE_INDEX_MODULE_DEBUG_INFO = 0;
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user