pytorch/test/jit/test_backends.py
Michael Suo 31621c828d Fix JIT tests when run locally in fbcode (#45776)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45776

Splitting out backend and custom class registration into their own library is
not currently implemented in fbcode, so detect that we are running tests in
fbcode and disable those tests.

Test Plan: buck test mode/no-gpu mode/dev caffe2/test:jit

Reviewed By: smessmer

Differential Revision: D24085871

fbshipit-source-id: 1fcc0547880bc4be59428e2810b6a7f6e50ef798
2020-10-02 17:43:01 -07:00

219 lines
7.0 KiB
Python

from torch.testing._internal.jit_utils import JitTestCase
import os
import sys
import unittest
import torch
import torch._C
from pathlib import Path
from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
TEST_WITH_ROCM,
skipIfRocm,
)
# Make the helper files in test/ importable: the directory two levels above
# this (symlink-resolved) file is appended to the module search path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    # These cases are collected and driven through test_jit.py instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
def to_test_backend(module, method_compile_spec):
    """
    Lower only the ``forward`` method of ``module`` to the "test_backend" backend.

    ``method_compile_spec`` is the compile spec used for ``forward``.
    """
    spec_by_method = {"forward": method_compile_spec}
    return torch._C._jit_to_backend("test_backend", module, spec_by_method)
def to_test_backend_multi(module, method_compile_spec):
    """
    Lower ``module`` to the "test_backend" backend.

    ``method_compile_spec`` maps each method name to lower onto that method's
    compile spec; it is forwarded to the backend unchanged.
    """
    backend_name = "test_backend"
    return torch._C._jit_to_backend(backend_name, module, method_compile_spec)
class BasicModule(torch.nn.Module):
    """
    A minimal Module with three scriptable methods, used to exercise the
    to_backend lowering machinery.
    """

    def forward(self, x, h):
        # Return both the sum and the difference of the two inputs.
        summed = self.accum(x, h)
        diffed = self.sub_accum(x, h)
        return summed, diffed

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
class JitBackendTestCase(JitTestCase):
    """
    A common base class for JIT backend tests that contains common utility
    functions for output comparison and serialization/deserialization.

    Subclasses are expected to set up three attributes in their setUp methods:
      module          - a regular, Python version of the module being tested
      scripted_module - a scripted version of module
      lowered_module  - a version of module lowered to a backend
    """

    def setUp(self):
        super().setUp()
        # The backend library is loaded from a hard-coded path inside a local
        # build tree (build/lib/*.so), so skip wherever that layout is
        # presumably unavailable (fbcode does not build the backend as a
        # separate library).
        if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # Three levels up from this file: test/jit/<file> -> repository root.
        torch_root = Path(__file__).resolve().parent.parent.parent
        p = torch_root / 'build' / 'lib' / 'libjitbackend_test.so'
        torch.ops.load_library(str(p))

    def check_function(self, function_name, input):
        """
        Check that the function named 'function_name' produces the same output
        using Python, regular JIT and the backend for the given 'input'.

        The same 'input' is passed as both positional arguments of the method.
        """
        # Get handles for Python, JIT and backend methods. getattr() performs
        # the normal lookup and falls back to the module's __getattr__, which
        # is how ScriptModules expose compiled methods.
        python_method = getattr(self.module, function_name)
        jit_method = getattr(self.scripted_module, function_name)
        backend_method = getattr(self.lowered_module, function_name)

        # Run methods.
        python_output = python_method(input, input)
        jit_output = jit_method(input, input)
        backend_output = backend_method(input, input)

        # The answers returned by Python, JIT and to_backend should all match.
        self.assertEqual(python_output, backend_output)
        self.assertEqual(jit_output, backend_output)

    def save_load(self):
        """
        Save and load the lowered module, replacing self.lowered_module with
        the copy returned by getExportImportCopy.
        """
        self.lowered_module = self.getExportImportCopy(self.lowered_module)
class BasicModuleTest(JitBackendTestCase):
    """
    Tests for BasicModule: every method is lowered to the test backend and its
    output is compared with the Python and TorchScript versions.
    """

    def setUp(self):
        super().setUp()
        # Python, TorchScript and backend-lowered flavours of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        # Lower all three methods, each with a trivial compile spec.
        compile_spec = {
            method: {"": ""} for method in ("accum", "sub_accum", "forward")
        }
        self.lowered_module = to_test_backend_multi(self.scripted_module, compile_spec)

    def test_execution(self):
        # Compare backend output against Python and JIT for every method.
        sample = torch.randn(5)
        for method in ("accum", "sub_accum", "forward"):
            self.check_function(method, sample)

    @skipIfRocm
    def test_save_load(self):
        # Outputs must match before serialization...
        self.test_execution()
        spec_before = getattr(self.lowered_module, "__method_compile_spec")

        # ...the compile spec must survive a save/load round trip...
        self.save_load()
        spec_after = getattr(self.lowered_module, "__method_compile_spec")
        self.assertEqual(spec_before, spec_after)

        # ...and outputs must still match after loading.
        self.test_execution()
class NestedModuleTest(JitBackendTestCase):
    """
    Tests for NestedModule, checking that a module lowered to a backend can be
    used as a submodule of another scripted module.
    """

    class NestedModule(torch.nn.Module):
        """
        A Module wrapping a single submodule, used to test that lowered
        Modules can serve as submodules.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, h):
            # Delegate straight to the wrapped module.
            return self.submodule.forward(x, h)

    def setUp(self):
        super().setUp()
        wrap = NestedModuleTest.NestedModule
        # Plain Python: both the wrapper and the inner module are nn.Modules.
        self.module = wrap(BasicModule())
        # TorchScript: both levels are ScriptModules.
        self.scripted_module = torch.jit.script(wrap(BasicModule()))
        # Lower forward of the scripted module, then wrap the result so that
        # self.lowered_module is a ScriptModule whose submodule is lowered.
        lowered_inner = to_test_backend(self.scripted_module, {"": ""})
        self.lowered_module = torch.jit.script(wrap(lowered_inner))

    def test_execution(self):
        # NestedModule only defines forward, so that is the method compared.
        self.check_function("forward", torch.randn(5))

    def test_save_load(self):
        # Outputs must match both before and after a save/load round trip.
        self.test_execution()
        self.save_load()
        self.test_execution()
class TestBackends(JitTestCase):
    """
    Aggregator that owns one instance of each JitBackendTestCase subclass and
    forwards setUp and each test to them, so that test_jit.py only has to
    import this single class.
    """

    def __init__(self, name):
        super().__init__(name)
        # Each wrapped case is constructed with the same test method name.
        self.basic_module_test = BasicModuleTest(name)
        self.nested_module_test = NestedModuleTest(name)

    def _wrapped_cases(self):
        # Wrapped cases in the order they are set up and run: basic, then nested.
        return (self.basic_module_test, self.nested_module_test)

    def setUp(self):
        super().setUp()
        # The forwarding tests below are skipped on ROCm, so don't prepare
        # the wrapped cases there.
        if TEST_WITH_ROCM:
            return
        for case in self._wrapped_cases():
            case.setUp()

    @skipIfRocm
    def test_execution(self):
        for case in self._wrapped_cases():
            case.test_execution()

    @skipIfRocm
    def test_save_load(self):
        for case in self._wrapped_cases():
            case.test_save_load()