Add option to specify custom NNAPI serializer (#61025)

Summary:
To add serialization support for custom ops, subclass the default serializer
(_NnapiSerializer) and update its ADDER_MAP, as sketched below.
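
For example, a minimal sketch (not part of this change): the op kind "custom::my_op"
and the adder method are hypothetical, and it assumes ADDER_MAP is a class-level dict
mapping node kinds to adder callables invoked as adder(self, node), as the built-in
entries suggest.

    from torch.backends._nnapi.serializer import _NnapiSerializer

    class CustomOpSerializer(_NnapiSerializer):
        # Hypothetical adder for a custom op; a real adder would emit the
        # NNAPI operands/operations for this node into the serializer state.
        def add_custom_my_op(self, node):
            raise NotImplementedError("sketch only")

        # Start from the default adders and register the custom op kind.
        # Assumes adders are looked up by node.kind() and called as adder(self, node).
        ADDER_MAP = {
            **_NnapiSerializer.ADDER_MAP,
            "custom::my_op": add_custom_my_op,
        }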

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61025

Test Plan:
* pytest test/test_nnapi.py::TestNNAPI covers the current (default) serializer
* Custom serializers will be tested together with the custom ops that use them

Imported from OSS

Reviewed By: anshuljain1

Differential Revision: D29480745

fbshipit-source-id: 37e3f8de3c97f6c8a486f9879ce11430ea89af34
Authored by Akshit Khurana on 2021-07-09 15:21:35 -07:00; committed by Facebook GitHub Bot
Parent: cbb6ab6d88
Commit: a3670ba377


@@ -1,7 +1,8 @@
 from typing import Optional, List
 import torch
 from torch.backends._nnapi.serializer import serialize_model
+from torch.backends._nnapi.serializer import _NnapiSerializer


 class NnapiModule(torch.nn.Module):
     """Torch Module that wraps an NNAPI Compilation.
@@ -75,14 +76,15 @@ class NnapiModule(torch.nn.Module):
                 raise Exception("Invalid mem_fmt")
         return outs


-def convert_model_to_nnapi(model, inputs):
+def convert_model_to_nnapi(model, inputs, serializer=None):
     model = torch.jit.freeze(model)

     if isinstance(inputs, torch.Tensor):
         inputs = [inputs]

-    ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines, retval_count = serialize_model(model, inputs)
+    serializer = serializer or _NnapiSerializer(config=None)
+    (ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines,
+        retval_count) = serializer.serialize_model(model, inputs)
     ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)

     # We have to create a new class here every time this function is called
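
With the new convert_model_to_nnapi signature above, usage would look roughly like
this sketch. TinyModel and CustomOpSerializer are placeholder names, and it assumes
the usual torch.backends._nnapi.prepare import path for convert_model_to_nnapi.

    import torch
    from torch.backends._nnapi.prepare import convert_model_to_nnapi
    from torch.backends._nnapi.serializer import _NnapiSerializer

    class CustomOpSerializer(_NnapiSerializer):
        """Placeholder for a subclass that extends ADDER_MAP (see Summary)."""
        pass

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    model = torch.jit.script(TinyModel().eval())
    example_input = torch.zeros(1, 3, 224, 224)

    # Default path: serializer=None falls back to _NnapiSerializer(config=None).
    nnapi_model = convert_model_to_nnapi(model, example_input)

    # Custom path: pass an instance of a serializer subclass.
    nnapi_model = convert_model_to_nnapi(
        model, example_input, serializer=CustomOpSerializer(config=None))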