mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Split ConcreteModuleType into two types (#29824)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29824 We have two distinct phases/uses for ConcreteModuleType: 1. We are building it up and using it to check whether we can reuse JIT types. (RawConcreteModuleType) 2. We are using it to satisfy ModuleValue::attr queries. (ConcreteModuleType) These types share an underlying `ConcreteModuleTypeData` which actually stores the relevant info. Previously they were the same type because I was lazy, but it's been the source of a bug. So split them to formalize the differing invariants for the two phases. Test Plan: Imported from OSS Differential Revision: D18575010 Pulled By: suo fbshipit-source-id: 3e4ebcd36e78b947150d8f0dbb74ecccad23e7c4
This commit is contained in:
parent
7495c25440
commit
63e66fd267
|
|
@ -3,16 +3,7 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
namespace script {
|
||||
|
||||
ClassTypePtr ConcreteModuleType::getJitType() const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
return jitType_;
|
||||
}
|
||||
|
||||
ClassTypePtr ConcreteModuleType::createNewTypeFromThis() {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
TORCH_INTERNAL_ASSERT(pyClass_);
|
||||
|
||||
ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
|
||||
auto cu = get_python_cu();
|
||||
py::object pyQualName = py::module::import("torch._jit_internal")
|
||||
.attr("_qualified_name")(pyClass_);
|
||||
|
|
@ -41,21 +32,85 @@ ClassTypePtr ConcreteModuleType::createNewTypeFromThis() {
|
|||
moduleInfo.name_, moduleInfo.getJitType(), /*is_parameter=*/false);
|
||||
}
|
||||
|
||||
jitType_ = std::move(cls);
|
||||
return cls;
|
||||
}
|
||||
|
||||
ConcreteModuleType::ConcreteModuleType(ConcreteModuleTypeBuilder data)
|
||||
: data_(std::move(data)) {
|
||||
jitType_ = data_.createTypeFromThis();
|
||||
}
|
||||
|
||||
TypePtr ConcreteModuleTypeBuilder::ModuleInfo::getJitType() const {
|
||||
return meta_ == nullptr ? type_ : meta_->getJitType();
|
||||
}
|
||||
|
||||
bool operator==(
|
||||
const ConcreteModuleTypeBuilder::ModuleInfo& lhs,
|
||||
const ConcreteModuleTypeBuilder::ModuleInfo& rhs) {
|
||||
if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
|
||||
return lhs.meta_->equals(*rhs.meta_);
|
||||
} else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
|
||||
return *(lhs.type_) == *(rhs.type_);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool ConcreteModuleTypeBuilder::equals(
|
||||
const ConcreteModuleTypeBuilder& other) const {
|
||||
if (isPoisoned_ || other.isPoisoned_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
// These are vaguely ordered so that cheap, discriminating checks happen first.
|
||||
bool equal =
|
||||
pyClass_.is(other.pyClass_) &&
|
||||
iterableModuleKind_ == other.iterableModuleKind_ &&
|
||||
constants_ == other.constants_ &&
|
||||
attributes_ == other.attributes_ &&
|
||||
overloads_ == other.overloads_ &&
|
||||
functionAttributes_ == other.functionAttributes_;
|
||||
// clang-format on
|
||||
if (!equal) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// We store modules in order of insertion (to make compilation
|
||||
// deterministic). However, for the purposes of equality, insertion order
|
||||
// should not matter, so sort them by name.
|
||||
// We put this check last because it involves the most work.
|
||||
auto thisSorted = modules_;
|
||||
std::sort(
|
||||
thisSorted.begin(),
|
||||
thisSorted.end(),
|
||||
[](const ModuleInfo& a, const ModuleInfo& b) {
|
||||
return a.name_ < b.name_;
|
||||
});
|
||||
|
||||
auto otherSorted = other.modules_;
|
||||
std::sort(
|
||||
otherSorted.begin(),
|
||||
otherSorted.end(),
|
||||
[](const ModuleInfo& a, const ModuleInfo& b) {
|
||||
return a.name_ < b.name_;
|
||||
});
|
||||
|
||||
return thisSorted == otherSorted;
|
||||
}
|
||||
|
||||
ClassTypePtr ConcreteModuleType::getJitType() const {
|
||||
return jitType_;
|
||||
}
|
||||
|
||||
py::object ConcreteModuleType::getPyClass() const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
TORCH_INTERNAL_ASSERT(pyClass_);
|
||||
return pyClass_;
|
||||
return data_.pyClass_;
|
||||
}
|
||||
|
||||
c10::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
|
||||
const std::string& name) const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
const auto it = overloads_.find(name);
|
||||
if (it != overloads_.end()) {
|
||||
const auto it = data_.overloads_.find(name);
|
||||
if (it != data_.overloads_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return c10::nullopt;
|
||||
|
|
@ -63,9 +118,8 @@ c10::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
|
|||
|
||||
c10::optional<Function*> ConcreteModuleType::findFunctionAttribute(
|
||||
const std::string& name) const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
const auto it = functionAttributes_.find(name);
|
||||
if (it != functionAttributes_.end()) {
|
||||
const auto it = data_.functionAttributes_.find(name);
|
||||
if (it != data_.functionAttributes_.end()) {
|
||||
return it->second.function_->function();
|
||||
}
|
||||
return c10::nullopt;
|
||||
|
|
@ -73,9 +127,8 @@ c10::optional<Function*> ConcreteModuleType::findFunctionAttribute(
|
|||
|
||||
c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
|
||||
const std::string& name) const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
const auto it = failedAttributes_.find(name);
|
||||
if (it != failedAttributes_.end()) {
|
||||
const auto it = data_.failedAttributes_.find(name);
|
||||
if (it != data_.failedAttributes_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return c10::nullopt;
|
||||
|
|
@ -83,130 +136,114 @@ c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
|
|||
|
||||
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
|
||||
findSubmoduleConcreteType(const std::string& name) const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
const auto it = std::find_if(
|
||||
modules_.cbegin(), modules_.cend(), [&](const ModuleInfo& info) {
|
||||
data_.modules_.cbegin(),
|
||||
data_.modules_.cend(),
|
||||
[&](const ConcreteModuleTypeBuilder::ModuleInfo& info) {
|
||||
return info.name_ == name;
|
||||
});
|
||||
if (it == modules_.end()) {
|
||||
if (it == data_.modules_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->meta_;
|
||||
}
|
||||
|
||||
void ConcreteModuleType::setIterableModuleKind(IterableModuleKind kind) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
void ConcreteModuleTypeBuilder::setIterableModuleKind(IterableModuleKind kind) {
|
||||
iterableModuleKind_ = kind;
|
||||
}
|
||||
|
||||
IterableModuleKind ConcreteModuleType::getIterableModuleKind() const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
return iterableModuleKind_;
|
||||
return data_.iterableModuleKind_;
|
||||
}
|
||||
|
||||
void ConcreteModuleType::setPoisoned() {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_)
|
||||
void ConcreteModuleTypeBuilder::setPoisoned() {
|
||||
isPoisoned_ = true;
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addJitType(ClassTypePtr type) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_)
|
||||
jitType_ = std::move(type);
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addPyClass(py::object pyClass) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
pyClass_ = std::move(pyClass);
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addConstant(std::string name, py::object value) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
void ConcreteModuleTypeBuilder::addConstant(std::string name, py::object value) {
|
||||
constants_.emplace(std::move(name), std::move(value));
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addAttribute(
|
||||
void ConcreteModuleTypeBuilder::addAttribute(
|
||||
std::string name,
|
||||
TypePtr type,
|
||||
bool isParameter) {
|
||||
TORCH_INTERNAL_ASSERT(type);
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
// Function attributes should be handled separately
|
||||
TORCH_INTERNAL_ASSERT(type->cast<FunctionType>() == nullptr);
|
||||
attributes_.emplace(
|
||||
std::move(name), Attribute(unshapedType(type), isParameter));
|
||||
std::move(name),
|
||||
ConcreteModuleTypeBuilder::Attribute(unshapedType(type), isParameter));
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addFunctionAttribute(
|
||||
void ConcreteModuleTypeBuilder::addFunctionAttribute(
|
||||
std::string name,
|
||||
const TypePtr& type,
|
||||
py::object pyFunction) {
|
||||
TORCH_INTERNAL_ASSERT(type);
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
functionAttributes_.emplace(
|
||||
std::move(name),
|
||||
FunctionAttribute{type->expect<FunctionType>(), std::move(pyFunction)});
|
||||
ConcreteModuleTypeBuilder::FunctionAttribute{type->expect<FunctionType>(),
|
||||
std::move(pyFunction)});
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addModule(
|
||||
void ConcreteModuleTypeBuilder::addModule(
|
||||
std::string name,
|
||||
std::shared_ptr<ConcreteModuleType> meta) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
modules_.emplace_back(ModuleInfo{std::move(name), std::move(meta)});
|
||||
modules_.emplace_back(
|
||||
ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), std::move(meta)});
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addModuleInterface(
|
||||
void ConcreteModuleTypeBuilder::addModuleInterface(
|
||||
std::string name,
|
||||
const TypePtr& type) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
TORCH_INTERNAL_ASSERT(type->cast<InterfaceType>() && type->is_module());
|
||||
modules_.emplace_back(ModuleInfo{std::move(name), type});
|
||||
modules_.emplace_back(
|
||||
ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), type});
|
||||
}
|
||||
|
||||
|
||||
void ConcreteModuleType::addOverload(
|
||||
void ConcreteModuleTypeBuilder::addOverload(
|
||||
std::string methodName,
|
||||
std::vector<std::string> overloadedMethodNames) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
overloads_.emplace(std::move(methodName), std::move(overloadedMethodNames));
|
||||
}
|
||||
|
||||
void ConcreteModuleType::addFailedAttribute(
|
||||
void ConcreteModuleTypeBuilder::addFailedAttribute(
|
||||
std::string name,
|
||||
std::string failureReason) {
|
||||
TORCH_INTERNAL_ASSERT(!jitType_);
|
||||
failedAttributes_.emplace(std::move(name), std::move(failureReason));
|
||||
}
|
||||
|
||||
c10::optional<py::object> ConcreteModuleType::findConstant(
|
||||
const std::string& name) const {
|
||||
auto it = constants_.find(name);
|
||||
if (it != constants_.end()) {
|
||||
auto it = data_.constants_.find(name);
|
||||
if (it != data_.constants_.end()) {
|
||||
return it->second.v_;
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
void ConcreteModuleType::dump() const {
|
||||
std::cout << "ConcreteModuleType for: " << py::getattr(pyClass_, "__name__") << "\n";
|
||||
std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n";
|
||||
std::cout << "Constants: \n";
|
||||
for (const auto& pr : constants_) {
|
||||
for (const auto& pr : data_.constants_) {
|
||||
std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n";
|
||||
}
|
||||
std::cout << "\nAttributes: \n";
|
||||
for (const auto& pr : attributes_) {
|
||||
for (const auto& pr : data_.attributes_) {
|
||||
std::cout << "\t" << pr.first << ": " << pr.second.type_->python_str()
|
||||
<< "\n";
|
||||
}
|
||||
std::cout << "\nSubmodules: \n";
|
||||
for (const auto& info : modules_) {
|
||||
for (const auto& info : data_.modules_) {
|
||||
std::cout << "\t" << info.name_ << ": "
|
||||
<< info.getJitType()->python_str() << "\n";
|
||||
}
|
||||
std::cout << "\nOverloads: \n";
|
||||
for (const auto& pr : overloads_) {
|
||||
for (const auto& pr : data_.overloads_) {
|
||||
std::cout << "\t" << pr.first << ": " << pr.second << "\n";
|
||||
}
|
||||
std::string isPoisoned = isPoisoned_ ? "true" : "false";
|
||||
std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
|
||||
std::cout << "isPoisoned: " << isPoisoned << "\n";
|
||||
if (jitType_) {
|
||||
std::cout << "jit type: " << jitType_->python_str() << "\n";
|
||||
|
|
@ -215,11 +252,10 @@ void ConcreteModuleType::dump() const {
|
|||
|
||||
std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
|
||||
const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
// Convert to a more pybind-friendly representation, so we don't
|
||||
// need to bind ConcreteModuleType::Constant as well.
|
||||
std::unordered_map<std::string, py::object> ret;
|
||||
for (const auto& pr : constants_) {
|
||||
for (const auto& pr : data_.constants_) {
|
||||
ret.emplace(pr.first, pr.second.v_);
|
||||
}
|
||||
return ret;
|
||||
|
|
@ -227,11 +263,10 @@ std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
|
|||
|
||||
std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
|
||||
getAttributesPy() const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
// Convert to a more pybind-friendly representation, so we don't
|
||||
// need to bind ConcreteModuleType::Attribute as well.
|
||||
std::unordered_map<std::string, std::pair<TypePtr, bool>> ret;
|
||||
for (auto& pr : attributes_) {
|
||||
for (auto& pr : data_.attributes_) {
|
||||
ret.emplace(
|
||||
pr.first,
|
||||
std::pair<TypePtr, bool>(pr.second.type_, pr.second.isParam_));
|
||||
|
|
@ -241,10 +276,9 @@ std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
|
|||
|
||||
std::vector<std::pair<std::string, TypePtr>> ConcreteModuleType::getModulesPy()
|
||||
const {
|
||||
TORCH_INTERNAL_ASSERT(jitType_);
|
||||
std::vector<std::pair<std::string, TypePtr>> ret;
|
||||
|
||||
for (const ModuleInfo& info: modules_) {
|
||||
for (const auto& info : data_.modules_) {
|
||||
ret.emplace_back(std::make_pair(info.name_, info.getJitType()));
|
||||
}
|
||||
return ret;
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@
|
|||
namespace torch {
|
||||
namespace jit {
|
||||
namespace script {
|
||||
|
||||
enum class IterableModuleKind { NONE, LIST, DICT };
|
||||
class ConcreteModuleType;
|
||||
|
||||
// You can think of an nn.Module as a template that corresponds to a family of
|
||||
// JIT types. The template "arguments" are things like the constant values.
|
||||
|
|
@ -41,21 +43,22 @@ enum class IterableModuleKind { NONE, LIST, DICT };
|
|||
// ModuleValue::attr calls. This is so we can guarantee that if two Module's
|
||||
// share a JIT type (and thus a ConcreteModuleType), then they behave the same
|
||||
// way when you access attributes on them.
|
||||
class VISIBILITY_HIDDEN ConcreteModuleType {
|
||||
public:
|
||||
// ConcreteModuleType has two states.
|
||||
// 1. Building: First we build it up, during the ScriptModule conversion
|
||||
// process
|
||||
// ... to transition, we freeze the type by associating with a JIT type.
|
||||
// 2. Querying: Then we ask use it as a source of truth during method
|
||||
// compilation.
|
||||
// During this time the ModuleType is effectively const.
|
||||
// Yes, it could be two different types. Not terribly worth the verbosity.
|
||||
|
||||
/**
|
||||
* Builder methods (jitType must be null)
|
||||
*/
|
||||
void addPyClass(py::object pyClass);
|
||||
// ConcreteModuleType has two phases.
|
||||
// 1. Creation: First we build it up, during the ScriptModule conversion
|
||||
// process. This is represented by ConcreteModuleTypeBuilder.
|
||||
// ...then the converter calls ConcreteModuleTypeBuilder::build(), producing a
|
||||
// ConcreteModuleType ready for querying.
|
||||
// 2. Querying: We use ConcreteModuleType as a source of truth for
|
||||
// ModuleValue::attr calls during method compilation.
|
||||
|
||||
// Represents a concrete type during in the process for construction. We use
|
||||
// this to decide whether we can share types between modules.
|
||||
class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
|
||||
public:
|
||||
explicit ConcreteModuleTypeBuilder(py::object pyClass) {
|
||||
pyClass_ = std::move(pyClass);
|
||||
}
|
||||
void addConstant(std::string name, py::object value);
|
||||
void addAttribute(std::string name, TypePtr type, bool isParameter);
|
||||
void addFunctionAttribute(
|
||||
|
|
@ -73,95 +76,20 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
|
|||
std::vector<std::string> overloadedMethodNames);
|
||||
void addFailedAttribute(std::string name, std::string failureReason);
|
||||
void setIterableModuleKind(IterableModuleKind kind);
|
||||
|
||||
// If a ConcreteModuleType is poisoned, it will never compare equal to any other
|
||||
// concrete type
|
||||
void setPoisoned();
|
||||
|
||||
/**
|
||||
* Freezing methods
|
||||
*/
|
||||
|
||||
// Based on the data in this ConcreteType, create an equivalent JIT type and
|
||||
// associate this module type with it.
|
||||
ClassTypePtr createNewTypeFromThis();
|
||||
// Associate the provided type with this ConcreteType
|
||||
void addJitType(ClassTypePtr type);
|
||||
|
||||
/**
|
||||
* Query methods (jitType must be non-null)
|
||||
*/
|
||||
ClassTypePtr getJitType() const;
|
||||
py::object getPyClass() const;
|
||||
IterableModuleKind getIterableModuleKind() const;
|
||||
c10::optional<py::object> findConstant(const std::string& name) const;
|
||||
c10::optional<std::vector<std::string>> findOverloads(
|
||||
const std::string& name) const;
|
||||
c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
|
||||
std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
|
||||
const std::string& name) const;
|
||||
c10::optional<std::string> findFailedAttribute(const std::string& name) const;
|
||||
|
||||
// These getters are only here to return things as types that can be
|
||||
// automatically converted by pybind.
|
||||
std::unordered_map<std::string, py::object> getConstantsPy() const;
|
||||
std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
|
||||
const;
|
||||
std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
|
||||
std::shared_ptr<ConcreteModuleType> build() const {
|
||||
return std::make_shared<ConcreteModuleType>(*this);
|
||||
}
|
||||
|
||||
// This determines whether two modules can share a type. The container structs
|
||||
// used by ConcreteModuleType have been defined such that operator==
|
||||
// implements a meaningful comparison in that context.
|
||||
friend bool operator==(
|
||||
const ConcreteModuleType& lhs,
|
||||
const ConcreteModuleType& rhs) {
|
||||
if (lhs.jitType_ == rhs.jitType_) {
|
||||
// If the computed types are the same, these modules can (obviously) share
|
||||
// a type.
|
||||
return true;
|
||||
}
|
||||
bool equals(const ConcreteModuleTypeBuilder& other) const;
|
||||
|
||||
if (lhs.isPoisoned_ || rhs.isPoisoned_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
// These are vaguely ordered so that cheap, discriminating checks happen first.
|
||||
bool equal =
|
||||
lhs.pyClass_.is(rhs.pyClass_) &&
|
||||
lhs.iterableModuleKind_ == rhs.iterableModuleKind_ &&
|
||||
lhs.constants_ == rhs.constants_ &&
|
||||
lhs.attributes_ == rhs.attributes_ &&
|
||||
lhs.overloads_ == rhs.overloads_ &&
|
||||
lhs.functionAttributes_ == rhs.functionAttributes_;
|
||||
// clang-format on
|
||||
if (!equal) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// We store modules in order of insertion (to make compilation
|
||||
// deterministic). However, for the purposes of equality, insertion order
|
||||
// should not matter, so sort them by name.
|
||||
// We put this check last because it involves the most work.
|
||||
auto lhsSorted = lhs.modules_;
|
||||
std::sort(
|
||||
lhsSorted.begin(),
|
||||
lhsSorted.end(),
|
||||
[](const ModuleInfo& a, const ModuleInfo& b) {
|
||||
return a.name_ < b.name_;
|
||||
});
|
||||
|
||||
auto rhsSorted = rhs.modules_;
|
||||
std::sort(
|
||||
rhsSorted.begin(),
|
||||
rhsSorted.end(),
|
||||
[](const ModuleInfo& a, const ModuleInfo& b) {
|
||||
return a.name_ < b.name_;
|
||||
});
|
||||
|
||||
return lhsSorted == rhsSorted;
|
||||
}
|
||||
|
||||
void dump() const;
|
||||
|
||||
private:
|
||||
struct Constant {
|
||||
/* implicit */ Constant(py::object v) : v_(std::move(v)) {}
|
||||
friend bool operator==(const Constant& lhs, const Constant& rhs) {
|
||||
|
|
@ -206,28 +134,18 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
|
|||
|
||||
ModuleInfo(std::string name, const TypePtr& type)
|
||||
: name_(std::move(name)), meta_(nullptr), type_(type) {}
|
||||
friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs) {
|
||||
if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
|
||||
return *(lhs.meta_) == *(rhs.meta_);
|
||||
} else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
|
||||
return *(lhs.type_) == *(rhs.type_);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
TypePtr getJitType() const {
|
||||
return meta_ == nullptr? type_ : meta_->getJitType();
|
||||
}
|
||||
TypePtr getJitType() const;
|
||||
std::string name_;
|
||||
friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs);
|
||||
|
||||
// Module Info contains either an ConcreateModuleType or a type (which is
|
||||
// a Module Interface), these two are union relationship.
|
||||
std::shared_ptr<ConcreteModuleType> meta_;
|
||||
TypePtr type_;
|
||||
|
||||
};
|
||||
|
||||
private:
|
||||
ClassTypePtr createTypeFromThis() const;
|
||||
// If true, this type will never compare equally to anything else. This is
|
||||
// used if we want to ensure that this type is not shared (for example, if it
|
||||
// came from a traced module)
|
||||
|
|
@ -256,11 +174,54 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
|
|||
// The original `nn.Module` class that we derived this ScriptModule from.
|
||||
py::object pyClass_;
|
||||
|
||||
// The JIT type derived from this ConcreteModuleType.
|
||||
ClassTypePtr jitType_ = nullptr;
|
||||
// NOTE: If you ever add any more state to this struct, you need to make sure
|
||||
// operator== still makes sense! The only field that can be excluded from it
|
||||
// is `jitType_`.
|
||||
// operator== still makes sense!
|
||||
friend ConcreteModuleType;
|
||||
};
|
||||
|
||||
// Represents a finalized concrete type, used to service ModuleValue::attr calls
|
||||
// during method compilation.
|
||||
class VISIBILITY_HIDDEN ConcreteModuleType {
|
||||
public:
|
||||
explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
|
||||
|
||||
ClassTypePtr getJitType() const;
|
||||
py::object getPyClass() const;
|
||||
IterableModuleKind getIterableModuleKind() const;
|
||||
c10::optional<py::object> findConstant(const std::string& name) const;
|
||||
c10::optional<std::vector<std::string>> findOverloads(
|
||||
const std::string& name) const;
|
||||
c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
|
||||
std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
|
||||
const std::string& name) const;
|
||||
c10::optional<std::string> findFailedAttribute(const std::string& name) const;
|
||||
|
||||
// These getters are only here to return things as types that can be
|
||||
// automatically converted by pybind.
|
||||
std::unordered_map<std::string, py::object> getConstantsPy() const;
|
||||
std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
|
||||
const;
|
||||
std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
|
||||
|
||||
bool equals(const ConcreteModuleType& other) const {
|
||||
if (jitType_ == other.jitType_) {
|
||||
// If the computed types are the same, these modules can (obviously) share
|
||||
// a type.
|
||||
return true;
|
||||
}
|
||||
|
||||
return data_.equals(other.data_);
|
||||
}
|
||||
bool equals(const ConcreteModuleTypeBuilder& other) const {
|
||||
return data_.equals(other);
|
||||
}
|
||||
|
||||
void dump() const;
|
||||
|
||||
private:
|
||||
// The JIT type derived from this ConcreteModuleType.
|
||||
ConcreteModuleTypeBuilder data_;
|
||||
ClassTypePtr jitType_;
|
||||
};
|
||||
|
||||
} // namespace script
|
||||
|
|
|
|||
|
|
@ -1052,42 +1052,51 @@ void initJitScriptBindings(PyObject* module) {
|
|||
return Module(get_python_cu(), type);
|
||||
});
|
||||
|
||||
py::class_<ConcreteModuleTypeBuilder, std::shared_ptr<ConcreteModuleTypeBuilder>>(
|
||||
m, "ConcreteModuleTypeBuilder")
|
||||
.def(py::init<py::object>())
|
||||
.def("add_constant", &ConcreteModuleTypeBuilder::addConstant)
|
||||
.def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
|
||||
.def(
|
||||
"add_function_attribute",
|
||||
&ConcreteModuleTypeBuilder::addFunctionAttribute)
|
||||
.def("add_module", &ConcreteModuleTypeBuilder::addModule)
|
||||
.def("add_module_interface", &ConcreteModuleTypeBuilder::addModuleInterface)
|
||||
.def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
|
||||
.def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
|
||||
.def("add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute)
|
||||
.def(
|
||||
"set_module_dict",
|
||||
[](ConcreteModuleTypeBuilder& self) {
|
||||
self.setIterableModuleKind(IterableModuleKind::DICT);
|
||||
})
|
||||
.def("build", &ConcreteModuleTypeBuilder::build)
|
||||
.def(
|
||||
"equals",
|
||||
[](const ConcreteModuleTypeBuilder& self,
|
||||
const ConcreteModuleTypeBuilder& other) { return self.equals(other); })
|
||||
.def("set_module_list", [](ConcreteModuleTypeBuilder& self) {
|
||||
self.setIterableModuleKind(IterableModuleKind::LIST);
|
||||
});
|
||||
|
||||
py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
|
||||
m, "ConcreteModuleType")
|
||||
.def(py::init<>())
|
||||
.def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
|
||||
.def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
|
||||
.def("get_constants", &ConcreteModuleType::getConstantsPy)
|
||||
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
|
||||
.def("get_modules", &ConcreteModuleType::getModulesPy)
|
||||
.def("add_constant", &ConcreteModuleType::addConstant)
|
||||
.def("add_attribute", &ConcreteModuleType::addAttribute)
|
||||
.def("add_function_attribute", &ConcreteModuleType::addFunctionAttribute)
|
||||
.def("add_module", &ConcreteModuleType::addModule)
|
||||
.def("add_module_interface", &ConcreteModuleType::addModuleInterface)
|
||||
.def("add_pyclass", &ConcreteModuleType::addPyClass)
|
||||
.def("add_overload", &ConcreteModuleType::addOverload)
|
||||
.def("add_jit_type", &ConcreteModuleType::addJitType)
|
||||
.def("set_poisoned", &ConcreteModuleType::setPoisoned)
|
||||
.def(
|
||||
"set_module_dict",
|
||||
[](ConcreteModuleType& self) {
|
||||
self.setIterableModuleKind(IterableModuleKind::DICT);
|
||||
})
|
||||
.def(
|
||||
"set_module_list",
|
||||
[](ConcreteModuleType& self) {
|
||||
self.setIterableModuleKind(IterableModuleKind::LIST);
|
||||
})
|
||||
.def(
|
||||
"create_new_type_from_this",
|
||||
&ConcreteModuleType::createNewTypeFromThis)
|
||||
.def("add_failed_attribute", &ConcreteModuleType::addFailedAttribute)
|
||||
.def("dump", &ConcreteModuleType::dump)
|
||||
.def(
|
||||
"equals",
|
||||
[](const ConcreteModuleType& self, const ConcreteModuleType& other) {
|
||||
return self == other;
|
||||
return self.equals(other);
|
||||
})
|
||||
.def(
|
||||
"equals",
|
||||
[](const ConcreteModuleType& self,
|
||||
const ConcreteModuleTypeBuilder& other) {
|
||||
return self.equals(other);
|
||||
})
|
||||
.def(
|
||||
"_create_methods",
|
||||
|
|
|
|||
|
|
@ -59,18 +59,17 @@ def _get_valid_constant(attr, v):
|
|||
3. a list or tuple of (2)
|
||||
""".format(type(v).__name__, attr, constants)))
|
||||
|
||||
def infer_raw_concrete_type(nn_module):
|
||||
def infer_concrete_type_builder(nn_module):
|
||||
"""
|
||||
Build a ConcreteModuleType from an nn.Module. This ConcreteModuleType
|
||||
doesn't have a JIT type associated with it yet, it must be filled in
|
||||
by the caller.
|
||||
Build a ConcreteModuleTypeBuilder from an nn.Module. This
|
||||
ConcreteModuleType doesn't have a JIT type associated with it yet, it
|
||||
must be filled in by the caller.
|
||||
"""
|
||||
concrete_type = torch._C.ConcreteModuleType()
|
||||
concrete_type.add_pyclass(type(nn_module))
|
||||
concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
|
||||
if isinstance(nn_module, (torch.nn.ModuleDict)):
|
||||
concrete_type.set_module_dict()
|
||||
concrete_type_builder.set_module_dict()
|
||||
if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
|
||||
concrete_type.set_module_list()
|
||||
concrete_type_builder.set_module_list()
|
||||
|
||||
class_annotations = getattr(nn_module, '__annotations__', {})
|
||||
|
||||
|
|
@ -96,7 +95,7 @@ def infer_raw_concrete_type(nn_module):
|
|||
continue
|
||||
assert isinstance(item, torch.Tensor)
|
||||
attr_type = infer_type(name, item)
|
||||
concrete_type.add_attribute(name, attr_type, True)
|
||||
concrete_type_builder.add_attribute(name, attr_type, True)
|
||||
added_names.add(name)
|
||||
|
||||
for name, item in nn_module._buffers.items():
|
||||
|
|
@ -109,18 +108,18 @@ def infer_raw_concrete_type(nn_module):
|
|||
continue
|
||||
assert isinstance(item, torch.Tensor)
|
||||
attr_type = infer_type(name, item)
|
||||
concrete_type.add_attribute(name, attr_type, False)
|
||||
concrete_type_builder.add_attribute(name, attr_type, False)
|
||||
added_names.add(name)
|
||||
|
||||
for name, item in nn_module._modules.items():
|
||||
attr_type = infer_type(name, item)
|
||||
if attr_type is not None:
|
||||
# if the type can be inferred, it should be a module interface type
|
||||
concrete_type.add_module_interface(name, attr_type)
|
||||
concrete_type_builder.add_module_interface(name, attr_type)
|
||||
else:
|
||||
# otherwise we get the concrete module type for item and add it to concrete_type
|
||||
sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
|
||||
concrete_type.add_module(name, sub_concrete_type)
|
||||
concrete_type_builder.add_module(name, sub_concrete_type)
|
||||
|
||||
added_names.add(name)
|
||||
|
||||
|
|
@ -146,7 +145,7 @@ def infer_raw_concrete_type(nn_module):
|
|||
"Consider removing it.".format(name))
|
||||
continue
|
||||
value = getattr(nn_module, name)
|
||||
concrete_type.add_constant(name, _get_valid_constant(name, value))
|
||||
concrete_type_builder.add_constant(name, _get_valid_constant(name, value))
|
||||
added_names.add(name)
|
||||
|
||||
# populate overloads
|
||||
|
|
@ -154,7 +153,7 @@ def infer_raw_concrete_type(nn_module):
|
|||
# update with any annotated overloads
|
||||
overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module)))
|
||||
for name, overloaded_names in overloads.items():
|
||||
concrete_type.add_overload(name, overloaded_names)
|
||||
concrete_type_builder.add_overload(name, overloaded_names)
|
||||
|
||||
|
||||
for name, value in nn_module.__dict__.items():
|
||||
|
|
@ -171,7 +170,7 @@ def infer_raw_concrete_type(nn_module):
|
|||
if inspect.isfunction(value):
|
||||
try:
|
||||
scripted_fn = torch.jit.script(value)
|
||||
concrete_type.add_function_attribute(
|
||||
concrete_type_builder.add_function_attribute(
|
||||
name,
|
||||
torch._C._jit_try_infer_type(scripted_fn),
|
||||
value)
|
||||
|
|
@ -182,14 +181,14 @@ def infer_raw_concrete_type(nn_module):
|
|||
hint = ("(This function exists as an attribute on the Python module, "
|
||||
"but we failed to compile it to a TorchScript function. "
|
||||
"\nThe error stack is reproduced here:\n{}").format(e)
|
||||
concrete_type.add_failed_attribute(name, hint)
|
||||
concrete_type_builder.add_failed_attribute(name, hint)
|
||||
pass
|
||||
|
||||
continue
|
||||
|
||||
# Handle Script function attributes
|
||||
if isinstance(value, torch.jit.ScriptFunction):
|
||||
concrete_type.add_function_attribute(
|
||||
concrete_type_builder.add_function_attribute(
|
||||
name,
|
||||
torch._C._jit_try_infer_type(value),
|
||||
value)
|
||||
|
|
@ -198,14 +197,14 @@ def infer_raw_concrete_type(nn_module):
|
|||
# If we got here, this is a regular "data" attribute, Add it to the concrete type
|
||||
attr_type = infer_type(name, value)
|
||||
if attr_type is not None:
|
||||
concrete_type.add_attribute(name, attr_type, False)
|
||||
concrete_type_builder.add_attribute(name, attr_type, False)
|
||||
else:
|
||||
# TODO: could add more detail here. For example, what the user should do
|
||||
# when the pytype is `list` or `NoneType`
|
||||
hint = ("(This attribute exists on the Python module, "
|
||||
"but we failed to convert Python type: '{}' "
|
||||
"to a TorchScript type.)").format(type(value).__name__)
|
||||
concrete_type.add_failed_attribute(name, hint)
|
||||
concrete_type_builder.add_failed_attribute(name, hint)
|
||||
|
||||
# Add @property methods as failed attributes, to give a better error message.
|
||||
for name, value in type(nn_module).__dict__.items():
|
||||
|
|
@ -213,9 +212,9 @@ def infer_raw_concrete_type(nn_module):
|
|||
hint = ("\n(This attribute exists on the Python module, but it's an @property "
|
||||
"method. @property methods are not yet supported in TorchScript. "
|
||||
"Please file a feature request on Github)")
|
||||
concrete_type.add_failed_attribute(name, hint)
|
||||
concrete_type_builder.add_failed_attribute(name, hint)
|
||||
|
||||
return concrete_type
|
||||
return concrete_type_builder
|
||||
|
||||
class ConcreteTypeStore(object):
|
||||
def __init__(self):
|
||||
|
|
@ -234,7 +233,7 @@ class ConcreteTypeStore(object):
|
|||
hasattr(nn_module, "_concrete_type"):
|
||||
return nn_module._concrete_type
|
||||
|
||||
raw_concrete_type = infer_raw_concrete_type(nn_module)
|
||||
concrete_type_builder = infer_concrete_type_builder(nn_module)
|
||||
|
||||
nn_module_type = type(nn_module)
|
||||
if nn_module_type not in self.type_store:
|
||||
|
|
@ -243,13 +242,13 @@ class ConcreteTypeStore(object):
|
|||
# Search the type store for an already-available JIT type
|
||||
known_types = self.type_store[nn_module_type]
|
||||
for known_type in known_types:
|
||||
if raw_concrete_type.equals(known_type):
|
||||
if known_type.equals(concrete_type_builder):
|
||||
return known_type
|
||||
|
||||
# We didn't find anything; generate a new JIT type from this concrete type
|
||||
raw_concrete_type.create_new_type_from_this()
|
||||
self.type_store[nn_module_type].append(raw_concrete_type)
|
||||
return raw_concrete_type
|
||||
concrete_type = concrete_type_builder.build()
|
||||
self.type_store[nn_module_type].append(concrete_type)
|
||||
return concrete_type
|
||||
|
||||
concrete_type_store = ConcreteTypeStore()
|
||||
|
||||
|
|
@ -272,11 +271,12 @@ def create_script_module_for_tracing(nn_module, stubs):
|
|||
stubs: ScriptMethodStubs to compile as part of the conversion process.
|
||||
"""
|
||||
check_module_initialized(nn_module)
|
||||
# Get a ConcreteType without a JIT type. We will generate one ourselves
|
||||
# and fill it in.
|
||||
concrete_type = infer_raw_concrete_type(nn_module)
|
||||
concrete_type.set_poisoned()
|
||||
concrete_type.create_new_type_from_this()
|
||||
# Get a concrete type directly, without trying to re-use an existing JIT
|
||||
# type from the type store.
|
||||
concrete_type_builder = infer_concrete_type_builder(nn_module)
|
||||
concrete_type_builder.set_poisoned()
|
||||
concrete_type = concrete_type_builder.build()
|
||||
|
||||
cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
|
||||
return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user