mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: X-link: https://github.com/pytorch/torchrec/pull/674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/84984 this is to allow named_buffers to return the same buffer objects with different names multiple times, needed by internal use cases ghstack-source-id: 168589597 Test Plan: python test/test_nn.py -k test_buffers_and_named_buffers Imported from OSS Reviewed By: albanD Differential Revision: D39493161 Pull Request resolved: https://github.com/pytorch/pytorch/pull/85903 Approved by: https://github.com/albanD
This commit is contained in:
parent
693250ac85
commit
c12f829cce
|
|
@ -897,6 +897,19 @@ class TestNN(NNTestCase):
|
|||
names(s.named_buffers()),
|
||||
['0.dummy_buf', '0.l1.layer_dummy_buf'])
|
||||
|
||||
# test remove_duplicate
|
||||
class M(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer("buffer1", torch.empty(3, 5))
|
||||
self.register_buffer("buffer2", self.buffer1)
|
||||
|
||||
m = M()
|
||||
self.assertEqual(names(m.named_buffers()),
|
||||
["buffer1"])
|
||||
self.assertEqual(names(m.named_buffers(remove_duplicate=False)),
|
||||
["buffer1", "buffer2"])
|
||||
|
||||
def test_call_supports_python_dict_output(self):
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -390,7 +390,10 @@ class _RemoteModule(nn.Module):
|
|||
_raise_not_supported(self.buffers.__name__)
|
||||
|
||||
def named_buffers( # type: ignore[return]
|
||||
self, prefix: str = "", recurse: bool = True
|
||||
self,
|
||||
prefix: str = "",
|
||||
recurse: bool = True,
|
||||
remove_duplicate: bool = True
|
||||
) -> Iterator[Tuple[str, Tensor]]:
|
||||
_raise_not_supported(self.named_buffers.__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1668,16 +1668,17 @@ class Module:
|
|||
self.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
def _named_members(self, get_members_fn, prefix='', recurse=True):
|
||||
def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate: bool = True):
|
||||
r"""Helper method for yielding various names + members of modules."""
|
||||
memo = set()
|
||||
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
|
||||
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
|
||||
for module_prefix, module in modules:
|
||||
members = get_members_fn(module)
|
||||
for k, v in members:
|
||||
if v is None or v in memo:
|
||||
continue
|
||||
memo.add(v)
|
||||
if remove_duplicate:
|
||||
memo.add(v)
|
||||
name = module_prefix + ('.' if module_prefix else '') + k
|
||||
yield name, v
|
||||
|
||||
|
|
@ -1756,15 +1757,16 @@ class Module:
|
|||
for _, buf in self.named_buffers(recurse=recurse):
|
||||
yield buf
|
||||
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
|
||||
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
|
||||
r"""Returns an iterator over module buffers, yielding both the
|
||||
name of the buffer as well as the buffer itself.
|
||||
|
||||
Args:
|
||||
prefix (str): prefix to prepend to all buffer names.
|
||||
recurse (bool): if True, then yields buffers of this module
|
||||
recurse (bool, optional): if True, then yields buffers of this module
|
||||
and all submodules. Otherwise, yields only buffers that
|
||||
are direct members of this module.
|
||||
are direct members of this module. Defaults to True.
|
||||
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
|
||||
|
||||
Yields:
|
||||
(str, torch.Tensor): Tuple containing the name and buffer
|
||||
|
|
@ -1779,7 +1781,7 @@ class Module:
|
|||
"""
|
||||
gen = self._named_members(
|
||||
lambda module: module._buffers.items(),
|
||||
prefix=prefix, recurse=recurse)
|
||||
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
|
||||
for elem in gen:
|
||||
yield elem
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user