diff --git a/test/test_nn.py b/test/test_nn.py index 92357d9ce15..fa0321a697d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1418,20 +1418,22 @@ class TestNN(NNTestCase): self.assertEqual(m.param_name, param3) def test_add_module_raises_error_if_attr_exists(self): - m = nn.Module() - m.attribute_name = 5 - with self.assertRaises(KeyError): - m.add_module('attribute_name', nn.Module()) + methods_to_test = ['add_module', 'register_module'] + for fn in methods_to_test: + m = nn.Module() + m.attribute_name = 5 + with self.assertRaises(KeyError): + getattr(m, fn)('attribute_name', nn.Module()) - del m.attribute_name - m.register_buffer('attribute_name', torch.rand(5)) - with self.assertRaises(KeyError): - m.add_module('attribute_name', nn.Module()) + del m.attribute_name + m.register_buffer('attribute_name', torch.rand(5)) + with self.assertRaises(KeyError): + getattr(m, fn)('attribute_name', nn.Module()) - del m.attribute_name - m.register_parameter('attribute_name', nn.Parameter()) - with self.assertRaises(KeyError): - m.add_module('attribute_name', nn.Module()) + del m.attribute_name + m.register_parameter('attribute_name', nn.Parameter()) + with self.assertRaises(KeyError): + getattr(m, fn)('attribute_name', nn.Module()) @unittest.expectedFailure def test_getattr_with_property(self): @@ -1835,24 +1837,26 @@ class TestNN(NNTestCase): check() def test_add_module(self): - l = nn.Linear(10, 20) - net = nn.Module() - net.l = l - net.l2 = l - net.add_module('empty', None) - self.assertEqual(net.l, l) - self.assertEqual(net.l2, l) - self.assertEqual(net.empty, None) - net.add_module('l3', l) - self.assertEqual(net.l3, l) - l3 = nn.Linear(20, 10) - net.add_module('l', l3) - self.assertEqual(net.l, l3) - self.assertRaises(TypeError, lambda: net.add_module('x', 'non-module')) - self.assertRaisesRegex(TypeError, 'module name should be a string. Got int', - lambda: net.add_module(1, l)) - self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType', - lambda: net.add_module(None, l)) + methods_to_test = ['add_module', 'register_module'] + for fn in methods_to_test: + l = nn.Linear(10, 20) + net = nn.Module() + net.l = l + net.l2 = l + getattr(net, fn)('empty', None) + self.assertEqual(net.l, l) + self.assertEqual(net.l2, l) + self.assertEqual(net.empty, None) + getattr(net, fn)('l3', l) + self.assertEqual(net.l3, l) + l3 = nn.Linear(20, 10) + getattr(net, fn)('l', l3) + self.assertEqual(net.l, l3) + self.assertRaises(TypeError, lambda: getattr(net, fn)('x', 'non-module')) + self.assertRaisesRegex(TypeError, 'module name should be a string. Got int', + lambda: getattr(net, fn)(1, l)) + self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType', + lambda: getattr(net, fn)(None, l)) def test_module_to_argparse(self): net = nn.Sequential(nn.Linear(3, 3)) diff --git a/torch/jit/_script.py b/torch/jit/_script.py index acc9e7c44f5..b800871dbb4 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -874,6 +874,7 @@ if _enabled: "forward", "register_buffer", "register_parameter", + "register_module", "add_module", "_apply", "apply", diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 28b220e2403..3d579134e1b 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -387,6 +387,10 @@ class Module: raise KeyError("module name can't be empty string \"\"") self._modules[name] = module + def register_module(self, name: str, module: Optional['Module']) -> None: + r"""Alias for :func:`add_module`.""" + self.add_module(name, module) + def get_submodule(self, target: str) -> "Module": """ Returns the submodule given by ``target`` if it exists,