[nn] implement extend method to sequential class (#81179)

Follows #71329

cc @kshitij12345 :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81179
Approved by: https://github.com/albanD
This commit is contained in:
Khushi Agrawal 2022-07-20 05:33:41 +00:00 committed by PyTorch MergeBot
parent 0f164d342f
commit 2c0b11b43b
2 changed files with 18 additions and 0 deletions

View File

@ -1600,6 +1600,19 @@ class TestNN(NNTestCase):
self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4)) self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4))
self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4)) self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4))
def test_Sequential_extend(self):
l1 = nn.Linear(10, 20)
l2 = nn.Linear(20, 30)
l3 = nn.Linear(30, 40)
l4 = nn.Linear(40, 50)
n1 = nn.Sequential(l1, l2)
n2 = nn.Sequential(l3, l4)
n3 = nn.Sequential(l1, l2)
for l in n2:
n1.append(l)
n3.extend(n2)
self.assertEqual(n3, n1)
def test_ModuleList(self): def test_ModuleList(self):
modules = [nn.ReLU(), nn.Linear(5, 5)] modules = [nn.ReLU(), nn.Linear(5, 5)]
module_list = nn.ModuleList(modules) module_list = nn.ModuleList(modules)

View File

@ -163,6 +163,11 @@ class Sequential(Module):
self.add_module(str(len(self)), module) self.add_module(str(len(self)), module)
return self return self
def extend(self, sequential) -> 'Sequential':
for layer in sequential:
self.append(layer)
return self
class ModuleList(Module): class ModuleList(Module):
r"""Holds submodules in a list. r"""Holds submodules in a list.