mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[nn] implement extend method to sequential class (#81179)
Follows #71329 cc @kshitij12345 :) Pull Request resolved: https://github.com/pytorch/pytorch/pull/81179 Approved by: https://github.com/albanD
This commit is contained in:
parent
0f164d342f
commit
2c0b11b43b
|
|
@ -1600,6 +1600,19 @@ class TestNN(NNTestCase):
|
|||
self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4))
|
||||
self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4))
|
||||
|
||||
def test_Sequential_extend(self):
|
||||
l1 = nn.Linear(10, 20)
|
||||
l2 = nn.Linear(20, 30)
|
||||
l3 = nn.Linear(30, 40)
|
||||
l4 = nn.Linear(40, 50)
|
||||
n1 = nn.Sequential(l1, l2)
|
||||
n2 = nn.Sequential(l3, l4)
|
||||
n3 = nn.Sequential(l1, l2)
|
||||
for l in n2:
|
||||
n1.append(l)
|
||||
n3.extend(n2)
|
||||
self.assertEqual(n3, n1)
|
||||
|
||||
def test_ModuleList(self):
|
||||
modules = [nn.ReLU(), nn.Linear(5, 5)]
|
||||
module_list = nn.ModuleList(modules)
|
||||
|
|
|
|||
|
|
@ -163,6 +163,11 @@ class Sequential(Module):
|
|||
self.add_module(str(len(self)), module)
|
||||
return self
|
||||
|
||||
def extend(self, sequential) -> 'Sequential':
|
||||
for layer in sequential:
|
||||
self.append(layer)
|
||||
return self
|
||||
|
||||
|
||||
class ModuleList(Module):
|
||||
r"""Holds submodules in a list.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user