Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Add option to specify custom NNAPI serializer (#61025)
Summary: To add a serializer for custom ops, we can subclass the default serializer and update its ADDER_MAP.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61025

Test Plan:
* pytest test/test_nnapi.py::TestNNAPI for the current serializer
* Custom serializers to be tested with custom ops

Imported from OSS

Reviewed By: anshuljain1

Differential Revision: D29480745

fbshipit-source-id: 37e3f8de3c97f6c8a486f9879ce11430ea89af34
parent cbb6ab6d88
commit a3670ba377
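As the summary notes, custom-op support is meant to come from subclassing the default serializer and extending its ADDER_MAP. The sketch below is illustrative only: the op name "custom::my_op", the handler name add_my_op, and the (self, node) adder signature are assumptions made for this example, not details taken from the PR.

    from torch.backends._nnapi.serializer import _NnapiSerializer


    class CustomOpSerializer(_NnapiSerializer):
        # Copy the default adder map and register a handler for a custom op.
        # The key "custom::my_op" and the (self, node) handler signature are
        # assumed for illustration; see serializer.py for the real conventions.
        ADDER_MAP = dict(_NnapiSerializer.ADDER_MAP)
        ADDER_MAP["custom::my_op"] = lambda self, node: self.add_my_op(node)

        def add_my_op(self, node):
            # Emit the NNAPI operands/operations for the custom op here.
            raise NotImplementedError("illustrative handler, not implemented")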
@@ -1,7 +1,8 @@
 from typing import Optional, List
 
 import torch
-from torch.backends._nnapi.serializer import serialize_model
+from torch.backends._nnapi.serializer import _NnapiSerializer
 
 
 class NnapiModule(torch.nn.Module):
     """Torch Module that wraps an NNAPI Compilation.
@@ -75,14 +76,15 @@ class NnapiModule(torch.nn.Module):
                 raise Exception("Invalid mem_fmt")
         return outs
 
 
-def convert_model_to_nnapi(model, inputs):
+def convert_model_to_nnapi(model, inputs, serializer=None):
     model = torch.jit.freeze(model)
 
     if isinstance(inputs, torch.Tensor):
         inputs = [inputs]
 
-    ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines, retval_count = serialize_model(model, inputs)
+    serializer = serializer or _NnapiSerializer(config=None)
+    (ser_model, used_weights, inp_mem_fmts, out_mem_fmts, shape_compute_lines,
+        retval_count) = serializer.serialize_model(model, inputs)
     ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
 
     # We have to create a new class here every time this function is called
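A brief usage sketch of the new keyword argument; the model and example input below are placeholders chosen for illustration and are not part of this commit.

    import torch
    from torch.backends._nnapi.prepare import convert_model_to_nnapi
    from torch.backends._nnapi.serializer import _NnapiSerializer

    # Placeholder model and example input for illustration.
    model = torch.jit.script(torch.nn.Sequential(torch.nn.ReLU()).eval())
    example_input = torch.zeros(1, 3, 224, 224)

    # Default path: the function builds _NnapiSerializer(config=None) itself,
    # so existing callers keep working unchanged.
    nnapi_model = convert_model_to_nnapi(model, example_input)

    # New path: pass an explicit serializer instance, typically a subclass of
    # _NnapiSerializer such as the CustomOpSerializer sketched earlier.
    nnapi_model = convert_model_to_nnapi(
        model, example_input, serializer=_NnapiSerializer(config=None)
    )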