[nn] Add remove_duplicate flag to named_buffers (#674) (#85903)

Summary:
X-link: https://github.com/pytorch/torchrec/pull/674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84984

this is to allow named_buffers to return the same buffer objects with different names multiple times, needed by internal use cases
ghstack-source-id: 168589597

Test Plan:
python test/test_nn.py -k test_buffers_and_named_buffers

Imported from OSS

Reviewed By: albanD

Differential Revision: D39493161

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85903
Approved by: https://github.com/albanD
This commit is contained in:
Jerry Zhang 2022-10-11 18:49:09 +00:00 committed by PyTorch MergeBot
parent 693250ac85
commit c12f829cce
3 changed files with 26 additions and 8 deletions

View File

@ -897,6 +897,19 @@ class TestNN(NNTestCase):
names(s.named_buffers()),
['0.dummy_buf', '0.l1.layer_dummy_buf'])
# test remove_duplicate
class M(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.empty(3, 5))
self.register_buffer("buffer2", self.buffer1)
m = M()
self.assertEqual(names(m.named_buffers()),
["buffer1"])
self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
["buffer1", "buffer2"])
def test_call_supports_python_dict_output(self):
class Net(nn.Module):
def __init__(self):

View File

@ -390,7 +390,10 @@ class _RemoteModule(nn.Module):
_raise_not_supported(self.buffers.__name__)
def named_buffers( # type: ignore[return]
self, prefix: str = "", recurse: bool = True
self,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True
) -> Iterator[Tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__)

View File

@ -1668,16 +1668,17 @@ class Module:
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _named_members(self, get_members_fn, prefix='', recurse=True):
def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True):
r"""Helper method for yielding various names + members of modules."""
memo = set()
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
if remove_duplicate:
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield name, v
@ -1756,15 +1757,16 @@ class Module:
for _, buf in self.named_buffers(recurse=recurse):
yield buf
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
r"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool): if True, then yields buffers of this module
recurse (bool, optional): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module.
are direct members of this module. Defaults to True.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields:
(str, torch.Tensor): Tuple containing the name and buffer
@ -1779,7 +1781,7 @@ class Module:
"""
gen = self._named_members(
lambda module: module._buffers.items(),
prefix=prefix, recurse=recurse)
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
for elem in gen:
yield elem