pytorch/test/jit/test_backend_nnapi.py
Amy He 046272f3e5 [6/N] Nnapi Backend Delegate: Comprehensive OSS Tests (#61782)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61782

This PR depends on https://github.com/pytorch/pytorch/pull/61787

### Summary:
Added more comprehensive tests for the Android NNAPI delegate.
Previously, there was only one basic test for lowering a PReLU module with the NNAPI delegate. Now, more tests are inherited from `test_nnapi.py`, the file for testing NNAPI conversion and execution without the delegate.

**test_backend_nnapi.py**
Test file for the Android NNAPI delegate.
- `TestNnapiBackend` class inherits tests from `test_nnapi.py` and overrides the model conversion to use the delegate API.
- Includes an extra test for passing input arguments as Tensors and Tensor Lists.
- Has extra setup for loading the NNAPI delegate library and for changing the default dtype from float64 to float32. The default dtype is normally float32, but another unit test sets it to float64 before the delegate backend tests run.
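The override scheme in these bullets can be sketched with the standard `unittest` module alone. This is a minimal sketch, not the PR's code. The names `BaseConversionTests`, `DelegateConversionTests`, `convert()` and `SETTINGS` are hypothetical stand-ins for `TestNNAPI`, `TestNnapiBackend`, `call_lowering_to_nnapi()` and torch's default dtype.

The sketch shows three things:
- The subclass inherits every `test_*` method and replaces only the conversion hook.
- It adds one delegate-only test.
- It saves and restores a global setting in `setUp`/`tearDown`.

```python
import unittest

# Hypothetical module-level setting standing in for torch's default dtype.
SETTINGS = {"default_dtype": "float64"}


class BaseConversionTests(unittest.TestCase):
    """Hypothetical stand-in for TestNNAPI: every test goes through convert()."""

    def convert(self, model, args):
        # Default path: convert without a delegate (placeholder result).
        return {"path": "direct", "model": model, "args": args}

    def test_converts_model(self):
        converted = self.convert("prelu", [1.0, -1.0])
        self.assertEqual(converted["model"], "prelu")


class DelegateConversionTests(BaseConversionTests):
    """Hypothetical stand-in for TestNnapiBackend: overrides only convert()."""

    def setUp(self):
        super().setUp()
        # Save the global setting, then force the value the delegate supports.
        self.saved_dtype = SETTINGS["default_dtype"]
        SETTINGS["default_dtype"] = "float32"

    def tearDown(self):
        # Restore the setting so later test classes are unaffected.
        SETTINGS["default_dtype"] = self.saved_dtype
        super().tearDown()

    def convert(self, model, args):
        # Same compile-spec shape as call_lowering_to_nnapi() in this file.
        compile_spec = {"forward": {"inputs": args}}
        return {"path": "delegate", "model": model, "args": compile_spec}

    def test_single_and_list_inputs(self):
        # Delegate-only test: inputs given as a single value and as a list.
        for args in (1.0, [1.0]):
            converted = self.convert("prelu", args)
            self.assertEqual(converted["args"]["forward"]["inputs"], args)
        self.assertEqual(SETTINGS["default_dtype"], "float32")
```

Loading `DelegateConversionTests` with `unittest.TestLoader` runs two tests. The first is the inherited test, which goes through the overridden hook; the second is the delegate-only test. Afterwards the setting is back to `"float64"`.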

**test_nnapi.py**
Test file for Android NNAPI without the delegate.
- Some code was refactored so that a subclass can override only the NNAPI conversion call.
- An extra function was added to allow the NNAPI delegate unit test to turn off the model execution step. Once the NNAPI delegate's execution implementation is complete, this may no longer be necessary.
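The execution switch can be illustrated the same way. In this hypothetical sketch, `set_can_run()` plays the role that `set_can_run_nnapi()` plays in `test_nnapi.py`. The "execution" step is only a placeholder that records which converted models would have run. The subclass keeps the conversion checks but turns execution off in `setUp`, as the delegate tests do.

```python
import unittest


class BaseModelTests(unittest.TestCase):
    """Hypothetical base class; in test_nnapi.py the toggle is set_can_run_nnapi()."""

    def setUp(self):
        super().setUp()
        self.can_run = True   # execute converted models by default
        self.executed = []    # records which converted models were executed

    def set_can_run(self, flag):
        # Toggles only the execution step; conversion always runs.
        self.can_run = flag

    def convert(self, model):
        # Placeholder for the (overridable) conversion call.
        return f"converted:{model}"

    def check(self, model):
        converted = self.convert(model)
        self.assertTrue(converted.startswith("converted:"))
        if self.can_run:
            # Placeholder for running the converted model and comparing outputs.
            self.executed.append(converted)
        return converted

    def test_prelu(self):
        self.check("prelu")


class LoweringOnlyTests(BaseModelTests):
    def setUp(self):
        super().setUp()
        self.set_can_run(False)  # test lowering only, skip execution

    def test_nothing_executed(self):
        self.check("linear")
        self.assertEqual(self.executed, [])
```

Keeping the switch in the base class means that, once delegate execution is implemented, re-enabling it is a one-line change in the subclass's `setUp`.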

### Test Plan:
I ran `python test/test_jit.py TestNnapiBackend` and `python test/test_nnapi.py` to run both test files.

Test Plan: Imported from OSS

Reviewed By: raziel, iseeyuan

Differential Revision: D29772005

fbshipit-source-id: 5d14067a4f6081835699b87a2ece5bd6bed00c6b
2021-07-23 17:04:07 -07:00

import os
import sys
import unittest

import torch
import torch._C
from pathlib import Path
from test_nnapi import TestNNAPI
from torch.testing._internal.common_utils import TEST_WITH_ASAN

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
"""
Unit Tests for Nnapi backend with delegate
Inherits most tests from TestNNAPI, which loads Android NNAPI models
without the delegate API.
"""
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
# Second skip is because ASAN is currently causing an error.
# It is still unclear how to resolve this. T95764916
torch_root = Path(__file__).resolve().parent.parent.parent
lib_path = torch_root / 'build' / 'lib' / 'libnnapi_backend.so'
@unittest.skipIf(not os.path.exists(lib_path),
"Skipping the test as libnnapi_backend.so was not found")
@unittest.skipIf(TEST_WITH_ASAN, "Unresolved bug with ASAN")
class TestNnapiBackend(TestNNAPI):
    def setUp(self):
        super().setUp()

        # Save the current default dtype (a newly created PReLU's weight uses it)
        module = torch.nn.PReLU()
        self.default_dtype = module.weight.dtype
        # Change dtype to float32 (since a different unit test changed dtype to float64,
        # which is not supported by the Android NNAPI delegate)
        # Float32 should typically be the default in other files.
        torch.set_default_dtype(torch.float32)

        # Load the NNAPI delegate library
        torch.ops.load_library(str(lib_path))

        # Disable execution tests; only test lowering modules
        # TODO: Re-enable execution tests after the Nnapi delegate is complete
        super().set_can_run_nnapi(False)

    # Override: lower the traced module through the NNAPI delegate API
    def call_lowering_to_nnapi(self, traced_module, args):
        compile_spec = {"forward": {"inputs": args}}
        return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)

    def test_tensor_input(self):
        # Lower a simple module
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Argument input is a single Tensor
        self.call_lowering_to_nnapi(traced, args)
        # Argument input is a Tensor in a list
        self.call_lowering_to_nnapi(traced, [args])

    def tearDown(self):
        # Restore the saved default dtype (otherwise, other unit tests will complain)
        torch.set_default_dtype(self.default_dtype)
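The two stacked `@unittest.skipIf` decorators on `TestNnapiBackend` skip the entire class as soon as either condition holds. The sketch below reproduces that behavior using two hypothetical stand-ins:
- a library path that does not exist;
- an `UNDER_SANITIZER` flag in place of `TEST_WITH_ASAN`.

It only illustrates standard `unittest` behavior; it is not part of the PyTorch test suite.

```python
import os
import tempfile
import unittest

# Hypothetical path that does not exist; the real file checks build/lib/libnnapi_backend.so.
MISSING_LIB = os.path.join(tempfile.gettempdir(), "does_not_exist", "libexample_backend.so")
UNDER_SANITIZER = False  # hypothetical stand-in for TEST_WITH_ASAN


@unittest.skipIf(not os.path.exists(MISSING_LIB), "library was not found")
@unittest.skipIf(UNDER_SANITIZER, "Unresolved bug with ASAN")
class GuardedTests(unittest.TestCase):
    def test_something(self):
        # Never reached: the class-level skip prevents this method from running.
        self.fail("should not run when the library is missing")
```

Decorators are applied bottom-up, so when both conditions hold, the recorded skip reason is the one from the topmost decorator. A skipped class still counts as a successful run.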