mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61782 This PR depends on https://github.com/pytorch/pytorch/pull/61787 ### Summary: Added more comprehensive tests for Android NNAPI delegate. Previously, there was only one basic test for lowering a PReLU module with the NNAPI delegate. Now, more tests are inherited from `test_nnapi.py`, the file for testing NNAPI conversion and execution without the delegate. **test_backend_nnapi.py** Test file for Android NNAPI delegate. - `TestNnapiBackend` class inherits tests from `test_nnapi.py` and overrides the model conversion to use the delegate API. - Includes an extra test for passing input arguments as Tensors and Tensor Lists. - Has extra set up for loading the NNAPI delegate library and changing the default dtype from float64 to float32 (dtype is typically float32 by default, but not in delegate backend unit tests) **test_nnapi.py** Test file for Android NNAPI without the delegate. - Some code was refactored to allow override of only the NNAPI conversion call. - An extra function was added to allow the NNAPI delegate unit test to turn off the model execution step. Once the NNAPI delegate's execution implementation is complete, this may no longer be necessary. ### Test Plan: I ran `python test/test_jit.py TestNnapiBackend` and `python test/test_nnapi.py` to run both test files. Test Plan: Imported from OSS Reviewed By: raziel, iseeyuan Differential Revision: D29772005 fbshipit-source-id: 5d14067a4f6081835699b87a2ece5bd6bed00c6b
73 lines
2.6 KiB
Python
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._C
|
|
from pathlib import Path
|
|
from test_nnapi import TestNNAPI
|
|
from torch.testing._internal.common_utils import TEST_WITH_ASAN
|
|
|
|
# Make the helper files in test/ importable by appending the test/ directory
# (two levels above this file's resolved location) to the module search path.
# NOTE(review): this runs after `from test_nnapi import TestNNAPI` at the top of
# the file, so that import cannot rely on it — it presumably works because the
# suite is launched via test/test_jit.py, whose directory is already on sys.path.
pytorch_test_dir = os.path.realpath(__file__)          # .../test/<subdir>/<this file>
pytorch_test_dir = os.path.dirname(pytorch_test_dir)   # .../test/<subdir>
pytorch_test_dir = os.path.dirname(pytorch_test_dir)   # .../test
sys.path.append(pytorch_test_dir)
|
|
|
|
if __name__ == "__main__":
    # This module only defines test cases meant to be run through
    # test/test_jit.py; executing it directly is refused with usage help.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
|
|
|
|
"""
|
|
Unit Tests for Nnapi backend with delegate
|
|
Inherits most tests from TestNNAPI, which loads Android NNAPI models
|
|
without the delegate API.
|
|
"""
|
|
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
|
|
# Second skip is because ASAN is currently causing an error.
|
|
# It is still unclear how to resolve this. T95764916
|
|
torch_root = Path(__file__).resolve().parent.parent.parent
|
|
lib_path = torch_root / 'build' / 'lib' / 'libnnapi_backend.so'
|
|
@unittest.skipIf(not os.path.exists(lib_path),
|
|
"Skipping the test as libnnapi_backend.so was not found")
|
|
@unittest.skipIf(TEST_WITH_ASAN, "Unresolved bug with ASAN")
|
|
class TestNnapiBackend(TestNNAPI):
|
|
def setUp(self):
|
|
super().setUp()
|
|
|
|
# Save default dtype
|
|
module = torch.nn.PReLU()
|
|
self.default_dtype = module.weight.dtype
|
|
# Change dtype to float32 (since a different unit test changed dtype to float64,
|
|
# which is not supported by the Android NNAPI delegate)
|
|
# Float32 should typically be the default in other files.
|
|
torch.set_default_dtype(torch.float32)
|
|
|
|
# Load nnapi delegate library
|
|
torch.ops.load_library(str(lib_path))
|
|
|
|
# Disable execution tests, only test lowering modules
|
|
# TODO: Re-enable execution tests after the Nnapi delegate is complete
|
|
super().set_can_run_nnapi(False)
|
|
|
|
# Override
|
|
def call_lowering_to_nnapi(self, traced_module, args):
|
|
compile_spec = {"forward": {"inputs": args}}
|
|
return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)
|
|
|
|
def test_tensor_input(self):
|
|
# Lower a simple module
|
|
args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
|
|
module = torch.nn.PReLU()
|
|
traced = torch.jit.trace(module, args)
|
|
|
|
# Argument input is a single Tensor
|
|
self.call_lowering_to_nnapi(traced, args)
|
|
# Argument input is a Tensor in a list
|
|
self.call_lowering_to_nnapi(traced, [args])
|
|
|
|
def tearDown(self):
|
|
# Change dtype back to default (Otherwise, other unit tests will complain)
|
|
torch.set_default_dtype(self.default_dtype)
|