pytorch/torch/csrc/jit/import_source.cpp
Zachary DeVito e2ccccee9a Load tensors directly from pickle archive
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23281

Test Plan: Imported from OSS

Differential Revision: D16452815

Pulled By: zdevito

fbshipit-source-id: 918eef3ad444b598ab655c39037e4baafdcb51e1
2019-08-22 11:48:09 -07:00

467 lines
16 KiB
C++

#include "import_source.h"
#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/resolver.h>
#include <torch/csrc/jit/script/script_type_parser.h>
namespace torch {
namespace jit {
namespace script {
// Sugared value backing the `ops` identifier in serialized TorchScript.
// Attribute access like `ops.aten` resolves to the corresponding builtin
// operator namespace, tagged with the serialized op version set so that
// version-dependent operator behavior can be selected.
struct OpsValue : public SugaredValue {
  // explicit: prevent accidental implicit conversion from a bare size_t.
  explicit OpsValue(size_t version) : version_(version) {}

  std::string kind() const override {
    return "ops";
  }

  // `ops.<field>` -> builtin module named `field` at version `version_`.
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override {
    return std::make_shared<BuiltinModule>(field, version_);
  }

  size_t version_;
};
// Sugared value that wraps a fixed IValue and splices it into the graph as a
// constant when used (the resolver uses this for `inf` and `nan`).
struct ConstantValue : public SugaredValue {
  // explicit: prevent accidental implicit conversion from any IValue.
  explicit ConstantValue(IValue value) : value_(std::move(value)) {}

  std::string kind() const override {
    return "constant";
  }

  // Materialize the wrapped value as a constant node in `m`'s graph.
  Value* asValue(const SourceRange& loc, Function& m) override {
    return m.graph()->insertConstant(value_);
  }

  IValue value_;
};
// Represents nested namespaces, like `foo.bar.Baz`.
// Right now these namespaces can only contain other namespaces or NamedTypes
struct TORCH_API ClassNamespaceValue : public SugaredValue {
  /**
   * @param name The fully qualified path, which can resolve either to a
   * namespace or a NamedType
   * @param cu The compilation unit to search for classes in
   */
  explicit ClassNamespaceValue(
      c10::QualifiedName name,
      const CompilationUnit& cu)
      : basename_(std::move(name)), cu_(cu) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& name) override {
    auto qualified = c10::QualifiedName(basename_, name);

    // First preference: a NamedType registered under this name. That is
    // either a class, or a NamedTuple (which is a TupleType, not a
    // ClassType, hence the two casts).
    if (auto named_type = cu_.get_type(qualified)) {
      if (auto as_class = named_type->cast<ClassType>()) {
        return std::make_shared<ClassValue>(as_class);
      }
      if (auto as_tuple = named_type->cast<TupleType>()) {
        return std::make_shared<NamedTupleConstructor>(as_tuple);
      }
    }

    // Second preference: a free function registered under this name.
    if (auto as_function = cu_.find_function(qualified)) {
      return std::make_shared<FunctionValue>(as_function);
    }

    // Fallback: treat it as one more level of namespace nesting.
    return std::make_shared<ClassNamespaceValue>(std::move(qualified), cu_);
  }

  std::string kind() const override {
    return "Class Namespace";
  }

 private:
  c10::QualifiedName basename_;
  const CompilationUnit& cu_;
};
// This value maps attributes CONSTANTS.c0, CONSTANTS.c1, ... to entries
// in the 'constants' vector. This table will be stored in a container format
// and given to the import_method when restoring the code.
struct ConstantTableValue : public SugaredValue {
  // explicit: prevent accidental implicit conversion from an ArrayRef.
  explicit ConstantTableValue(ArrayRef<at::Tensor> constants)
      : constants_(constants) {}

  std::string kind() const override {
    return "CONSTANTS";
  }

  // select an attribute on it, e.g. `this.field`
  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override {
    // Fields look like "c<N>". Validate the length *before* parsing: the
    // previous code called strtoll(field_s + 1, ...) first, which reads past
    // the end of the string's buffer when `field` is empty.
    if (field.size() < 2) {
      throw ErrorReport(loc) << "invalid constant specifier: " << field;
    }
    const char* field_s = field.c_str();
    char* end = nullptr;
    // Skip the leading letter and parse the numeric index.
    int64_t offset = std::strtoll(field_s + 1, &end, 10);
    if (*end != '\0') {
      throw ErrorReport(loc) << "invalid constant specifier: " << field;
    }
    if (offset < 0 || size_t(offset) >= constants_.size()) {
      throw ErrorReport(loc) << "constant index " << offset
                             << " is out of bounds (constant table has "
                             << constants_.size() << " entries)";
    }
    Value* value =
        m.graph()->insertConstant(constants_.at(offset), nullptr, loc);

    // specializing tensor type on compilation messes up typing relations
    value->setType(unshapedType(value->type()));

    return std::make_shared<SimpleValue>(value);
  }

 private:
  ArrayRef<at::Tensor> constants_;
};
// A resolver that doesn't rely on Python, and understands references to model
// constants.
struct SourceResolver : public Resolver {
  SourceResolver(
      std::shared_ptr<CompilationUnit> cu,
      size_t version,
      const std::vector<at::Tensor>& tensor_table)
      : cu_(std::move(cu)) {
    // Fixed environment of names that serialized TorchScript may reference.
    env_.emplace("torch", std::make_shared<BuiltinModule>("aten", version));
    env_.emplace("ops", std::make_shared<OpsValue>(version));
    // Constants present in the model. Used to resolve "CONSTANTS.n" to the
    // actual value
    env_.emplace(
        "CONSTANTS", std::make_shared<ConstantTableValue>(tensor_table));
    env_.emplace("fork", std::make_shared<ForkValue>());
    env_.emplace("annotate", std::make_shared<AnnotateValue>());
    env_.emplace("uninitialized", std::make_shared<UninitializedValue>());
    env_.emplace(
        "inf",
        std::make_shared<ConstantValue>(
            std::numeric_limits<double>::infinity()));
    env_.emplace(
        "nan",
        std::make_shared<ConstantValue>(
            std::numeric_limits<double>::quiet_NaN()));
  }

  std::shared_ptr<SugaredValue> resolveValue(
      const std::string& name,
      Function& m,
      const SourceRange& loc) const override {
    // Fixed environment entries win first.
    const auto found = env_.find(name);
    if (found != env_.end()) {
      return found->second;
    }
    // Names rooted at "__torch__" resolve through namespace lookup.
    if (name == "__torch__") {
      return std::make_shared<ClassNamespaceValue>(
          c10::QualifiedName(name), *cu_);
    }
    return nullptr;
  }

  TypePtr resolveType(const std::string& name, const SourceRange& loc)
      const override {
    return cu_->get_type(c10::QualifiedName(name));
  }

 private:
  // Compilation unit to look classes up in
  std::shared_ptr<CompilationUnit> cu_;
  std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_;
};
// Parses serialized TorchScript source text and registers the resulting
// functions, classes, and named tuples into a CompilationUnit. Construction
// immediately consumes the leading `op_version_set = N` line from the source;
// version *validation* is deferred to checkVersionNumber() because it may
// throw (see the note there).
struct SourceImporter {
  SourceImporter(
      const std::shared_ptr<CompilationUnit> cu,
      const std::shared_ptr<Source>& src,
      const std::vector<at::Tensor>& tensor_table,
      const std::function<void(const std::string&)>& import_callback)
      : p_(src),
        cu_(cu),
        import_callback_(import_callback),
        tensor_table_(tensor_table) {
    // Consumes the first line of `src` via the parser's lexer.
    version_ = parseVersionNumber();
    resolver_ = std::make_shared<SourceResolver>(cu_, version_, tensor_table_);
  }

  // Parse every top-level class and free function in the source and register
  // each one into the compilation unit under `qualifier`.
  void import(const std::string& qualifier) {
    checkVersionNumber();
    auto& L = p_.lexer();
    while (L.cur().kind != TK_EOF) {
      // Imports may precede each top-level definition; handle them first.
      parseImportsAndDoCallback();

      auto tk = L.cur();
      auto kind = tk.kind;
      switch (kind) {
        case TK_CLASS_DEF: {
          auto parsed_treeref = p_.parseClass();
          importClass(qualifier, ClassDef(parsed_treeref));
        } break;
        case TK_DEF: {
          auto parsed_treeref = p_.parseFunction(/*is_method=*/false);
          importFunction(qualifier, Def(parsed_treeref));
        } break;
        default:
          throw ErrorReport(L.cur().range)
              << "Unexpected token in code import: " << kindToString(kind);
      }
    }
  }

  // Legacy (pre-class) serialization format: the source is a flat list of
  // definitions. When `self` is non-null they are compiled as methods of it;
  // all definitions are registered under `prefix` in one define() call.
  void LEGACY_importFunctions(
      const c10::optional<c10::QualifiedName>& prefix,
      const Self* self) {
    checkVersionNumber();
    parseImportsAndDoCallback();

    std::vector<Def> definitions;
    std::vector<ResolverPtr> resolvers;
    while (p_.lexer().cur().kind != TK_EOF) {
      auto def = Def(p_.parseFunction(/*is_method=*/bool(self)));
      definitions.emplace_back(def);
      resolvers.emplace_back(resolver_);
    }
    cu_->define(prefix, definitions, resolvers, self);
  }

 private:
  // Compile a single free function (no `self`) into the compilation unit.
  void importFunction(const std::string& qualifier, const Def& def) {
    std::vector<Def> definitions{def};
    std::vector<ResolverPtr> resolvers{resolver_};
    cu_->define(qualifier, definitions, resolvers, nullptr);
  }

  // Turn a parsed class definition into a registered ClassType (or, for
  // NamedTuple subclasses, a TupleType) plus its compiled methods.
  void importClass(const std::string& qualifier, const ClassDef& class_def) {
    bool is_module = false;
    if (class_def.superclass().present()) {
      const auto& superclass_name =
          Var(class_def.superclass().get()).name().name();
      if (superclass_name == "Module") {
        is_module = true;
      } else if (superclass_name == "NamedTuple") {
        // NamedTuples have special rules (since they are TupleTypes and not
        // ClassTypes)
        return importNamedTuple(qualifier, class_def);
      } else {
        throw ErrorReport(class_def.range())
            << "Torchscript does not support class inheritance.";
      }
    }
    const auto qualified_classname =
        QualifiedName(QualifiedName(qualifier), class_def.name().name());
    auto class_type = ClassType::create(
        c10::QualifiedName(qualified_classname), cu_, is_module);

    std::vector<Def> methods;
    std::vector<ResolverPtr> resolvers;
    std::vector<Assign> attributes;
    // Module-specific: which attrs are parameters?
    std::unordered_set<std::string> parameter_names;
    // Process statements, splitting things into attribute and method
    // definitions.
    for (const auto& statement : class_def.body()) {
      switch (statement.kind()) {
        case TK_ASSIGN: {
          const auto assign = Assign(statement);
          switch (assign.lhs().kind()) {
            case TK_VAR: {
              const auto name = Var(assign.lhs()).name().name();
              if (name == "__parameters__") {
                // Populate the module parameter list. This is a field that
                // looks like:
                //   __parameters__ = ["foo", "bar", "baz"]
                // which tells us which attributes are module parameters.
                TORCH_INTERNAL_ASSERT(
                    is_module,
                    "Assignments in class body only "
                    "supported on modules right now");
                const auto param_list =
                    ListLiteral(assign.rhs().get()).inputs();
                for (const auto& param : param_list) {
                  parameter_names.insert(StringLiteral(param).text());
                }
              } else if (name == "__annotations__") {
                // This is to initialize the annotations dict, just ignore.
                continue;
              } else {
                // This is a regular attribute assignment, of the form:
                //   foo : Tensor
                if (assign.rhs().present()) {
                  throw ErrorReport(assign.rhs())
                      << "Unexpected right-hand found in assignment in class body. "
                         "This is not yet supported.";
                }
                attributes.push_back(assign);
              }
            } break;
            case TK_SUBSCRIPT: {
              // This is a special attribute assignment where the attribute
              // is not a valid python identifier. Looks like:
              //   __annotations__["0"] = Tensor
              const auto lhs = Subscript(assign.lhs());
              TORCH_INTERNAL_ASSERT(
                  Var(lhs.value()).name().name() == "__annotations__");
              TORCH_INTERNAL_ASSERT(lhs.subscript_exprs().size() == 1);
              attributes.push_back(assign);
            } break;
            default: {
              TORCH_INTERNAL_ASSERT(
                  false,
                  "Unexpected statement kind in module metadata: ",
                  kindToString(statement.kind()));
            }
          }
        } break;
        case TK_DEF: {
          methods.emplace_back(Def(statement));
          resolvers.push_back(resolver_);
        } break;
        default: {
          TORCH_INTERNAL_ASSERT(
              false,
              "Unexpected statement kind in class body: ",
              kindToString(statement.kind()));
        }
      }
    }

    // Populate class attributes
    ScriptTypeParser type_parser(resolver_);
    for (const auto& assign : attributes) {
      switch (assign.lhs().kind()) {
        case TK_VAR: {
          const auto name = Var(assign.lhs()).name().name();
          TORCH_INTERNAL_ASSERT(name != "__parameters__");
          // Plain annotation: the type comes from the annotation expression.
          const auto type = type_parser.parseTypeFromExpr(assign.type().get());
          const bool is_parameter = parameter_names.count(name);
          class_type->addAttribute(name, type, is_parameter);
        } break;
        case TK_SUBSCRIPT: {
          // __annotations__["name"] = Type: the attribute name is the string
          // subscript and its type is the right-hand side.
          const auto name =
              StringLiteral(Subscript(assign.lhs()).subscript_exprs()[0])
                  .text();
          const auto type = type_parser.parseTypeFromExpr(assign.rhs().get());
          const bool is_parameter = parameter_names.count(name);
          class_type->addAttribute(name, type, is_parameter);
        }
      }
    }

    // Register the type first so the method bodies can refer to it, then
    // compile the methods with the class itself as `self`.
    cu_->register_type(class_type);
    const auto self = SimpleSelf(class_type);
    cu_->define(qualified_classname, methods, resolvers, &self);
  }

  // Build and register a TupleType for a `class Foo(NamedTuple):` definition.
  // The body may only contain field annotations (name : Type).
  void importNamedTuple(
      const std::string& qualifier,
      const ClassDef& named_tuple_def) {
    auto qualified_name =
        c10::QualifiedName(qualifier + "." + named_tuple_def.name().name());
    ScriptTypeParser type_parser(resolver_);
    std::vector<std::string> field_names;
    std::vector<TypePtr> field_types;
    for (const auto& statement : named_tuple_def.body()) {
      if (statement.kind() != TK_ASSIGN) {
        throw ErrorReport(statement.range())
            << "Unexpected statement in NamedTuple body: "
               "only attribute annotations are currently supported.";
      }

      const auto assign = Assign(statement);
      auto name = Var(assign.lhs()).name().name();
      field_names.emplace_back(std::move(name));
      auto type = type_parser.parseTypeFromExpr(assign.type().get());
      field_types.emplace_back(std::move(type));
    }

    auto tt = TupleType::create(
        field_types,
        qualified_name,
        TupleType::namedTupleSchemaFromNamesAndTypes(
            qualified_name, field_names, field_types));
    cu_->register_type(tt);
  }

  // Reject scripts produced by a newer PyTorch than this build supports.
  void checkVersionNumber() {
    // note: this cannot be called in the constructor because it may throw
    if (version_ > CURRENT_OP_VERSION_SET) {
      throw ErrorReport(p_.lexer().cur().range)
          << "Attempting to load a script generated from a newer version of "
          << "PyTorch. Maximum supported TorchScript version is "
          << CURRENT_OP_VERSION_SET
          << " but the script being loaded is version " << version_;
    }
  }

  // Consume the mandatory first line `op_version_set = N` and return N.
  // Throws if the line is missing or N is not an integer literal.
  size_t parseVersionNumber() {
    auto& L = p_.lexer();
    auto range = L.cur().range;
    auto name = L.expect(TK_IDENT).text();
    L.expect('=');
    std::string version_text = L.expect(TK_NUMBER).text();
    L.expect(TK_NEWLINE);
    auto version = Const::create(L.cur().range, version_text);
    if (name != "op_version_set")
      throw ErrorReport(range) << "expected an assignment to op_version_set";
    if (!version.isIntegral())
      throw ErrorReport(range)
          << "expected an integral version but found " << version.text();
    return size_t(version.asIntegral());
  }

  // Consume any run of `import ...` lines at the current lexer position and
  // hand each import string to the callback.
  void parseImportsAndDoCallback() {
    // Gather all imports
    auto& L = p_.lexer();
    std::vector<std::string> imports;
    while (L.nextIf(TK_IMPORT)) {
      std::ostringstream s;
      // Everything up to the newline is the imported qualified name.
      while (L.cur().kind != TK_NEWLINE) {
        s << L.cur().text();
        L.next();
      }
      L.expect(TK_NEWLINE);
      const auto str = s.str();
      AT_ASSERT(!str.empty());
      imports.push_back(str);
    }

    // Call the callback to actually compile them
    for (const auto& import : imports) {
      if (import_callback_) {
        import_callback_(import);
      }
    }
  }

  Parser p_;
  size_t version_;
  std::shared_ptr<CompilationUnit> cu_;
  // NOTE(review): these two are stored as references; the caller must keep
  // the callback and tensor table alive for the importer's lifetime.
  const std::function<void(const std::string&)>& import_callback_;
  const std::vector<at::Tensor>& tensor_table_;
  std::shared_ptr<SourceResolver> resolver_;
};
// Restore methods serialized in the legacy (pre-class) format directly onto
// `mod`, compiling each definition with the module itself as `self`.
void LEGACY_import_methods(
    const Module& mod,
    const std::shared_ptr<Source>& src,
    const std::vector<at::Tensor>& constant_table,
    const std::function<void(const std::string&)>& import_callback) {
  auto self = SimpleSelf(mod.type());
  SourceImporter importer(
      mod.class_compilation_unit(), src, constant_table, import_callback);
  importer.LEGACY_importFunctions(mod.name(), &self);
}
// Parse all classes/functions in `src` and register them in `cu` under the
// namespace `qualifier`, resolving model constants via `tensor_table` and
// delegating cross-file imports to `import_callback`.
void import_libs(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& qualifier,
    const std::shared_ptr<Source>& src,
    const std::vector<at::Tensor>& tensor_table,
    const std::function<void(const std::string&)>& import_callback) {
  SourceImporter importer(std::move(cu), src, tensor_table, import_callback);
  importer.import(qualifier);
}
} // namespace script
} // namespace jit
} // namespace torch