mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Cleaning up onnx module imports to prepare for updating `__init__`.
- Simplify importing the `_C` and `_C._onnx` name spaces
- Remove alias of the symbolic_helper module in imports
- Remove any module level function imports. Import modules instead
- Alias `symbilic_opsetx` as `opsetx`
- Fix some docstrings
Requires:
- https://github.com/pytorch/pytorch/pull/77448
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77423
Approved by: https://github.com/BowenBao
319 lines
9.8 KiB
Python
319 lines
9.8 KiB
Python
import importlib
|
|
from inspect import getmembers, isfunction
|
|
|
|
from torch.onnx import symbolic_helper
|
|
from torch.onnx import symbolic_opset9 as opset9
|
|
from torch.onnx import symbolic_registry
|
|
|
|
|
|
def register_quantized_ops(domain: str, version: int):
|
|
# Register all the non-quantized ops
|
|
symbolic_registry.register_version("", version)
|
|
# Register all quantized ops
|
|
module = importlib.import_module("torch.onnx.symbolic_caffe2")
|
|
symbolic_registry._symbolic_versions["caffe2"] = module
|
|
quant_version_ops = getmembers(symbolic_registry._symbolic_versions["caffe2"])
|
|
for op in quant_version_ops:
|
|
if isfunction(op[1]) and not symbolic_registry.is_registered_op(
|
|
op[0], domain, version
|
|
):
|
|
aten_q_ops = [
|
|
"relu",
|
|
"_empty_affine_quantized",
|
|
"dequantize",
|
|
"quantize_per_tensor",
|
|
"upsample_nearest2d",
|
|
"avg_pool2d",
|
|
"reshape",
|
|
"slice",
|
|
"cat",
|
|
"max_pool2d",
|
|
"sigmoid",
|
|
]
|
|
if op[0] in aten_q_ops:
|
|
symbolic_registry.register_op(op[0], op[1], "", version)
|
|
symbolic_registry.register_op(op[0], op[1], domain, version)
|
|
|
|
|
|
def _permute_helper(g, input, axes):
|
|
quant_args = {
|
|
"axes_i": axes,
|
|
"Y_scale_f": input.node()["Y_scale"],
|
|
"Y_zero_point_i": input.node()["Y_zero_point"],
|
|
}
|
|
output = g.op("_caffe2::Int8Transpose", input, **quant_args)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
def nchw2nhwc(g, input):
|
|
axes = [0, 2, 3, 1]
|
|
return _permute_helper(g, input, axes)
|
|
|
|
|
|
def nhwc2nchw(g, input):
|
|
axes = [0, 3, 1, 2]
|
|
return _permute_helper(g, input, axes)
|
|
|
|
|
|
def linear_prepack(g, weight, bias):
|
|
# Mapping to a dummy caffe2 prepack node.
|
|
# During the onnx -> c2 conversion we can look up original weight and bias
|
|
# from this node
|
|
output = g.op("_caffe2::WeightPrepack", weight, bias)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "f", "i")
|
|
def linear(g, input, weight, bias, scale, zero_point):
|
|
kwargs = {
|
|
"Y_scale_f": scale,
|
|
"Y_zero_point_i": zero_point,
|
|
}
|
|
output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
def conv_prepack(g, input, weight, bias, stride, padding, dilation, groups):
|
|
# Mapping to a dummy caffe2 prepack node.
|
|
# During the onnx -> c2 conversion we can look up original weight and bias
|
|
# from this node
|
|
output = g.op("_caffe2::WeightPrepack", input, weight, bias)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
|
|
def conv2d(
|
|
g, input, weight, bias, stride, padding, dilation, groups, scale, zero_point
|
|
):
|
|
kernel_size = weight.node()["shape"][1:3]
|
|
kwargs = {
|
|
"strides_i": stride,
|
|
"pads_i": padding + padding,
|
|
"dilations_i": dilation,
|
|
"group_i": groups,
|
|
"kernels_i": kernel_size,
|
|
"order_s": "NHWC",
|
|
"Y_scale_f": scale,
|
|
"Y_zero_point_i": zero_point,
|
|
}
|
|
output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
|
|
def conv2d_relu(
|
|
g, input, weight, bias, stride, padding, dilation, groups, scale, zero_point
|
|
):
|
|
kernel_size = weight.node()["shape"][1:3]
|
|
kwargs = {
|
|
"strides_i": stride,
|
|
"pads_i": padding + padding,
|
|
"dilations_i": dilation,
|
|
"group_i": groups,
|
|
"kernels_i": kernel_size,
|
|
"order_s": "NHWC",
|
|
"Y_scale_f": scale,
|
|
"Y_zero_point_i": zero_point,
|
|
}
|
|
output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "f", "i")
|
|
def add(g, input_a, input_b, scale, zero_point):
|
|
kwargs = {
|
|
"Y_scale_f": scale,
|
|
"Y_zero_point_i": zero_point,
|
|
}
|
|
output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v")
|
|
def relu(g, input):
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.relu(g, input)
|
|
kwargs = {
|
|
"Y_scale_f": input.node()["Y_scale"],
|
|
"Y_zero_point_i": input.node()["Y_zero_point"],
|
|
}
|
|
output = g.op("_caffe2::Int8Relu", input, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "f", "i", "t")
|
|
def quantize_per_tensor(g, input, scale, zero_point, dtype):
|
|
kwargs = {
|
|
"Y_scale_f": scale,
|
|
"Y_zero_point_i": zero_point,
|
|
}
|
|
output = g.op("_caffe2::Int8Quantize", input, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v")
|
|
def dequantize(g, input):
|
|
return g.op("_caffe2::Int8Dequantize", input)
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
|
|
def _empty_affine_quantized(
|
|
g, input, shape, scale, zero_point, dtype, pin_memory, memory_format, layout
|
|
):
|
|
return input
|
|
|
|
|
|
def upsample_nearest2d(
|
|
g, input, output_size, align_corners=None, scales_h=None, scales_w=None
|
|
):
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.upsample_nearest2d(g, input, output_size, align_corners)
|
|
|
|
output_size = symbolic_helper._parse_arg(output_size, "is")
|
|
kwargs = {
|
|
"output_size_i": output_size,
|
|
"Y_scale_f": input.node()["Y_scale"],
|
|
"Y_zero_point_i": input.node()["Y_zero_point"],
|
|
}
|
|
input = nchw2nhwc(g, input)
|
|
output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs)
|
|
output = nhwc2nchw(g, output)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
|
|
def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.max_pool2d(
|
|
g, input, kernel_size, stride, padding, dilation, ceil_mode
|
|
)
|
|
kwargs = {
|
|
"strides_i": stride,
|
|
"pads_i": padding + padding,
|
|
"kernel_i": kernel_size[0],
|
|
"order_s": "NHWC",
|
|
"Y_scale_f": input.node()["Y_scale"],
|
|
"Y_zero_point_i": input.node()["Y_zero_point"],
|
|
}
|
|
input = nchw2nhwc(g, input)
|
|
output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
|
|
output = nhwc2nchw(g, output)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
|
|
def avg_pool2d(
|
|
g,
|
|
input,
|
|
kernel_size,
|
|
stride,
|
|
padding,
|
|
ceil_mode,
|
|
count_include_pad,
|
|
divisor_override=None,
|
|
):
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.avg_pool2d(
|
|
g,
|
|
input,
|
|
kernel_size,
|
|
stride,
|
|
padding,
|
|
ceil_mode,
|
|
count_include_pad,
|
|
divisor_override,
|
|
)
|
|
kwargs = {
|
|
"strides_i": stride,
|
|
"pads_i": padding + padding,
|
|
"kernel_i": kernel_size[0],
|
|
"order_s": "NHWC",
|
|
"Y_scale_f": input.node()["Y_scale"],
|
|
"Y_zero_point_i": input.node()["Y_zero_point"],
|
|
}
|
|
input = nchw2nhwc(g, input)
|
|
output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
|
|
output = nhwc2nchw(g, output)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
def reshape(g, input, shape):
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.reshape(g, input, shape)
|
|
|
|
kwargs = {
|
|
"Y_scale_f": input.node()["Y_scale"],
|
|
"Y_zero_point_i": input.node()["Y_zero_point"],
|
|
}
|
|
output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
|
|
def slice(g, input, dim, start, end, step):
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.slice(g, input, dim, start, end, step)
|
|
|
|
if step != 1:
|
|
raise RuntimeError("ONNX quantized slice export only works for step 1.")
|
|
start = symbolic_helper._parse_arg(start, "i")
|
|
end = symbolic_helper._parse_arg(end, "i")
|
|
dim = symbolic_helper._parse_arg(dim, "i")
|
|
|
|
kwargs = {
|
|
"start_idx_i": start,
|
|
"end_idx_i": end,
|
|
"dim_i": dim,
|
|
"Y_scale_f": input.node()["Y_scale"],
|
|
"Y_zero_point_i": input.node()["Y_zero_point"],
|
|
}
|
|
output = g.op("_caffe2::Int8Slice", input, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
def cat(g, tensor_list, dim, scale=None, zero_point=None):
|
|
tensors = symbolic_helper._unpack_list(tensor_list)
|
|
input = tensors[0]
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.cat(g, tensor_list, dim)
|
|
|
|
dim = symbolic_helper._parse_arg(dim, "i")
|
|
kwargs = {
|
|
"Y_scale_f": tensors[0].node()["Y_scale"],
|
|
"Y_zero_point_i": tensors[0].node()["Y_zero_point"],
|
|
}
|
|
output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|
|
|
|
|
|
@symbolic_helper.parse_args("v")
|
|
def sigmoid(g, input):
|
|
if input not in symbolic_helper._quantized_ops:
|
|
return opset9.sigmoid(g, input)
|
|
# Caffe2 expects the output scale to be 1/2^8
|
|
# and output zero_point to be 0 (quint8 type)
|
|
out_scale = 1.0 / 256
|
|
zero_point = 0
|
|
kwargs = {
|
|
"Y_scale_f": out_scale,
|
|
"Y_zero_point_i": zero_point,
|
|
}
|
|
output = g.op("_caffe2::Int8Sigmoid", input, **kwargs)
|
|
symbolic_helper._quantized_ops.add(output)
|
|
return output
|