diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 461123ce23d..bce33270687 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -44,6 +44,8 @@ from torch.onnx.symbolic_helper import _unimplemented
 from torch.onnx.utils import unpack_quantized_tensor
 
+_ORT_PROVIDERS = ["CPUExecutionProvider"]
+
 def flatten_tuples(elem):
     tup = []
     for t in elem:
@@ -99,7 +101,7 @@ def convert_to_onnx(model, input=None, opset_version=9, do_constant_folding=True
     # suppress ort warnings.
     # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
     so.log_severity_level = 3
-    ort_sess = onnxruntime.InferenceSession(f.getvalue(), so)
+    ort_sess = onnxruntime.InferenceSession(f.getvalue(), so, providers=_ORT_PROVIDERS)
     return ort_sess
@@ -373,7 +375,9 @@ class TestONNXRuntime(unittest.TestCase):
         # suppress ort warnings.
         # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
         ort_sess_opt.log_severity_level = 3
-        ort_sess = onnxruntime.InferenceSession(model_file_name, sess_options=ort_sess_opt)
+        ort_sess = onnxruntime.InferenceSession(model_file_name,
+                                                sess_options=ort_sess_opt,
+                                                providers=_ORT_PROVIDERS)
         input_copy = copy.deepcopy(input)
         ort_outs = run_ort(ort_sess, input_copy)
         ort_compare_with_pytorch(ort_outs, output, rtol, atol)
@@ -730,6 +734,48 @@ class TestONNXRuntime(unittest.TestCase):
                       dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}},
                       rtol=1e-3, atol=1e-5)
 
+    @disableScriptTest()
+    def test_mobilenet_v3(self):
+        model = torchvision.models.quantization.mobilenet_v3_large(pretrained=False)
+        dummy_input = torch.randn(1, 3, 224, 224)
+        self.run_test(model, (dummy_input,))
+
+    @unittest.skip("Fixed in PyTorch master. RuntimeError: Error(s) in loading state_dict for QuantizableMobileNetV3")
+    @disableScriptTest()
+    def test_mobilenet_v3_quant(self):
+        model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
+        from PIL import Image
+        from torchvision import transforms
+        data_dir = os.path.join(os.path.dirname(__file__), "assets")
+        path = os.path.join(data_dir, "grace_hopper_517x606.jpg")
+        input_image = Image.open(path)
+        # Based on example from https://pytorch.org/hub/pytorch_vision_resnet/
+        preprocess = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+        input_tensor = preprocess(input_image).unsqueeze(0)
+
+        # Due to precision error from quantization, check only that the top prediction matches.
+        class TopPredictor(torch.nn.Module):
+            def __init__(self, mobilenet):
+                super().__init__()
+                self.mobilenet = mobilenet
+
+            def forward(self, x):
+                x = self.mobilenet(x)
+                _, topk_catid = torch.topk(x[0], 1)
+                return topk_catid
+
+        # Currently, we need to convert the model to a ScriptModule before export.
+        # The reason is that PackedParams contains an int (not a tensor),
+        # so the exporter fails when it calls _trace_and_get_graph_from_model().
+        # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/1547858
+        model = torch.jit.trace(TopPredictor(model), input_tensor)
+        self.run_test(model, (input_tensor, ))
+
     @disableScriptTest()
     def test_word_language_model_RNN_TANH(self):
         self.run_word_language_model("RNN_TANH")
@@ -10413,11 +10459,14 @@ class TestONNXRuntime(unittest.TestCase):
         loaded_model = onnx.load_from_string(f.getvalue())
         self.assertEqual(loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128)
 
+    # NOTE: For quantization tests, choose scale and zero point carefully
+    # such that inputs and outputs do not always overflow/underflow.
+    # Otherwise test results could be inaccurate.
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_quantized_linear(self):
-        model = torch.nn.quantized.Linear(1, 2)
-        input = torch.rand(1, 1)
-        input_tensor = torch.quantize_per_tensor(input, 1, 0, torch.quint8)
+        model = torch.nn.quantized.Linear(4, 8)
+        input = torch.randn(4, 4)
+        input_tensor = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
         # Currently, we need convert the model to ScriptModule before export.
         # The reason is that PackedParams contains int (not tensor).
         # Then it fails when the exporter calls _trace_and_get_graph_from_model().
@@ -10425,6 +10474,100 @@ class TestONNXRuntime(unittest.TestCase):
         self.run_test(torch.jit.trace(model, input_tensor), (input_tensor,))
         self.run_test(torch.jit.script(model), (input_tensor,))
 
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_conv2d(self):
+        model = torch.nn.quantized.Conv2d(16, 33, 3, stride=2)
+        # Manually initialize model weight and bias to random numbers.
+        # By default all zeros.
+        q_weight = torch.quantize_per_tensor(torch.randn(33, 16, 3, 3), 0.2, 0, torch.qint8)
+        bias = torch.randn(33)
+        model.set_weight_bias(q_weight, bias)
+        input = torch.randn(3, 16, 32, 32)
+        q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_adaptive_avg_pool2d(self):
+        model = torch.nn.AdaptiveAvgPool2d((5, 7))
+        input = torch.randn(4, 3, 10, 14)
+        q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_conv2d_relu(self):
+        model = torch.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2)
+        # Manually initialize model weight and bias to random numbers.
+        # By default all zeros.
+        q_weight = torch.quantize_per_tensor(torch.randn(33, 16, 3, 3), 0.2, 0, torch.qint8)
+        bias = torch.randn(33)
+        model.set_weight_bias(q_weight, bias)
+        input = torch.randn(3, 16, 32, 32)
+        q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_hardswish(self):
+        model = torch.nn.quantized.Hardswish(1, 0)
+        input = torch.randn(2, 6)
+        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_hardsigmoid(self):
+        model = torch.nn.Hardsigmoid()
+        input = torch.randn(2, 6)
+        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_flatten(self):
+        class FlattenModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.flatten(input)
+
+        x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
+        self.run_test(torch.jit.trace(FlattenModel(), x), x)
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_arithmetic(self):
+        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
+        y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
+
+        class ArithmeticModel(torch.nn.Module):
+            def forward(self, x, y):
+                o = torch.nn.quantized.QFunctional().add(x, y)
+                o = torch.nn.quantized.QFunctional().mul(o, x)
+                return o
+
+        self.run_test(torch.jit.trace(ArithmeticModel(), (x, y)), (x, y))
+
+        class ArithmeticModel2(torch.nn.Module):
+            def forward(self, x, y):
+                o = torch.ops.quantized.add(x, y, 0.4, 100)
+                o = torch.ops.quantized.mul(o, x, 0.4, 100)
+                return o
+
+        self.run_test(torch.jit.trace(ArithmeticModel2(), (x, y)), (x, y))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantize_per_tensor(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return (torch.quantize_per_tensor(x, 0.2, 0, torch.qint8),
+                        torch.quantize_per_tensor(x, 0.2, 128, torch.quint8))
+
+        x = torch.randn(4, 6)
+        self.run_test(torch.jit.trace(Module(), x), x)
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_dequantize(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return torch.dequantize(x)
+
+        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 0, torch.qint8)
+        self.run_test(torch.jit.trace(Module(), x), x)
+
 def make_test(name, base, layer, bidirectional, initial_state,
               variable_length, dropout, script_test_min_opset_version,
               **extra_kwargs):
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 433a9d2cc75..03f950ce3e4 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -891,39 +891,6 @@ class TestUtilityFuns_opset9(_BaseTestCase):
         self.assertEqual(next(iter).kind(), "onnx::Constant")
         self.assertEqual(next(iter).kind(), "aten::cosine_similarity")
 
-    def test_quantized_fallthrough(self):
-        # Test Quantized op
-        class QModule(torch.nn.Module):
-            def __init__(self):
-                super(QModule, self).__init__()
-                self.quant1 = torch.ao.quantization.QuantStub()
-                self.dequant = torch.ao.quantization.DeQuantStub()
-
-            def forward(self, x):
-                res = self.quant1(x)
-                return self.dequant(res)
-
-        model = QModule()
-        torch.backends.quantized.engine = "qnnpack"
-        pt_inputs = (torch.randn(1, 2, 3, 4))
-        model.qconfig = torch.ao.quantization.default_qconfig
-        q_model = torch.ao.quantization.prepare(model, inplace=False)
-        q_model = torch.ao.quantization.convert(q_model, inplace=False)
-
-        q_model.eval()
-
-        graph, _, __ = self._model_to_graph(q_model, pt_inputs,
-                                            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
-                                            input_names=["pt_inputs"],
-                                            dynamic_axes={"pt_inputs": [0, 1, 2, 3]})
-
-        iter = graph.nodes()
-        self.assertEqual(next(iter).kind(), "onnx::Constant")
-        self.assertEqual(next(iter).kind(), "onnx::Constant")
-        self.assertEqual(next(iter).kind(), "onnx::Constant")
-        self.assertEqual(next(iter).kind(), "aten::quantize_per_tensor")
-        self.assertEqual(next(iter).kind(), "aten::dequantize")
-
     # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
     @skipIfUnsupportedMaxOpsetVersion(10)
     def test_prim_fallthrough(self):
diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp
index e5114d9f1e8..3bc0998972a 100644
--- a/torch/csrc/jit/passes/onnx/peephole.cpp
+++ b/torch/csrc/jit/passes/onnx/peephole.cpp
@@ -775,8 +775,10 @@ static void eraseTupleConstruct(Block* block) {
       for (auto* input : output_node->inputs()) {
         block->insertOutput(index + (input_index++), input);
       }
+      index += input_index;
+    } else {
+      index++;
     }
-    index++;
   }
 }
diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
index 7c237cc8046..e385563921f 100644
--- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
+++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
@@ -130,18 +130,21 @@ Node* CreateQuantizedBiasCaffe2(
 }
 
 std::vector<Node*> CreateQuantizedWeights(
-    std::vector<float> data,
+    int8_t* data,
     std::shared_ptr<Graph>& graph,
-    std::vector<int64_t> shapes,
+    const std::vector<int64_t>& shapes,
+    const std::vector<int64_t>& strides,
     float scale,
     int64_t zero_point) {
   Node* const_node_1 = graph->create(prim::Constant);
   auto const_value =
-      at::from_blob(data.data(), c10::IntArrayRef(shapes), at::kFloat)
+      at::from_blob(
+          data, c10::IntArrayRef(shapes), c10::IntArrayRef(strides), at::kChar)
           .to(at::kCPU);
-  auto options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
-  at::Tensor const_value_copy = at::empty(c10::IntArrayRef(shapes), options);
-  const_value.copy_(const_value);
+  auto options = c10::TensorOptions().dtype(at::kChar).device(at::kCPU);
+  at::Tensor const_value_copy = at::empty_strided(
+      c10::IntArrayRef(shapes), c10::IntArrayRef(strides), options);
+  const_value_copy.copy_(const_value);
   const_node_1->t_(Symbol::attr("value"), const_value_copy);
 
   Node* const_node_2 = graph->create(prim::Constant);
@@ -150,6 +153,7 @@ std::vector<Node*> CreateQuantizedWeights(
   auto const_shape =
       at::from_blob(scale_v.data(), c10::IntArrayRef(scale_shapes), at::kFloat)
          .to(at::kCPU);
+  options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
   at::Tensor const_shape_copy =
       at::empty(c10::IntArrayRef(scale_shapes), options);
   const_shape_copy.copy_(const_shape);
@@ -162,6 +166,7 @@ std::vector<Node*> CreateQuantizedWeights(
       at::from_blob(
          zero_point_v.data(), c10::IntArrayRef(zero_shapes), at::kInt)
          .to(at::kCPU);
+  options = c10::TensorOptions().dtype(at::kInt).device(at::kCPU);
   at::Tensor const_zero_copy =
       at::empty(c10::IntArrayRef(zero_shapes), options);
   const_zero_copy.copy_(const_zero);
@@ -402,9 +407,10 @@ void unpackQuantizedWeightsHelper(
      std::tie(unpacked_weight, bias) = op.call(packed_weight);
    }
 
-    // Permute weights
    std::vector<int64_t> wt_sizes = unpacked_weight.sizes().vec();
-    if (unpacked_weight.ndimension() == 4) {
+    std::vector<int64_t> wt_strides = unpacked_weight.strides().vec();
+    if (unpacked_weight.ndimension() == 4 && caffe2) {
+      // Permute weights
      unpacked_weight.permute({0, 2, 3, 1});
      wt_sizes = {
          unpacked_weight.size(0),
@@ -416,13 +422,13 @@
    // Remove packed_params
    qlinear_node->removeInput(1);
 
-    // Convert from int8 to uint8
    int8_t* inp_data =
        reinterpret_cast<int8_t*>(unpacked_weight.data_ptr());
-    const int64_t weight_zp = unpacked_weight.q_zero_point() + 128;
-    const int64_t wt_numel = unpacked_weight.numel();
 
    if (caffe2) {
+      // Convert from int8 to uint8
+      const int64_t weight_zp = unpacked_weight.q_zero_point() + 128;
+      const int64_t wt_numel = unpacked_weight.numel();
      // Create caffe2::Int8GivenTensorFill node
      std::ostringstream os;
      for (const auto i : c10::irange(wt_numel)) {
@@ -434,27 +440,21 @@
      c2_weight->insertBefore(qlinear_node);
      qlinear_node->insertInput(1, c2_weight->output());
    } else {
-      std::vector<float> unpacked_weight_values;
-      unpacked_weight_values.reserve(unpacked_weight.numel());
-      auto unpacked_weight_data =
-          reinterpret_cast<int8_t*>(unpacked_weight.data_ptr());
-      for (const auto i : c10::irange(unpacked_weight.numel())) {
-        unpacked_weight_values.push_back(
-            static_cast<float>(unpacked_weight_data[i]));
-      }
-      std::vector<Node*> c2_weight = CreateQuantizedWeights(
-          unpacked_weight_values,
+      std::vector<Node*> unpacked_wt = CreateQuantizedWeights(
+          inp_data,
          graph,
          wt_sizes,
+          wt_strides,
          static_cast<float>(unpacked_weight.q_scale()),
-          weight_zp);
+          unpacked_weight.q_zero_point());
      graph->setInsertPoint(qlinear_node);
-      c2_weight[0]->insertBefore(qlinear_node);
-      qlinear_node->insertInput(1, c2_weight[0]->output());
-      c2_weight[1]->insertBefore(qlinear_node);
-      qlinear_node->insertInput(2, c2_weight[1]->output());
-      c2_weight[2]->insertBefore(qlinear_node);
-      qlinear_node->insertInput(3, c2_weight[2]->output());
+      Node* quant_node = graph->create(prim::TupleConstruct);
+      for (auto* n : unpacked_wt) {
+        n->insertBefore(qlinear_node);
+        quant_node->addInput(n->output());
+      }
+      quant_node->insertBefore(qlinear_node);
+      qlinear_node->insertInput(1, quant_node->output());
    }
 
    // Add bias
@@ -507,10 +507,8 @@
        CreateQuantizedBias(bias_values, graph, original_bias.sizes().vec());
    bias->insertBefore(qlinear_node);
    // For quantized_linear inputs, the order is input, weight, bias, ....
-    // We unpack weight into 3 values. then it is
-    // input, weight_value, weight_scale, weight_zero_point, bias, ...
-    // Therefore bias is at location 4.
-    qlinear_node->insertInput(4, bias->output());
+    // Therefore bias is at location 2.
+    qlinear_node->insertInput(2, bias->output());
   }
 
   // add conv arguments: stride, padding, dilation, groups
@@ -520,6 +518,7 @@
      conv_ints_args.push_back(stride);
      conv_ints_args.push_back(padding);
      conv_ints_args.push_back(dilation);
+      // skip (input, weight, bias)
      const size_t arg_offset = 3;
      for (const auto i : c10::irange(conv_ints_args.size())) {
        Node* ints_node =
@@ -616,32 +615,35 @@ void UnpackQuantizedWeights(
      "quantized::linear_unpack",
      QuantizedParamsType::LINEAR,
      caffe2);
-  if (caffe2) {
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv2d,
-        "quantized::conv2d_unpack",
-        QuantizedParamsType::CONV);
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv2d_relu,
-        "quantized::conv2d_unpack",
-        QuantizedParamsType::CONV);
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv3d,
-        "quantized::conv3d_unpack",
-        QuantizedParamsType::CONV);
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv3d_relu,
-        "quantized::conv3d_unpack",
-        QuantizedParamsType::CONV);
-  } else {
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv2d,
+      "quantized::conv2d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv2d_relu,
+      "quantized::conv2d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv3d,
+      "quantized::conv3d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv3d_relu,
+      "quantized::conv3d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  if (!caffe2) {
    UnpackQuantizedTensorInputs(graph);
   }
   GRAPH_DUMP("After UnpackQuantizedWeights: ", graph);
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 9adc8397aae..0a2c07a5510 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -134,7 +134,8 @@ def _unpack_list(list_value):
 
 def _unpack_tuple(tuple_value):
     tuple_node = tuple_value.node()
-    assert tuple_node.kind() == "prim::TupleConstruct"
+    if tuple_node.kind() != "prim::TupleConstruct":
+        raise RuntimeError("ONNX symbolic expected node type `prim::TupleConstruct`, got `{}`".format(tuple_node))
     return list(tuple_node.inputs())
 
 # Check if list_value is output from prim::ListConstruct
@@ -192,6 +193,74 @@ def parse_args(*arg_descriptors):
         return wrapper
     return decorator
 
+def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
+    """A decorator which extends support for the quantized version of the base operator.
+    Quantization is detected by examining the arguments that are annotated by
+    `arg_q_descriptors`.
+    If quantization is detected, the base operator symbolic function will be wrapped with
+    argument dequantization and output quantization.
+    Otherwise, only the base symbolic function will be invoked.
+
+    For example:
+    @quantized_args(True, False)
+    def foo(g, x, y):
+        return x + y
+
+    is equivalent to
+
+    def q_foo(g, x, y):
+        if is_quantized_tensor(x):
+            x = dequantize(x)
+            out = foo(g, x, y)
+            return quantize(out)
+        else:
+            return foo(g, x, y)
+
+    Args:
+        arg_q_descriptors: list of bool, where each element indicates whether the
+            argument is a QTensor for the quantized version of this operator.
+        scale: float, default None, quantized output scale. If None, derive from
+            the first quantized input scale.
+        zero_point: int, default None, quantized output zero point. If None,
+            derive from the first quantized input zero point.
+ """ + def decorator(fn): + fn._scale = scale + fn._zero_point = zero_point + + @wraps(fn) + def wrapper(g, *args, **kwargs): + _scale = fn._scale + if _scale is not None: + _scale = g.op("Constant", value_t=torch.tensor(_scale)) + _zero_point = fn._zero_point + if _zero_point is not None: + _zero_point = g.op("Constant", value_t=torch.tensor(_zero_point)) + + # some args may be optional, so the length may be smaller + assert len(arg_q_descriptors) >= len(args) + desc_args = tuple(zip(arg_q_descriptors[:len(args)], args)) + # Run regular symbolic function if none of the argument is QTensor. + if not any((desc and arg.node().kind() == "prim::TupleConstruct") for desc, arg in desc_args): + return fn(g, *args, **kwargs) + + dequantized_args = [] + for desc, arg in desc_args: + if desc: + dequantized_arg, scale, zero_point = dequantize_helper(g, arg) + dequantized_args.append(dequantized_arg) + if _scale is None: + _scale = scale + if _zero_point is None: + _zero_point = zero_point + else: + dequantized_args.append(arg) + # TODO: only support single output + output = fn(g, *dequantized_args, **kwargs) + + return quantize_helper(g, output, _scale, _zero_point) + return wrapper + return decorator def _scalar(x): """Convert a scalar tensor into a Python value.""" @@ -826,6 +895,30 @@ def _handle_reduce_dim_none(g, self, op_name): return g.op(op_name, self, keepdims_i=1) return g.op(op_name, self, keepdims_i=0) +def dequantize_helper(g, qtensor, qdtype=None): + tensor, scale, zero_point = _unpack_tuple(qtensor) + input_qdtype = cast_pytorch_to_onnx[tensor.type().scalarType()] + if qdtype is None: + if input_qdtype is not None: + qdtype = input_qdtype + else: + qdtype = torch.onnx.TensorProtoDataType.UINT8 + value = g.op("Cast", tensor, to_i=qdtype) + scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT) + zero_point = g.op("Cast", zero_point, to_i=qdtype) + return g.op("DequantizeLinear", value, scale, zero_point), scale, zero_point + +def quantize_helper(g, tensor, scale, zero_point): + assert scale is not None + if scale.type().scalarType() != "Float": + scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT) + + assert zero_point is not None + if zero_point.type().scalarType() not in ("Byte", "Char"): + zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8) + output = g.op("QuantizeLinear", tensor, scale, zero_point) + return g.op("prim::TupleConstruct", output, scale, zero_point) + # --------------------------------------------------------------------- # ONNX operator version # --------------------------------------------------------------------- diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 0c15875eb58..a08d012a952 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -10,7 +10,7 @@ import torch.onnx.utils import torch.onnx.symbolic_helper as sym_help from torch.onnx.symbolic_helper import parse_args, _unimplemented import torch.onnx.symbolic_opset9 -from torch.onnx.symbolic_opset9 import linear +from torch.onnx.symbolic_opset9 import linear, conv2d, add, mul, hardswish, relu from sys import maxsize @@ -322,39 +322,75 @@ def isfinite(g, input): return __not_(g, __or_(g, inf_node, nan_node)) -# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export +def quantize_per_tensor(g, input, scale, zero_point, dtype): + dtype = sym_help._get_const(dtype, "i", "dtype") + zero_point = g.op("Cast", zero_point, to_i=sym_help.scalar_type_to_onnx[dtype]) 
+ scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT) + return sym_help.quantize_helper(g, input, scale, zero_point) + + +def dequantize(g, input): + return sym_help.dequantize_helper(g, input)[0] + + class Quantized: + """ + https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export + + Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were introduced in opset version 10. + """ domain = "quantized" - # DequantizeLinear was added in opset version 10. @staticmethod - def linear(g, input_original, weight, weight_scale, weight_zero_point, bias, op_scale, op_zero_point): - input_value, input_scale, input_zero_point = sym_help._unpack_tuple(input_original) - # From https://pytorch.org/docs/master/generated/torch.nn.quantized.functional.linear.html - # input (Tensor) – Quantized input of type torch.quint8 - input_type_dq = torch.onnx.TensorProtoDataType.UINT8 - input_value = g.op("Cast", input_value, to_i=input_type_dq) - input_scale = g.op("Cast", input_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT) - input_zero_point = g.op("Cast", input_zero_point, to_i=input_type_dq) - input = g.op("DequantizeLinear", input_value, input_scale, input_zero_point) - # weight (Tensor) – Quantized weight of type torch.qint8 - weight_type_dq = torch.onnx.TensorProtoDataType.INT8 - weight = g.op("Cast", weight, to_i=weight_type_dq) - weight_scale = g.op("Cast", weight_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT) - weight_zero_point = g.op("Cast", weight_zero_point, to_i=weight_type_dq) - weight = g.op("DequantizeLinear", weight, weight_scale, weight_zero_point) - # bias (Tensor) – None or fp32 bias of type torch.float - bias = g.op("Cast", bias, to_i=torch.onnx.TensorProtoDataType.FLOAT) + def linear(g, q_input, q_weight, bias, op_scale, op_zero_point): + input, _, _ = sym_help.dequantize_helper(g, q_input) + weight, _, _ = sym_help.dequantize_helper(g, q_weight) + output = linear(g, input, weight, bias) - if op_scale is None: - op_scale = input_scale - elif op_scale.type().scalarType() != "Float": - op_scale = g.op("Cast", op_scale, to_i=sym_help.cast_pytorch_to_onnx["Float"]) + return sym_help.quantize_helper(g, output, op_scale, op_zero_point) - if op_zero_point is None: - op_zero_point = input_zero_point - elif op_zero_point.type().scalarType() != "Byte": - op_zero_point = g.op("Cast", op_zero_point, to_i=sym_help.cast_pytorch_to_onnx["Byte"]) - output = g.op("QuantizeLinear", output, op_scale, op_zero_point) - return g.op("prim::TupleConstruct", output, op_scale, op_zero_point) + @staticmethod + def add(g, x, y, op_scale, op_zero_point): + x, _, _ = sym_help.dequantize_helper(g, x) + y, _, _ = sym_help.dequantize_helper(g, y) + + output = add(g, x, y) + + return sym_help.quantize_helper(g, output, op_scale, op_zero_point) + + @staticmethod + def mul(g, x, y, op_scale, op_zero_point): + x, _, _ = sym_help.dequantize_helper(g, x) + y, _, _ = sym_help.dequantize_helper(g, y) + + output = mul(g, x, y) + + return sym_help.quantize_helper(g, output, op_scale, op_zero_point) + + @staticmethod + def hardswish(g, x, op_scale, op_zero_point): + x, _, _ = sym_help.dequantize_helper(g, x) + + output = hardswish(g, x) + + return sym_help.quantize_helper(g, output, op_scale, op_zero_point) + + @staticmethod + def conv2d_relu(g, q_input, q_weight, bias, stride, padding, dilation, groups, op_scale, op_zero_point): + input, _, _ = sym_help.dequantize_helper(g, q_input) + weight, _, _ = sym_help.dequantize_helper(g, q_weight) + + output = conv2d(g, 
input, weight, bias, stride, padding, dilation, groups) + output = relu(g, output) + + return sym_help.quantize_helper(g, output, op_scale, op_zero_point) + + @staticmethod + def conv2d(g, q_input, q_weight, bias, stride, padding, dilation, groups, op_scale, op_zero_point): + input, _, _ = sym_help.dequantize_helper(g, q_input) + weight, _, _ = sym_help.dequantize_helper(g, q_weight) + + output = conv2d(g, input, weight, bias, stride, padding, dilation, groups) + + return sym_help.quantize_helper(g, output, op_scale, op_zero_point) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 1d85090cec7..7e9471528d7 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -7,7 +7,7 @@ import torch import torch.onnx.symbolic_helper as sym_help import warnings -from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType +from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType, quantized_args from torch.onnx.symbolic_opset9 import expand, unused, mul from torch.onnx.symbolic_opset9 import linalg_vector_norm as lvn from torch.nn.modules.utils import _single, _pair, _triple @@ -791,7 +791,7 @@ def narrow(g, input, dim, start, length): end = g.op("Add", start, length) return _slice_helper(g, input, axes=dim, starts=start, ends=end, dynamic_slice=True) - +@quantized_args(True, False, False) @parse_args("v", "i", "i") def flatten(g, input, start_dim, end_dim): dim = sym_help._get_tensor_rank(input) diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index d4775b553da..f1ef80f1221 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -52,3 +52,18 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome new_running_mean.setType(running_mean.type()) new_running_var.setType(running_var.type()) return res + + +class Quantized: + """ + https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export + """ + domain = "quantized" + + @staticmethod + def hardswish(g, x, op_scale, op_zero_point): + x, _, _ = sym_help.dequantize_helper(g, x) + + output = hardswish(g, x) + + return sym_help.quantize_helper(g, output, op_scale, op_zero_point) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 6c05a8c56a4..83780de9c7b 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -13,7 +13,7 @@ from functools import partial from functools import wraps import torch.onnx.symbolic_helper as sym_help -from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType +from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType, quantized_args from typing import Optional from sys import maxsize as maxsize @@ -949,6 +949,7 @@ avg_pool3d = _avg_pool("avg_pool3d", _triple) def _adaptive_pool(name, type, tuple_fn, fn=None): + @quantized_args(True, False) def symbolic_fn(g, input, output_size): # _adaptive_pool is supported for cases where output_size is 1 for all dimensions, # by executing a GlobalPool. @@ -1957,6 +1958,8 @@ def hardswish(g, self): return g.op("Mul", self, hs) +# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp +@quantized_args(True, scale=1.0 / 256.0, zero_point=0) @parse_args("v") def hardsigmoid(g, self): # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid. 
@@ -2570,6 +2573,7 @@ def erf(g, input):
     return g.op("Erf", input)
 
 
+@quantized_args(True, False, False)
 @parse_args("v", "i", "i")
 def flatten(g, input, start_dim, end_dim):
     dim = sym_help._get_tensor_rank(input)
diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py
index 2e577222a89..b10da3c83e1 100644
--- a/torch/onnx/symbolic_registry.py
+++ b/torch/onnx/symbolic_registry.py
@@ -131,14 +131,14 @@ class UnsupportedOperatorError(RuntimeError):
     def __init__(self, domain, opname, version):
         supported_version = get_op_supported_version(opname, domain, version)
         if domain in ["", "aten", "prim", "quantized"]:
-            msg = "Exporting the operator " + opname + " to ONNX opset version " + str(version) + " is not supported. "
+            msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
             if supported_version is not None:
-                msg += "Support for this operator was added in version " + str(supported_version) + \
-                    ", try exporting with this version."
+                msg += (f"Support for this operator was added in version {supported_version}, "
+                        "try exporting with this version.")
             else:
                 msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
         else:
-            msg = ("ONNX export failed on an operator with unrecognized namespace {}::{}. "
+            msg = (f"ONNX export failed on an operator with unrecognized namespace {domain}::{opname}. "
                    "If you are trying to export a custom operator, make sure you registered "
-                   "it with the right domain and version.".format(domain, opname))
+                   "it with the right domain and version.")
         super().__init__(msg)
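For reference, a minimal sketch of how the quantized export path covered by the new tests can be driven directly. The added tests go through run_test, which wraps the same export; the explicit torch.onnx.export call, the BytesIO target, and the opset choice below are illustrative assumptions (older releases may additionally require example_outputs when exporting a ScriptModule).

    import io

    import torch

    # Quantized modules carry PackedParams (ints, not tensors), so the module is
    # traced to a ScriptModule before export, exactly as the new tests do.
    model = torch.nn.quantized.Linear(4, 8)
    q_input = torch.quantize_per_tensor(torch.randn(4, 4), 0.5, 128, torch.quint8)
    traced = torch.jit.trace(model, q_input)

    f = io.BytesIO()
    # QuantizeLinear/DequantizeLinear exist only in ONNX opset 10 and later,
    # so opset_version must be >= 10.
    torch.onnx.export(traced, (q_input,), f, opset_version=10)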