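# Tests for TorchScript compilation of module containers
# (nn.Sequential, nn.ModuleList, nn.ModuleDict).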
import os
import sys

from typing import List
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.testing._internal.jit_utils import JitTestCase

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")


class TestModuleContainers(JitTestCase):
    def test_sequential_intermediary_types(self):
        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()

            def forward(self, x):
                return x + 3

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()

            def forward(self, x):
                return {"1": x}

        class C(torch.nn.Module):
            def __init__(self):
                super(C, self).__init__()
                self.foo = torch.nn.Sequential(A(), B())

            def forward(self, x):
                return self.foo(x)

        self.checkModule(C(), (torch.tensor(1),))
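
    # TorchScript can iterate a ModuleDict directly (over its keys) as well as
    # via keys(), values(), and items(), since the set of submodules is a
    # compile-time constant.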
    def test_moduledict(self):
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class Inner2(torch.nn.Module):
            def forward(self, x):
                return x * 2

        class Inner3(torch.nn.Module):
            def forward(self, x):
                return (x - 4) * 3

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                modules = OrderedDict([
                    ('one', Inner()),
                    ('two', Inner2()),
                    ('three', Inner3()),
                ])
                self.moduledict = nn.ModuleDict(modules)

            def forward(self, x, skip_name):
                # type: (Tensor, str)
                names = torch.jit.annotate(List[str], [])
                values = []
                for name in self.moduledict:
                    names.append(name)

                for name, mod in self.moduledict.items():
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for mod in self.moduledict.values():
                    x = mod(x)
                    values.append(x)

                for key in self.moduledict.keys():
                    names.append(key)

                return x, names
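
        # M2 additionally exercises enumerate() and zip() over the ModuleDict
        # iterators. Note that 'i' leaks out of the last enumerate loop into
        # the zip loop, and 'for mod, mod' rebinds the same name twice; both
        # are kept as-is, since checkModule only verifies that the scripted
        # module matches eager execution.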
        class M2(M):
            def __init__(self):
                super(M2, self).__init__()

            def forward(self, x, skip_name):
                # type: (Tensor, str)
                names = torch.jit.annotate(List[str], [])
                values = []
                x2 = x
                iter = 0
                for name in self.moduledict:
                    names.append(name)

                for i, (name, mod) in enumerate(self.moduledict.items()):
                    iter += i
                    if name != skip_name:
                        names.append(name)
                        x = mod(x)
                        values.append(x)

                for i, mod in enumerate(self.moduledict.values()):
                    iter += i
                    x = mod(x)
                    values.append(x)

                for i, key in enumerate(self.moduledict.keys()):
                    iter += i
                    names.append(key)

                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
                    iter += i
                    x2 = mod(mod(x2))

                return x, x2, names, iter

        for name in ["", "one", "two", "three"]:
            inp = torch.tensor(1)
            self.checkModule(M(), (inp, name))
            self.checkModule(M2(), (inp, name))
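
    # Subclasses of the built-in containers may define their own forward();
    # scripting should use the user-defined forward rather than the base
    # class implementation.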
    def test_custom_container_forward(self):
        class Inner(torch.nn.Module):
            def forward(self, x):
                return x + 10

        class CustomSequential(nn.Sequential):
            def __init__(self):
                super(CustomSequential, self).__init__(
                    nn.ReLU(), Inner())

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomSequential(), (torch.tensor(.5),))

        class CustomModuleList(nn.ModuleList):
            def __init__(self):
                super(CustomModuleList, self).__init__(
                    [nn.ReLU(), Inner()])

            def forward(self, x):
                x = x + 3
                for mod in self:
                    x = mod(x)
                return x - 5

        self.checkModule(CustomModuleList(), (torch.tensor(.5),))

        class CustomModuleDict(nn.ModuleDict):
            def __init__(self):
                super(CustomModuleDict, self).__init__(
                    OrderedDict([
                        ('one', Inner()),
                        ('two', nn.ReLU()),
                        ('three', Inner()),
                    ]))

            def forward(self, x):
                x = x + 3
                names = torch.jit.annotate(List[str], [])
                for name, mod in self.items():
                    x = mod(x)
                    names.append(name)
                return names, x - 5

        self.checkModule(CustomModuleDict(), (torch.tensor(.5),))

    def test_script_module_list_sequential(self):
        class M(torch.jit.ScriptModule):
            def __init__(self, mod_list):
                super(M, self).__init__()
                self.mods = mod_list

            @torch.jit.script_method
            def forward(self, v):
                for m in self.mods:
                    v = m(v)
                return v

        with torch.jit.optimized_execution(False):
            m = M(nn.Sequential(nn.ReLU()))
            self.assertExportImportModule(m, (torch.randn(2, 2),))
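
    # ModuleList/Sequential indexing in TorchScript requires an integer
    # literal (negative literals included); indexing with a non-literal value
    # is a compile-time error that suggests enumerating the container instead
    # (https://github.com/pytorch/pytorch/pull/43361), roughly:
    #
    #     for i, mod in enumerate(self.mods):
    #         if i == some_index:
    #             v = mod(v)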
    def test_script_modulelist_index(self):
        class Sub(torch.nn.Module):
            def __init__(self, i):
                super(Sub, self).__init__()
                self.i = i

            def forward(self, thing):
                return thing - self.i

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4].forward(v)
                v = self.mods[-1].forward(v)
                v = self.mods[-9].forward(v)
                return v

        x = torch.tensor(1)
        self.checkModule(M(), (x,))

        class MForward(torch.nn.Module):
            def __init__(self):
                super(MForward, self).__init__()
                self.mods = nn.ModuleList([Sub(i) for i in range(10)])

            def forward(self, v):
                v = self.mods[4](v)
                v = self.mods[-1](v)
                v = self.mods[-9](v)
                return v

        self.checkModule(MForward(), (torch.tensor(1),))

        class M2(M):
            def __init__(self):
                super(M2, self).__init__()

            def forward(self, v):
                return self.mods[-11].forward(v)

        with self.assertRaisesRegex(Exception, "Index -11 out of range"):
            torch.jit.script(M2())

        class M3(M):
            def __init__(self):
                super(M3, self).__init__()

            def forward(self, v):
                i = 3
                return self.mods[i].forward(v)

        with self.assertRaisesRegex(Exception, "Enumeration is supported"):
            torch.jit.script(M3())
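
    # The compiler synthesizes __getitem__, __len__, __iter__, and
    # __contains__ for the container base classes; these should keep working
    # when a custom subclass inherits them through an extra base.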
    def test_module_interface_special_methods(self):
        class CustomModuleInterface(torch.nn.Module):
            def __init__(self):
                super(CustomModuleInterface, self).__init__()

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

        class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.Sequential.__init__(self, modules)

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])
                self.sequential = CustomSequential(self.submod)
                self.moduledict = CustomModuleDict({"submod": self.submod})

            def forward(self, inputs):
                assert self.modulelist[0] is self.submod, "__getitem__ failing for ModuleList"
                assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
                for module in self.modulelist:
                    assert module is self.submod, "__iter__ failing for ModuleList"

                assert self.sequential[0] is self.submod, "__getitem__ failing for Sequential"
                assert len(self.sequential) == 1, "__len__ failing for Sequential"
                for module in self.sequential:
                    assert module is self.submod, "__iter__ failing for Sequential"

                assert self.moduledict["submod"] is self.submod, "__getitem__ failing for ModuleDict"
                assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"

                # note: unable to index moduledict with a string variable currently
                i = 0
                for key in self.moduledict:
                    i += 1
                assert i == len(self.moduledict), "iteration failing for ModuleDict"

                assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"

                for key in self.moduledict.keys():
                    assert key == "submod", "keys() fails for ModuleDict"

                for item in self.moduledict.items():
                    assert item[0] == "submod", "items() fails for ModuleDict"
                    assert item[1] is self.submod, "items() fails for ModuleDict"

                for value in self.moduledict.values():
                    assert value is self.submod, "values() fails for ModuleDict"

                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])
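
    # A __len__ defined in Python on the subclass should take precedence over
    # the compiler-generated one (see the comment inside __len__ below).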
    def test_special_method_with_override(self):
        class CustomModuleInterface(torch.nn.Module):
            def __init__(self):
                super(CustomModuleInterface, self).__init__()

        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleList.__init__(self, modules)

            def __len__(self):
                # this is arbitrary, just to check that the overridden py __len__ from
                # CustomModuleList takes precedence over the automatically generated
                # __len__ added by the jit compiler
                return 2

        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.modulelist = CustomModuleList([self.submod])

            def forward(self, inputs):
                assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])
        mm = torch.jit.script(m)
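
    # ModuleDict lookups with string-literal keys resolve to the stored
    # submodules themselves.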
    def test_moduledict_getitem(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.relu = torch.jit.script(torch.nn.ReLU())
                self.tanh = torch.jit.script(torch.nn.Tanh())
                self.moduledict = torch.nn.ModuleDict({"relu": self.relu,
                                                       "tanh": self.tanh})

            def forward(self, input):
                assert self.moduledict['relu'] is self.relu
                assert self.moduledict['tanh'] is self.tanh
                return input

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])
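
    # Bad ModuleDict lookups fail when the module is scripted, not at run
    # time: a missing key is a compile-time "Key Error", and a non-literal
    # key is rejected outright.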
    def test_moduledict_keyerror(self):
        class BadModule(torch.nn.Module):
            def __init__(self):
                super(BadModule, self).__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None,
                                                       "bar": None})

            def forward(self, input):
                assert self.moduledict['blah'] == "blah", "this is a keyerror"

        with self.assertRaisesRegex(RuntimeError, "Key Error, blah"):
            b = BadModule()
            torch.jit.script(b)

        class AnotherBadModule(torch.nn.Module):
            def __init__(self):
                super(AnotherBadModule, self).__init__()
                self.moduledict = torch.nn.ModuleDict({"foo": None,
                                                       "bar": None})

            def forward(self, input):
                idx = 'blah'
                assert self.moduledict[idx] == "blah", "this is a string literal error"

        with self.assertRaisesRegex(RuntimeError, "Unable to extract string literal index. "
                                                  "ModuleDict indexing is only supported with string literals."):
            b = AnotherBadModule()
            torch.jit.script(b)
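
    # __contains__ should also work (and return False) on an empty
    # ModuleDict subclass.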
    def test_empty_dict_override_contains(self):
        class CustomModuleInterface(torch.nn.Module):
            def __init__(self):
                super(CustomModuleInterface, self).__init__()

        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
            def __init__(self, modules=None):
                CustomModuleInterface.__init__(self)
                torch.nn.ModuleDict.__init__(self, modules)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                # work around aliasing issue for 'is' operator by scripting ReLU up front
                self.submod = torch.jit.script(torch.nn.ReLU())
                self.moduledict = CustomModuleDict()

            def forward(self, inputs):
                assert "submod" not in self.moduledict, "__contains__ fails for ModuleDict"
                return inputs

        m = MyModule()
        self.checkModule(m, [torch.randn(2, 2)])