diff --git a/test/test_nn.py b/test/test_nn.py index 464c44b753f..1439c1a0408 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1600,6 +1600,19 @@ class TestNN(NNTestCase): self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4)) self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4)) + def test_Sequential_extend(self): + l1 = nn.Linear(10, 20) + l2 = nn.Linear(20, 30) + l3 = nn.Linear(30, 40) + l4 = nn.Linear(40, 50) + n1 = nn.Sequential(l1, l2) + n2 = nn.Sequential(l3, l4) + n3 = nn.Sequential(l1, l2) + for l in n2: + n1.append(l) + n3.extend(n2) + self.assertEqual(n3, n1) + def test_ModuleList(self): modules = [nn.ReLU(), nn.Linear(5, 5)] module_list = nn.ModuleList(modules) diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index bd49c235aa5..96a823f888b 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -163,6 +163,11 @@ class Sequential(Module): self.add_module(str(len(self)), module) return self + def extend(self, sequential) -> 'Sequential': + for layer in sequential: + self.append(layer) + return self + class ModuleList(Module): r"""Holds submodules in a list.