mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix type sharing on loaded ScriptModules (#29826)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29826 After save/load, we lose concrete type information. So if you tried to script something that contained a loaded ScriptModule as a submodule, the following sequence happened: 1. During ConcreteType inference, the loaded submodule got a new inferred type. 2. But it already has a type! So there was a type mismatch. To fix this, we should generate a ConcreteType directly from the loaded submodule type (similar to what we do for interfaces). This makes sense too--the ConcreteModuleType should be empty, since all the "sugaredness" was stripped out during the save/load process. Test Plan: Imported from OSS Differential Revision: D18575009 Pulled By: suo fbshipit-source-id: 4d329b7e9b7e7624f459e50092e35ab0ab813791
This commit is contained in:
parent
558a777615
commit
93db2b86d1
|
|
@ -629,3 +629,24 @@ class TestRecursiveScript(JitTestCase):
|
|||
dummies = nn.ModuleList([dummy])
|
||||
model = Model(dummies)
|
||||
self.checkModule(model, (torch.rand(5, 5), ))
|
||||
|
||||
def test_script_loaded_module(self):
|
||||
"""
|
||||
Test that we can hold a loaded ScriptModule as a submodule.
|
||||
"""
|
||||
class Dummy(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
dummy = torch.jit.script(Dummy())
|
||||
dummy = self.getExportImportCopy(dummy)
|
||||
|
||||
class ContainsLoaded(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(ContainsLoaded, self).__init__()
|
||||
self.encoder = dummy
|
||||
|
||||
def forward(self, input):
|
||||
return self.encoder(input)
|
||||
|
||||
self.checkModule(ContainsLoaded(), (torch.rand(2, 3), ))
|
||||
|
|
|
|||
|
|
@ -35,11 +35,16 @@ ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
|
|||
return cls;
|
||||
}
|
||||
|
||||
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromInterface(
|
||||
InterfaceTypePtr interface) {
|
||||
TORCH_INTERNAL_ASSERT(interface->is_module());
|
||||
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
|
||||
TypePtr type) {
|
||||
// `type` should either be a module interface or a class type
|
||||
if (auto interface = type->cast<InterfaceType>()){
|
||||
TORCH_INTERNAL_ASSERT(interface->is_module());
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(type->cast<ClassType>());
|
||||
}
|
||||
auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
|
||||
ret->jitType_ = std::move(interface);
|
||||
ret->jitType_ = std::move(type);
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -178,8 +178,7 @@ class VISIBILITY_HIDDEN ConcreteModuleType {
|
|||
public:
|
||||
explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
|
||||
|
||||
static std::shared_ptr<ConcreteModuleType> fromInterface(
|
||||
InterfaceTypePtr interface);
|
||||
static std::shared_ptr<ConcreteModuleType> fromJitType(TypePtr type);
|
||||
|
||||
TypePtr getJitType() const;
|
||||
py::object getPyClass() const;
|
||||
|
|
|
|||
|
|
@ -1082,7 +1082,7 @@ void initJitScriptBindings(PyObject* module) {
|
|||
m, "ConcreteModuleType")
|
||||
.def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
|
||||
.def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
|
||||
.def_static("from_interface", &ConcreteModuleType::fromInterface)
|
||||
.def_static("from_jit_type", &ConcreteModuleType::fromJitType)
|
||||
.def("get_constants", &ConcreteModuleType::getConstantsPy)
|
||||
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
|
||||
.def("get_modules", &ConcreteModuleType::getModulesPy)
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def infer_concrete_type_builder(nn_module):
|
|||
attr_type = infer_type(name, item)
|
||||
if attr_type is not None:
|
||||
# if the type can be inferred, it should be a module interface type
|
||||
sub_concrete_type = torch._C.ConcreteModuleType.from_interface(attr_type)
|
||||
sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(attr_type)
|
||||
else:
|
||||
# otherwise we get the concrete module type for item and add it to concrete_type
|
||||
sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
|
||||
|
|
@ -561,6 +561,7 @@ def wrap_cpp_module(cpp_module):
|
|||
def init_fn(script_module):
|
||||
for name, cpp_module in torch._C.ModuleDict(script_module._c).items():
|
||||
setattr(script_module, name, wrap_cpp_module(cpp_module))
|
||||
script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type())
|
||||
return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
|
||||
|
||||
def compile_unbound_method(concrete_type, fn):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user