mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
Fixes https://github.com/pytorch/pytorch/issues/48141
~Mypy is complaining about a missing arg in a function call.~
```bash
torch/backends/_nnapi/serializer.py:806: error: Too few arguments for "_do_add_binary" [call-arg]
Found 1 error in 1 file (checked 1140 source files)
```
9392137dbe/torch/backends/_nnapi/serializer.py (L804-L806)
~dreiss, would you mind taking a look when you have some cycles to spare and seeing what would be the appropriate value for `fuse_code` here? Thanks :)~
Edit: https://github.com/pytorch/pytorch/issues/48925 got merged a couple of days ago. The blocking part is now unblocked, and I just pushed the changes to make mypy happy again. This PR is ready for review.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48142
Reviewed By: ezyang
Differential Revision: D28006249
Pulled By: walterddr
fbshipit-source-id: 5e43eeba7143512a549efaad31541f86718add7c
138 lines
5.0 KiB
Python
138 lines
5.0 KiB
Python
from typing import Optional, List
|
|
|
|
import torch
|
|
from torch.backends._nnapi.serializer import serialize_model
|
|
|
|
class NnapiModule(torch.nn.Module):
    """Torch Module that wraps an NNAPI Compilation.

    This module handles preparing the weights, initializing the
    NNAPI TorchBind object, and adjusting the memory formats
    of all inputs and outputs.
    """

    # _nnapi.Compilation is defined
    # comp: the TorchBind NNAPI compilation object; None until init() runs.
    comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
    # weights: model weights passed to the compilation (made contiguous in init()).
    weights: List[torch.Tensor]
    # out_templates: per-output empty-tensor templates produced by the
    # shape-compute module; used to allocate outputs in forward().
    out_templates: List[torch.Tensor]

    def __init__(
        self,
        shape_compute_module: torch.nn.Module,
        ser_model: torch.Tensor,
        weights: List[torch.Tensor],
        inp_mem_fmts: List[int],
        out_mem_fmts: List[int],
    ):
        """Store the serialized model, weights, and per-tensor memory formats.

        Args:
            shape_compute_module: code-gen-ed module whose ``prepare`` method
                mutates ``ser_model`` for the actual input shapes and returns
                output templates.
            ser_model: serialized NNAPI model as an int32 tensor.
            weights: weight tensors referenced by the serialized model.
            inp_mem_fmts: per-input memory-format codes (values match
                DimOrder in serializer.py; see forward()).
            out_mem_fmts: per-output memory-format codes (same encoding).
        """
        super().__init__()
        self.shape_compute_module = shape_compute_module
        self.ser_model = ser_model
        self.weights = weights
        self.inp_mem_fmts = inp_mem_fmts
        self.out_mem_fmts = out_mem_fmts
        self.out_templates = []
        self.comp = None

    @torch.jit.export
    def init(self, args: List[torch.Tensor]):
        """Build the NNAPI Compilation for the shapes of ``args``.

        Must be called at most once (asserts comp is still None).
        Computes the output templates, makes weights contiguous, and
        initializes the TorchBind Compilation object.
        """
        assert self.comp is None
        self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
        # NNAPI expects contiguous weight buffers.
        self.weights = [w.contiguous() for w in self.weights]
        comp = torch.classes._nnapi.Compilation()
        comp.init(self.ser_model, self.weights)
        # Assign only after successful init so comp is never half-built.
        self.comp = comp

    def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
        """Run the NNAPI compilation on ``args``; lazily init on first call.

        Converts each input to the memory format NNAPI expects, allocates
        outputs from the templates, runs the compilation, and converts
        outputs back.
        """
        if self.comp is None:
            # Lazy initialization using the first call's input shapes.
            self.init(args)
        comp = self.comp
        assert comp is not None
        # Allocate one output tensor per template (shape/dtype from prepare()).
        outs = [torch.empty_like(out) for out in self.out_templates]

        assert len(args) == len(self.inp_mem_fmts)
        fixed_args = []
        for idx in range(len(args)):
            fmt = self.inp_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
            if fmt == 0:
                # fmt 0: plain contiguous layout.
                fixed_args.append(args[idx].contiguous())
            elif fmt == 1:
                # fmt 1: permute NCHW dims to channels-last order
                # (presumably DimOrder's channels-last — see comment above).
                fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
            else:
                raise Exception("Invalid mem_fmt")
        comp.run(fixed_args, outs)
        assert len(outs) == len(self.out_mem_fmts)
        for idx in range(len(self.out_templates)):
            fmt = self.out_mem_fmts[idx]
            # These constants match the values in DimOrder in serializer.py
            # TODO: See if it's possible to use those directly.
            if fmt == 0:
                # fmt 0: output already in the desired layout.
                pass
            elif fmt == 1:
                # fmt 1: inverse of the input permutation — back to NCHW order.
                outs[idx] = outs[idx].permute(0, 3, 1, 2)
            else:
                raise Exception("Invalid mem_fmt")
        return outs
|
|
|
|
|
|
def convert_model_to_nnapi(model, inputs):
    """Convert a scripted model into a TorchScript module that runs via NNAPI.

    Freezes ``model``, serializes it for NNAPI using ``inputs`` as example
    inputs, and wraps the result so it can be called with the original
    (non-list) calling convention.

    Args:
        model: a scripted ``torch.nn.Module`` to convert.
        inputs: a single example input tensor, or a list of them.

    Returns:
        A ``torch.jit.ScriptModule`` whose ``forward`` takes the inputs as
        separate arguments and returns a single tensor or a tuple, matching
        the original module.
    """
    model = torch.jit.freeze(model)

    # Accept a bare tensor for convenience; serialize_model wants a list.
    if isinstance(inputs, torch.Tensor):
        inputs = [inputs]

    (ser_model, used_weights, inp_mem_fmts, out_mem_fmts,
     shape_compute_lines, retval_count) = serialize_model(model, inputs)
    ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)

    # We have to create a new class here every time this function is called
    # because module.define adds a method to the *class*, not the instance.
    class ShapeComputeModule(torch.nn.Module):
        """Code-gen-ed module for tensor shape computation

        module.prepare will mutate ser_model according to the computed operand
        shapes, based on the shapes of args. Returns a list of output templates.
        """

    shape_compute_module = torch.jit.script(ShapeComputeModule())
    # Assemble the generated "prepare" method: fixed signature plus the
    # serializer-emitted body, each body line indented one level.
    prepare_header = (
        "def prepare(self, ser_model: torch.Tensor, "
        "args: List[torch.Tensor]) -> List[torch.Tensor]:\n"
    )
    prepare_body = "".join(f"    {line}\n" for line in shape_compute_lines)
    shape_compute_module.define(prepare_header + prepare_body)

    nnapi_model = NnapiModule(
        shape_compute_module,
        ser_model_tensor,
        used_weights,
        inp_mem_fmts,
        out_mem_fmts)

    class NnapiInterfaceWrapper(torch.nn.Module):
        """NNAPI list-ifying and de-list-ifying wrapper.

        NNAPI always expects a list of inputs and provides a list of outputs.
        This module allows us to accept inputs as separate arguments.
        It returns results as either a single tensor or tuple,
        matching the original module.
        """

        def __init__(self, mod):
            super().__init__()
            self.mod = mod

    wrapper_model = torch.jit.script(NnapiInterfaceWrapper(nnapi_model))
    # TODO: Maybe make these names match the original.
    arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
    # A negative retval_count marks a single-tensor return; otherwise emit a
    # trailing-comma tuple expression covering each return value.
    ret_expr = (
        "retvals[0]"
        if retval_count < 0
        else "".join(f"retvals[{idx}], " for idx in range(retval_count))
    )
    forward_src = (
        f"def forward(self, {arg_list}):\n"
        f"    retvals = self.mod([{arg_list}])\n"
        f"    return {ret_expr}\n"
    )
    wrapper_model.define(forward_src)
    return wrapper_model
|