diff --git a/test/jit/test_concrete_module_type.py b/test/jit/test_concrete_module_type.py new file mode 100644 index 00000000000..7a7503f5721 --- /dev/null +++ b/test/jit/test_concrete_module_type.py @@ -0,0 +1,53 @@ +# Owner(s): ["oncall: jit"] + +import unittest + +import torch +from torch.testing._internal.common_utils import raise_on_run_directly + + +class TestConcreteModuleTypeFindSubmodule(unittest.TestCase): + def test_error_message_includes_submodule_name(self): + class ChildModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 3) + + def forward(self, x): + return self.linear(x) + + class ParentModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.existing_child = ChildModule() + + def forward(self, x): + return self.existing_child(x) + + module = ParentModule() + scripted_module = torch.jit.script(module) + + self.assertIsNotNone(scripted_module.existing_child) + + # Now try to trigger the error by accessing a non-existent submodule + # through the internal ConcreteModuleType mechanism. This happens + # when the TorchScript compiler tries to resolve submodule references. + class BrokenModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.parent = ParentModule() + + def forward(self, x): + return self.parent.missing_submodule(x) + + broken_module = BrokenModule() + + with self.assertRaises(RuntimeError) as context: + torch.jit.script(broken_module) + + error_msg = str(context.exception) + self.assertIn("missing_submodule", error_msg.lower()) + + +if __name__ == "__main__": + raise_on_run_directly("test/test_jit.py") diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp index cfdef51afc3..91d41607f9d 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.cpp +++ b/torch/csrc/jit/frontend/concrete_module_type.cpp @@ -204,7 +204,8 @@ std::shared_ptr ConcreteModuleType:: [&](const ConcreteModuleTypeBuilder::ModuleInfo& info) { return info.name_ == name; }); - TORCH_INTERNAL_ASSERT(it != data_.modules_.end()); + TORCH_INTERNAL_ASSERT( + it != data_.modules_.end(), "Cannot find submodule with name/key ", name); return it->meta_; }