From 63e66fd26714c767be87bcd5f3cebda2bc567bdf Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Wed, 20 Nov 2019 01:11:11 -0800 Subject: [PATCH] Split ConcreteModuleType into two types (#29824) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29824 We have two distinct phases/uses for ConcreteModuleType: 1. We are building it up and using it to check whether we can reuse JIT types. (RawConcreteModuleType) 2. We are using it to satisfy ModuleValue::attr queries. (ConcreteModuleType) These types share an underlying `ConcreteModuleTypeData` which actually stores the relevant info. Previously they were the same type because I was lazy, but it's been the source of a bug. So split them to formalize the differing invariants for the two phases. Test Plan: Imported from OSS Differential Revision: D18575010 Pulled By: suo fbshipit-source-id: 3e4ebcd36e78b947150d8f0dbb74ecccad23e7c4 --- .../csrc/jit/script/concrete_module_type.cpp | 184 ++++++++++------- torch/csrc/jit/script/concrete_module_type.h | 189 +++++++----------- torch/csrc/jit/script/init.cpp | 59 +++--- torch/jit/_recursive.py | 62 +++--- 4 files changed, 249 insertions(+), 245 deletions(-) diff --git a/torch/csrc/jit/script/concrete_module_type.cpp b/torch/csrc/jit/script/concrete_module_type.cpp index 2ab9e902b94..1e0b96ba32a 100644 --- a/torch/csrc/jit/script/concrete_module_type.cpp +++ b/torch/csrc/jit/script/concrete_module_type.cpp @@ -3,16 +3,7 @@ namespace torch { namespace jit { namespace script { - -ClassTypePtr ConcreteModuleType::getJitType() const { - TORCH_INTERNAL_ASSERT(jitType_); - return jitType_; -} - -ClassTypePtr ConcreteModuleType::createNewTypeFromThis() { - TORCH_INTERNAL_ASSERT(!jitType_); - TORCH_INTERNAL_ASSERT(pyClass_); - +ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const { auto cu = get_python_cu(); py::object pyQualName = py::module::import("torch._jit_internal") .attr("_qualified_name")(pyClass_); @@ -41,21 +32,85 @@ ClassTypePtr ConcreteModuleType::createNewTypeFromThis() { moduleInfo.name_, moduleInfo.getJitType(), /*is_parameter=*/false); } - jitType_ = std::move(cls); + return cls; +} + +ConcreteModuleType::ConcreteModuleType(ConcreteModuleTypeBuilder data) + : data_(std::move(data)) { + jitType_ = data_.createTypeFromThis(); +} + +TypePtr ConcreteModuleTypeBuilder::ModuleInfo::getJitType() const { + return meta_ == nullptr ? type_ : meta_->getJitType(); +} + +bool operator==( + const ConcreteModuleTypeBuilder::ModuleInfo& lhs, + const ConcreteModuleTypeBuilder::ModuleInfo& rhs) { + if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) { + return lhs.meta_->equals(*rhs.meta_); + } else if (lhs.type_ != nullptr && rhs.type_ != nullptr) { + return *(lhs.type_) == *(rhs.type_); + } else { + return false; + } +} + +bool ConcreteModuleTypeBuilder::equals( + const ConcreteModuleTypeBuilder& other) const { + if (isPoisoned_ || other.isPoisoned_) { + return false; + } + + // clang-format off + // These are vaguely ordered so that cheap, discriminating checks happen first. + bool equal = + pyClass_.is(other.pyClass_) && + iterableModuleKind_ == other.iterableModuleKind_ && + constants_ == other.constants_ && + attributes_ == other.attributes_ && + overloads_ == other.overloads_ && + functionAttributes_ == other.functionAttributes_; + // clang-format on + if (!equal) { + return false; + } + + // We store modules in order of insertion (to make compilation + // deterministic). However, for the purposes of equality, insertion order + // should not matter, so sort them by name. + // We put this check last because it involves the most work. + auto thisSorted = modules_; + std::sort( + thisSorted.begin(), + thisSorted.end(), + [](const ModuleInfo& a, const ModuleInfo& b) { + return a.name_ < b.name_; + }); + + auto otherSorted = other.modules_; + std::sort( + otherSorted.begin(), + otherSorted.end(), + [](const ModuleInfo& a, const ModuleInfo& b) { + return a.name_ < b.name_; + }); + + return thisSorted == otherSorted; +} + +ClassTypePtr ConcreteModuleType::getJitType() const { return jitType_; } py::object ConcreteModuleType::getPyClass() const { - TORCH_INTERNAL_ASSERT(jitType_); - TORCH_INTERNAL_ASSERT(pyClass_); - return pyClass_; + return data_.pyClass_; } c10::optional> ConcreteModuleType::findOverloads( const std::string& name) const { - TORCH_INTERNAL_ASSERT(jitType_); - const auto it = overloads_.find(name); - if (it != overloads_.end()) { + const auto it = data_.overloads_.find(name); + if (it != data_.overloads_.end()) { return it->second; } return c10::nullopt; @@ -63,9 +118,8 @@ c10::optional> ConcreteModuleType::findOverloads( c10::optional ConcreteModuleType::findFunctionAttribute( const std::string& name) const { - TORCH_INTERNAL_ASSERT(jitType_); - const auto it = functionAttributes_.find(name); - if (it != functionAttributes_.end()) { + const auto it = data_.functionAttributes_.find(name); + if (it != data_.functionAttributes_.end()) { return it->second.function_->function(); } return c10::nullopt; @@ -73,9 +127,8 @@ c10::optional ConcreteModuleType::findFunctionAttribute( c10::optional ConcreteModuleType::findFailedAttribute( const std::string& name) const { - TORCH_INTERNAL_ASSERT(jitType_); - const auto it = failedAttributes_.find(name); - if (it != failedAttributes_.end()) { + const auto it = data_.failedAttributes_.find(name); + if (it != data_.failedAttributes_.end()) { return it->second; } return c10::nullopt; @@ -83,130 +136,114 @@ c10::optional ConcreteModuleType::findFailedAttribute( std::shared_ptr ConcreteModuleType:: findSubmoduleConcreteType(const std::string& name) const { - TORCH_INTERNAL_ASSERT(jitType_); const auto it = std::find_if( - modules_.cbegin(), modules_.cend(), [&](const ModuleInfo& info) { + data_.modules_.cbegin(), + data_.modules_.cend(), + [&](const ConcreteModuleTypeBuilder::ModuleInfo& info) { return info.name_ == name; }); - if (it == modules_.end()) { + if (it == data_.modules_.end()) { return nullptr; } return it->meta_; } -void ConcreteModuleType::setIterableModuleKind(IterableModuleKind kind) { - TORCH_INTERNAL_ASSERT(!jitType_); +void ConcreteModuleTypeBuilder::setIterableModuleKind(IterableModuleKind kind) { iterableModuleKind_ = kind; } IterableModuleKind ConcreteModuleType::getIterableModuleKind() const { - TORCH_INTERNAL_ASSERT(jitType_); - return iterableModuleKind_; + return data_.iterableModuleKind_; } -void ConcreteModuleType::setPoisoned() { - TORCH_INTERNAL_ASSERT(!jitType_) +void ConcreteModuleTypeBuilder::setPoisoned() { isPoisoned_ = true; } -void ConcreteModuleType::addJitType(ClassTypePtr type) { - TORCH_INTERNAL_ASSERT(!jitType_) - jitType_ = std::move(type); -} - -void ConcreteModuleType::addPyClass(py::object pyClass) { - TORCH_INTERNAL_ASSERT(!jitType_); - pyClass_ = std::move(pyClass); -} - -void ConcreteModuleType::addConstant(std::string name, py::object value) { - TORCH_INTERNAL_ASSERT(!jitType_); +void ConcreteModuleTypeBuilder::addConstant(std::string name, py::object value) { constants_.emplace(std::move(name), std::move(value)); } -void ConcreteModuleType::addAttribute( +void ConcreteModuleTypeBuilder::addAttribute( std::string name, TypePtr type, bool isParameter) { TORCH_INTERNAL_ASSERT(type); - TORCH_INTERNAL_ASSERT(!jitType_); // Function attributes should be handled separately TORCH_INTERNAL_ASSERT(type->cast() == nullptr); attributes_.emplace( - std::move(name), Attribute(unshapedType(type), isParameter)); + std::move(name), + ConcreteModuleTypeBuilder::Attribute(unshapedType(type), isParameter)); } -void ConcreteModuleType::addFunctionAttribute( +void ConcreteModuleTypeBuilder::addFunctionAttribute( std::string name, const TypePtr& type, py::object pyFunction) { TORCH_INTERNAL_ASSERT(type); - TORCH_INTERNAL_ASSERT(!jitType_); functionAttributes_.emplace( std::move(name), - FunctionAttribute{type->expect(), std::move(pyFunction)}); + ConcreteModuleTypeBuilder::FunctionAttribute{type->expect(), + std::move(pyFunction)}); } -void ConcreteModuleType::addModule( +void ConcreteModuleTypeBuilder::addModule( std::string name, std::shared_ptr meta) { - TORCH_INTERNAL_ASSERT(!jitType_); - modules_.emplace_back(ModuleInfo{std::move(name), std::move(meta)}); + modules_.emplace_back( + ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), std::move(meta)}); } -void ConcreteModuleType::addModuleInterface( +void ConcreteModuleTypeBuilder::addModuleInterface( std::string name, const TypePtr& type) { - TORCH_INTERNAL_ASSERT(!jitType_); TORCH_INTERNAL_ASSERT(type->cast() && type->is_module()); - modules_.emplace_back(ModuleInfo{std::move(name), type}); + modules_.emplace_back( + ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), type}); } - -void ConcreteModuleType::addOverload( +void ConcreteModuleTypeBuilder::addOverload( std::string methodName, std::vector overloadedMethodNames) { - TORCH_INTERNAL_ASSERT(!jitType_); overloads_.emplace(std::move(methodName), std::move(overloadedMethodNames)); } -void ConcreteModuleType::addFailedAttribute( +void ConcreteModuleTypeBuilder::addFailedAttribute( std::string name, std::string failureReason) { - TORCH_INTERNAL_ASSERT(!jitType_); failedAttributes_.emplace(std::move(name), std::move(failureReason)); } c10::optional ConcreteModuleType::findConstant( const std::string& name) const { - auto it = constants_.find(name); - if (it != constants_.end()) { + auto it = data_.constants_.find(name); + if (it != data_.constants_.end()) { return it->second.v_; } return c10::nullopt; } void ConcreteModuleType::dump() const { - std::cout << "ConcreteModuleType for: " << py::getattr(pyClass_, "__name__") << "\n"; + std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n"; std::cout << "Constants: \n"; - for (const auto& pr : constants_) { + for (const auto& pr : data_.constants_) { std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n"; } std::cout << "\nAttributes: \n"; - for (const auto& pr : attributes_) { + for (const auto& pr : data_.attributes_) { std::cout << "\t" << pr.first << ": " << pr.second.type_->python_str() << "\n"; } std::cout << "\nSubmodules: \n"; - for (const auto& info : modules_) { + for (const auto& info : data_.modules_) { std::cout << "\t" << info.name_ << ": " << info.getJitType()->python_str() << "\n"; } std::cout << "\nOverloads: \n"; - for (const auto& pr : overloads_) { + for (const auto& pr : data_.overloads_) { std::cout << "\t" << pr.first << ": " << pr.second << "\n"; } - std::string isPoisoned = isPoisoned_ ? "true" : "false"; + std::string isPoisoned = data_.isPoisoned_ ? "true" : "false"; std::cout << "isPoisoned: " << isPoisoned << "\n"; if (jitType_) { std::cout << "jit type: " << jitType_->python_str() << "\n"; @@ -215,11 +252,10 @@ void ConcreteModuleType::dump() const { std::unordered_map ConcreteModuleType::getConstantsPy() const { - TORCH_INTERNAL_ASSERT(jitType_); // Convert to a more pybind-friendly representation, so we don't // need to bind ConcreteModuleType::Constant as well. std::unordered_map ret; - for (const auto& pr : constants_) { + for (const auto& pr : data_.constants_) { ret.emplace(pr.first, pr.second.v_); } return ret; @@ -227,11 +263,10 @@ std::unordered_map ConcreteModuleType::getConstantsPy() std::unordered_map> ConcreteModuleType:: getAttributesPy() const { - TORCH_INTERNAL_ASSERT(jitType_); // Convert to a more pybind-friendly representation, so we don't // need to bind ConcreteModuleType::Attribute as well. std::unordered_map> ret; - for (auto& pr : attributes_) { + for (auto& pr : data_.attributes_) { ret.emplace( pr.first, std::pair(pr.second.type_, pr.second.isParam_)); @@ -241,10 +276,9 @@ std::unordered_map> ConcreteModuleType:: std::vector> ConcreteModuleType::getModulesPy() const { - TORCH_INTERNAL_ASSERT(jitType_); std::vector> ret; - for (const ModuleInfo& info: modules_) { + for (const auto& info : data_.modules_) { ret.emplace_back(std::make_pair(info.name_, info.getJitType())); } return ret; diff --git a/torch/csrc/jit/script/concrete_module_type.h b/torch/csrc/jit/script/concrete_module_type.h index d098c402ec8..a5877726378 100644 --- a/torch/csrc/jit/script/concrete_module_type.h +++ b/torch/csrc/jit/script/concrete_module_type.h @@ -9,7 +9,9 @@ namespace torch { namespace jit { namespace script { + enum class IterableModuleKind { NONE, LIST, DICT }; +class ConcreteModuleType; // You can think of an nn.Module as a template that corresponds to a family of // JIT types. The template "arguments" are things like the constant values. @@ -41,21 +43,22 @@ enum class IterableModuleKind { NONE, LIST, DICT }; // ModuleValue::attr calls. This is so we can guarantee that if two Module's // share a JIT type (and thus a ConcreteModuleType), then they behave the same // way when you access attributes on them. -class VISIBILITY_HIDDEN ConcreteModuleType { - public: - // ConcreteModuleType has two states. - // 1. Building: First we build it up, during the ScriptModule conversion - // process - // ... to transition, we freeze the type by associating with a JIT type. - // 2. Querying: Then we ask use it as a source of truth during method - // compilation. - // During this time the ModuleType is effectively const. - // Yes, it could be two different types. Not terribly worth the verbosity. - /** - * Builder methods (jitType must be null) - */ - void addPyClass(py::object pyClass); +// ConcreteModuleType has two phases. +// 1. Creation: First we build it up, during the ScriptModule conversion +// process. This is represented by ConcreteModuleTypeBuilder. +// ...then the converter calls ConcreteModuleTypeBuilder::build(), producing a +// ConcreteModuleType ready for querying. +// 2. Querying: We use ConcreteModuleType as a source of truth for +// ModuleValue::attr calls during method compilation. + +// Represents a concrete type during in the process for construction. We use +// this to decide whether we can share types between modules. +class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder { + public: + explicit ConcreteModuleTypeBuilder(py::object pyClass) { + pyClass_ = std::move(pyClass); + } void addConstant(std::string name, py::object value); void addAttribute(std::string name, TypePtr type, bool isParameter); void addFunctionAttribute( @@ -73,95 +76,20 @@ class VISIBILITY_HIDDEN ConcreteModuleType { std::vector overloadedMethodNames); void addFailedAttribute(std::string name, std::string failureReason); void setIterableModuleKind(IterableModuleKind kind); + + // If a ConcreteModuleType is poisoned, it will never compare equal to any other + // concrete type void setPoisoned(); - /** - * Freezing methods - */ - - // Based on the data in this ConcreteType, create an equivalent JIT type and - // associate this module type with it. - ClassTypePtr createNewTypeFromThis(); - // Associate the provided type with this ConcreteType - void addJitType(ClassTypePtr type); - - /** - * Query methods (jitType must be non-null) - */ - ClassTypePtr getJitType() const; - py::object getPyClass() const; - IterableModuleKind getIterableModuleKind() const; - c10::optional findConstant(const std::string& name) const; - c10::optional> findOverloads( - const std::string& name) const; - c10::optional findFunctionAttribute(const std::string& name) const; - std::shared_ptr findSubmoduleConcreteType( - const std::string& name) const; - c10::optional findFailedAttribute(const std::string& name) const; - - // These getters are only here to return things as types that can be - // automatically converted by pybind. - std::unordered_map getConstantsPy() const; - std::unordered_map> getAttributesPy() - const; - std::vector> getModulesPy() const; + std::shared_ptr build() const { + return std::make_shared(*this); + } // This determines whether two modules can share a type. The container structs // used by ConcreteModuleType have been defined such that operator== // implements a meaningful comparison in that context. - friend bool operator==( - const ConcreteModuleType& lhs, - const ConcreteModuleType& rhs) { - if (lhs.jitType_ == rhs.jitType_) { - // If the computed types are the same, these modules can (obviously) share - // a type. - return true; - } + bool equals(const ConcreteModuleTypeBuilder& other) const; - if (lhs.isPoisoned_ || rhs.isPoisoned_) { - return false; - } - - // clang-format off - // These are vaguely ordered so that cheap, discriminating checks happen first. - bool equal = - lhs.pyClass_.is(rhs.pyClass_) && - lhs.iterableModuleKind_ == rhs.iterableModuleKind_ && - lhs.constants_ == rhs.constants_ && - lhs.attributes_ == rhs.attributes_ && - lhs.overloads_ == rhs.overloads_ && - lhs.functionAttributes_ == rhs.functionAttributes_; - // clang-format on - if (!equal) { - return false; - } - - // We store modules in order of insertion (to make compilation - // deterministic). However, for the purposes of equality, insertion order - // should not matter, so sort them by name. - // We put this check last because it involves the most work. - auto lhsSorted = lhs.modules_; - std::sort( - lhsSorted.begin(), - lhsSorted.end(), - [](const ModuleInfo& a, const ModuleInfo& b) { - return a.name_ < b.name_; - }); - - auto rhsSorted = rhs.modules_; - std::sort( - rhsSorted.begin(), - rhsSorted.end(), - [](const ModuleInfo& a, const ModuleInfo& b) { - return a.name_ < b.name_; - }); - - return lhsSorted == rhsSorted; - } - - void dump() const; - - private: struct Constant { /* implicit */ Constant(py::object v) : v_(std::move(v)) {} friend bool operator==(const Constant& lhs, const Constant& rhs) { @@ -206,28 +134,18 @@ class VISIBILITY_HIDDEN ConcreteModuleType { ModuleInfo(std::string name, const TypePtr& type) : name_(std::move(name)), meta_(nullptr), type_(type) {} - friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs) { - if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) { - return *(lhs.meta_) == *(rhs.meta_); - } else if (lhs.type_ != nullptr && rhs.type_ != nullptr) { - return *(lhs.type_) == *(rhs.type_); - } else { - return false; - } - } - - TypePtr getJitType() const { - return meta_ == nullptr? type_ : meta_->getJitType(); - } + TypePtr getJitType() const; std::string name_; + friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs); // Module Info contains either an ConcreateModuleType or a type (which is // a Module Interface), these two are union relationship. std::shared_ptr meta_; TypePtr type_; - }; + private: + ClassTypePtr createTypeFromThis() const; // If true, this type will never compare equally to anything else. This is // used if we want to ensure that this type is not shared (for example, if it // came from a traced module) @@ -256,11 +174,54 @@ class VISIBILITY_HIDDEN ConcreteModuleType { // The original `nn.Module` class that we derived this ScriptModule from. py::object pyClass_; - // The JIT type derived from this ConcreteModuleType. - ClassTypePtr jitType_ = nullptr; // NOTE: If you ever add any more state to this struct, you need to make sure - // operator== still makes sense! The only field that can be excluded from it - // is `jitType_`. + // operator== still makes sense! + friend ConcreteModuleType; +}; + +// Represents a finalized concrete type, used to service ModuleValue::attr calls +// during method compilation. +class VISIBILITY_HIDDEN ConcreteModuleType { + public: + explicit ConcreteModuleType(ConcreteModuleTypeBuilder data); + + ClassTypePtr getJitType() const; + py::object getPyClass() const; + IterableModuleKind getIterableModuleKind() const; + c10::optional findConstant(const std::string& name) const; + c10::optional> findOverloads( + const std::string& name) const; + c10::optional findFunctionAttribute(const std::string& name) const; + std::shared_ptr findSubmoduleConcreteType( + const std::string& name) const; + c10::optional findFailedAttribute(const std::string& name) const; + + // These getters are only here to return things as types that can be + // automatically converted by pybind. + std::unordered_map getConstantsPy() const; + std::unordered_map> getAttributesPy() + const; + std::vector> getModulesPy() const; + + bool equals(const ConcreteModuleType& other) const { + if (jitType_ == other.jitType_) { + // If the computed types are the same, these modules can (obviously) share + // a type. + return true; + } + + return data_.equals(other.data_); + } + bool equals(const ConcreteModuleTypeBuilder& other) const { + return data_.equals(other); + } + + void dump() const; + + private: + // The JIT type derived from this ConcreteModuleType. + ConcreteModuleTypeBuilder data_; + ClassTypePtr jitType_; }; } // namespace script diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index beb7e798835..f495fd27719 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -1052,42 +1052,51 @@ void initJitScriptBindings(PyObject* module) { return Module(get_python_cu(), type); }); + py::class_>( + m, "ConcreteModuleTypeBuilder") + .def(py::init()) + .def("add_constant", &ConcreteModuleTypeBuilder::addConstant) + .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute) + .def( + "add_function_attribute", + &ConcreteModuleTypeBuilder::addFunctionAttribute) + .def("add_module", &ConcreteModuleTypeBuilder::addModule) + .def("add_module_interface", &ConcreteModuleTypeBuilder::addModuleInterface) + .def("add_overload", &ConcreteModuleTypeBuilder::addOverload) + .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned) + .def("add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute) + .def( + "set_module_dict", + [](ConcreteModuleTypeBuilder& self) { + self.setIterableModuleKind(IterableModuleKind::DICT); + }) + .def("build", &ConcreteModuleTypeBuilder::build) + .def( + "equals", + [](const ConcreteModuleTypeBuilder& self, + const ConcreteModuleTypeBuilder& other) { return self.equals(other); }) + .def("set_module_list", [](ConcreteModuleTypeBuilder& self) { + self.setIterableModuleKind(IterableModuleKind::LIST); + }); + py::class_>( m, "ConcreteModuleType") - .def(py::init<>()) .def_property_readonly("py_class", &ConcreteModuleType::getPyClass) .def_property_readonly("jit_type", &ConcreteModuleType::getJitType) .def("get_constants", &ConcreteModuleType::getConstantsPy) .def("get_attributes", &ConcreteModuleType::getAttributesPy) .def("get_modules", &ConcreteModuleType::getModulesPy) - .def("add_constant", &ConcreteModuleType::addConstant) - .def("add_attribute", &ConcreteModuleType::addAttribute) - .def("add_function_attribute", &ConcreteModuleType::addFunctionAttribute) - .def("add_module", &ConcreteModuleType::addModule) - .def("add_module_interface", &ConcreteModuleType::addModuleInterface) - .def("add_pyclass", &ConcreteModuleType::addPyClass) - .def("add_overload", &ConcreteModuleType::addOverload) - .def("add_jit_type", &ConcreteModuleType::addJitType) - .def("set_poisoned", &ConcreteModuleType::setPoisoned) - .def( - "set_module_dict", - [](ConcreteModuleType& self) { - self.setIterableModuleKind(IterableModuleKind::DICT); - }) - .def( - "set_module_list", - [](ConcreteModuleType& self) { - self.setIterableModuleKind(IterableModuleKind::LIST); - }) - .def( - "create_new_type_from_this", - &ConcreteModuleType::createNewTypeFromThis) - .def("add_failed_attribute", &ConcreteModuleType::addFailedAttribute) .def("dump", &ConcreteModuleType::dump) .def( "equals", [](const ConcreteModuleType& self, const ConcreteModuleType& other) { - return self == other; + return self.equals(other); + }) + .def( + "equals", + [](const ConcreteModuleType& self, + const ConcreteModuleTypeBuilder& other) { + return self.equals(other); }) .def( "_create_methods", diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 4b1b6b5bc87..277dab1d5e1 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -59,18 +59,17 @@ def _get_valid_constant(attr, v): 3. a list or tuple of (2) """.format(type(v).__name__, attr, constants))) -def infer_raw_concrete_type(nn_module): +def infer_concrete_type_builder(nn_module): """ - Build a ConcreteModuleType from an nn.Module. This ConcreteModuleType - doesn't have a JIT type associated with it yet, it must be filled in - by the caller. + Build a ConcreteModuleTypeBuilder from an nn.Module. This + ConcreteModuleType doesn't have a JIT type associated with it yet, it + must be filled in by the caller. """ - concrete_type = torch._C.ConcreteModuleType() - concrete_type.add_pyclass(type(nn_module)) + concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module)) if isinstance(nn_module, (torch.nn.ModuleDict)): - concrete_type.set_module_dict() + concrete_type_builder.set_module_dict() if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)): - concrete_type.set_module_list() + concrete_type_builder.set_module_list() class_annotations = getattr(nn_module, '__annotations__', {}) @@ -96,7 +95,7 @@ def infer_raw_concrete_type(nn_module): continue assert isinstance(item, torch.Tensor) attr_type = infer_type(name, item) - concrete_type.add_attribute(name, attr_type, True) + concrete_type_builder.add_attribute(name, attr_type, True) added_names.add(name) for name, item in nn_module._buffers.items(): @@ -109,18 +108,18 @@ def infer_raw_concrete_type(nn_module): continue assert isinstance(item, torch.Tensor) attr_type = infer_type(name, item) - concrete_type.add_attribute(name, attr_type, False) + concrete_type_builder.add_attribute(name, attr_type, False) added_names.add(name) for name, item in nn_module._modules.items(): attr_type = infer_type(name, item) if attr_type is not None: # if the type can be inferred, it should be a module interface type - concrete_type.add_module_interface(name, attr_type) + concrete_type_builder.add_module_interface(name, attr_type) else: # otherwise we get the concrete module type for item and add it to concrete_type sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item) - concrete_type.add_module(name, sub_concrete_type) + concrete_type_builder.add_module(name, sub_concrete_type) added_names.add(name) @@ -146,7 +145,7 @@ def infer_raw_concrete_type(nn_module): "Consider removing it.".format(name)) continue value = getattr(nn_module, name) - concrete_type.add_constant(name, _get_valid_constant(name, value)) + concrete_type_builder.add_constant(name, _get_valid_constant(name, value)) added_names.add(name) # populate overloads @@ -154,7 +153,7 @@ def infer_raw_concrete_type(nn_module): # update with any annotated overloads overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module))) for name, overloaded_names in overloads.items(): - concrete_type.add_overload(name, overloaded_names) + concrete_type_builder.add_overload(name, overloaded_names) for name, value in nn_module.__dict__.items(): @@ -171,7 +170,7 @@ def infer_raw_concrete_type(nn_module): if inspect.isfunction(value): try: scripted_fn = torch.jit.script(value) - concrete_type.add_function_attribute( + concrete_type_builder.add_function_attribute( name, torch._C._jit_try_infer_type(scripted_fn), value) @@ -182,14 +181,14 @@ def infer_raw_concrete_type(nn_module): hint = ("(This function exists as an attribute on the Python module, " "but we failed to compile it to a TorchScript function. " "\nThe error stack is reproduced here:\n{}").format(e) - concrete_type.add_failed_attribute(name, hint) + concrete_type_builder.add_failed_attribute(name, hint) pass continue # Handle Script function attributes if isinstance(value, torch.jit.ScriptFunction): - concrete_type.add_function_attribute( + concrete_type_builder.add_function_attribute( name, torch._C._jit_try_infer_type(value), value) @@ -198,14 +197,14 @@ def infer_raw_concrete_type(nn_module): # If we got here, this is a regular "data" attribute, Add it to the concrete type attr_type = infer_type(name, value) if attr_type is not None: - concrete_type.add_attribute(name, attr_type, False) + concrete_type_builder.add_attribute(name, attr_type, False) else: # TODO: could add more detail here. For example, what the user should do # when the pytype is `list` or `NoneType` hint = ("(This attribute exists on the Python module, " "but we failed to convert Python type: '{}' " "to a TorchScript type.)").format(type(value).__name__) - concrete_type.add_failed_attribute(name, hint) + concrete_type_builder.add_failed_attribute(name, hint) # Add @property methods as failed attributes, to give a better error message. for name, value in type(nn_module).__dict__.items(): @@ -213,9 +212,9 @@ def infer_raw_concrete_type(nn_module): hint = ("\n(This attribute exists on the Python module, but it's an @property " "method. @property methods are not yet supported in TorchScript. " "Please file a feature request on Github)") - concrete_type.add_failed_attribute(name, hint) + concrete_type_builder.add_failed_attribute(name, hint) - return concrete_type + return concrete_type_builder class ConcreteTypeStore(object): def __init__(self): @@ -234,7 +233,7 @@ class ConcreteTypeStore(object): hasattr(nn_module, "_concrete_type"): return nn_module._concrete_type - raw_concrete_type = infer_raw_concrete_type(nn_module) + concrete_type_builder = infer_concrete_type_builder(nn_module) nn_module_type = type(nn_module) if nn_module_type not in self.type_store: @@ -243,13 +242,13 @@ class ConcreteTypeStore(object): # Search the type store for an already-available JIT type known_types = self.type_store[nn_module_type] for known_type in known_types: - if raw_concrete_type.equals(known_type): + if known_type.equals(concrete_type_builder): return known_type # We didn't find anything; generate a new JIT type from this concrete type - raw_concrete_type.create_new_type_from_this() - self.type_store[nn_module_type].append(raw_concrete_type) - return raw_concrete_type + concrete_type = concrete_type_builder.build() + self.type_store[nn_module_type].append(concrete_type) + return concrete_type concrete_type_store = ConcreteTypeStore() @@ -272,11 +271,12 @@ def create_script_module_for_tracing(nn_module, stubs): stubs: ScriptMethodStubs to compile as part of the conversion process. """ check_module_initialized(nn_module) - # Get a ConcreteType without a JIT type. We will generate one ourselves - # and fill it in. - concrete_type = infer_raw_concrete_type(nn_module) - concrete_type.set_poisoned() - concrete_type.create_new_type_from_this() + # Get a concrete type directly, without trying to re-use an existing JIT + # type from the type store. + concrete_type_builder = infer_concrete_type_builder(nn_module) + concrete_type_builder.set_poisoned() + concrete_type = concrete_type_builder.build() + cpp_module = torch._C._create_module_with_type(concrete_type.jit_type) return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)