[nn] Add remove_duplicate flag to named_buffers (#674) (#85903)

Summary:
X-link: https://github.com/pytorch/torchrec/pull/674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84984

this is to allow named_buffers to return the same buffer objects with different names multiple times, needed by internal use cases
ghstack-source-id: 168589597

Test Plan:
python test/test_nn.py -k test_buffers_and_named_buffers

Imported from OSS

Reviewed By: albanD

Differential Revision: D39493161

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85903
Approved by: https://github.com/albanD
This commit is contained in:
Jerry Zhang 2022-10-11 18:49:09 +00:00 committed by PyTorch MergeBot
parent 693250ac85
commit c12f829cce
3 changed files with 26 additions and 8 deletions

View File

@ -897,6 +897,19 @@ class TestNN(NNTestCase):
names(s.named_buffers()), names(s.named_buffers()),
['0.dummy_buf', '0.l1.layer_dummy_buf']) ['0.dummy_buf', '0.l1.layer_dummy_buf'])
# test remove_duplicate
class M(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.empty(3, 5))
self.register_buffer("buffer2", self.buffer1)
m = M()
self.assertEqual(names(m.named_buffers()),
["buffer1"])
self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
["buffer1", "buffer2"])
def test_call_supports_python_dict_output(self): def test_call_supports_python_dict_output(self):
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):

View File

@ -390,7 +390,10 @@ class _RemoteModule(nn.Module):
_raise_not_supported(self.buffers.__name__) _raise_not_supported(self.buffers.__name__)
def named_buffers( # type: ignore[return] def named_buffers( # type: ignore[return]
self, prefix: str = "", recurse: bool = True self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]: ) -> Iterator[Tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__) _raise_not_supported(self.named_buffers.__name__)

View File

@ -1668,15 +1668,16 @@ class Module:
self.__class__.__name__, "\n\t".join(error_msgs))) self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys) return _IncompatibleKeys(missing_keys, unexpected_keys)
def _named_members(self, get_members_fn, prefix='', recurse=True): def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True):
r"""Helper method for yielding various names + members of modules.""" r"""Helper method for yielding various names + members of modules."""
memo = set() memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
for module_prefix, module in modules: for module_prefix, module in modules:
members = get_members_fn(module) members = get_members_fn(module)
for k, v in members: for k, v in members:
if v is None or v in memo: if v is None or v in memo:
continue continue
if remove_duplicate:
memo.add(v) memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k name = module_prefix + ('.' if module_prefix else '') + k
yield name, v yield name, v
@ -1756,15 +1757,16 @@ class Module:
for _, buf in self.named_buffers(recurse=recurse): for _, buf in self.named_buffers(recurse=recurse):
yield buf yield buf
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]: def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
r"""Returns an iterator over module buffers, yielding both the r"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself. name of the buffer as well as the buffer itself.
Args: Args:
prefix (str): prefix to prepend to all buffer names. prefix (str): prefix to prepend to all buffer names.
recurse (bool): if True, then yields buffers of this module recurse (bool, optional): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that and all submodules. Otherwise, yields only buffers that
are direct members of this module. are direct members of this module. Defaults to True.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields: Yields:
(str, torch.Tensor): Tuple containing the name and buffer (str, torch.Tensor): Tuple containing the name and buffer
@ -1779,7 +1781,7 @@ class Module:
""" """
gen = self._named_members( gen = self._named_members(
lambda module: module._buffers.items(), lambda module: module._buffers.items(),
prefix=prefix, recurse=recurse) prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
for elem in gen: for elem in gen:
yield elem yield elem