pytorch/torch/backends/_nnapi/serializer.py
David Reiss da7a27b847 [NNAPI] Initial flexible size support (#54701)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54701

We need NNAPI models to support inputs (and, by extension, intermediate
values and outputs) whose shape is only determined at load time.  For
example, a vision model's input shape might depend on the aspect
ratio of the device camera.  While NNAPI has full support for variable
shapes (by setting components of the operand shape to 0), the guidance
we have received is that vendor-provided drivers for real hardware are
not able to support this efficiently.  Therefore, we take a hybrid
approach where shapes are calculated at model load time to
semi-dynamically construct our NNAPI model.  While this doesn't let us
have truly dynamic input shapes, it does allow us to ensure that the
vendor driver only sees fixed shapes, so we get maximum performance.

In this initial commit, only PReLU supports dynamic shapes.  Additional
operators will be converted in separate diffs.

- In order to convert a flexible-shape model, the user supplies inputs
  with shapes containing dimensions of size 0 for the flexible
  dimensions.
- During conversion, we generate code to compute the shapes of all
  intermediates and outputs as a function of the input shapes.
- We no longer run the input model to produce the output templates.
  Instead, we generate code to return properly-sized templates, given
  the input shapes.
- All of this generated code goes into a "ShapeComputeModule" that is
  used by the NnapiModule during initialization.
- The ShapeComputeModule mutates the serialized model to fill in the
  computed sizes for each operand.  This requires us to change the dtype
  for the serialized model to int32, but this should be fine because
  everything in it is already 4-byte aligned.
- NnapiInitWrapper no longer exists.  Instead, initialization is
  performed on the first run, based on the real arguments.  We plan to
  provide an API for doing eager initialization.
- Unit test updated to allow separate arguments to be given for trace,
  conversion, and inference.  A flexible-shape test case was added for
  PReLU.

Test Plan: Unit test

Reviewed By: axitkhurana

Differential Revision: D27536796

Pulled By: dreiss

fbshipit-source-id: 105585f247987b1e6ec6946a6fe44401237cb0a0
2021-04-06 13:49:43 -07:00

1675 lines
60 KiB
Python

import enum
import struct
import array
import logging
from typing import (
Tuple,
NamedTuple,
)
import torch
# TODO: Add type annotations
# TODO: Check tensor types for ops

# Module-level logger; serialize_model emits one debug record per graph node.
LOG = logging.getLogger(name="nnapi_serialize")
class NNAPI_OperandCode(object):
    """Operand type codes.

    These integers are written verbatim into the serialized model (the
    operand records packed in serialize_model), so their values are part of
    the serialized format. NOTE(review): presumably they mirror Android
    NNAPI's OperandCode enum; verify against NeuralNetworks.h before adding
    or changing entries.
    """
    FLOAT32 = 0
    INT32 = 1
    UINT32 = 2
    TENSOR_FLOAT32 = 3
    TENSOR_INT32 = 4
    TENSOR_QUANT8_ASYMM = 5
    BOOL = 6
    TENSOR_QUANT16_SYMM = 7
    TENSOR_FLOAT16 = 8
    TENSOR_BOOL8 = 9
    FLOAT16 = 10
    TENSOR_QUANT8_SYMM_PER_CHANNEL = 11
    TENSOR_QUANT16_ASYMM = 12
class NNAPI_OperationCode(object):
    """Operation codes.

    The chosen code is stored as the first field of each operation record
    (see add_operation and the ``"iii"`` packing in serialize_model), so the
    values are part of the serialized format. NOTE(review): presumably these
    mirror Android NNAPI's OperationCode enum; verify against
    NeuralNetworks.h before editing.
    """
    ADD = 0
    AVERAGE_POOL_2D = 1
    CONCATENATION = 2
    CONV_2D = 3
    DEPTHWISE_CONV_2D = 4
    DEPTH_TO_SPACE = 5
    DEQUANTIZE = 6
    EMBEDDING_LOOKUP = 7
    FLOOR = 8
    FULLY_CONNECTED = 9
    HASHTABLE_LOOKUP = 10
    L2_NORMALIZATION = 11
    L2_POOL_2D = 12
    LOCAL_RESPONSE_NORMALIZATION = 13
    LOGISTIC = 14
    LSH_PROJECTION = 15
    LSTM = 16
    MAX_POOL_2D = 17
    MUL = 18
    RELU = 19
    RELU1 = 20
    RELU6 = 21
    RESHAPE = 22
    RESIZE_BILINEAR = 23
    RNN = 24
    SOFTMAX = 25
    SPACE_TO_DEPTH = 26
    SVDF = 27
    TANH = 28
    BATCH_TO_SPACE_ND = 29
    DIV = 30
    MEAN = 31
    PAD = 32
    SPACE_TO_BATCH_ND = 33
    SQUEEZE = 34
    STRIDED_SLICE = 35
    SUB = 36
    TRANSPOSE = 37
    ABS = 38
    ARGMAX = 39
    ARGMIN = 40
    AXIS_ALIGNED_BBOX_TRANSFORM = 41
    BIDIRECTIONAL_SEQUENCE_LSTM = 42
    BIDIRECTIONAL_SEQUENCE_RNN = 43
    BOX_WITH_NMS_LIMIT = 44
    CAST = 45
    CHANNEL_SHUFFLE = 46
    DETECTION_POSTPROCESSING = 47
    EQUAL = 48
    EXP = 49
    EXPAND_DIMS = 50
    GATHER = 51
    GENERATE_PROPOSALS = 52
    GREATER = 53
    GREATER_EQUAL = 54
    GROUPED_CONV_2D = 55
    HEATMAP_MAX_KEYPOINT = 56
    INSTANCE_NORMALIZATION = 57
    LESS = 58
    LESS_EQUAL = 59
    LOG = 60
    LOGICAL_AND = 61
    LOGICAL_NOT = 62
    LOGICAL_OR = 63
    LOG_SOFTMAX = 64
    MAXIMUM = 65
    MINIMUM = 66
    NEG = 67
    NOT_EQUAL = 68
    PAD_V2 = 69
    POW = 70
    PRELU = 71
    QUANTIZE = 72
    QUANTIZED_16BIT_LSTM = 73
    RANDOM_MULTINOMIAL = 74
    REDUCE_ALL = 75
    REDUCE_ANY = 76
    REDUCE_MAX = 77
    REDUCE_MIN = 78
    REDUCE_PROD = 79
    REDUCE_SUM = 80
    ROI_ALIGN = 81
    ROI_POOLING = 82
    RSQRT = 83
    SELECT = 84
    SIN = 85
    SLICE = 86
    SPLIT = 87
    SQRT = 88
    TILE = 89
    TOPK_V2 = 90
    TRANSPOSE_CONV_2D = 91
    UNIDIRECTIONAL_SEQUENCE_LSTM = 92
    UNIDIRECTIONAL_SEQUENCE_RNN = 93
    RESIZE_NEAREST_NEIGHBOR = 94
class NNAPI_FuseCode(object):
    """Fused-activation codes.

    Passed to operations as an INT32 immediate operand (for example by
    _do_add_binary for ADD/SUB/MUL) to select an activation applied to the
    operation's result.
    """
    FUSED_NONE = 0
    FUSED_RELU = 1
    FUSED_RELU1 = 2
    FUSED_RELU6 = 3
class OperandValueSourceType(object):
    """Where a constant operand's value comes from.

    Stored alongside each entry of ``_NnapiSerializer.values`` and written
    into the value records by serialize_values.

    - IMMEDIATE: the raw value bytes are embedded in the serialized model.
    - NUMBERED_BUFFER: the value data is a packed (buffer number, offset,
      length) triple referring to an external weight buffer.
    - NUMBERED_MEMORY: not produced by the code visible here.
    """
    IMMEDIATE = 0
    NUMBERED_BUFFER = 2
    NUMBERED_MEMORY = 3
# Scalar types that appear explicitly in models.
# These must be kept in sync with
# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS.
# TODO: Expose these directly to Python to avoid maintaining this list.
class TorchScalarTypes(enum.Enum):
    """PyTorch scalar-type codes compared against integer scalar-type
    arguments in the graph (e.g. the last argument of
    aten::quantize_per_tensor in add_quantize)."""
    QUINT8 = 13
def approx_equal(lhs, rhs, tolerance=1e-6):
    """Return True if ``lhs`` and ``rhs`` agree within a relative tolerance.

    The allowed difference is ``tolerance`` times the smaller *magnitude* of
    the two values, which keeps the check symmetric and conservative.
    Magnitudes matter: the previous ``min(lhs, rhs)`` went negative for
    negative inputs, so even identical negative values compared unequal.

    Args:
        lhs, rhs: numbers to compare.
        tolerance: relative tolerance (default 1e-6).
    """
    return abs(lhs - rhs) <= tolerance * min(abs(lhs), abs(rhs))
def tensor_size(op_type, dims):
    """Return the size in bytes of a tensor operand.

    Args:
        op_type: an NNAPI_OperandCode tensor type with a known element size.
        dims: the tensor's dimensions.

    Raises:
        KeyError: if ``op_type`` has no known element size.
    """
    bytes_per_element = {
        NNAPI_OperandCode.TENSOR_FLOAT32: 4,
        NNAPI_OperandCode.TENSOR_INT32: 4,
        NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1,
        NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2,
    }[op_type]
    element_count = 1
    for extent in dims:
        element_count *= extent
    return bytes_per_element * element_count
def change_element(tup, index, value):
    """Return a tuple copy of ``tup`` with position ``index`` replaced by ``value``.

    ``index`` follows normal sequence indexing, so negative values count from
    the end. Raises IndexError if it is out of range.
    """
    replaced = [*tup]
    replaced[index] = value
    return (*replaced,)
class ConvPoolArgs2d(NamedTuple):
    """Configuration arguments for a 2D convolution or pooling op.

    Field order is the positional constructor order used by
    get_conv_pool_args_2d_common.
    """
    # Kernel (filter / pooling window) height and width.
    kernel_h: int
    kernel_w: int
    # Strides along height and width.
    stride_h: int
    stride_w: int
    # Explicit padding: top, bottom, left, right.
    pad_t: int
    pad_b: int
    pad_l: int
    pad_r: int
    # Dilation along height and width.
    dilation_h: int
    dilation_w: int
    # Number of groups. May be None when no group argument applies
    # (see get_conv_pool_args_2d_from_jit).
    group: int
class DimOrder(enum.Enum):
    """How an operand's NNAPI shape relates to its tracked PyTorch shape.

    See fix_shape for the concrete mapping applied to each value.
    """
    # The NNAPI operand has exactly the PyTorch shape.
    PRESUMED_CONTIGUOUS = 0
    # The PyTorch shape is NCHW; the NNAPI operand is explicitly NHWC.
    CHANNELS_LAST = 1
    # Used for immediate operands; fix_shape requires rank 0 or 1.
    SCALAR_OR_VECTOR = 2
    # Used for constant weight tensors; fix_shape passes the shape through
    # unchanged (its handling is marked as not fully thought through).
    UNKNOWN_CONSTANT = 999
class Operand(NamedTuple):
    """Representation of an NNAPI operand."""

    # NNAPI operand type: one of the NNAPI_OperandCode values.
    # TODO: Make this an enum.
    op_type: int

    # Always the PyTorch shape, i.e. NCHW for feature maps. The physical
    # NNAPI operand may be laid out with a transposed shape (see dim_order).
    shape: Tuple[int, ...]

    # How the NNAPI operand's shape relates to ``shape`` above:
    # - PRESUMED_CONTIGUOUS: the NNAPI operand matches the PyTorch tensor's
    #   shape exactly.
    # - CHANNELS_LAST: the PyTorch tensor is NCHW and the NNAPI operand is
    #   represented explicitly as NHWC.
    dim_order: DimOrder

    # Quantization parameters.
    scale: float
    zero_point: int

    def use_nchw(self):
        """Return True for an NCHW (PRESUMED_CONTIGUOUS) operand, False for NHWC.

        Raises:
            Exception: for any other dim order.
        """
        order = self.dim_order
        if order is DimOrder.CHANNELS_LAST:
            return False
        if order is not DimOrder.PRESUMED_CONTIGUOUS:
            raise Exception("Unknown dim order")
        return True
def broadcast_shapes(shape1, shape2):
    """Return the broadcast of two non-empty, equal-rank shapes as a tuple.

    A dimension of size 1 stretches to match the other operand's size.

    Raises:
        Exception: if the ranks differ (not supported yet) or a pair of
            dimensions cannot be broadcast.
    """
    assert len(shape1) > 0
    assert len(shape2) > 0
    # TODO: Support non-equal-rank broadcast where semantics match.
    # This can be tricky for NHWC tensors because dimension orders
    # don't match between PT and NNAPI, even though semantics match.
    if len(shape1) != len(shape2):
        raise Exception("Non-equal-rank broadcast is not supported yet.")
    result = []
    for lhs_dim, rhs_dim in zip(shape1, shape2):
        if lhs_dim == rhs_dim or rhs_dim == 1:
            result.append(lhs_dim)
        elif lhs_dim == 1:
            result.append(rhs_dim)
        else:
            raise Exception("Cannot broadcast shapes: {} and {}".format(shape1, shape2))
    return tuple(result)
def get_conv_pool_shape(image_shape, args, out_ch, transpose):
    """Compute the NCHW output shape of a 2D convolution or pooling op.

    Args:
        image_shape: input shape ``(batch, channels, height, width)``.
        args: a ConvPoolArgs2d-like object (kernel, stride, padding and
            dilation fields).
        out_ch: number of output channels.
        transpose: True for a transposed convolution.

    A height or width of 0 marks a flexible (load-time) dimension and is
    propagated to the output as 0.

    Raises:
        Exception: if either dilation is not 1 (not supported yet).
    """
    batch, _in_c, in_h, in_w = image_shape

    # TODO: Handle dilation
    if args.dilation_h != 1 or args.dilation_w != 1:
        raise Exception("Dilation not supported yet.")

    if transpose:
        out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b
        # Subtract left AND right padding (previously pad_l was subtracted twice).
        out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_r
    else:
        out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1
        out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1

    # Handle variable-sized tensors.
    if in_h == 0:
        out_h = 0
    if in_w == 0:
        out_w = 0

    return (batch, out_ch, out_h, out_w)
def fix_shape(shape, dim_order):
    """Return the shape an operand actually has in NNAPI.

    PyTorch shapes are always tracked as NCHW; this is where CHANNELS_LAST
    operands are converted to their explicit NHWC shape. Every other dim
    order passes the shape through unchanged.

    Raises:
        Exception: for an unrecognized ``dim_order``.
    """
    if dim_order is DimOrder.CHANNELS_LAST:
        # (N, C, *spatial) -> (N, *spatial, C)
        return (shape[0], *shape[2:], shape[1])
    if dim_order is DimOrder.SCALAR_OR_VECTOR:
        assert len(shape) == 0 or len(shape) == 1
        return shape
    if dim_order is DimOrder.PRESUMED_CONTIGUOUS or dim_order is DimOrder.UNKNOWN_CONSTANT:
        # XXX think this through (UNKNOWN_CONSTANT)
        return shape
    raise Exception(f"Bad dim_order: {dim_order!r}.")
def reverse_map_dim(dim_order, d):
    """Map an NNAPI dimension index ``d`` back to its PyTorch dimension index.

    Identity for PRESUMED_CONTIGUOUS. For CHANNELS_LAST (NNAPI NHWC vs.
    PyTorch NCHW) the mapping is 0->0, 1->2, 2->3, 3->1. Any other dim order
    fails an assertion.
    """
    if dim_order is not DimOrder.PRESUMED_CONTIGUOUS:
        assert dim_order is DimOrder.CHANNELS_LAST
        return (0, 2, 3, 1)[d]
    return d
def flex_name(op_id, dim):
    """Name of the generated local variable that holds the computed flexible
    size of dimension ``dim`` of operand ``op_id`` (e.g. ``s_3_1``)."""
    return "s_{}_{}".format(op_id, dim)
class _NnapiSerializer(object):
def __init__(self, config):
self.operands = []
self.values = []
self.operations = []
self.value_data = []
self.operation_args = []
self.inputs = []
self.outputs = []
self.flexible_shape_computation_lines = []
self.modules = {}
self.constants = {}
self.tensor_sequences = {}
self.jitval_operand_map = {}
self.cached_immediates = {}
self.used_weights = []
self.weight_offset = 0
if config is None:
config = {}
def get_next_operand_id(self):
return len(self.operands)
def add_tensor_operand(self, jitval, oper):
    """Add a tensor operand corresponding to a JIT Value.

    Returns the new NNAPI operand ID, which can be looked up later with
    get_tensor_operand_by_jitval.

    Raises:
        Exception: if ``jitval`` already has an operand.
    """
    assert isinstance(oper, Operand)
    if jitval in self.jitval_operand_map:
        raise Exception("Duplicate tensor: %r" % jitval)
    self.operands.append(oper)
    new_id = len(self.operands) - 1
    self.jitval_operand_map[jitval] = new_id
    return new_id
def add_anonymous_tensor_operand(self, oper):
    """Add a tensor operand that does not correspond to any JIT Value.

    Useful when one JIT IR node needs several NNAPI operands.
    Returns the new NNAPI operand ID.
    """
    assert isinstance(oper, Operand)
    self.operands.append(oper)
    return len(self.operands) - 1
@staticmethod
def torch_tensor_to_operand(tensor, dim_order):
    """Describe ``tensor`` as an Operand with the given ``dim_order``.

    float32 maps to TENSOR_FLOAT32, quint8 to TENSOR_QUANT8_ASYMM and qint32
    to TENSOR_INT32 (whose zero point must be 0). Quantized dtypes carry the
    tensor's ``q_scale()`` / ``q_zero_point()``; float32 uses 0.0 / 0.

    Raises:
        Exception: for any other dtype.
    """
    dtype = str(tensor.dtype).replace("torch.", "")
    quantized_types = {
        "quint8": NNAPI_OperandCode.TENSOR_QUANT8_ASYMM,
        "qint32": NNAPI_OperandCode.TENSOR_INT32,
    }
    if dtype == "float32":
        op_type = NNAPI_OperandCode.TENSOR_FLOAT32
        scale, zero_point = 0.0, 0
    elif dtype in quantized_types:
        op_type = quantized_types[dtype]
        scale = tensor.q_scale()
        zero_point = tensor.q_zero_point()
        if dtype == "qint32":
            assert zero_point == 0
    else:
        raise Exception(f"Can't handle input with dtype '{tensor.dtype}'")
    return Operand(
        shape=tuple(tensor.shape),
        op_type=op_type,
        dim_order=dim_order,
        scale=scale,
        zero_point=zero_point,
    )
def add_tensor_operand_for_input(self, arg_idx, jitval, tensor):
    """Add the operand for model argument number ``arg_idx`` (JIT value ``jitval``).

    The operand is CHANNELS_LAST if the example ``tensor`` has a truthy
    ``nnapi_nhwc`` attribute, otherwise PRESUMED_CONTIGUOUS. It is appended
    to the model inputs. Every size-0 dimension of ``tensor`` is flexible:
    generated code reads it from ``args[arg_idx]`` at load time.

    Returns the new operand ID.
    """
    wants_nhwc = getattr(tensor, "nnapi_nhwc", False)
    dim_order = DimOrder.CHANNELS_LAST if wants_nhwc else DimOrder.PRESUMED_CONTIGUOUS
    operand_id = self.add_tensor_operand(
        jitval, self.torch_tensor_to_operand(tensor, dim_order))
    self.inputs.append(operand_id)
    flexible_dims = [d for d, size in enumerate(tensor.shape) if size == 0]
    for dim in flexible_dims:
        self.compute_operand_shape(operand_id, dim, f"args[{arg_idx}].shape[{dim}]")
    return operand_id
def add_tensor_operand_for_weight(self, tensor):
    """Add an operand whose value is the constant tensor ``tensor``.

    The tensor is appended to ``self.used_weights``. The operand's value
    record is NUMBERED_BUFFER with packed data (buffer number, offset 0,
    byte size), where the buffer number is the tensor's index in
    ``used_weights``.

    Returns the new operand ID.
    """
    toper = self.torch_tensor_to_operand(tensor, DimOrder.UNKNOWN_CONSTANT)
    operand_id = len(self.operands)
    self.operands.append(toper)
    tsize = tensor_size(toper.op_type, toper.shape)
    # One buffer per weight, so its data always starts at offset 0.
    # (Value data is padded to 4 bytes later, in serialize_values.)
    buf_num = len(self.used_weights)
    offset = 0
    self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER))
    self.value_data.append(struct.pack("iii", buf_num, offset, tsize))
    self.used_weights.append(tensor)
    return operand_id
def add_immediate_operand(self, code, value, dims):
    """Return the operand ID of an immediate (inline-valued) operand.

    Args:
        code: NNAPI_OperandCode of the operand.
        value: the raw value bytes.
        dims: shape tuple (empty for scalars).

    Operands are cached by ``(code, value)``. Repeated requests reuse the
    first operand, and ``dims`` is not part of the cache key.
    """
    assert isinstance(dims, tuple)
    cache_key = (code, value)
    cached_id = self.cached_immediates.get(cache_key)
    if cached_id is not None:
        return cached_id
    new_id = len(self.operands)
    self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0))
    self.values.append((new_id, OperandValueSourceType.IMMEDIATE))
    self.value_data.append(value)
    self.cached_immediates[cache_key] = new_id
    return new_id
def add_immediate_int_scalar(self, value):
    """Return the (cached) operand ID of an INT32 scalar immediate ``value``."""
    packed = struct.pack("i", value)
    return self.add_immediate_operand(NNAPI_OperandCode.INT32, packed, ())
def add_immediate_float_scalar(self, value):
    """Return the (cached) operand ID of a FLOAT32 scalar immediate ``value``."""
    packed = struct.pack("f", value)
    return self.add_immediate_operand(NNAPI_OperandCode.FLOAT32, packed, ())
def add_immediate_bool_scalar(self, value):
    """Return the (cached) operand ID of a BOOL scalar immediate.

    The value is serialized as one byte: 0x01 if ``value`` is truthy, else 0x00.
    """
    if value:
        payload = b"\x01"
    else:
        payload = b"\x00"
    return self.add_immediate_operand(NNAPI_OperandCode.BOOL, payload, ())
def add_immediate_int_vector(self, value):
    """Return the (cached) operand ID of a 1-D TENSOR_INT32 immediate.

    ``value`` is a sized sequence of ints, packed as native C ints.
    """
    packed = array.array("i", value).tobytes()
    return self.add_immediate_operand(NNAPI_OperandCode.TENSOR_INT32, packed, (len(value),))
def get_tensor_operand_by_jitval(self, jitval):
operand_id = self.jitval_operand_map[jitval]
return (operand_id, self.operands[operand_id])
def get_tensor_operand_by_jitval_fixed_size(self, jitval):
op_id, oper = self.get_tensor_operand_by_jitval(jitval)
for s in oper.shape:
if s <= 0:
# TODO: Improve this error message, possibly after converting
# many callsites to support flexible size.
raise Exception("Flexible size is not supported for this operand.")
return op_id, oper
def get_tensor_operand_or_constant(self, jitval):
operand_id = self.jitval_operand_map.get(jitval)
if operand_id is None:
_, value = self.get_constant_value(jitval, "TensorType")
operand_id = self.add_tensor_operand_for_weight(value)
return (operand_id, self.operands[operand_id])
def get_tensor_operand_for_weight(self, jitval):
    """Add the constant tensor held by ``jitval`` as a weight operand.

    Returns ``(operand_id, operand)``. Raises if ``jitval`` is not a recorded
    constant of TensorType.
    """
    weight = self.get_constant_value(jitval, "TensorType")[1]
    operand_id = self.add_tensor_operand_for_weight(weight)
    return operand_id, self.operands[operand_id]
def add_operation(self, opcode, inputs, outputs):
self.operations.append((opcode, len(inputs), len(outputs)))
self.operation_args.extend(inputs + outputs)
def add_tensor_sequence(self, jitval, values):
assert jitval not in self.tensor_sequences
self.tensor_sequences[jitval] = values
def add_constant_value(self, jitval, ctype, value):
assert jitval not in self.constants
self.constants[jitval] = (ctype, value)
def get_constant_value(self, jitval, typekind=None):
record = self.constants.get(jitval)
if record is None:
raise Exception(f"Could not find constant value for '{jitval!r}'.")
ctype, _ = record
if typekind is not None and ctype.kind() != typekind:
raise Exception(
f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'")
return record
@staticmethod
def operand_to_template_torchscript(op_id, oper):
    """Return a TorchScript expression to build a template for a given operand.

    Fixed dimensions are emitted as literals. Flexible (<= 0) dimensions
    reference the variable named by flex_name(op_id, dim), which the generated
    shape-computation code must already have assigned.

    Raises:
        Exception: unless the operand is TENSOR_FLOAT32 or TENSOR_QUANT8_ASYMM.
    """
    dim_exprs = [
        str(size) if size > 0 else flex_name(op_id, dim)
        for dim, size in enumerate(oper.shape)
    ]
    # Every entry gets a trailing comma so a 1-element shape is still a tuple.
    shape_code = "(" + "".join(expr + "," for expr in dim_exprs) + ")"
    if oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
        return f"torch.zeros({shape_code}, dtype=torch.float32)"
    if oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
        return (
            f"torch.quantize_per_tensor("
            f"torch.zeros(1), scale={oper.scale}, zero_point={oper.zero_point}, dtype=torch.quint8)"
            f".expand({shape_code}).contiguous()"
        )
    raise Exception(f"Unsupported output operand type: {oper.op_type}")
def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim):
    """Generate code making flexible dim ``out_dim`` of ``out_op_id`` equal to
    the computed size of dim ``in_dim`` of ``in_op_id``."""
    source_var = flex_name(in_op_id, in_dim)
    self.compute_operand_shape(out_op_id, out_dim, source_var)
def compute_operand_shape(self, op_id, dim, expr):
    """Append a generated line that assigns the TorchScript expression
    ``expr`` to the flexible-size variable for (``op_id``, ``dim``)."""
    line = "{} = {}".format(flex_name(op_id, dim), expr)
    self.flexible_shape_computation_lines.append(line)
def transpose_to_nhwc(self, in_id, oper):
    """Emit a TRANSPOSE that turns an NCHW operand into a CHANNELS_LAST one.

    Only supported when H and W are both 1. The new operand keeps the same
    PyTorch (NCHW) ``shape`` but is marked CHANNELS_LAST.

    Returns:
        ``(out_id, out_oper)`` for the transposed operand.

    Raises:
        Exception: if the spatial dims are not (1, 1).
    """
    # Compare as a tuple: some adders (e.g. add_mean) store ``shape`` as a
    # list, and ``[1, 1] != (1, 1)`` is always True in Python.
    if tuple(oper.shape[2:]) != (1, 1):
        raise Exception("Automatic transpose only supported for H,W == 1,1")

    out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)

    # Operand creation order (permutation immediate, then output) is kept.
    inputs = [None] * 2
    inputs[0] = in_id
    inputs[1] = self.add_immediate_int_vector([0, 2, 3, 1])

    outputs = [None] * 1
    outputs[0] = self.add_anonymous_tensor_operand(out_oper)

    self.add_operation(NNAPI_OperationCode.TRANSPOSE, inputs, outputs)

    return outputs[0], out_oper
def transpose_for_broadcast(self, in0_id, in0_oper, in1_id, in1_oper):
    """Transpose inputs as necessary to allow broadcasting.

    If the two dim orders differ, NHWC is assumed to be preferred and the
    PRESUMED_CONTIGUOUS side is transposed with transpose_to_nhwc.

    Returns:
        ``(in0_id, in0_oper, in1_id, in1_oper)`` after any transpose.

    Raises:
        Exception: for any other combination of mismatched dim orders.
    """
    lhs_order, rhs_order = in0_oper.dim_order, in1_oper.dim_order
    if lhs_order == rhs_order:
        return in0_id, in0_oper, in1_id, in1_oper

    pair = (lhs_order, rhs_order)
    if pair == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
        new_id, new_oper = self.transpose_to_nhwc(in0_id, in0_oper)
        return new_id, new_oper, in1_id, in1_oper
    if pair == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
        new_id, new_oper = self.transpose_to_nhwc(in1_id, in1_oper)
        return in0_id, in0_oper, new_id, new_oper

    raise Exception(
        "Automatic transpose not supported for dim_orders: %r, %r" %
        (in0_oper.dim_order, in1_oper.dim_order))
def get_size_arg(self, jitval):
    """Return the constant ``int[]`` value held by ``jitval``.

    Raises:
        Exception: if the constant is not a list.
    """
    ctype, value = self.get_constant_value(jitval)
    if ctype.kind() != "ListType":
        raise Exception(f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'")
    assert ctype.getElementType().kind() == "IntType"
    return value
def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config):
    """Build ConvPoolArgs2d from a packed conv config plus ``kernel_size``.

    ``packed_config`` yields 11 items, each read with ``.item()``, laid out
    as [2, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
    output_pad_h, output_pad_w, groups, transpose]. Only non-transposed
    configs without output padding are accepted.
    """
    pc = [i.item() for i in packed_config]
    # Check the length first, so a malformed config fails this assertion
    # instead of raising IndexError on the lookups below.
    assert len(pc) == 11
    assert pc[0] == 2
    strides = [pc[1], pc[2]]
    paddings = [pc[3], pc[4]]
    dilations = [pc[5], pc[6]]
    output_padding = [pc[7], pc[8]]
    group_num = pc[9]
    transpose = pc[10]
    assert output_padding == [0, 0]
    assert transpose == 0

    return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num)
def get_conv_pool_args_2d_from_jit(self, kernel_size, stride, padding, dilation, group=None):
    """Build ConvPoolArgs2d from JIT constant arguments.

    ``stride``, ``padding`` and ``dilation`` are JIT values holding constant
    ``int[]`` lists. ``group`` is an optional JIT value holding a constant int;
    when it is omitted the group is None.
    """
    strides = self.get_size_arg(stride)
    paddings = self.get_size_arg(padding)
    dilations = self.get_size_arg(dilation)
    group_num = None if group is None else self.get_constant_value(group, "IntType")[1]
    return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num)
def get_conv_pool_args_2d_common(self, kernel_size, strides, paddings, dilations, group_num):
    """Assemble a ConvPoolArgs2d from (H, W) pairs.

    Args:
        kernel_size, strides, paddings, dilations: 2-element sequences in
            (H, W) order. Lists and tuples are both accepted; previously a
            tuple crashed the list concatenation below.
        group_num: group count, or None.
    """
    kernels = list(kernel_size)
    strides = list(strides)
    paddings = list(paddings)
    dilations = list(dilations)
    assert len(kernels) == 2
    assert len(strides) == 2
    assert len(paddings) == 2
    assert len(dilations) == 2

    # NNAPI uses 4 explicit padding values (top, bottom, left, right);
    # expand the (H, W) padding symmetrically.
    ph, pw = paddings
    real_paddings = [ph, ph, pw, pw]

    return ConvPoolArgs2d(*(kernels + strides + real_paddings + dilations + [group_num]))
def serialize_model(self, model, inputs):
    """Lower a TorchScript ``model`` into the serialized NNAPI model format.

    Args:
        model: a TorchScript module. The first input of ``model.graph`` is the
            module itself; the remaining graph inputs pair positionally
            with ``inputs``.
        inputs: example tensors. A dimension of size 0 marks it as flexible
            (computed at load time from the real argument's shape).

    Returns a 6-tuple:
        - ``array.array("i", ...)`` holding the serialized model;
        - ``self.used_weights`` (tensors referenced by NUMBERED_BUFFER values);
        - the DimOrder values (as ints) of the model inputs;
        - the DimOrder values (as ints) of the model outputs;
        - generated TorchScript lines that compute flexible sizes, patch
          them into the serialized model (``ser_model``), and finally return
          a list of output templates;
        - the number of returned values, or -1 if the model returns a single
          tensor rather than a tuple.
    """
    # Pre-create the boolean immediates. add_immediate_operand caches by
    # value, so later uses of False/True reuse these operand IDs.
    self.add_immediate_bool_scalar(False)
    self.add_immediate_bool_scalar(True)

    inp_dim_orders = []
    out_dim_orders = []

    # Register the module ("self") as a constant so prim::GetAttr nodes can
    # look up attributes on it.
    self_jitval = next(model.graph.inputs())
    self.add_constant_value(self_jitval, self_jitval.type(), model)

    for arg_idx, (input_value, input_tensor) in enumerate(zip(list(model.graph.inputs())[1:], inputs)):
        op_id = self.add_tensor_operand_for_input(arg_idx, input_value, input_tensor)
        inp_dim_orders.append(self.operands[op_id].dim_order.value)

    for idx, node in enumerate(model.graph.nodes()):
        LOG.debug("Processing node #%d: %r", idx, node)
        self.add_node(node)

    retn = model.graph.return_node()
    assert retn.inputsSize() == 1
    assert retn.outputsSize() == 0
    retn_input = retn.inputsAt(0)
    # Generated code returning correctly sized output templates.
    template_return_lines = ["return ["]
    if retn_input.type().kind() == "TensorType":
        return_values = [retn_input]
        # -1 signals a single (non-tuple) tensor return.
        retval_count = -1
    elif retn_input.type().kind() == "TupleType":
        return_values = self.tensor_sequences[retn_input]
        retval_count = len(return_values)
    else:
        raise Exception(f"Unsupported return type: {retn_input.type()}")

    for v in return_values:
        op_id = self.jitval_operand_map[v]
        self.outputs.append(op_id)
        out_dim_orders.append(self.operands[op_id].dim_order.value)
        template_return_lines.append(self.operand_to_template_torchscript(op_id, self.operands[op_id]) + ",")
    template_return_lines.append("]")

    # NOTE: ``model`` is rebound here to the list of serialized byte chunks;
    # the TorchScript module is not used past this point.
    model = []

    # Header: version followed by the counts of each table.
    version = 1
    header = struct.pack(
        "iiiiii",
        version,
        len(self.operands),
        len(self.values),
        len(self.operations),
        len(self.inputs),
        len(self.outputs),
    )
    model.append(header)

    serialized_values, serialized_value_data = self.serialize_values()

    # Fixed-size records: operands (type, rank, scale, zero_point), value
    # descriptors, and operations (opcode, input count, output count).
    model.extend(struct.pack("iifi", t, len(d), s, z) for (t, d, _m, s, z) in self.operands)
    model.extend(serialized_values)
    model.extend(struct.pack("iii", *x) for x in self.operations)

    # Compact the model so we can get its length so far.
    model = [b"".join(model)]
    model_offset = len(model[0])
    # Model offset is the index into the model (in 32-bit words, not bytes)
    # of the next dimension we're about to serialize.  If it's 0,
    # generate code to mutate it before passing to NNAPI.
    assert model_offset % 4 == 0
    model_offset = int(model_offset / 4)

    # Operand dimensions, in the order NNAPI sees them (after fix_shape).
    # Each flexible (0) dimension gets a generated line that writes the
    # computed PyTorch-dimension size into its word of ``ser_model``.
    for (op_id, (_, dims, dim_order, _, _)) in enumerate(self.operands):
        shape = fix_shape(dims, dim_order)
        for d, s in enumerate(shape):
            if s == 0:
                pt_d = reverse_map_dim(dim_order, d)
                self.flexible_shape_computation_lines.append(
                    f"ser_model[{model_offset}] = {flex_name(op_id, pt_d)}")
            model_offset += 1
        model.append(self.serialize_ints(shape))

    # Variable-length tail: value payloads, operation args, inputs, outputs.
    model.extend(serialized_value_data)
    model.append(self.serialize_ints(self.operation_args))
    model.append(self.serialize_ints(self.inputs))
    model.append(self.serialize_ints(self.outputs))

    # The template return statement goes after all shape computations.
    self.flexible_shape_computation_lines.extend(template_return_lines)

    return (
        array.array("i", b"".join(model)),
        self.used_weights,
        inp_dim_orders,
        out_dim_orders,
        self.flexible_shape_computation_lines,
        retval_count,
    )
def serialize_values(self):
serialized_values = []
serialized_value_data = []
assert len(self.values) == len(self.value_data)
for ((op_index, source_type), data) in zip(self.values, self.value_data):
source_length = len(data)
# Pad with 0 bytes out to a multiple of 4 for alignment.
physical_length = ((source_length - 1) | 0x3) + 1
padded_data = data + (b"\0" * (physical_length - source_length))
serialized_values.append(struct.pack("iii", op_index, source_type, source_length))
serialized_value_data.append(padded_data)
return serialized_values, serialized_value_data
@staticmethod
def serialize_ints(ints):
return array.array("i", ints).tobytes()
# Dispatch table used by add_node. It maps a TorchScript node kind to a
# function called as ``handler(serializer, node)`` that lowers the node.
# Node kinds not listed here are rejected by add_node as unsupported.
ADDER_MAP = {
    "prim::GetAttr": lambda self, node:
        self.add_getattr(node),
    "prim::Constant": lambda self, node:
        self.add_constant_node(node),
    "prim::ListConstruct": lambda self, node:
        self.add_list_construct(node),
    "prim::TupleConstruct": lambda self, node:
        self.add_tuple_construct(node),
    "aten::unsqueeze": lambda self, node:
        self.add_unsqueeze(node),
    "aten::reshape": lambda self, node:
        self.add_reshape(node),
    "aten::size": lambda self, node:
        self.add_size(node),
    "aten::cat": lambda self, node:
        self.add_cat(node),
    "aten::mean": lambda self, node:
        self.add_mean(node),
    "aten::quantize_per_tensor": lambda self, node:
        self.add_quantize(node),
    "aten::dequantize": lambda self, node:
        self.add_dequantize(node),
    "aten::add": lambda self, node:
        self.add_add_sub_op(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE),
    "aten::sub": lambda self, node:
        self.add_add_sub_op(node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE),
    "aten::mul": lambda self, node:
        self.add_pointwise_simple_binary_broadcast_op(node, NNAPI_OperationCode.MUL, NNAPI_FuseCode.FUSED_NONE),
    "aten::relu": lambda self, node:
        self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.RELU),
    "aten::sigmoid": lambda self, node:
        self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.LOGISTIC),
    "aten::hardtanh": lambda self, node:
        self.add_hardtanh(node),
    "aten::max_pool2d": lambda self, node:
        self.add_pool2d_node(node, NNAPI_OperationCode.MAX_POOL_2D),
    "aten::adaptive_avg_pool2d": lambda self, node:
        self.add_adaptive_avg_pool2d(node),
    "aten::upsample_nearest2d": lambda self, node:
        self.add_upsample_nearest2d(node),
    "aten::prelu": lambda self, node:
        self.add_prelu_op(node),
    "aten::addmm": lambda self, node:
        self.add_addmm(node),
    "aten::linear": lambda self, node:
        self.add_linear(node),
    "aten::_convolution": lambda self, node:
        self.add_conv_underscore(node),
    "aten::conv2d": lambda self, node:
        self.add_conv2d(node),
    # Quantized ops. The "_relu" variants differ only in the fuse code.
    "quantized::linear": lambda self, node:
        self.add_qlinear(node),
    "quantized::conv2d": lambda self, node:
        self.add_qconv2d(node, NNAPI_FuseCode.FUSED_NONE),
    "quantized::conv2d_relu": lambda self, node:
        self.add_qconv2d(node, NNAPI_FuseCode.FUSED_RELU),
    "quantized::add": lambda self, node:
        self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE),
    "quantized::add_relu": lambda self, node:
        self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU),
}
def add_node(self, node):
    """Lower one graph node by dispatching on its kind through ADDER_MAP.

    Raises:
        Exception: if the node kind has no handler.
    """
    kind = node.kind()
    adder = self.ADDER_MAP.get(kind)
    if not adder:
        raise Exception("Unsupported node kind (%r) in node %r" % (kind, node))
    adder(self, node)
def add_getattr(self, node):
    """Resolve prim::GetAttr on a constant object and record the attribute
    as a constant for the node's output.

    The object's type name must start with ``__torch__.``.
    """
    assert node.inputsSize() == 1
    assert node.outputsSize() == 1
    obj_ctype, obj = self.get_constant_value(node.inputsAt(0))
    assert str(obj_ctype).startswith("__torch__.")
    attr_value = getattr(obj, node.s("name"))
    result = node.outputsAt(0)
    self.add_constant_value(result, result.type(), attr_value)
def add_constant_node(self, node):
    """Record the value of a prim::Constant node as a constant for its output."""
    assert node.inputsSize() == 0
    assert node.outputsSize() == 1
    result = node.outputsAt(0)
    self.add_constant_value(result, result.type(), result.toIValue())
def add_list_construct(self, node):
    """Record a prim::ListConstruct output as a constant list and/or a tensor sequence.

    If every input is a recorded constant, the output becomes a constant
    list. If every input is a tensor, the output becomes a tensor sequence.
    Both may apply.

    Raises:
        Exception: if the inputs are neither all constants nor all tensors.
    """
    assert node.outputsSize() == 1
    output = node.outputsAt(0)
    ctype = output.type()
    const_vals = []
    tensors = []
    for inp in node.inputs():
        if const_vals is not None and inp in self.constants:
            _, val = self.get_constant_value(inp)
            const_vals.append(val)
        else:
            const_vals = None
        if tensors is not None and inp.type().kind() == "TensorType":
            tensors.append(inp)
        else:
            # Any non-tensor input rules out the tensor-sequence reading.
            # (This was previously misspelled "tensros", so mixed lists
            # registered a partial tensor sequence instead of raising.)
            tensors = None

    if const_vals is not None:
        # NOTE: Now that TorchScript supports list constants,
        # this code path might not be used anymore.
        self.add_constant_value(output, ctype, const_vals)
    if tensors is not None:
        self.add_tensor_sequence(output, tensors)
    if const_vals is None and tensors is None:
        raise Exception(
            "Unable to handle ListConstruct node."
            " Neither all constants nor all tensors. %r" % node)
def add_tuple_construct(self, node):
    """Record the inputs of a prim::TupleConstruct as a sequence for its output."""
    assert node.outputsSize() == 1
    self.add_tensor_sequence(node.outputsAt(0), list(node.inputs()))
def add_unsqueeze(self, node):
    """Lower aten::unsqueeze to NNAPI EXPAND_DIMS.

    Only fixed-size PRESUMED_CONTIGUOUS inputs are supported. The output
    shape is the input shape with a 1 inserted at ``dim``.
    """
    assert node.inputsSize() == 2
    assert node.outputsSize() == 1

    in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
    assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS

    # A negative dim counts from the end of the (rank + 1) output shape.
    insert_at = dim + len(in_oper.shape) + 1 if dim < 0 else dim
    expanded = list(in_oper.shape)
    expanded.insert(insert_at, 1)
    out_oper = in_oper._replace(shape=tuple(expanded))

    # NNAPI receives the original (possibly negative) dim.
    axis_id = self.add_immediate_int_scalar(dim)
    out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
    self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, [in_id, axis_id], [out_id])
def add_reshape(self, node):
    """Lower aten::reshape (constant int[] target shape) to NNAPI RESHAPE.

    CHANNELS_LAST inputs are only accepted when the target is ``[X, -1]``.
    The output is always PRESUMED_CONTIGUOUS.
    """
    assert node.inputsSize() == 2
    assert node.outputsSize() == 1

    in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))

    shape_ctype, shape = self.get_constant_value(node.inputsAt(1))
    assert shape_ctype.kind() == "ListType"
    assert shape_ctype.getElementType().kind() == "IntType"

    flattens_tail = len(shape) == 2 and shape[1] == -1
    if not flattens_tail and in_oper.dim_order != DimOrder.PRESUMED_CONTIGUOUS:
        raise Exception(
            "Currently, reshape is only supported on NHWC tensors if the target size is [X, -1].")

    # Let PyTorch infer the output shape (including any -1) from an
    # expanded dummy tensor of the input shape.
    out_shape = torch.zeros(1).expand(in_oper.shape).reshape(shape).shape
    out_oper = in_oper._replace(shape=out_shape, dim_order=DimOrder.PRESUMED_CONTIGUOUS)

    shape_id = self.add_immediate_int_vector(shape)
    out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
    self.add_operation(NNAPI_OperationCode.RESHAPE, [in_id, shape_id], [out_id])
def add_size(self, node):
    """Fold aten::size(tensor, dim) into a constant, read from the input's fixed shape."""
    assert node.inputsSize() == 2
    assert node.outputsSize() == 1
    _, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    dim = self.constants[node.inputsAt(1)][1]
    size = in_oper.shape[dim]
    result = node.outputsAt(0)
    self.add_constant_value(result, result.type(), size)
def add_cat(self, node):
    """Lower aten::cat over a tensor sequence to NNAPI CONCATENATION.

    All inputs must be fixed-size and agree on op_type, dim_order and every
    dimension except ``dim``. The output's ``dim`` is the sum of the inputs'.
    """
    assert node.inputsSize() == 2
    assert node.outputsSize() == 1

    tensors = self.tensor_sequences[node.inputsAt(0)]
    _, dim = self.get_constant_value(node.inputsAt(1), "IntType")

    assert len(tensors) > 0
    in_ids = []
    out_oper = None
    total = 0
    for jitval in tensors:
        in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(jitval)
        if out_oper is None:
            # The first input is the template, with its ``dim`` masked out.
            out_oper = in_oper._replace(shape=change_element(in_oper.shape, dim, -1))
        assert in_oper.op_type == out_oper.op_type
        assert in_oper.dim_order == out_oper.dim_order
        assert change_element(in_oper.shape, dim, -1) == change_element(out_oper.shape, dim, -1)
        # TODO: Possibly check scale and zero point.
        in_ids.append(in_id)
        # TODO: Possibly support variable-sized inputs.
        total += in_oper.shape[dim]
    out_oper = out_oper._replace(shape=change_element(out_oper.shape, dim, total))

    if in_oper.dim_order == DimOrder.CHANNELS_LAST:
        assert len(out_oper.shape) == 4
        # PyTorch NCHW axis -> NNAPI NHWC axis.
        nnapi_dim = [0, 3, 1, 2][dim]
    else:
        nnapi_dim = dim

    axis_id = self.add_immediate_int_scalar(nnapi_dim)
    out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
    self.add_operation(NNAPI_OperationCode.CONCATENATION, in_ids + [axis_id], [out_id])
def add_mean(self, node):
    """Lower aten::mean(tensor, int[] dim, bool keepdim, dtype=None) to NNAPI MEAN.

    For CHANNELS_LAST inputs the reduced dims are remapped to NHWC positions.
    Without keep_dim, a CHANNELS_LAST input must reduce both H and W, and the
    result is PRESUMED_CONTIGUOUS.
    """
    assert node.inputsSize() == 4
    assert node.outputsSize() == 1

    in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
    assert dim_ctype.kind() == "ListType"
    assert dim_ctype.getElementType().kind() == "IntType"
    _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
    # Expect None for dtype
    self.get_constant_value(node.inputsAt(3), "NoneType")

    if in_oper.dim_order == DimOrder.CHANNELS_LAST:
        assert len(in_oper.shape) == 4
        nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
    else:
        nnapi_dim = dim

    # Normalize negative dims so they match enumerate() indices below.
    rank = len(in_oper.shape)
    collapsed_dims = {d + rank if d < 0 else d for d in dim}

    if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
        # With H and W both dropped, the remaining dims are in the same order
        # for NCHW and NHWC.
        assert collapsed_dims.issuperset({2, 3})
        out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
    else:
        out_dim_order = in_oper.dim_order

    out_shape = []
    for i, s in enumerate(in_oper.shape):
        if i not in collapsed_dims:
            out_shape.append(s)
        elif keep_dim:
            out_shape.append(1)

    # Store the shape as a tuple like every other operand. A list breaks tuple
    # comparisons downstream, e.g. ``shape[2:] != (1, 1)`` in transpose_to_nhwc.
    out_oper = in_oper._replace(shape=tuple(out_shape), dim_order=out_dim_order)

    inputs = [None] * 3
    inputs[0] = in_id
    inputs[1] = self.add_immediate_int_vector(nnapi_dim)
    inputs[2] = self.add_immediate_int_scalar(keep_dim)

    outputs = [None] * 1
    outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

    self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)
def add_quantize(self, node):
    """Lower aten::quantize_per_tensor to NNAPI QUANTIZE (quint8 only).

    The input must be CHANNELS_LAST. The output is TENSOR_QUANT8_ASYMM with
    the node's constant scale and zero point.
    """
    assert node.inputsSize() == 4
    assert node.outputsSize() == 1

    in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    if in_oper.dim_order != DimOrder.CHANNELS_LAST:
        raise Exception(
            "Most hardware backends prefer NHWC quantized tensors. "
            "Try setting `t.nnapi_nhwc = True` on your tensor inputs. ")

    scale = self.get_constant_value(node.inputsAt(1), "FloatType")[1]
    zero_point = self.get_constant_value(node.inputsAt(2), "IntType")[1]
    scalar_type = self.get_constant_value(node.inputsAt(3), "IntType")[1]
    if scalar_type != TorchScalarTypes.QUINT8.value:
        raise Exception(
            "PyTorch NNAPI export only supports quantized tensors "
            "with the quint8 dtype.")

    out_oper = in_oper._replace(
        op_type=NNAPI_OperandCode.TENSOR_QUANT8_ASYMM,
        scale=scale,
        zero_point=zero_point,
    )
    out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
    self.add_operation(NNAPI_OperationCode.QUANTIZE, [in_id], [out_id])
def add_dequantize(self, node):
    """Add an NNAPI DEQUANTIZE operation for a single-input node.

    The output operand mirrors the input operand, but is TENSOR_FLOAT32 with
    scale 0.0 and zero point 0.
    """
    assert node.inputsSize() == 1
    assert node.outputsSize() == 1
    src_id, src_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    float_oper = src_oper._replace(
        op_type=NNAPI_OperandCode.TENSOR_FLOAT32,
        scale=0.0,
        zero_point=0,
    )
    dst_id = self.add_tensor_operand(node.outputsAt(0), float_oper)
    self.add_operation(NNAPI_OperationCode.DEQUANTIZE, [src_id], [dst_id])
def add_pointwise_simple_unary_op(self, node, opcode):
    """Add ``opcode`` for a one-input, one-output elementwise node.

    The output operand reuses the input operand description unchanged.
    """
    assert node.inputsSize() == 1
    assert node.outputsSize() == 1
    src_id, src_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    dst_id = self.add_tensor_operand(node.outputsAt(0), src_oper)
    self.add_operation(opcode, [src_id], [dst_id])
def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None):
    """Helper for pointwise binary broadcast ops with superfluous extra args.

    The first two node inputs must be tensors with the same operand type.
    They go through ``transpose_for_broadcast``, and the output shape is the
    broadcast of the two resulting shapes. When ``qparams`` is a
    ``(scale, zero_point)`` pair, it overrides the output operand's
    quantization parameters. ``fuse_code`` is passed as the third operand.
    """
    assert node.outputsSize() == 1
    for idx in (0, 1):
        assert node.inputsAt(idx).type().kind() == "TensorType"

    # TODO: Should support constant as either operand.
    lhs_id, lhs_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    rhs_id, rhs_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(1))
    assert lhs_oper.op_type == rhs_oper.op_type

    lhs_id, lhs_oper, rhs_id, rhs_oper = self.transpose_for_broadcast(
        lhs_id, lhs_oper, rhs_id, rhs_oper)

    # NOTE: PyTorch and NNAPI have the same broadcast semantics.
    result_oper = lhs_oper._replace(
        shape=broadcast_shapes(lhs_oper.shape, rhs_oper.shape))
    if qparams is not None:
        q_scale, q_zp = qparams
        result_oper = result_oper._replace(scale=q_scale, zero_point=q_zp)

    fuse_id = self.add_immediate_int_scalar(fuse_code)
    result_id = self.add_tensor_operand(node.outputsAt(0), result_oper)
    self.add_operation(opcode, [lhs_id, rhs_id, fuse_id], [result_id])
def add_pointwise_simple_binary_broadcast_op(self, node, opcode, fuse_code):
    """Add ``opcode`` for a node with exactly two broadcasting tensor inputs."""
    input_count = node.inputsSize()
    assert input_count == 2
    self._do_add_binary(node, opcode, fuse_code)
def add_add_sub_op(self, node, opcode, fuse_code):
    """Add an add/sub ``opcode`` for a node of the form ``(a, b, alpha)``.

    Raises:
        Exception: if the constant ``alpha`` is anything other than 1.
    """
    assert node.inputsSize() == 3
    _alpha_type, alpha = self.get_constant_value(node.inputsAt(2), "IntType")
    if alpha == 1:
        self._do_add_binary(node, opcode, fuse_code)
    else:
        raise Exception("NNAPI does not support add/sub with alpha.")
def add_qadd(self, node, opcode, fuse_code):
    """Add a quantized binary ``opcode`` for ``(a, b, scale, zero_point)``.

    The constant output scale and zero point become the output operand's
    quantization parameters.
    """
    assert node.inputsSize() == 4
    _scale_type, out_scale = self.get_constant_value(node.inputsAt(2), "FloatType")
    _zp_type, out_zero_point = self.get_constant_value(node.inputsAt(3), "IntType")
    output_qparams = (out_scale, out_zero_point)
    self._do_add_binary(node, opcode, fuse_code, qparams=output_qparams)
def add_hardtanh(self, node):
    """Map hardtanh with bounds (-1, 1) or (0, 6) onto NNAPI RELU1 / RELU6.

    Raises:
        Exception: for any other (min, max) pair.
    """
    assert node.inputsSize() == 3
    assert node.outputsSize() == 1
    src_id, src_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    _lo_type, lower = self.get_constant_value(node.inputsAt(1), "FloatType")
    _hi_type, upper = self.get_constant_value(node.inputsAt(2), "FloatType")

    # Float bounds compare/hash equal to these int keys (e.g. -1.0 == -1).
    supported_ranges = {
        (-1, 1): NNAPI_OperationCode.RELU1,
        (0, 6): NNAPI_OperationCode.RELU6,
    }
    bounds = (lower, upper)
    if bounds not in supported_ranges:
        raise Exception("NNAPI only supports hardtanh with args (-1, 1) or (0, 6).")
    opcode = supported_ranges[bounds]

    dst_id = self.add_tensor_operand(node.outputsAt(0), src_oper)
    self.add_operation(opcode, [src_id], [dst_id])
def add_prelu_op(self, node):
    """Add an NNAPI PRELU operation, supporting flexible trailing dimensions.

    Input 0 is the activation tensor. Input 1 is a 1-D weight, loaded as a
    weight operand. A weight with more than one element is rejected when the
    input uses NCHW layout. Dimensions 0 and 1 of the input must have a
    positive fixed size. Any later dimension without one (0 marks a flexible
    dimension) has its runtime size forwarded from the input to the output.

    Raises:
        Exception: for per-channel weights with NCHW input, or a flexible
            dim 0 / dim 1.
    """
    assert node.inputsSize() == 2
    assert node.outputsSize() == 1
    assert node.inputsAt(0).type().kind() == "TensorType"
    assert node.inputsAt(1).type().kind() == "TensorType"

    src_id, src_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
    weight_id, weight_oper = self.get_tensor_operand_for_weight(node.inputsAt(1))
    assert len(weight_oper.shape) == 1
    assert weight_oper.shape[0] > 0
    per_channel = weight_oper.shape[0] > 1
    if per_channel and src_oper.use_nchw():
        # TODO: Support this by adding trailing 1 dims.
        raise Exception("Per-channel PReLU only supports channels_last right now.")

    dst_id = self.add_tensor_operand(node.outputsAt(0), src_oper)
    for axis, extent in enumerate(src_oper.shape):
        if extent > 0:
            continue  # Fixed size: nothing to forward.
        if axis <= 1:
            raise Exception("PReLU requires fixed size for dim 0 and dim 1.")
        self.forward_operand_shape(dst_id, axis, src_id, axis)

    self.add_operation(NNAPI_OperationCode.PRELU, [src_id, weight_id], [dst_id])
def add_pool2d_node(self, node, opcode):
    """Add a 2-D pooling ``opcode`` with explicit padding for a pool node.

    The node's inputs are (image, kernel, stride, padding, dilation,
    ceil_mode). The output keeps the input's channel count (dim 1). Its
    spatial size comes from ``get_conv_pool_shape``.

    Raises:
        Exception: if the pooling is dilated.
    """
    assert node.inputsSize() == 6
    assert node.outputsSize() == 1
    image, kernel, stride, padding, dilation, ceil_mode = node.inputs()

    # NOTE(review): ``stride`` is a JIT value here, so this fallback may
    # never trigger; confirm how an empty stride list is handled.
    stride = stride or kernel

    # TODO: Validate ceil_mode semantics.
    args = self.get_conv_pool_args_2d_from_jit(self.get_size_arg(kernel), stride, padding, dilation)
    if (args.dilation_h, args.dilation_w) != (1, 1):
        raise Exception("NNAPI does not support dilated pooling.")

    image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
    assert len(image_oper.shape) == 4
    pooled_shape = get_conv_pool_shape(image_oper.shape, args, image_oper.shape[1], False)
    layout_is_nchw = image_oper.use_nchw()

    # Scalar operand order: pads (l, r, t, b), strides (w, h),
    # filter (w, h), fuse code; followed by the layout flag.
    scalar_args = [
        args.pad_l, args.pad_r, args.pad_t, args.pad_b,
        args.stride_w, args.stride_h,
        args.kernel_w, args.kernel_h,
        NNAPI_FuseCode.FUSED_NONE,
    ]
    inputs = [image_id]
    inputs.extend(self.add_immediate_int_scalar(value) for value in scalar_args)
    inputs.append(self.add_immediate_bool_scalar(layout_is_nchw))

    out_oper = image_oper._replace(shape=pooled_shape)
    outputs = [self.add_tensor_operand(node.outputsAt(0), out_oper)]
    self.add_operation(opcode, inputs, outputs)
def add_adaptive_avg_pool2d(self, node):
    """Lower adaptive_avg_pool2d with output size (1, 1) to AVERAGE_POOL_2D.

    With output size (1, 1), the op is emitted as an average pool with zero
    padding, unit stride, and a filter covering the input's full height
    (dim 2) and width (dim 3).

    Raises:
        Exception: if the requested output size is not (1, 1).
    """
    assert node.inputsSize() == 2
    assert node.outputsSize() == 1
    image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
    assert len(image_oper.shape) == 4

    size_ctype, size_arg = self.get_constant_value(node.inputsAt(1))
    assert size_ctype.kind() == "ListType"
    assert size_ctype.getElementType().kind() == "IntType"
    if size_arg != [1, 1]:
        raise Exception("NNAPI only supports adaptive_avg_pool2d with output size (1, 1).")

    # Some converters store operand shapes as lists, so normalize to a tuple
    # before concatenating; list + tuple would raise TypeError.
    out_shape = tuple(image_oper.shape[0:2]) + tuple(size_arg)
    use_nchw = image_oper.use_nchw()

    inputs = [None] * 11
    inputs[0] = image_id
    # Explicit padding (left, right, top, bottom): all zero.
    inputs[1] = self.add_immediate_int_scalar(0)
    inputs[2] = self.add_immediate_int_scalar(0)
    inputs[3] = self.add_immediate_int_scalar(0)
    inputs[4] = self.add_immediate_int_scalar(0)
    # Strides (w, h).
    inputs[5] = self.add_immediate_int_scalar(1)
    inputs[6] = self.add_immediate_int_scalar(1)
    # Filter (w, h) spans the whole spatial extent.
    inputs[7] = self.add_immediate_int_scalar(image_oper.shape[3])
    inputs[8] = self.add_immediate_int_scalar(image_oper.shape[2])
    inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)
    inputs[10] = self.add_immediate_bool_scalar(use_nchw)

    outputs = [None] * 1
    outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape))

    self.add_operation(NNAPI_OperationCode.AVERAGE_POOL_2D, inputs, outputs)
def add_upsample_nearest2d(self, node):
    """Add RESIZE_NEAREST_NEIGHBOR for an upsample node.

    Exactly one of ``size`` (int list) and ``scale`` (float list) must be
    non-None. A one-element list applies to both height and width. With
    ``size``, the int output height/width are passed. With ``scale``, the
    float scales are passed and the output extent is
    ``int(scale * input_extent)``.

    Raises:
        Exception: if both or neither of size and scale are given.
    """
    assert node.inputsSize() == 3
    assert node.outputsSize() == 1
    image, size_jit, scale_jit = node.inputs()
    size_ctype, size_arg = self.get_constant_value(size_jit)
    scale_ctype, scale_arg = self.get_constant_value(scale_jit)

    image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(image)
    assert len(image_oper.shape) == 4

    has_size = size_ctype.kind() != "NoneType"
    has_scale = scale_ctype.kind() != "NoneType"
    if has_size and has_scale:
        raise Exception("Size and scale cannot both be non-None.")
    if not (has_size or has_scale):
        raise Exception("Size and scale cannot both be None.")

    if has_size:
        assert size_ctype.kind() == "ListType"
        assert size_ctype.getElementType().kind() == "IntType"
        assert scale_ctype.kind() == "NoneType"
        assert scale_arg is None
        assert isinstance(size_arg, list)
        assert size_arg
        assert all(isinstance(val, int) for val in size_arg)
        if len(size_arg) == 1:
            size_arg = size_arg * 2
        assert len(size_arg) == 2
        out_h, out_w = size_arg
        arg_h = self.add_immediate_int_scalar(out_h)
        arg_w = self.add_immediate_int_scalar(out_w)
    else:
        assert scale_ctype.kind() == "ListType"
        assert scale_ctype.getElementType().kind() == "FloatType"
        assert size_ctype.kind() == "NoneType"
        assert size_arg is None
        assert isinstance(scale_arg, list)
        assert scale_arg
        assert all(isinstance(val, float) for val in scale_arg)
        if len(scale_arg) == 1:
            scale_arg = scale_arg * 2
        assert len(scale_arg) == 2
        scale_h, scale_w = scale_arg
        out_h = int(scale_h * image_oper.shape[2])
        out_w = int(scale_w * image_oper.shape[3])
        arg_h = self.add_immediate_float_scalar(scale_h)
        arg_w = self.add_immediate_float_scalar(scale_w)

    out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w)
    use_nchw = image_oper.use_nchw()

    # Width operand comes before height operand.
    inputs = [image_id, arg_w, arg_h, self.add_immediate_bool_scalar(use_nchw)]
    out_oper = image_oper._replace(shape=out_shape)
    outputs = [self.add_tensor_operand(node.outputsAt(0), out_oper)]
    self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs)
def add_addmm(self, node):
    """Add FULLY_CONNECTED for ``addmm(bias, input, weight, beta, alpha)``.

    beta and alpha must both be the constant 1. The weight is handed to
    ``add_addmm_or_linear`` with ``transpose_weight=True``.

    Raises:
        Exception: if beta or alpha differs from 1.
    """
    assert node.inputsSize() == 5
    assert node.outputsSize() == 1
    jit_bias, jit_input, jit_weight, jit_beta, jit_alpha = node.inputs()

    for jit_coeff in (jit_beta, jit_alpha):
        coeff_ctype, coeff = self.get_constant_value(jit_coeff)
        assert coeff_ctype.kind() in ("IntType", "FloatType")
        if coeff != 1:
            raise Exception("NNAPI Fully-Connected does not support alpha and beta.")

    self.add_addmm_or_linear(
        node, True, jit_input, jit_weight, jit_bias)
def add_linear(self, node):
    """Add FULLY_CONNECTED for ``linear(input, weight, bias)``.

    The weight is passed through without transposition.
    """
    assert node.inputsSize() == 3
    assert node.outputsSize() == 1
    jit_input, jit_weight, jit_bias = node.inputs()
    transpose_weight = False
    self.add_addmm_or_linear(node, transpose_weight, jit_input, jit_weight, jit_bias)
def add_addmm_or_linear(self, node, transpose_weight, jit_input, jit_weight, jit_bias):
    """Add an NNAPI FULLY_CONNECTED operation shared by addmm and linear.

    Args:
        node: JIT node whose single output receives the result.
        transpose_weight: if True, the constant weight is transposed before
            being stored (addmm); otherwise it is stored as-is (linear).
        jit_input: the 2-D input tensor value.
        jit_weight: the constant 2-D weight tensor value.
        jit_bias: the constant 1-D bias tensor, or a None constant (``linear``
            allows the bias to be omitted). A None bias is replaced with zeros
            sized from dim 0 of the NNAPI-layout weight.
    """
    input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
    assert len(input_oper.shape) == 2

    # TODO: Transform at load time to share weights with CPU model.
    _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
    assert len(weight_tensor.shape) == 2
    if transpose_weight:
        nnapi_weight_tensor = weight_tensor.t().contiguous()
    else:
        nnapi_weight_tensor = weight_tensor.contiguous()

    # The bias operand is still created before the weight operand. Using
    # get_optional_bias lets a None bias (linear without bias) become zeros
    # instead of failing in the weight loader.
    bias_id, bias_oper = self.get_optional_bias(jit_bias, nnapi_weight_tensor)
    assert len(bias_oper.shape) == 1

    weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
    weight_oper = self.operands[weight_id]

    out_shape = (input_oper.shape[0], weight_oper.shape[0])

    inputs = [None] * 4
    inputs[0] = input_id
    inputs[1] = weight_id
    inputs[2] = bias_id
    inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

    outputs = [None] * 1
    outputs[0] = self.add_tensor_operand(node.outputsAt(0), input_oper._replace(shape=out_shape))

    self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
def add_qlinear(self, node):
    """Add NNAPI FULLY_CONNECTED for a quantized linear node.

    The node's inputs are (input, packed_weight, output scale, output zero
    point). The ``LinearPackedParamsBase`` must carry a bias. A qint8 weight
    is re-encoded as quint8 by adding 128 to its integer representation and
    zero point. The float bias is quantized to int32 with scale
    ``input_scale * weight_scale`` and zero point 0.

    Raises:
        Exception: if the requantization multiplier
            ``input_scale * weight_scale / out_scale`` is >= 1.
    """
    assert node.inputsSize() == 4
    assert node.outputsSize() == 1
    (
        jit_input,
        jit_packed_weight,
        jit_scale,
        jit_zero_point,
    ) = node.inputs()

    input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input)
    # TODO: Support automatic reshape
    assert len(input_oper.shape) == 2

    _, out_scale = self.get_constant_value(jit_scale, "FloatType")
    _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
    weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
    assert weight_ctype.name() == "LinearPackedParamsBase"
    raw_weight, raw_bias = packed_weight.__getstate__()[0]
    assert raw_bias is not None

    assert len(raw_weight.shape) == 2
    assert len(raw_bias.shape) == 1
    assert raw_bias.shape[0] == raw_weight.shape[0]
    assert raw_weight.shape[1] == input_oper.shape[1]

    assert raw_weight.qscheme() == torch.per_tensor_affine
    if raw_weight.dtype == torch.quint8:
        unsigned_weight = raw_weight
    else:
        assert raw_weight.dtype == torch.qint8
        # Shift int8 values into uint8 range; moving the zero point by the
        # same 128 keeps (q - zero_point) * scale unchanged.
        unsigned_weight = torch._make_per_tensor_quantized_tensor(
            (raw_weight.int_repr().int() + 128).to(torch.uint8),
            scale=raw_weight.q_scale(),
            zero_point=raw_weight.q_zero_point() + 128)
    weight_scale = unsigned_weight.q_scale()
    bias_scale = input_oper.scale * weight_scale
    int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
    bias_id = self.add_tensor_operand_for_weight(int_bias)

    multiplier = input_oper.scale * weight_scale / out_scale
    assert multiplier > 0
    if multiplier >= 1:
        # This is the linear converter; the message previously said "convolution".
        raise Exception(
            "Quantized linear multiplier is greater than 1. "
            "This is supported by NNAPI, but not by most hardware backends. "
            "Try training a model without quantization-aware training. ")

    # TODO: Transform at load time to share weights with CPU model.
    nnapi_weight_tensor = unsigned_weight.contiguous()
    weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
    weight_oper = self.operands[weight_id]

    out_shape = (input_oper.shape[0], weight_oper.shape[0])
    out_oper = input_oper._replace(
        shape=out_shape,
        scale=out_scale,
        zero_point=out_zero_point,
    )

    inputs = [None] * 4
    inputs[0] = input_id
    inputs[1] = weight_id
    inputs[2] = bias_id
    inputs[3] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE)

    outputs = [None] * 1
    outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

    self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs)
def get_optional_bias(self, jit_bias, weight_tensor):
    """Return ``(operand_id, operand)`` for a bias value that may be None.

    A None bias is materialized as a zero tensor with the weight's dtype and
    one element per entry of the weight's first dimension. Otherwise the
    bias is loaded as a weight operand.
    """
    bias_ctype, _unused_value = self.get_constant_value(jit_bias)
    if bias_ctype.kind() != "NoneType":
        return self.get_tensor_operand_for_weight(jit_bias)
    zero_bias = torch.zeros(weight_tensor.size()[0], dtype=weight_tensor.dtype)
    zero_id = self.add_tensor_operand_for_weight(zero_bias)
    return zero_id, self.operands[zero_id]
def add_conv2d(self, node):
    """Add a 2-D convolution for a 7-input conv2d node.

    The node's inputs are (image, weight, bias, stride, padding, dilation,
    groups). A None bias is replaced by zeros. The output gets scale 0.0,
    zero point 0, and no fused activation.
    """
    assert node.inputsSize() == 7
    assert node.outputsSize() == 1
    (jit_image, jit_weight, jit_bias,
     jit_stride, jit_pad, jit_dilation, jit_groups) = node.inputs()

    _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
    bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
    kernel_hw = weight_tensor.shape[2:4]
    conv_args = self.get_conv_pool_args_2d_from_jit(
        kernel_hw, jit_stride, jit_pad, jit_dilation, jit_groups)

    return self.add_conv2d_common(
        jit_out=node.outputsAt(0),
        out_scale=0.0,
        out_zero_point=0,
        jit_image=jit_image,
        weight_tensor=weight_tensor,
        bias_id=bias_id,
        args=conv_args,
        transpose=False,
        fuse_code=NNAPI_FuseCode.FUSED_NONE,
    )
def add_conv_underscore(self, node):
    """Add a 2-D convolution for a 13-input ``_convolution`` node.

    Used inputs: image (0), weight (1), bias (2), stride (3), padding (4),
    dilation (5), transposed flag (6), and groups (8). Input 7 and inputs
    9-12 are ignored. A None bias is replaced by zeros. The output gets
    scale 0.0, zero point 0, and no fused activation.

    Raises:
        Exception: if the transposed flag is set. The flag was previously
            ignored, so a transposed convolution would have been lowered as
            a regular CONV_2D with a mismatched weight layout.
    """
    assert node.inputsSize() == 13
    assert node.outputsSize() == 1
    (
        jit_image,
        jit_weight,
        jit_bias,
        jit_stride,
        jit_pad,
        jit_dilation,
        jit_transpose,
        _,
        jit_groups,
        _,
        _,
        _,
        _,
    ) = node.inputs()

    _, transpose = self.get_constant_value(jit_transpose, "BoolType")
    if transpose:
        raise Exception("NNAPI export does not support transposed _convolution yet.")

    _, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
    bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor)
    args = self.get_conv_pool_args_2d_from_jit(
        weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups)

    return self.add_conv2d_common(
        node.outputsAt(0),
        0.0,
        0,
        jit_image,
        weight_tensor,
        bias_id,
        args,
        False,  # transpose
        NNAPI_FuseCode.FUSED_NONE,
    )
def add_qconv2d(self, node, fuse_code):
    """Add a quantized 2-D convolution for a 4-input quantized conv node.

    The node's inputs are (image, packed_weight, output scale, output zero
    point). The ``Conv2dPackedParamsBase`` is unpacked through
    ``__getstate__``. It must be pack version "2" and carry a non-None bias.
    A qint8 weight is re-encoded as quint8 by adding 128 to its integer
    representation and zero point. The float bias is quantized to int32
    with scale ``image_scale * weight_scale`` and zero point 0. The shared
    ``add_conv2d_common`` lowering then runs with ``transpose=False``.

    Args:
        node: JIT node being converted.
        fuse_code: NNAPI fused-activation code forwarded to the operation.

    Raises:
        Exception: if the requantization multiplier
            ``image_scale * weight_scale / out_scale`` is >= 1.
    """
    assert node.inputsSize() == 4
    assert node.outputsSize() == 1
    (
        jit_image,
        jit_packed_weight,
        jit_scale,
        jit_zero_point,
    ) = node.inputs()
    _, out_scale = self.get_constant_value(jit_scale, "FloatType")
    _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType")
    weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight)
    assert weight_ctype.name() == "Conv2dPackedParamsBase"
    # Packed state is (version, tensors, optional_tensors).
    # NOTE(review): this unpacking is tied to the pack-version-2 format
    # asserted below; revisit if the serialization format changes.
    (
        pack_version,
        tensors,
        opt_tensors,
    ) = packed_weight.__getstate__()[0]
    assert pack_version == "2"
    packed_config, raw_weight = tensors
    raw_bias, = opt_tensors
    assert raw_bias is not None
    # Kernel (h, w) comes from weight dims 2:4. The remaining conv args are
    # presumably decoded from packed_config by the helper.
    args = self.get_conv_pool_args_2d_from_pack(raw_weight.shape[2:4], packed_config)
    assert raw_weight.qscheme() == torch.per_tensor_affine
    if raw_weight.dtype == torch.quint8:
        unsigned_weight = raw_weight
    else:
        assert raw_weight.dtype == torch.qint8
        # Shift int8 values into uint8 range; moving the zero point by the
        # same 128 keeps (q - zero_point) * scale unchanged.
        unsigned_weight = torch._make_per_tensor_quantized_tensor(
            (raw_weight.int_repr().int() + 128).to(torch.uint8),
            scale=raw_weight.q_scale(),
            zero_point=raw_weight.q_zero_point() + 128)
    weight_scale = unsigned_weight.q_scale()
    _, image_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_image)
    # add_conv2d_common checks that the bias scale is approximately
    # image_scale * weight_scale with zero point 0.
    bias_scale = image_oper.scale * weight_scale
    int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32)
    bias_id = self.add_tensor_operand_for_weight(int_bias)
    multiplier = image_oper.scale * weight_scale / out_scale
    assert multiplier > 0
    if multiplier >= 1:
        raise Exception(
            "Quantized convolution multiplier is greater than 1. "
            "This is supported by NNAPI, but not by most hardware backends. "
            "Try training a model without quantization-aware training. ")
    return self.add_conv2d_common(
        node.outputsAt(0),
        out_scale,
        out_zero_point,
        jit_image,
        unsigned_weight,
        bias_id,
        args,
        False,  # transpose
        fuse_code,
    )
def add_conv2d_common(
        self,
        jit_out,
        out_scale,
        out_zero_point,
        jit_image,
        weight_tensor,
        bias_id,
        args,
        transpose,
        fuse_code):
    """Shared lowering for float and quant8 2-D convolutions.

    Supports full convolution (``args.group == 1``) and depthwise convolution
    (``args.group`` equal to the input channel count, channel multiplier 1).
    Any other group count raises. The steps are:

    1. Permute the torch weight and add it as a weight operand.
    2. Validate operand types and quantization parameters.
    3. Compute the output shape with ``get_conv_pool_shape``.
    4. Add a CONV_2D, TRANSPOSE_CONV_2D, or DEPTHWISE_CONV_2D operation.

    Operands are created in a fixed order (weight, immediates, output), so
    statement order here determines operand ids.

    Args:
        jit_out: JIT value that receives the output operand.
        out_scale: quantization scale copied onto the output operand.
        out_zero_point: quantization zero point copied onto the output operand.
        jit_image: JIT value of the 4-D input image.
        weight_tensor: constant torch weight tensor.
        bias_id: operand id of an already-added 1-D bias operand.
        args: conv/pool argument record (padding, strides, group).
        transpose: for non-depthwise convs, selects TRANSPOSE_CONV_2D.
        fuse_code: NNAPI fused-activation code.

    Raises:
        Exception: for unsupported group counts or input operand types.
    """
    image_id, image_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_image)
    # Input channel count is read from dim 1 of the operand shape.
    in_c = image_oper.shape[1]

    if args.group == 1:
        # Full convolution
        depthwise = False
        # Moves weight dim 1 last; the kern_d == in_c check below implies
        # dim 1 is the input-channel dim.
        weight_permutation = (0, 2, 3, 1)
    elif args.group == in_c:
        # Depthwise convolution
        depthwise = True
        # Moves weight dim 0 last; the resulting leading dim must be 1.
        weight_permutation = (1, 2, 3, 0)
    else:
        raise Exception("Group convolution not supported yet.")

    # TODO: Transform at load time to share weights with CPU model.
    nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous()
    weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
    weight_oper = self.operands[weight_id]

    bias_oper = self.operands[bias_id]

    # Float convs need float weight and bias. Quant8 convs need a quant8
    # weight and an int32 bias with scale ~= image_scale * weight_scale and
    # zero point 0.
    if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32:
        assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
        assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32
    elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
        assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM
        assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32
        assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale)
        assert bias_oper.zero_point == 0
    else:
        raise Exception(
            "Unsupported input type for conv2d: {}"
            .format(image_oper.op_type))

    assert len(image_oper.shape) == 4
    assert len(weight_oper.shape) == 4
    assert len(bias_oper.shape) == 1

    if depthwise:
        # Depthwise convolution
        one, kern_h, kern_w, out_c = weight_oper.shape
        assert one == 1
        assert out_c % in_c == 0
        channel_multiplier = out_c // in_c
        assert channel_multiplier == 1  # Don't support multiplier
        assert out_c == in_c
    else:
        # Full convolution
        kern_nf, kern_h, kern_w, kern_d = weight_oper.shape
        out_c = kern_nf
        assert kern_d == in_c

    assert out_c == bias_oper.shape[0]

    out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose)
    out_oper = image_oper._replace(
        shape=out_shape,
        scale=out_scale,
        zero_point=out_zero_point,
    )

    use_nchw = image_oper.use_nchw()

    # Depthwise takes one extra scalar operand (index 9) before the fuse code.
    if depthwise:
        num_args = 12
        opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D
    else:
        num_args = 11
        if transpose:
            opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D
        else:
            opcode = NNAPI_OperationCode.CONV_2D

    inputs = [None] * num_args
    inputs[0] = image_id
    inputs[1] = weight_id
    inputs[2] = bias_id
    # Explicit padding (left, right, top, bottom), then strides (w, h).
    inputs[3] = self.add_immediate_int_scalar(args.pad_l)
    inputs[4] = self.add_immediate_int_scalar(args.pad_r)
    inputs[5] = self.add_immediate_int_scalar(args.pad_t)
    inputs[6] = self.add_immediate_int_scalar(args.pad_b)
    inputs[7] = self.add_immediate_int_scalar(args.stride_w)
    inputs[8] = self.add_immediate_int_scalar(args.stride_h)
    if depthwise:
        # Presumably the depth multiplier, fixed at 1 to match the
        # channel_multiplier == 1 assertion above.
        inputs[9] = self.add_immediate_int_scalar(1)
        inputs[10] = self.add_immediate_int_scalar(fuse_code)
        inputs[11] = self.add_immediate_bool_scalar(use_nchw)
    else:
        inputs[9] = self.add_immediate_int_scalar(fuse_code)
        inputs[10] = self.add_immediate_bool_scalar(use_nchw)

    outputs = [None] * 1
    outputs[0] = self.add_tensor_operand(jit_out, out_oper)

    self.add_operation(opcode, inputs, outputs)
def serialize_model(module, inputs, config=None):
    """Serialize ``module`` for NNAPI given ``inputs``.

    Thin convenience wrapper that constructs an ``_NnapiSerializer`` with
    ``config`` and returns the result of its ``serialize_model`` method.
    """
    serializer = _NnapiSerializer(config)
    return serializer.serialize_model(module, inputs)