Add register_module alias to nn.Module (#65174)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/60397. I'm not sure how aliases are supposed to be implemented, but this is the most basic/direct way, IMO. As a side-effect, this implementation results in a "duplicate" doc entry, inheriting the one from `add_module`:

![monkey-patch](https://user-images.githubusercontent.com/7027770/133693137-8408d8e7-1f4f-436b-b176-57dda9bc3a32.png)

An alternative implementation could be:

```python
def register_module(self, name: str, module: Optional['Module']) -> None:
    r"""Alias for :func:`add_module`."""
    self.add_module(name, module)
```

which results in this documentation:

![image](https://user-images.githubusercontent.com/7027770/133693249-d969a71a-be44-489d-9633-4f38b44ab887.png)

Questions:
1. Should I replicate the tests? There are two for `add_module`: [test_add_module_raises_error_if_attr_exists](873255c6d9/test/test_nn.py (L1420-L1434)) and [test_add_module](873255c6d9/test/test_nn.py (L1837-L1855)).
2. This PR only adds `register_module` to `nn.Module`. There is an `add_module` in [`_RemoteModule`](https://github.com/pytorch/pytorch/blob/master/torch/distributed/nn/api/remote_module.py#L311-L312), which raises `NotSupported`, and there is another one in [`ConcreteModuleTypeBuilder`](873255c6d9/torch/_C/__init__.pyi.in (L468)), which means something else, I think. Should I do anything about them?

cc ngimel SsnL

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65174

Reviewed By: soulitzer

Differential Revision: D31089717

Pulled By: jbschlosser

fbshipit-source-id: abd8d14a434fd8c7efa0bd8c242df56da33491e9
This commit is contained in:
Rodrigo Berriel 2021-09-22 16:36:00 -07:00 committed by Facebook GitHub Bot
parent 31584d065e
commit b80bdcc73b
3 changed files with 39 additions and 30 deletions

View File

@ -1418,20 +1418,22 @@ class TestNN(NNTestCase):
self.assertEqual(m.param_name, param3)
def test_add_module_raises_error_if_attr_exists(self):
methods_to_test = ['add_module', 'register_module']
for fn in methods_to_test:
m = nn.Module()
m.attribute_name = 5
with self.assertRaises(KeyError):
m.add_module('attribute_name', nn.Module())
getattr(m, fn)('attribute_name', nn.Module())
del m.attribute_name
m.register_buffer('attribute_name', torch.rand(5))
with self.assertRaises(KeyError):
m.add_module('attribute_name', nn.Module())
getattr(m, fn)('attribute_name', nn.Module())
del m.attribute_name
m.register_parameter('attribute_name', nn.Parameter())
with self.assertRaises(KeyError):
m.add_module('attribute_name', nn.Module())
getattr(m, fn)('attribute_name', nn.Module())
@unittest.expectedFailure
def test_getattr_with_property(self):
@ -1835,24 +1837,26 @@ class TestNN(NNTestCase):
check()
def test_add_module(self):
methods_to_test = ['add_module', 'register_module']
for fn in methods_to_test:
l = nn.Linear(10, 20)
net = nn.Module()
net.l = l
net.l2 = l
net.add_module('empty', None)
getattr(net, fn)('empty', None)
self.assertEqual(net.l, l)
self.assertEqual(net.l2, l)
self.assertEqual(net.empty, None)
net.add_module('l3', l)
getattr(net, fn)('l3', l)
self.assertEqual(net.l3, l)
l3 = nn.Linear(20, 10)
net.add_module('l', l3)
getattr(net, fn)('l', l3)
self.assertEqual(net.l, l3)
self.assertRaises(TypeError, lambda: net.add_module('x', 'non-module'))
self.assertRaises(TypeError, lambda: getattr(net, fn)('x', 'non-module'))
self.assertRaisesRegex(TypeError, 'module name should be a string. Got int',
lambda: net.add_module(1, l))
lambda: getattr(net, fn)(1, l))
self.assertRaisesRegex(TypeError, 'module name should be a string. Got NoneType',
lambda: net.add_module(None, l))
lambda: getattr(net, fn)(None, l))
def test_module_to_argparse(self):
net = nn.Sequential(nn.Linear(3, 3))

View File

@ -874,6 +874,7 @@ if _enabled:
"forward",
"register_buffer",
"register_parameter",
"register_module",
"add_module",
"_apply",
"apply",

View File

@ -387,6 +387,10 @@ class Module:
raise KeyError("module name can't be empty string \"\"")
self._modules[name] = module
def register_module(self, name: str, module: Optional['Module']) -> None:
r"""Alias for :func:`add_module`."""
self.add_module(name, module)
def get_submodule(self, target: str) -> "Module":
"""
Returns the submodule given by ``target`` if it exists,