diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py index 14a17c7cf12..4e96c8381be 100644 --- a/test/jit/test_recursive_script.py +++ b/test/jit/test_recursive_script.py @@ -629,3 +629,24 @@ class TestRecursiveScript(JitTestCase): dummies = nn.ModuleList([dummy]) model = Model(dummies) self.checkModule(model, (torch.rand(5, 5), )) + + def test_script_loaded_module(self): + """ + Test that we can hold a loaded ScriptModule as a submodule. + """ + class Dummy(nn.Module): + def forward(self, x): + return x + + dummy = torch.jit.script(Dummy()) + dummy = self.getExportImportCopy(dummy) + + class ContainsLoaded(torch.nn.Module): + def __init__(self): + super(ContainsLoaded, self).__init__() + self.encoder = dummy + + def forward(self, input): + return self.encoder(input) + + self.checkModule(ContainsLoaded(), (torch.rand(2, 3), )) diff --git a/torch/csrc/jit/script/concrete_module_type.cpp b/torch/csrc/jit/script/concrete_module_type.cpp index e688a645b8c..ff8e6c08d1c 100644 --- a/torch/csrc/jit/script/concrete_module_type.cpp +++ b/torch/csrc/jit/script/concrete_module_type.cpp @@ -35,11 +35,16 @@ ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const { return cls; } -std::shared_ptr ConcreteModuleType::fromInterface( - InterfaceTypePtr interface) { - TORCH_INTERNAL_ASSERT(interface->is_module()); +std::shared_ptr ConcreteModuleType::fromJitType( + TypePtr type) { + // `type` should either be a module interface or a class type + if (auto interface = type->cast()){ + TORCH_INTERNAL_ASSERT(interface->is_module()); + } else { + TORCH_INTERNAL_ASSERT(type->cast()); + } auto ret = std::shared_ptr(new ConcreteModuleType()); - ret->jitType_ = std::move(interface); + ret->jitType_ = std::move(type); return ret; } diff --git a/torch/csrc/jit/script/concrete_module_type.h b/torch/csrc/jit/script/concrete_module_type.h index 787bfdf3a55..883d1fba9e0 100644 --- a/torch/csrc/jit/script/concrete_module_type.h +++ b/torch/csrc/jit/script/concrete_module_type.h @@ -178,8 +178,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType { public: explicit ConcreteModuleType(ConcreteModuleTypeBuilder data); - static std::shared_ptr fromInterface( - InterfaceTypePtr interface); + static std::shared_ptr fromJitType(TypePtr type); TypePtr getJitType() const; py::object getPyClass() const; diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index ff96c43c83b..ddfddfd4c3e 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -1082,7 +1082,7 @@ void initJitScriptBindings(PyObject* module) { m, "ConcreteModuleType") .def_property_readonly("py_class", &ConcreteModuleType::getPyClass) .def_property_readonly("jit_type", &ConcreteModuleType::getJitType) - .def_static("from_interface", &ConcreteModuleType::fromInterface) + .def_static("from_jit_type", &ConcreteModuleType::fromJitType) .def("get_constants", &ConcreteModuleType::getConstantsPy) .def("get_attributes", &ConcreteModuleType::getAttributesPy) .def("get_modules", &ConcreteModuleType::getModulesPy) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index ca0cf635552..4677c1bf588 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -115,7 +115,7 @@ def infer_concrete_type_builder(nn_module): attr_type = infer_type(name, item) if attr_type is not None: # if the type can be inferred, it should be a module interface type - sub_concrete_type = torch._C.ConcreteModuleType.from_interface(attr_type) + sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type) else: # otherwise we get the concrete module type for item and add it to concrete_type sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item) @@ -561,6 +561,7 @@ def wrap_cpp_module(cpp_module): def init_fn(script_module): for name, cpp_module in torch._C.ModuleDict(script_module._c).items(): setattr(script_module, name, wrap_cpp_module(cpp_module)) + script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type()) return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) def compile_unbound_method(concrete_type, fn):