pytorch/test/jit/test_module_interface.py
Zino Benaissa 4d80c8c648 Fix inlining interface call in fork subgraph (#43790)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43790

Interface calls were not handled properly when they were used in a fork
subgraph. This PR fixes that issue.

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D23402039

Pulled By: bzinodev

fbshipit-source-id: 41adc5ee7d942250e732e243ab30e356d78d9bf7
2020-09-23 11:17:19 -07:00

675 lines
22 KiB
Python

# flake8: noqa
# TODO: enable linting check for this file
from typing import List
import torch
import torch.nn as nn
import os
import sys
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
if __name__ == '__main__':
    # This module only defines test cases; they are meant to be collected and
    # run through test/test_jit.py, so direct execution is a usage error.
    usage_message = ("This test file is not meant to be run directly, use:\n\n"
                     "\tpython test/test_jit.py TESTNAME\n\n"
                     "instead.")
    raise RuntimeError(usage_message)
class OrigModule(nn.Module):
    """Fixture module whose ``forward`` computes ``3 * input + 2``.

    ``one`` and ``forward`` match the ``ModuleInterface`` declarations used in
    the tests; ``two`` is an extra method that none of those interfaces
    declare.
    """

    def __init__(self):
        super().__init__()

    def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
        return inp1 + inp2 + 1

    def two(self, input: torch.Tensor) -> torch.Tensor:
        return input + 2

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # input + (input + input + 1) + 1 == 3 * input + 2
        return input + self.one(input, input) + 1
class NewModule(nn.Module):
    """Fixture module whose ``forward`` computes ``input * (input + 1) + 1``.

    It exposes the same ``one``/``forward`` signatures as ``OrigModule`` so it
    can replace it behind a module interface.
    """

    def __init__(self):
        super().__init__()

    def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
        return inp1 * inp2 + 1

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        shifted = input + 1
        return self.one(input, shifted)
class TestModuleInterface(JitTestCase):
def test_not_submodule_interface_call(self):
    """Calling a method the interface does not declare must fail to compile."""
    @torch.jit.interface
    class ModuleInterface(nn.Module):
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            pass

    class TestNotModuleInterfaceCall(nn.Module):
        proxy_mod: ModuleInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigModule()

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            # ``two`` exists on OrigModule but is not part of ModuleInterface.
            return self.proxy_mod.two(input)

    expected_error = "Tried to access nonexistent attribute"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        torch.jit.script(TestNotModuleInterfaceCall())
def test_module_interface(self):
    """Scripted modules can be passed as module interfaces and class interfaces."""
    # The interfaces must be module globals so that checkScript's string
    # frontend can resolve them from the annotations below.
    global OneTwoModule, OneTwoClass

    @torch.jit.interface
    class OneTwoModule(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            pass

        def two(self, x: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    @torch.jit.interface
    class OneTwoClass(object):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            pass

        def two(self, x: torch.Tensor) -> torch.Tensor:
            pass

    class FooMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.one(self.two(x), x)

    class BarMod(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 / x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.two(self.one(x, x))

        # Exported, but not declared by OneTwoModule.
        @torch.jit.export
        def forward2(self, x: torch.Tensor) -> torch.Tensor:
            return self.two(self.one(x, x)) + 1

    def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
        return mod_list[0].forward(x) + mod_list[1].forward(x)

    def use_class_interface(mod_list: List[OneTwoClass], x: torch.Tensor) -> torch.Tensor:
        return mod_list[0].two(x) + mod_list[1].one(x, x)

    foo = torch.jit.script(FooMod())
    bar = torch.jit.script(BarMod())
    self.checkScript(use_module_interface, ([foo, bar], torch.rand(3, 4),))
    self.checkScript(use_class_interface, ([foo, bar], torch.rand(3, 4),))

    def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: torch.Tensor) -> torch.Tensor:
        return mod_interface.forward2(x)

    # Only methods declared by the interface may be called through it.
    with self.assertRaisesRegex(RuntimeError, "Tried to access nonexistent attribute or method"):
        self.checkScript(call_module_interface_on_other_method, (bar, torch.rand(3, 4),))
def test_module_interface_subtype(self):
    """Only compatible ScriptModules are accepted where a module interface is expected."""
    global OneTwoModule

    @torch.jit.interface
    class OneTwoModule(nn.Module):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            pass

        def two(self, x: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    @torch.jit.script
    def as_module_interface(x: OneTwoModule) -> OneTwoModule:
        return x

    @torch.jit.script
    class Foo(object):
        def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y

        def two(self, x: torch.Tensor) -> torch.Tensor:
            return 2 * x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.one(self.two(x), x)

    # A script class instance is rejected even though it has every method the
    # module interface declares.
    with self.assertRaisesRegex(RuntimeError, "ScriptModule class can be subtype of module interface"):
        as_module_interface(Foo())

    class WrongMod(nn.Module):
        # ``two`` is int -> int and ``one`` is missing entirely.
        def two(self, x: int) -> int:
            return 2 * x

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + torch.randn(3, self.two(3))

    wrong = torch.jit.script(WrongMod())
    # A ScriptModule whose methods do not match the interface is rejected.
    with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
        as_module_interface(wrong)
def test_module_interface_inheritance(self):
    """Declaring an interface that inherits from a concrete module is rejected."""
    expected_error = "does not support inheritance yet. Please directly"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        @torch.jit.interface
        class InheritMod(nn.ReLU):
            def three(self, x: torch.Tensor) -> torch.Tensor:
                return 3 * x
def test_module_swap(self):
    """An interface-typed submodule can be replaced by another compatible ScriptModule."""
    @torch.jit.interface
    class ModuleInterface(nn.Module):
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            pass

    class TestModule(nn.Module):
        proxy_mod: ModuleInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigModule()

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.proxy_mod.forward(input)

    script_mod = torch.jit.script(TestModule())
    x = torch.randn(3, 4)
    # OrigModule.forward computes 3 * x + 2.
    self.assertEqual(script_mod(x), 3 * x + 2)

    # A scripted module that satisfies the same interface can be swapped in.
    script_mod.proxy_mod = torch.jit.script(NewModule())
    self.assertEqual(script_mod(x), x * (x + 1) + 1)

    # A plain, non-scripted nn.Module cannot be swapped in.
    with self.assertRaisesRegex(RuntimeError, "a ScriptModule with non-scripted module"):
        script_mod.proxy_mod = NewModule()
def test_module_swap_wrong_module(self):
    """Swapping in a module whose methods do not match the interface is rejected."""
    @torch.jit.interface
    class ModuleInterface(nn.Module):
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            pass

    class NewModuleWrong(nn.Module):
        # No ``one`` method, and ``forward`` is int -> int rather than Tensor -> Tensor.
        def __init__(self):
            super().__init__()

        def forward(self, input: int) -> int:
            return input + 1

    class TestModule(nn.Module):
        proxy_mod: ModuleInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigModule()

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.proxy_mod.forward(input)

    script_mod = torch.jit.script(TestModule())
    with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
        script_mod.proxy_mod = torch.jit.script(NewModuleWrong())
def test_module_swap_no_lazy_compile(self):
    """Interface methods not reachable from ``forward`` must be exported to allow a swap."""
    @torch.jit.interface
    class ModuleInterface(nn.Module):
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            pass

    class TestModule(nn.Module):
        proxy_mod: ModuleInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigModule()

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.proxy_mod.forward(input)

    class NewModuleMethodNotLazyCompile(nn.Module):
        def __init__(self):
            super().__init__()

        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            return inp1 * inp2 + 1

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return input + 1

    script_mod = torch.jit.script(TestModule())
    # ``one`` is never called from ``forward``, so scripting does not compile
    # it and the resulting module does not satisfy ModuleInterface.
    with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
        script_mod.proxy_mod = torch.jit.script(NewModuleMethodNotLazyCompile())

    class NewModuleMethodManualExport(nn.Module):
        def __init__(self):
            super().__init__()

        # Exporting ``one`` explicitly makes it part of the compiled module.
        @torch.jit.export
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            return inp1 * inp2 + 1

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return input + 1

    script_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
    x = torch.randn(3, 4)
    self.assertEqual(script_mod(x), x + 1)
def test_module_swap_no_module_interface(self):
    """Without an interface annotation, a swap must keep the submodule's exact JIT type."""
    class TestNoModuleInterface(nn.Module):
        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigModule()

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.proxy_mod(input)

    script_mod = torch.jit.script(TestNoModuleInterface())
    # A freshly scripted OrigModule shares the field's JIT type, so this succeeds.
    script_mod.proxy_mod = torch.jit.script(OrigModule())

    # NewModule has a different JIT type and there is no interface to accept it.
    type_mismatch = ("Expected a value of type '__torch__.jit.test_module_interface.OrigModule' "
                     "for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule'")
    with self.assertRaisesRegex(RuntimeError, type_mismatch):
        script_mod.proxy_mod = torch.jit.script(NewModule())
def test_script_module_as_interface_swap(self):
    """Legacy ``torch.jit.ScriptModule`` subclasses can back and replace an interface field."""
    @torch.jit.interface
    class ModuleInterface(nn.Module):
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            pass

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            pass

    class OrigScriptModule(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()

        @torch.jit.script_method
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            return inp1 + inp2 + 1

        @torch.jit.script_method
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return input + self.one(input, input) + 1

    class NewScriptModule(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()

        @torch.jit.script_method
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            return inp1 * inp2 + 1

        @torch.jit.script_method
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.one(input, input + 1)

    class TestNNModuleWithScriptModule(nn.Module):
        proxy_mod: ModuleInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigScriptModule()

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return self.proxy_mod.forward(input)

    x = torch.randn(3, 4)
    script_mod = torch.jit.script(TestNNModuleWithScriptModule())
    self.assertEqual(script_mod(x), 3 * x + 2)

    script_mod.proxy_mod = NewScriptModule()
    self.assertEqual(script_mod(x), x * (x + 1) + 1)
# Freezing a module that holds an interface-typed submodule must succeed both
# with and without ``freezeInterfaces``, and each frozen module must compute
# the same result as the scripted module.
def test_freeze_module_with_interface(self):
    """Freeze a module whose ``proxy_mod`` attribute is typed by a module interface.

    Checks that ``_freeze_module`` succeeds with the default settings and with
    ``freezeInterfaces=True``. It also checks that both frozen modules produce
    the same output as the unfrozen scripted module.
    """
    class SubModule(torch.nn.Module):
        def __init__(self):
            super(SubModule, self).__init__()
            self.b = 20

        def forward(self, x):
            return self.b

    class OrigMod(torch.nn.Module):
        def __init__(self):
            super(OrigMod, self).__init__()
            self.a = 0

        def forward(self, x):
            return self.a

    @torch.jit.interface
    class ModInterface(torch.nn.Module):
        def forward(self, x):
            # type: (Tensor) -> int
            pass

    class TestModule(torch.nn.Module):
        proxy_mod : ModInterface

        def __init__(self):
            super(TestModule, self).__init__()
            self.proxy_mod = OrigMod()
            self.sub = SubModule()  # folded

        def forward(self, x):
            return self.proxy_mod(x) + self.sub(x)

    m = torch.jit.script(TestModule())
    m.eval()
    # Default freezing (interfaces not frozen) must not raise. Previously its
    # result was overwritten and never checked.
    mf_default = torch._C._freeze_module(m._c)
    # Assume interface has no aliasing
    mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    x = torch.tensor([1])
    out_s = m.forward(x)
    out_f = mf.forward(x)
    self.assertEqual(out_s, out_f)
    self.assertEqual(out_s, mf_default.forward(x))
def test_freeze_module_with_setattr_in_interface(self):
    """Interface freezing fails when the bound module's forward assigns to an attribute."""
    class SubModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.b = 20

        def forward(self, x):
            self.b += 2
            return self.b

        @torch.jit.export
        def getb(self, x):
            return self.b

    class OrigMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = 0

        def forward(self, x):
            return self.a

    @torch.jit.interface
    class ModInterface(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> int:
            pass

    class TestModule(torch.nn.Module):
        proxy_mod: ModInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigMod()
            self.sub = SubModule()

        def forward(self, x):
            return self.proxy_mod(x) + self.sub.getb(x)

    m = torch.jit.script(TestModule())
    # Bind the interface to ``sub``, whose forward does ``self.b += 2``.
    m.proxy_mod = m.sub
    m.eval()
    with self.assertRaisesRegex(RuntimeError, "failed to freeze interface attribute 'proxy_mod'"):
        torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_inplace_mutation_in_interface(self):
    """Interface freezing fails when the bound module mutates a tensor attribute in place."""
    class SubModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.tensor([1.5])

        def forward(self, x):
            self.b[0] += 2
            return self.b

        @torch.jit.export
        def getb(self, x):
            return self.b

    class OrigMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.tensor([0.5])

        def forward(self, x):
            return self.a

    @torch.jit.interface
    class ModInterface(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    class TestModule(torch.nn.Module):
        proxy_mod: ModInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigMod()
            self.sub = SubModule()

        def forward(self, x):
            y = self.proxy_mod(x)
            z = self.sub.getb(x)
            return y[0] + z[0]

    m = torch.jit.script(TestModule())
    # Bind the interface to ``sub``, whose forward writes into ``self.b``.
    m.proxy_mod = m.sub
    # NOTE(review): proxy_mod is already ``m.sub`` here, so this presumably
    # re-assigns ``b`` to itself; kept to make the aliasing explicit.
    m.sub.b = m.proxy_mod.b
    m.eval()
    with self.assertRaisesRegex(RuntimeError, "failed to freeze interface attribute 'proxy_mod'"):
        torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_mutated_interface(self):
    """Interface freezing fails when ``forward`` reassigns the interface attribute."""
    class SubModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.b = torch.tensor([1.5])

        def forward(self, x):
            return self.b

        @torch.jit.export
        def getb(self, x):
            return self.b

    class OrigMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.tensor([0.5])

        def forward(self, x):
            return self.a

    @torch.jit.interface
    class ModInterface(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            pass

    class TestModule(torch.nn.Module):
        proxy_mod: ModInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigMod()
            self.sub = SubModule()

        def forward(self, x):
            # The interface attribute itself is overwritten at run time.
            self.proxy_mod = self.sub
            y = self.proxy_mod(x)
            z = self.sub.getb(x)
            return y[0] + z[0]

    m = torch.jit.script(TestModule())
    m.eval()
    with self.assertRaisesRegex(RuntimeError, "failed to freeze interface attribute 'proxy_mod'"):
        torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_interface_and_fork(self):
    """Freeze a module that reaches an interface-typed submodule through a fork.

    ``MainModule.forward`` runs ``TestModule.forward`` both via
    ``torch.jit._fork`` and directly, so the ``proxy_mod`` interface call
    appears inside the fork subgraph as well as in the main graph. Freezing
    with ``freezeInterfaces=True`` must succeed, and the frozen module must
    still execute.
    """
    class SubModule(torch.nn.Module):
        def __init__(self):
            super(SubModule, self).__init__()
            self.b = torch.tensor([1.5])

        def forward(self, x):
            self.b[0] += 3.2
            return self.b

    class OrigMod(torch.nn.Module):
        def __init__(self):
            super(OrigMod, self).__init__()
            self.a = torch.tensor([0.5])

        def forward(self, x):
            return self.a

    @torch.jit.interface
    class ModInterface(torch.nn.Module):
        def forward(self, x):
            # type: (Tensor) -> Tensor
            pass

    class TestModule(torch.nn.Module):
        proxy_mod : ModInterface

        def __init__(self):
            super(TestModule, self).__init__()
            self.proxy_mod = OrigMod()
            self.sub = SubModule()

        def forward(self, x):
            y = self.proxy_mod(x)
            z = self.sub(x)
            return y + z

    class MainModule(torch.nn.Module):
        def __init__(self):
            super(MainModule, self).__init__()
            self.test = TestModule()

        def forward(self, x):
            fut = torch.jit._fork(self.test.forward, x)
            y = self.test(x)
            z = torch.jit._wait(fut)
            return y + z

    m = torch.jit.script(MainModule())
    m.eval()
    mf = torch._C._freeze_module(m._c, freezeInterfaces=True)

    # Execute the frozen module so a graph broken by freezing is detected.
    # SubModule mutates ``b`` in place from both the forked and the direct
    # call. The resulting value may depend on how the fork is scheduled, so
    # only the output shape (a and b are both 1-element tensors) is checked.
    out_f = mf.forward(torch.tensor([1.0]))
    self.assertEqual(list(out_f.shape), [1])
def test_module_apis_interface(self):
    """Iterating ``self.modules()`` over an interface-typed submodule fails to compile."""
    @torch.jit.interface
    class ModuleInterface(nn.Module):
        def one(self, inp1: torch.Tensor, inp2: torch.Tensor) -> torch.Tensor:
            pass

    class TestModule(nn.Module):
        proxy_mod: ModuleInterface

        def __init__(self):
            super().__init__()
            self.proxy_mod = OrigModule()

        def forward(self, input):
            return input * 2

        @torch.jit.export
        def method(self, input):
            for module in self.modules():
                input = module(input)
            return input

    with self.assertRaisesRegex(Exception, "Could not compile"):
        torch.jit.script(TestModule())