Fix type sharing on loaded ScriptModules (#29826)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29826

After save/load, we lose concrete type information. So if you tried to
script something that contained a loaded ScriptModule as a submodule,
the following sequence happened:
1. During ConcreteType inference, the loaded submodule got a new
inferred type.
2. But it already has a type! So there was a type mismatch.

To fix this, we should generate a ConcreteType directly from the loaded
submodule type (similar to what we do for interfaces). This makes sense
too--the ConcreteModuleType should be empty, since all the "sugaredness"
was stripped out during the save/load process.

Test Plan: Imported from OSS

Differential Revision: D18575009

Pulled By: suo

fbshipit-source-id: 4d329b7e9b7e7624f459e50092e35ab0ab813791
This commit is contained in:
Michael Suo 2019-11-20 01:11:11 -08:00 committed by Facebook Github Bot
parent 558a777615
commit 93db2b86d1
5 changed files with 34 additions and 8 deletions

View File

@ -629,3 +629,24 @@ class TestRecursiveScript(JitTestCase):
dummies = nn.ModuleList([dummy])
model = Model(dummies)
self.checkModule(model, (torch.rand(5, 5), ))
def test_script_loaded_module(self):
"""
Test that we can hold a loaded ScriptModule as a submodule.
"""
class Dummy(nn.Module):
def forward(self, x):
return x
dummy = torch.jit.script(Dummy())
dummy = self.getExportImportCopy(dummy)
class ContainsLoaded(torch.nn.Module):
def __init__(self):
super(ContainsLoaded, self).__init__()
self.encoder = dummy
def forward(self, input):
return self.encoder(input)
self.checkModule(ContainsLoaded(), (torch.rand(2, 3), ))

View File

@ -35,11 +35,16 @@ ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
return cls;
}
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromInterface(
InterfaceTypePtr interface) {
TORCH_INTERNAL_ASSERT(interface->is_module());
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
TypePtr type) {
// `type` should either be a module interface or a class type
if (auto interface = type->cast<InterfaceType>()){
TORCH_INTERNAL_ASSERT(interface->is_module());
} else {
TORCH_INTERNAL_ASSERT(type->cast<ClassType>());
}
auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
ret->jitType_ = std::move(interface);
ret->jitType_ = std::move(type);
return ret;
}

View File

@ -178,8 +178,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
public:
explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
static std::shared_ptr<ConcreteModuleType> fromInterface(
InterfaceTypePtr interface);
static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type);
TypePtr getJitType() const;
py::object getPyClass() const;

View File

@ -1082,7 +1082,7 @@ void initJitScriptBindings(PyObject* module) {
m, "ConcreteModuleType")
.def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
.def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
.def_static("from_interface", &ConcreteModuleType::fromInterface)
.def_static("from_jit_type", &ConcreteModuleType::fromJitType)
.def("get_constants", &ConcreteModuleType::getConstantsPy)
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
.def("get_modules", &ConcreteModuleType::getModulesPy)

View File

@ -115,7 +115,7 @@ def infer_concrete_type_builder(nn_module):
attr_type = infer_type(name, item)
if attr_type is not None:
# if the type can be inferred, it should be a module interface type
sub_concrete_type = torch._C.ConcreteModuleType.from_interface(attr_type)
sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type)
else:
# otherwise we get the concrete module type for item and add it to concrete_type
sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
@ -561,6 +561,7 @@ def wrap_cpp_module(cpp_module):
def init_fn(script_module):
for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
setattr(script_module, name, wrap_cpp_module(cpp_module))
script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type())
return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
def compile_unbound_method(concrete_type, fn):