Split ConcreteModuleType into two types (#29824)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29824

We have two distinct phases/uses for ConcreteModuleType:
1. We are building it up and using it to check whether we can
reuse JIT types. (RawConcreteModuleType)
2. We are using it to satisfy ModuleValue::attr queries.
(ConcreteModuleType)

These types share an underlying `ConcreteModuleTypeData` which
actually stores the relevant info.

Previously they were the same type because I was lazy, but it's been the
source of a bug. So split them to formalize the differing invariants for
the two phases.

Test Plan: Imported from OSS

Differential Revision: D18575010

Pulled By: suo

fbshipit-source-id: 3e4ebcd36e78b947150d8f0dbb74ecccad23e7c4
This commit is contained in:
Michael Suo 2019-11-20 01:11:11 -08:00 committed by Facebook Github Bot
parent 7495c25440
commit 63e66fd267
4 changed files with 249 additions and 245 deletions

View File

@ -3,16 +3,7 @@
namespace torch {
namespace jit {
namespace script {
ClassTypePtr ConcreteModuleType::getJitType() const {
TORCH_INTERNAL_ASSERT(jitType_);
return jitType_;
}
ClassTypePtr ConcreteModuleType::createNewTypeFromThis() {
TORCH_INTERNAL_ASSERT(!jitType_);
TORCH_INTERNAL_ASSERT(pyClass_);
ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
auto cu = get_python_cu();
py::object pyQualName = py::module::import("torch._jit_internal")
.attr("_qualified_name")(pyClass_);
@ -41,21 +32,85 @@ ClassTypePtr ConcreteModuleType::createNewTypeFromThis() {
moduleInfo.name_, moduleInfo.getJitType(), /*is_parameter=*/false);
}
jitType_ = std::move(cls);
return cls;
}
ConcreteModuleType::ConcreteModuleType(ConcreteModuleTypeBuilder data)
: data_(std::move(data)) {
jitType_ = data_.createTypeFromThis();
}
TypePtr ConcreteModuleTypeBuilder::ModuleInfo::getJitType() const {
return meta_ == nullptr ? type_ : meta_->getJitType();
}
bool operator==(
const ConcreteModuleTypeBuilder::ModuleInfo& lhs,
const ConcreteModuleTypeBuilder::ModuleInfo& rhs) {
if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
return lhs.meta_->equals(*rhs.meta_);
} else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
return *(lhs.type_) == *(rhs.type_);
} else {
return false;
}
}
bool ConcreteModuleTypeBuilder::equals(
const ConcreteModuleTypeBuilder& other) const {
if (isPoisoned_ || other.isPoisoned_) {
return false;
}
// clang-format off
// These are vaguely ordered so that cheap, discriminating checks happen first.
bool equal =
pyClass_.is(other.pyClass_) &&
iterableModuleKind_ == other.iterableModuleKind_ &&
constants_ == other.constants_ &&
attributes_ == other.attributes_ &&
overloads_ == other.overloads_ &&
functionAttributes_ == other.functionAttributes_;
// clang-format on
if (!equal) {
return false;
}
// We store modules in order of insertion (to make compilation
// deterministic). However, for the purposes of equality, insertion order
// should not matter, so sort them by name.
// We put this check last because it involves the most work.
auto thisSorted = modules_;
std::sort(
thisSorted.begin(),
thisSorted.end(),
[](const ModuleInfo& a, const ModuleInfo& b) {
return a.name_ < b.name_;
});
auto otherSorted = other.modules_;
std::sort(
otherSorted.begin(),
otherSorted.end(),
[](const ModuleInfo& a, const ModuleInfo& b) {
return a.name_ < b.name_;
});
return thisSorted == otherSorted;
}
ClassTypePtr ConcreteModuleType::getJitType() const {
return jitType_;
}
py::object ConcreteModuleType::getPyClass() const {
TORCH_INTERNAL_ASSERT(jitType_);
TORCH_INTERNAL_ASSERT(pyClass_);
return pyClass_;
return data_.pyClass_;
}
c10::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
const std::string& name) const {
TORCH_INTERNAL_ASSERT(jitType_);
const auto it = overloads_.find(name);
if (it != overloads_.end()) {
const auto it = data_.overloads_.find(name);
if (it != data_.overloads_.end()) {
return it->second;
}
return c10::nullopt;
@ -63,9 +118,8 @@ c10::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
c10::optional<Function*> ConcreteModuleType::findFunctionAttribute(
const std::string& name) const {
TORCH_INTERNAL_ASSERT(jitType_);
const auto it = functionAttributes_.find(name);
if (it != functionAttributes_.end()) {
const auto it = data_.functionAttributes_.find(name);
if (it != data_.functionAttributes_.end()) {
return it->second.function_->function();
}
return c10::nullopt;
@ -73,9 +127,8 @@ c10::optional<Function*> ConcreteModuleType::findFunctionAttribute(
c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
const std::string& name) const {
TORCH_INTERNAL_ASSERT(jitType_);
const auto it = failedAttributes_.find(name);
if (it != failedAttributes_.end()) {
const auto it = data_.failedAttributes_.find(name);
if (it != data_.failedAttributes_.end()) {
return it->second;
}
return c10::nullopt;
@ -83,130 +136,114 @@ c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
findSubmoduleConcreteType(const std::string& name) const {
TORCH_INTERNAL_ASSERT(jitType_);
const auto it = std::find_if(
modules_.cbegin(), modules_.cend(), [&](const ModuleInfo& info) {
data_.modules_.cbegin(),
data_.modules_.cend(),
[&](const ConcreteModuleTypeBuilder::ModuleInfo& info) {
return info.name_ == name;
});
if (it == modules_.end()) {
if (it == data_.modules_.end()) {
return nullptr;
}
return it->meta_;
}
void ConcreteModuleType::setIterableModuleKind(IterableModuleKind kind) {
TORCH_INTERNAL_ASSERT(!jitType_);
void ConcreteModuleTypeBuilder::setIterableModuleKind(IterableModuleKind kind) {
iterableModuleKind_ = kind;
}
IterableModuleKind ConcreteModuleType::getIterableModuleKind() const {
TORCH_INTERNAL_ASSERT(jitType_);
return iterableModuleKind_;
return data_.iterableModuleKind_;
}
void ConcreteModuleType::setPoisoned() {
TORCH_INTERNAL_ASSERT(!jitType_)
void ConcreteModuleTypeBuilder::setPoisoned() {
isPoisoned_ = true;
}
void ConcreteModuleType::addJitType(ClassTypePtr type) {
TORCH_INTERNAL_ASSERT(!jitType_)
jitType_ = std::move(type);
}
void ConcreteModuleType::addPyClass(py::object pyClass) {
TORCH_INTERNAL_ASSERT(!jitType_);
pyClass_ = std::move(pyClass);
}
void ConcreteModuleType::addConstant(std::string name, py::object value) {
TORCH_INTERNAL_ASSERT(!jitType_);
void ConcreteModuleTypeBuilder::addConstant(std::string name, py::object value) {
constants_.emplace(std::move(name), std::move(value));
}
void ConcreteModuleType::addAttribute(
void ConcreteModuleTypeBuilder::addAttribute(
std::string name,
TypePtr type,
bool isParameter) {
TORCH_INTERNAL_ASSERT(type);
TORCH_INTERNAL_ASSERT(!jitType_);
// Function attributes should be handled separately
TORCH_INTERNAL_ASSERT(type->cast<FunctionType>() == nullptr);
attributes_.emplace(
std::move(name), Attribute(unshapedType(type), isParameter));
std::move(name),
ConcreteModuleTypeBuilder::Attribute(unshapedType(type), isParameter));
}
void ConcreteModuleType::addFunctionAttribute(
void ConcreteModuleTypeBuilder::addFunctionAttribute(
std::string name,
const TypePtr& type,
py::object pyFunction) {
TORCH_INTERNAL_ASSERT(type);
TORCH_INTERNAL_ASSERT(!jitType_);
functionAttributes_.emplace(
std::move(name),
FunctionAttribute{type->expect<FunctionType>(), std::move(pyFunction)});
ConcreteModuleTypeBuilder::FunctionAttribute{type->expect<FunctionType>(),
std::move(pyFunction)});
}
void ConcreteModuleType::addModule(
void ConcreteModuleTypeBuilder::addModule(
std::string name,
std::shared_ptr<ConcreteModuleType> meta) {
TORCH_INTERNAL_ASSERT(!jitType_);
modules_.emplace_back(ModuleInfo{std::move(name), std::move(meta)});
modules_.emplace_back(
ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), std::move(meta)});
}
void ConcreteModuleType::addModuleInterface(
void ConcreteModuleTypeBuilder::addModuleInterface(
std::string name,
const TypePtr& type) {
TORCH_INTERNAL_ASSERT(!jitType_);
TORCH_INTERNAL_ASSERT(type->cast<InterfaceType>() && type->is_module());
modules_.emplace_back(ModuleInfo{std::move(name), type});
modules_.emplace_back(
ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), type});
}
void ConcreteModuleType::addOverload(
void ConcreteModuleTypeBuilder::addOverload(
std::string methodName,
std::vector<std::string> overloadedMethodNames) {
TORCH_INTERNAL_ASSERT(!jitType_);
overloads_.emplace(std::move(methodName), std::move(overloadedMethodNames));
}
void ConcreteModuleType::addFailedAttribute(
void ConcreteModuleTypeBuilder::addFailedAttribute(
std::string name,
std::string failureReason) {
TORCH_INTERNAL_ASSERT(!jitType_);
failedAttributes_.emplace(std::move(name), std::move(failureReason));
}
c10::optional<py::object> ConcreteModuleType::findConstant(
const std::string& name) const {
auto it = constants_.find(name);
if (it != constants_.end()) {
auto it = data_.constants_.find(name);
if (it != data_.constants_.end()) {
return it->second.v_;
}
return c10::nullopt;
}
void ConcreteModuleType::dump() const {
std::cout << "ConcreteModuleType for: " << py::getattr(pyClass_, "__name__") << "\n";
std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n";
std::cout << "Constants: \n";
for (const auto& pr : constants_) {
for (const auto& pr : data_.constants_) {
std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n";
}
std::cout << "\nAttributes: \n";
for (const auto& pr : attributes_) {
for (const auto& pr : data_.attributes_) {
std::cout << "\t" << pr.first << ": " << pr.second.type_->python_str()
<< "\n";
}
std::cout << "\nSubmodules: \n";
for (const auto& info : modules_) {
for (const auto& info : data_.modules_) {
std::cout << "\t" << info.name_ << ": "
<< info.getJitType()->python_str() << "\n";
}
std::cout << "\nOverloads: \n";
for (const auto& pr : overloads_) {
for (const auto& pr : data_.overloads_) {
std::cout << "\t" << pr.first << ": " << pr.second << "\n";
}
std::string isPoisoned = isPoisoned_ ? "true" : "false";
std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
std::cout << "isPoisoned: " << isPoisoned << "\n";
if (jitType_) {
std::cout << "jit type: " << jitType_->python_str() << "\n";
@ -215,11 +252,10 @@ void ConcreteModuleType::dump() const {
std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
const {
TORCH_INTERNAL_ASSERT(jitType_);
// Convert to a more pybind-friendly representation, so we don't
// need to bind ConcreteModuleType::Constant as well.
std::unordered_map<std::string, py::object> ret;
for (const auto& pr : constants_) {
for (const auto& pr : data_.constants_) {
ret.emplace(pr.first, pr.second.v_);
}
return ret;
@ -227,11 +263,10 @@ std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
getAttributesPy() const {
TORCH_INTERNAL_ASSERT(jitType_);
// Convert to a more pybind-friendly representation, so we don't
// need to bind ConcreteModuleType::Attribute as well.
std::unordered_map<std::string, std::pair<TypePtr, bool>> ret;
for (auto& pr : attributes_) {
for (auto& pr : data_.attributes_) {
ret.emplace(
pr.first,
std::pair<TypePtr, bool>(pr.second.type_, pr.second.isParam_));
@ -241,10 +276,9 @@ std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
std::vector<std::pair<std::string, TypePtr>> ConcreteModuleType::getModulesPy()
const {
TORCH_INTERNAL_ASSERT(jitType_);
std::vector<std::pair<std::string, TypePtr>> ret;
for (const ModuleInfo& info: modules_) {
for (const auto& info : data_.modules_) {
ret.emplace_back(std::make_pair(info.name_, info.getJitType()));
}
return ret;

View File

@ -9,7 +9,9 @@
namespace torch {
namespace jit {
namespace script {
enum class IterableModuleKind { NONE, LIST, DICT };
class ConcreteModuleType;
// You can think of an nn.Module as a template that corresponds to a family of
// JIT types. The template "arguments" are things like the constant values.
@ -41,21 +43,22 @@ enum class IterableModuleKind { NONE, LIST, DICT };
// ModuleValue::attr calls. This is so we can guarantee that if two Module's
// share a JIT type (and thus a ConcreteModuleType), then they behave the same
// way when you access attributes on them.
class VISIBILITY_HIDDEN ConcreteModuleType {
public:
// ConcreteModuleType has two states.
// 1. Building: First we build it up, during the ScriptModule conversion
// process
// ... to transition, we freeze the type by associating with a JIT type.
// 2. Querying: Then we ask use it as a source of truth during method
// compilation.
// During this time the ModuleType is effectively const.
// Yes, it could be two different types. Not terribly worth the verbosity.
/**
* Builder methods (jitType must be null)
*/
void addPyClass(py::object pyClass);
// ConcreteModuleType has two phases.
// 1. Creation: First we build it up, during the ScriptModule conversion
// process. This is represented by ConcreteModuleTypeBuilder.
// ...then the converter calls ConcreteModuleTypeBuilder::build(), producing a
// ConcreteModuleType ready for querying.
// 2. Querying: We use ConcreteModuleType as a source of truth for
// ModuleValue::attr calls during method compilation.
// Represents a concrete type during in the process for construction. We use
// this to decide whether we can share types between modules.
class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
public:
explicit ConcreteModuleTypeBuilder(py::object pyClass) {
pyClass_ = std::move(pyClass);
}
void addConstant(std::string name, py::object value);
void addAttribute(std::string name, TypePtr type, bool isParameter);
void addFunctionAttribute(
@ -73,95 +76,20 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
std::vector<std::string> overloadedMethodNames);
void addFailedAttribute(std::string name, std::string failureReason);
void setIterableModuleKind(IterableModuleKind kind);
// If a ConcreteModuleType is poisoned, it will never compare equal to any other
// concrete type
void setPoisoned();
/**
* Freezing methods
*/
// Based on the data in this ConcreteType, create an equivalent JIT type and
// associate this module type with it.
ClassTypePtr createNewTypeFromThis();
// Associate the provided type with this ConcreteType
void addJitType(ClassTypePtr type);
/**
* Query methods (jitType must be non-null)
*/
ClassTypePtr getJitType() const;
py::object getPyClass() const;
IterableModuleKind getIterableModuleKind() const;
c10::optional<py::object> findConstant(const std::string& name) const;
c10::optional<std::vector<std::string>> findOverloads(
const std::string& name) const;
c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
const std::string& name) const;
c10::optional<std::string> findFailedAttribute(const std::string& name) const;
// These getters are only here to return things as types that can be
// automatically converted by pybind.
std::unordered_map<std::string, py::object> getConstantsPy() const;
std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
const;
std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
std::shared_ptr<ConcreteModuleType> build() const {
return std::make_shared<ConcreteModuleType>(*this);
}
// This determines whether two modules can share a type. The container structs
// used by ConcreteModuleType have been defined such that operator==
// implements a meaningful comparison in that context.
friend bool operator==(
const ConcreteModuleType& lhs,
const ConcreteModuleType& rhs) {
if (lhs.jitType_ == rhs.jitType_) {
// If the computed types are the same, these modules can (obviously) share
// a type.
return true;
}
bool equals(const ConcreteModuleTypeBuilder& other) const;
if (lhs.isPoisoned_ || rhs.isPoisoned_) {
return false;
}
// clang-format off
// These are vaguely ordered so that cheap, discriminating checks happen first.
bool equal =
lhs.pyClass_.is(rhs.pyClass_) &&
lhs.iterableModuleKind_ == rhs.iterableModuleKind_ &&
lhs.constants_ == rhs.constants_ &&
lhs.attributes_ == rhs.attributes_ &&
lhs.overloads_ == rhs.overloads_ &&
lhs.functionAttributes_ == rhs.functionAttributes_;
// clang-format on
if (!equal) {
return false;
}
// We store modules in order of insertion (to make compilation
// deterministic). However, for the purposes of equality, insertion order
// should not matter, so sort them by name.
// We put this check last because it involves the most work.
auto lhsSorted = lhs.modules_;
std::sort(
lhsSorted.begin(),
lhsSorted.end(),
[](const ModuleInfo& a, const ModuleInfo& b) {
return a.name_ < b.name_;
});
auto rhsSorted = rhs.modules_;
std::sort(
rhsSorted.begin(),
rhsSorted.end(),
[](const ModuleInfo& a, const ModuleInfo& b) {
return a.name_ < b.name_;
});
return lhsSorted == rhsSorted;
}
void dump() const;
private:
struct Constant {
/* implicit */ Constant(py::object v) : v_(std::move(v)) {}
friend bool operator==(const Constant& lhs, const Constant& rhs) {
@ -206,28 +134,18 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
ModuleInfo(std::string name, const TypePtr& type)
: name_(std::move(name)), meta_(nullptr), type_(type) {}
friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs) {
if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
return *(lhs.meta_) == *(rhs.meta_);
} else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
return *(lhs.type_) == *(rhs.type_);
} else {
return false;
}
}
TypePtr getJitType() const {
return meta_ == nullptr? type_ : meta_->getJitType();
}
TypePtr getJitType() const;
std::string name_;
friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs);
// Module Info contains either an ConcreateModuleType or a type (which is
// a Module Interface), these two are union relationship.
std::shared_ptr<ConcreteModuleType> meta_;
TypePtr type_;
};
private:
ClassTypePtr createTypeFromThis() const;
// If true, this type will never compare equally to anything else. This is
// used if we want to ensure that this type is not shared (for example, if it
// came from a traced module)
@ -256,11 +174,54 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
// The original `nn.Module` class that we derived this ScriptModule from.
py::object pyClass_;
// The JIT type derived from this ConcreteModuleType.
ClassTypePtr jitType_ = nullptr;
// NOTE: If you ever add any more state to this struct, you need to make sure
// operator== still makes sense! The only field that can be excluded from it
// is `jitType_`.
// operator== still makes sense!
friend ConcreteModuleType;
};
// Represents a finalized concrete type, used to service ModuleValue::attr calls
// during method compilation.
class VISIBILITY_HIDDEN ConcreteModuleType {
public:
explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
ClassTypePtr getJitType() const;
py::object getPyClass() const;
IterableModuleKind getIterableModuleKind() const;
c10::optional<py::object> findConstant(const std::string& name) const;
c10::optional<std::vector<std::string>> findOverloads(
const std::string& name) const;
c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
const std::string& name) const;
c10::optional<std::string> findFailedAttribute(const std::string& name) const;
// These getters are only here to return things as types that can be
// automatically converted by pybind.
std::unordered_map<std::string, py::object> getConstantsPy() const;
std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
const;
std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
bool equals(const ConcreteModuleType& other) const {
if (jitType_ == other.jitType_) {
// If the computed types are the same, these modules can (obviously) share
// a type.
return true;
}
return data_.equals(other.data_);
}
bool equals(const ConcreteModuleTypeBuilder& other) const {
return data_.equals(other);
}
void dump() const;
private:
// The JIT type derived from this ConcreteModuleType.
ConcreteModuleTypeBuilder data_;
ClassTypePtr jitType_;
};
} // namespace script

View File

@ -1052,42 +1052,51 @@ void initJitScriptBindings(PyObject* module) {
return Module(get_python_cu(), type);
});
py::class_<ConcreteModuleTypeBuilder, std::shared_ptr<ConcreteModuleTypeBuilder>>(
m, "ConcreteModuleTypeBuilder")
.def(py::init<py::object>())
.def("add_constant", &ConcreteModuleTypeBuilder::addConstant)
.def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
.def(
"add_function_attribute",
&ConcreteModuleTypeBuilder::addFunctionAttribute)
.def("add_module", &ConcreteModuleTypeBuilder::addModule)
.def("add_module_interface", &ConcreteModuleTypeBuilder::addModuleInterface)
.def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
.def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
.def("add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute)
.def(
"set_module_dict",
[](ConcreteModuleTypeBuilder& self) {
self.setIterableModuleKind(IterableModuleKind::DICT);
})
.def("build", &ConcreteModuleTypeBuilder::build)
.def(
"equals",
[](const ConcreteModuleTypeBuilder& self,
const ConcreteModuleTypeBuilder& other) { return self.equals(other); })
.def("set_module_list", [](ConcreteModuleTypeBuilder& self) {
self.setIterableModuleKind(IterableModuleKind::LIST);
});
py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
m, "ConcreteModuleType")
.def(py::init<>())
.def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
.def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
.def("get_constants", &ConcreteModuleType::getConstantsPy)
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
.def("get_modules", &ConcreteModuleType::getModulesPy)
.def("add_constant", &ConcreteModuleType::addConstant)
.def("add_attribute", &ConcreteModuleType::addAttribute)
.def("add_function_attribute", &ConcreteModuleType::addFunctionAttribute)
.def("add_module", &ConcreteModuleType::addModule)
.def("add_module_interface", &ConcreteModuleType::addModuleInterface)
.def("add_pyclass", &ConcreteModuleType::addPyClass)
.def("add_overload", &ConcreteModuleType::addOverload)
.def("add_jit_type", &ConcreteModuleType::addJitType)
.def("set_poisoned", &ConcreteModuleType::setPoisoned)
.def(
"set_module_dict",
[](ConcreteModuleType& self) {
self.setIterableModuleKind(IterableModuleKind::DICT);
})
.def(
"set_module_list",
[](ConcreteModuleType& self) {
self.setIterableModuleKind(IterableModuleKind::LIST);
})
.def(
"create_new_type_from_this",
&ConcreteModuleType::createNewTypeFromThis)
.def("add_failed_attribute", &ConcreteModuleType::addFailedAttribute)
.def("dump", &ConcreteModuleType::dump)
.def(
"equals",
[](const ConcreteModuleType& self, const ConcreteModuleType& other) {
return self == other;
return self.equals(other);
})
.def(
"equals",
[](const ConcreteModuleType& self,
const ConcreteModuleTypeBuilder& other) {
return self.equals(other);
})
.def(
"_create_methods",

View File

@ -59,18 +59,17 @@ def _get_valid_constant(attr, v):
3. a list or tuple of (2)
""".format(type(v).__name__, attr, constants)))
def infer_raw_concrete_type(nn_module):
def infer_concrete_type_builder(nn_module):
"""
Build a ConcreteModuleType from an nn.Module. This ConcreteModuleType
doesn't have a JIT type associated with it yet, it must be filled in
by the caller.
Build a ConcreteModuleTypeBuilder from an nn.Module. This
ConcreteModuleType doesn't have a JIT type associated with it yet, it
must be filled in by the caller.
"""
concrete_type = torch._C.ConcreteModuleType()
concrete_type.add_pyclass(type(nn_module))
concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
if isinstance(nn_module, (torch.nn.ModuleDict)):
concrete_type.set_module_dict()
concrete_type_builder.set_module_dict()
if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
concrete_type.set_module_list()
concrete_type_builder.set_module_list()
class_annotations = getattr(nn_module, '__annotations__', {})
@ -96,7 +95,7 @@ def infer_raw_concrete_type(nn_module):
continue
assert isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
concrete_type.add_attribute(name, attr_type, True)
concrete_type_builder.add_attribute(name, attr_type, True)
added_names.add(name)
for name, item in nn_module._buffers.items():
@ -109,18 +108,18 @@ def infer_raw_concrete_type(nn_module):
continue
assert isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
concrete_type.add_attribute(name, attr_type, False)
concrete_type_builder.add_attribute(name, attr_type, False)
added_names.add(name)
for name, item in nn_module._modules.items():
attr_type = infer_type(name, item)
if attr_type is not None:
# if the type can be inferred, it should be a module interface type
concrete_type.add_module_interface(name, attr_type)
concrete_type_builder.add_module_interface(name, attr_type)
else:
# otherwise we get the concrete module type for item and add it to concrete_type
sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
concrete_type.add_module(name, sub_concrete_type)
concrete_type_builder.add_module(name, sub_concrete_type)
added_names.add(name)
@ -146,7 +145,7 @@ def infer_raw_concrete_type(nn_module):
"Consider removing it.".format(name))
continue
value = getattr(nn_module, name)
concrete_type.add_constant(name, _get_valid_constant(name, value))
concrete_type_builder.add_constant(name, _get_valid_constant(name, value))
added_names.add(name)
# populate overloads
@ -154,7 +153,7 @@ def infer_raw_concrete_type(nn_module):
# update with any annotated overloads
overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module)))
for name, overloaded_names in overloads.items():
concrete_type.add_overload(name, overloaded_names)
concrete_type_builder.add_overload(name, overloaded_names)
for name, value in nn_module.__dict__.items():
@ -171,7 +170,7 @@ def infer_raw_concrete_type(nn_module):
if inspect.isfunction(value):
try:
scripted_fn = torch.jit.script(value)
concrete_type.add_function_attribute(
concrete_type_builder.add_function_attribute(
name,
torch._C._jit_try_infer_type(scripted_fn),
value)
@ -182,14 +181,14 @@ def infer_raw_concrete_type(nn_module):
hint = ("(This function exists as an attribute on the Python module, "
"but we failed to compile it to a TorchScript function. "
"\nThe error stack is reproduced here:\n{}").format(e)
concrete_type.add_failed_attribute(name, hint)
concrete_type_builder.add_failed_attribute(name, hint)
pass
continue
# Handle Script function attributes
if isinstance(value, torch.jit.ScriptFunction):
concrete_type.add_function_attribute(
concrete_type_builder.add_function_attribute(
name,
torch._C._jit_try_infer_type(value),
value)
@ -198,14 +197,14 @@ def infer_raw_concrete_type(nn_module):
# If we got here, this is a regular "data" attribute, Add it to the concrete type
attr_type = infer_type(name, value)
if attr_type is not None:
concrete_type.add_attribute(name, attr_type, False)
concrete_type_builder.add_attribute(name, attr_type, False)
else:
# TODO: could add more detail here. For example, what the user should do
# when the pytype is `list` or `NoneType`
hint = ("(This attribute exists on the Python module, "
"but we failed to convert Python type: '{}' "
"to a TorchScript type.)").format(type(value).__name__)
concrete_type.add_failed_attribute(name, hint)
concrete_type_builder.add_failed_attribute(name, hint)
# Add @property methods as failed attributes, to give a better error message.
for name, value in type(nn_module).__dict__.items():
@ -213,9 +212,9 @@ def infer_raw_concrete_type(nn_module):
hint = ("\n(This attribute exists on the Python module, but it's an @property "
"method. @property methods are not yet supported in TorchScript. "
"Please file a feature request on Github)")
concrete_type.add_failed_attribute(name, hint)
concrete_type_builder.add_failed_attribute(name, hint)
return concrete_type
return concrete_type_builder
class ConcreteTypeStore(object):
def __init__(self):
@ -234,7 +233,7 @@ class ConcreteTypeStore(object):
hasattr(nn_module, "_concrete_type"):
return nn_module._concrete_type
raw_concrete_type = infer_raw_concrete_type(nn_module)
concrete_type_builder = infer_concrete_type_builder(nn_module)
nn_module_type = type(nn_module)
if nn_module_type not in self.type_store:
@ -243,13 +242,13 @@ class ConcreteTypeStore(object):
# Search the type store for an already-available JIT type
known_types = self.type_store[nn_module_type]
for known_type in known_types:
if raw_concrete_type.equals(known_type):
if known_type.equals(concrete_type_builder):
return known_type
# We didn't find anything; generate a new JIT type from this concrete type
raw_concrete_type.create_new_type_from_this()
self.type_store[nn_module_type].append(raw_concrete_type)
return raw_concrete_type
concrete_type = concrete_type_builder.build()
self.type_store[nn_module_type].append(concrete_type)
return concrete_type
concrete_type_store = ConcreteTypeStore()
@ -272,11 +271,12 @@ def create_script_module_for_tracing(nn_module, stubs):
stubs: ScriptMethodStubs to compile as part of the conversion process.
"""
check_module_initialized(nn_module)
# Get a ConcreteType without a JIT type. We will generate one ourselves
# and fill it in.
concrete_type = infer_raw_concrete_type(nn_module)
concrete_type.set_poisoned()
concrete_type.create_new_type_from_this()
# Get a concrete type directly, without trying to re-use an existing JIT
# type from the type store.
concrete_type_builder = infer_concrete_type_builder(nn_module)
concrete_type_builder.set_poisoned()
concrete_type = concrete_type_builder.build()
cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)