mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add option to specify custom NNAPI serializer (#61025)
Summary: To add serializer for custom ops we can subclass default serializer and update ADDER_MAP Pull Request resolved: https://github.com/pytorch/pytorch/pull/61025 Test Plan: * pytest test/test_nnapi.py::TestNNAPI for current serializer * Custom serializers to be tested with custom ops Imported from OSS Reviewed By: anshuljain1 Differential Revision: D29480745 fbshipit-source-id: 37e3f8de3c97f6c8a486f9879ce11430ea89af34
This commit is contained in:
parent
cbb6ab6d88
commit
a3670ba377
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
from torch.backends._nnapi.serializer import serialize_model
|
||||
from torch.backends._nnapi.serializer import _NnapiSerializer
|
||||
|
||||
|
||||
class NnapiModule(torch.nn.Module):
|
||||
"""Torch Module that wraps an NNAPI Compilation.
|
||||
|
|
@ -75,14 +76,15 @@ class NnapiModule(torch.nn.Module):
|
|||
raise Exception("Invalid mem_fmt")
|
||||
return outs
|
||||
|
||||
|
||||
def convert_model_to_nnapi(model, inputs):
|
||||
def convert_model_to_nnapi(model, inputs, serializer=None):
|
||||
model = torch.jit.freeze(model)
|
||||
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = [inputs]
|
||||
|
||||
ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines, retval_count = serialize_model(model, inputs)
|
||||
serializer = serializer or _NnapiSerializer(config=None)
|
||||
(ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines,
|
||||
retval_count) = serializer.serialize_model(model, inputs)
|
||||
ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
|
||||
|
||||
# We have to create a new class here every time this function is called
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user