[ONNX] Add infra for quantized model export and support quantized mobilenet v3 (#72215)

* Add infrastructure and helper functions to enable future work for other quantized operators and models.
* Add export for quantized operators needed by torchvision mobilenet v3 large.
    * ATen namespace: hardsigmoid, flatten, adaptive_avg_pool, quantize_per_tensor, dequantize.
    * Quantized namespace: conv2d, conv2d_relu, hardswish, add, mul.
* Numerous bug fixes in unpack_quantized_weight.cpp, symbolic functions, and unit tests.
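
For reference, a minimal export sketch (not part of this change) that follows the pattern the new unit tests use: quantize the input, trace the quantized module, export at opset 10 or higher, and open an ONNX Runtime session. The module choice and quantization parameters below are illustrative only.

```python
import io

import onnxruntime
import torch

# Illustrative quantized module and input; scale/zero point values are arbitrary.
model = torch.nn.quantized.Linear(4, 8)
x = torch.randn(4, 4)
q_x = torch.quantize_per_tensor(x, 0.5, 128, torch.quint8)

# Tracing is currently required before export because PackedParams
# holds ints (not tensors); see the TODO carried in the tests.
traced = torch.jit.trace(model, q_x)

# Quantized export needs opset >= 10 (QuantizeLinear/DequantizeLinear).
f = io.BytesIO()
torch.onnx.export(traced, (q_x,), f, opset_version=10)
ort_sess = onnxruntime.InferenceSession(f.getvalue(), providers=["CPUExecutionProvider"])
```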

Co-authored-by: BowenBao <bowbao@microsoft.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73102
BowenBao 2022-02-22 14:55:18 -08:00 committed by PyTorch MergeBot
parent 785ebb9d6d
commit 40de6b80ee
10 changed files with 398 additions and 136 deletions


@@ -44,6 +44,8 @@ from torch.onnx.symbolic_helper import _unimplemented
from torch.onnx.utils import unpack_quantized_tensor
+_ORT_PROVIDERS = ["CPUExecutionProvider"]
def flatten_tuples(elem):
tup = []
for t in elem:
@@ -99,7 +101,7 @@ def convert_to_onnx(model, input=None, opset_version=9, do_constant_folding=True
# suppress ort warnings.
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
so.log_severity_level = 3
-ort_sess = onnxruntime.InferenceSession(f.getvalue(), so)
+ort_sess = onnxruntime.InferenceSession(f.getvalue(), so, providers=_ORT_PROVIDERS)
return ort_sess
@@ -373,7 +375,9 @@ class TestONNXRuntime(unittest.TestCase):
# suppress ort warnings.
# 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
ort_sess_opt.log_severity_level = 3
-ort_sess = onnxruntime.InferenceSession(model_file_name, sess_options=ort_sess_opt)
+ort_sess = onnxruntime.InferenceSession(model_file_name,
+                                        sess_options=ort_sess_opt,
+                                        providers=_ORT_PROVIDERS)
input_copy = copy.deepcopy(input)
ort_outs = run_ort(ort_sess, input_copy)
ort_compare_with_pytorch(ort_outs, output, rtol, atol)
@@ -730,6 +734,48 @@ class TestONNXRuntime(unittest.TestCase):
dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}},
rtol=1e-3, atol=1e-5)
@disableScriptTest()
def test_mobilenet_v3(self):
model = torchvision.models.quantization.mobilenet_v3_large(pretrained=False)
dummy_input = torch.randn(1, 3, 224, 224)
self.run_test(model, (dummy_input,))
@unittest.skip("Fixed in PyTorch master. RuntimeError: Error(s) in loading state_dict for QuantizableMobileNetV3")
@disableScriptTest()
def test_mobilenet_v3_quant(self):
model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
from PIL import Image
from torchvision import transforms
data_dir = os.path.join(os.path.dirname(__file__), "assets")
path = os.path.join(data_dir, "grace_hopper_517x606.jpg")
input_image = Image.open(path)
# Based on example from https://pytorch.org/hub/pytorch_vision_resnet/
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image).unsqueeze(0)
# Due to precision error from quantization, check only that the top prediction matches.
class TopPredictor(torch.nn.Module):
def __init__(self, mobilenet):
super().__init__()
self.mobilenet = mobilenet
def forward(self, x):
x = self.mobilenet(x)
_, topk_catid = torch.topk(x[0], 1)
return topk_catid
# Currently, we need convert the model to ScriptModule before export.
# The reason is that PackedParams contains int (not tensor).
# Then it fails when the exporter calls _trace_and_get_graph_from_model().
# TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/1547858
model = torch.jit.trace(TopPredictor(model), input_tensor)
self.run_test(model, (input_tensor, ))
@disableScriptTest()
def test_word_language_model_RNN_TANH(self):
self.run_word_language_model("RNN_TANH")
@@ -10413,11 +10459,14 @@ class TestONNXRuntime(unittest.TestCase):
loaded_model = onnx.load_from_string(f.getvalue())
self.assertEqual(loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128)
+# NOTE: For quantization tests, choose scale and zero point carefully
+# such that inputs and outputs do not always overflow/underflow.
+# Otherwise test results could be inaccurate.
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_linear(self):
-model = torch.nn.quantized.Linear(1, 2)
+model = torch.nn.quantized.Linear(4, 8)
-input = torch.rand(1, 1)
+input = torch.randn(4, 4)
-input_tensor = torch.quantize_per_tensor(input, 1, 0, torch.quint8)
+input_tensor = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
# Currently, we need convert the model to ScriptModule before export.
# The reason is that PackedParams contains int (not tensor).
# Then it fails when the exporter calls _trace_and_get_graph_from_model().
@@ -10425,6 +10474,100 @@ class TestONNXRuntime(unittest.TestCase):
self.run_test(torch.jit.trace(model, input_tensor), (input_tensor,))
self.run_test(torch.jit.script(model), (input_tensor,))
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_conv2d(self):
model = torch.nn.quantized.Conv2d(16, 33, 3, stride=2)
# Manually initialize model weight and bias to random numbers.
# By default all zeros.
q_weight = torch.quantize_per_tensor(torch.randn(33, 16, 3, 3), 0.2, 0, torch.qint8)
bias = torch.randn(33)
model.set_weight_bias(q_weight, bias)
input = torch.randn(3, 16, 32, 32)
q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
self.run_test(torch.jit.trace(model, q_input), (q_input,))
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_adaptive_avg_pool2d(self):
model = torch.nn.AdaptiveAvgPool2d((5, 7))
input = torch.randn(4, 3, 10, 14)
q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
self.run_test(torch.jit.trace(model, q_input), (q_input,))
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_conv2d_relu(self):
model = torch.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2)
# Manually initialize model weight and bias to random numbers.
# By default all zeros.
q_weight = torch.quantize_per_tensor(torch.randn(33, 16, 3, 3), 0.2, 0, torch.qint8)
bias = torch.randn(33)
model.set_weight_bias(q_weight, bias)
input = torch.randn(3, 16, 32, 32)
q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
self.run_test(torch.jit.trace(model, q_input), (q_input,))
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_hardswish(self):
model = torch.nn.quantized.Hardswish(1, 0)
input = torch.randn(2, 6)
q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
self.run_test(torch.jit.trace(model, q_input), (q_input,))
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_hardsigmoid(self):
model = torch.nn.Hardsigmoid()
input = torch.randn(2, 6)
q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
self.run_test(torch.jit.trace(model, q_input), (q_input,))
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_flatten(self):
class FlattenModel(torch.nn.Module):
def forward(self, input):
return torch.flatten(input)
x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
self.run_test(torch.jit.trace(FlattenModel(), x), x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_arithmetic(self):
x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
class ArithmeticModel(torch.nn.Module):
def forward(self, x, y):
o = torch.nn.quantized.QFunctional().add(x, y)
o = torch.nn.quantized.QFunctional().mul(o, x)
return o
self.run_test(torch.jit.trace(ArithmeticModel(), (x, y)), (x, y))
class ArithmeticModel2(torch.nn.Module):
def forward(self, x, y):
o = torch.ops.quantized.add(x, y, 0.4, 100)
o = torch.ops.quantized.mul(o, x, 0.4, 100)
return o
self.run_test(torch.jit.trace(ArithmeticModel2(), (x, y)), (x, y))
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantize_per_tensor(self):
class Module(torch.nn.Module):
def forward(self, x):
return (torch.quantize_per_tensor(x, 0.2, 0, torch.qint8),
torch.quantize_per_tensor(x, 0.2, 128, torch.quint8))
x = torch.randn(4, 6)
self.run_test(torch.jit.trace(Module(), x), x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_dequantize(self):
class Module(torch.nn.Module):
def forward(self, x):
return torch.dequantize(x)
x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 0, torch.qint8)
self.run_test(torch.jit.trace(Module(), x), x)
def make_test(name, base, layer, bidirectional, initial_state, def make_test(name, base, layer, bidirectional, initial_state,
variable_length, dropout, script_test_min_opset_version,
**extra_kwargs):


@@ -891,39 +891,6 @@ class TestUtilityFuns_opset9(_BaseTestCase):
self.assertEqual(next(iter).kind(), "onnx::Constant")
self.assertEqual(next(iter).kind(), "aten::cosine_similarity")
def test_quantized_fallthrough(self):
# Test Quantized op
class QModule(torch.nn.Module):
def __init__(self):
super(QModule, self).__init__()
self.quant1 = torch.ao.quantization.QuantStub()
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
res = self.quant1(x)
return self.dequant(res)
model = QModule()
torch.backends.quantized.engine = "qnnpack"
pt_inputs = (torch.randn(1, 2, 3, 4))
model.qconfig = torch.ao.quantization.default_qconfig
q_model = torch.ao.quantization.prepare(model, inplace=False)
q_model = torch.ao.quantization.convert(q_model, inplace=False)
q_model.eval()
graph, _, __ = self._model_to_graph(q_model, pt_inputs,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["pt_inputs"],
dynamic_axes={"pt_inputs": [0, 1, 2, 3]})
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "onnx::Constant")
self.assertEqual(next(iter).kind(), "onnx::Constant")
self.assertEqual(next(iter).kind(), "onnx::Constant")
self.assertEqual(next(iter).kind(), "aten::quantize_per_tensor")
self.assertEqual(next(iter).kind(), "aten::dequantize")
# prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
@skipIfUnsupportedMaxOpsetVersion(10)
def test_prim_fallthrough(self):


@@ -775,8 +775,10 @@ static void eraseTupleConstruct(Block* block) {
for (auto* input : output_node->inputs()) {
  block->insertOutput(index + (input_index++), input);
}
+index += input_index;
+} else {
+  index++;
}
-index++;
}
}


@@ -130,18 +130,21 @@ Node* CreateQuantizedBiasCaffe2(
}
std::vector<Node*> CreateQuantizedWeights(
-std::vector<float> data,
+int8_t* data,
std::shared_ptr<Graph>& graph,
-std::vector<int64_t> shapes,
+const std::vector<int64_t>& shapes,
+const std::vector<int64_t>& strides,
float scale,
int64_t zero_point) {
Node* const_node_1 = graph->create(prim::Constant);
auto const_value =
-at::from_blob(data.data(), c10::IntArrayRef(shapes), at::kFloat)
+at::from_blob(
+    data, c10::IntArrayRef(shapes), c10::IntArrayRef(strides), at::kChar)
.to(at::kCPU);
-auto options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
+auto options = c10::TensorOptions().dtype(at::kChar).device(at::kCPU);
-at::Tensor const_value_copy = at::empty(c10::IntArrayRef(shapes), options);
+at::Tensor const_value_copy = at::empty_strided(
+    c10::IntArrayRef(shapes), c10::IntArrayRef(strides), options);
-const_value.copy_(const_value);
+const_value_copy.copy_(const_value);
const_node_1->t_(Symbol::attr("value"), const_value_copy);
Node* const_node_2 = graph->create(prim::Constant);
@@ -150,6 +153,7 @@ std::vector<Node*> CreateQuantizedWeights(
auto const_shape =
at::from_blob(scale_v.data(), c10::IntArrayRef(scale_shapes), at::kFloat)
.to(at::kCPU);
+options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
at::Tensor const_shape_copy =
at::empty(c10::IntArrayRef(scale_shapes), options);
const_shape_copy.copy_(const_shape);
@@ -162,6 +166,7 @@ std::vector<Node*> CreateQuantizedWeights(
at::from_blob(
zero_point_v.data(), c10::IntArrayRef(zero_shapes), at::kInt)
.to(at::kCPU);
+options = c10::TensorOptions().dtype(at::kInt).device(at::kCPU);
at::Tensor const_zero_copy =
at::empty(c10::IntArrayRef(zero_shapes), options);
const_zero_copy.copy_(const_zero);
@@ -402,9 +407,10 @@ void unpackQuantizedWeightsHelper(
std::tie(unpacked_weight, bias) = op.call(packed_weight);
}
-// Permute weights
std::vector<int64_t> wt_sizes = unpacked_weight.sizes().vec();
+std::vector<int64_t> wt_strides = unpacked_weight.strides().vec();
-if (unpacked_weight.ndimension() == 4) {
+if (unpacked_weight.ndimension() == 4 && caffe2) {
+// Permute weights
unpacked_weight.permute({0, 2, 3, 1});
wt_sizes = {
unpacked_weight.size(0),
@@ -416,13 +422,13 @@ void unpackQuantizedWeightsHelper(
// Remove packed_params
qlinear_node->removeInput(1);
-// Convert from int8 to uint8
int8_t* inp_data =
reinterpret_cast<int8_t*>(unpacked_weight.data_ptr<c10::qint8>());
-const int64_t weight_zp = unpacked_weight.q_zero_point() + 128;
-const int64_t wt_numel = unpacked_weight.numel();
if (caffe2) {
+// Convert from int8 to uint8
+const int64_t weight_zp = unpacked_weight.q_zero_point() + 128;
+const int64_t wt_numel = unpacked_weight.numel();
// Create caffe2::Int8GivenTensorFill node
std::ostringstream os;
for (const auto i : c10::irange(wt_numel)) {
@@ -434,27 +440,21 @@ void unpackQuantizedWeightsHelper(
c2_weight->insertBefore(qlinear_node);
qlinear_node->insertInput(1, c2_weight->output());
} else {
-std::vector<float> unpacked_weight_values;
-unpacked_weight_values.reserve(unpacked_weight.numel());
-auto unpacked_weight_data =
-    reinterpret_cast<int8_t*>(unpacked_weight.data_ptr<c10::qint8>());
-for (const auto i : c10::irange(unpacked_weight.numel())) {
-  unpacked_weight_values.push_back(
-      static_cast<float>(unpacked_weight_data[i]));
-}
-std::vector<Node*> c2_weight = CreateQuantizedWeights(
-    unpacked_weight_values,
+std::vector<Node*> unpacked_wt = CreateQuantizedWeights(
+    inp_data,
    graph,
    wt_sizes,
+    wt_strides,
    static_cast<float>(unpacked_weight.q_scale()),
-    weight_zp);
+    unpacked_weight.q_zero_point());
graph->setInsertPoint(qlinear_node);
-c2_weight[0]->insertBefore(qlinear_node);
-qlinear_node->insertInput(1, c2_weight[0]->output());
-c2_weight[1]->insertBefore(qlinear_node);
-qlinear_node->insertInput(2, c2_weight[1]->output());
-c2_weight[2]->insertBefore(qlinear_node);
-qlinear_node->insertInput(3, c2_weight[2]->output());
+Node* quant_node = graph->create(prim::TupleConstruct);
+for (auto* n : unpacked_wt) {
+  n->insertBefore(qlinear_node);
+  quant_node->addInput(n->output());
+}
+quant_node->insertBefore(qlinear_node);
+qlinear_node->insertInput(1, quant_node->output());
}
// Add bias
@@ -507,10 +507,8 @@ void unpackQuantizedWeightsHelper(
CreateQuantizedBias(bias_values, graph, original_bias.sizes().vec());
bias->insertBefore(qlinear_node);
// For quantized_linear inputs, the order is input, weight, bias, ....
-// We unpack weight into 3 values. then it is
-// input, weight_value, weight_scale, weight_zero_point, bias, ...
-// Therefore bias is at location 4.
-qlinear_node->insertInput(4, bias->output());
+// Therefore bias is at location 2.
+qlinear_node->insertInput(2, bias->output());
}
// add conv arguments: stride, padding, dilation, groups
@@ -520,6 +518,7 @@ void unpackQuantizedWeightsHelper(
conv_ints_args.push_back(stride);
conv_ints_args.push_back(padding);
conv_ints_args.push_back(dilation);
+// skip (input, weight, bias)
const size_t arg_offset = 3;
for (const auto i : c10::irange(conv_ints_args.size())) {
Node* ints_node =
@@ -616,32 +615,35 @@ void UnpackQuantizedWeights(
"quantized::linear_unpack",
QuantizedParamsType::LINEAR,
caffe2);
-if (caffe2) {
-  unpackQuantizedWeightsHelper(
-      graph,
-      paramsDict,
-      qconv2d,
-      "quantized::conv2d_unpack",
-      QuantizedParamsType::CONV);
-  unpackQuantizedWeightsHelper(
-      graph,
-      paramsDict,
-      qconv2d_relu,
-      "quantized::conv2d_unpack",
-      QuantizedParamsType::CONV);
-  unpackQuantizedWeightsHelper(
-      graph,
-      paramsDict,
-      qconv3d,
-      "quantized::conv3d_unpack",
-      QuantizedParamsType::CONV);
-  unpackQuantizedWeightsHelper(
-      graph,
-      paramsDict,
-      qconv3d_relu,
-      "quantized::conv3d_unpack",
-      QuantizedParamsType::CONV);
-} else {
+unpackQuantizedWeightsHelper(
+    graph,
+    paramsDict,
+    qconv2d,
+    "quantized::conv2d_unpack",
+    QuantizedParamsType::CONV,
+    caffe2);
+unpackQuantizedWeightsHelper(
+    graph,
+    paramsDict,
+    qconv2d_relu,
+    "quantized::conv2d_unpack",
+    QuantizedParamsType::CONV,
+    caffe2);
+unpackQuantizedWeightsHelper(
+    graph,
+    paramsDict,
+    qconv3d,
+    "quantized::conv3d_unpack",
+    QuantizedParamsType::CONV,
+    caffe2);
+unpackQuantizedWeightsHelper(
+    graph,
+    paramsDict,
+    qconv3d_relu,
+    "quantized::conv3d_unpack",
+    QuantizedParamsType::CONV,
+    caffe2);
+if (!caffe2) {
UnpackQuantizedTensorInputs(graph);
}
GRAPH_DUMP("After UnpackQuantizedWeights: ", graph);


@@ -134,7 +134,8 @@ def _unpack_list(list_value):
def _unpack_tuple(tuple_value):
tuple_node = tuple_value.node()
-assert tuple_node.kind() == "prim::TupleConstruct"
+if tuple_node.kind() != "prim::TupleConstruct":
+    raise RuntimeError("ONNX symbolic expected node type `prim::TupleConstruct`, got `{}`".format(tuple_node))
return list(tuple_node.inputs())
# Check if list_value is output from prim::ListConstruct
@@ -192,6 +193,74 @@ def parse_args(*arg_descriptors):
return wrapper
return decorator
def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
"""A decorator which extends support for quantized version of the base operator.
Quantization is detected by examining the arguments that are annotated by
`arg_q_descriptors`.
If quantization is detected, the base operator symbolic function will be wrapped with
argument dequantization and output quantization.
Otherwise, only base symbolic function will be invoked.
For example:
@quantized_args(True, False)
def foo(g, x, y):
return x + y
is equivalent to
def q_foo(g, x, y):
if is_quantized_tensor(x):
x = dequantize(x)
out = foo(g, x, y)
return quantize(out)
else:
return foo(g, x, y)
Args:
arg_q_descriptors: list of bool, where each element represents if the
argument is QTensor for quantized version of this operator.
scale: float default None, quantized output scale. If None, derive from
the first quantized input scale.
zero_point: int default None, quantized output zero point. If None,
derive from the first quantized input zero point.
"""
def decorator(fn):
fn._scale = scale
fn._zero_point = zero_point
@wraps(fn)
def wrapper(g, *args, **kwargs):
_scale = fn._scale
if _scale is not None:
_scale = g.op("Constant", value_t=torch.tensor(_scale))
_zero_point = fn._zero_point
if _zero_point is not None:
_zero_point = g.op("Constant", value_t=torch.tensor(_zero_point))
# some args may be optional, so the length may be smaller
assert len(arg_q_descriptors) >= len(args)
desc_args = tuple(zip(arg_q_descriptors[:len(args)], args))
# Run regular symbolic function if none of the argument is QTensor.
if not any((desc and arg.node().kind() == "prim::TupleConstruct") for desc, arg in desc_args):
return fn(g, *args, **kwargs)
dequantized_args = []
for desc, arg in desc_args:
if desc:
dequantized_arg, scale, zero_point = dequantize_helper(g, arg)
dequantized_args.append(dequantized_arg)
if _scale is None:
_scale = scale
if _zero_point is None:
_zero_point = zero_point
else:
dequantized_args.append(arg)
# TODO: only support single output
output = fn(g, *dequantized_args, **kwargs)
return quantize_helper(g, output, _scale, _zero_point)
return wrapper
return decorator
def _scalar(x):
"""Convert a scalar tensor into a Python value."""
@@ -826,6 +895,30 @@ def _handle_reduce_dim_none(g, self, op_name):
return g.op(op_name, self, keepdims_i=1)
return g.op(op_name, self, keepdims_i=0)
def dequantize_helper(g, qtensor, qdtype=None):
tensor, scale, zero_point = _unpack_tuple(qtensor)
input_qdtype = cast_pytorch_to_onnx[tensor.type().scalarType()]
if qdtype is None:
if input_qdtype is not None:
qdtype = input_qdtype
else:
qdtype = torch.onnx.TensorProtoDataType.UINT8
value = g.op("Cast", tensor, to_i=qdtype)
scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
zero_point = g.op("Cast", zero_point, to_i=qdtype)
return g.op("DequantizeLinear", value, scale, zero_point), scale, zero_point
def quantize_helper(g, tensor, scale, zero_point):
assert scale is not None
if scale.type().scalarType() != "Float":
scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
assert zero_point is not None
if zero_point.type().scalarType() not in ("Byte", "Char"):
zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
output = g.op("QuantizeLinear", tensor, scale, zero_point)
return g.op("prim::TupleConstruct", output, scale, zero_point)
# ---------------------------------------------------------------------
# ONNX operator version
# ---------------------------------------------------------------------


@@ -10,7 +10,7 @@ import torch.onnx.utils
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
import torch.onnx.symbolic_opset9
-from torch.onnx.symbolic_opset9 import linear
+from torch.onnx.symbolic_opset9 import linear, conv2d, add, mul, hardswish, relu
from sys import maxsize
@@ -322,39 +322,75 @@ def isfinite(g, input):
return __not_(g, __or_(g, inf_node, nan_node))
-# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
+def quantize_per_tensor(g, input, scale, zero_point, dtype):
+dtype = sym_help._get_const(dtype, "i", "dtype")
+zero_point = g.op("Cast", zero_point, to_i=sym_help.scalar_type_to_onnx[dtype])
+scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+return sym_help.quantize_helper(g, input, scale, zero_point)
+def dequantize(g, input):
+return sym_help.dequantize_helper(g, input)[0]
class Quantized:
+"""
+https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
+Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were introduced in opset version 10.
+"""
domain = "quantized"
-# DequantizeLinear was added in opset version 10.
@staticmethod
-def linear(g, input_original, weight, weight_scale, weight_zero_point, bias, op_scale, op_zero_point):
-input_value, input_scale, input_zero_point = sym_help._unpack_tuple(input_original)
-# From https://pytorch.org/docs/master/generated/torch.nn.quantized.functional.linear.html
-# input (Tensor) Quantized input of type torch.quint8
-input_type_dq = torch.onnx.TensorProtoDataType.UINT8
-input_value = g.op("Cast", input_value, to_i=input_type_dq)
-input_scale = g.op("Cast", input_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
-input_zero_point = g.op("Cast", input_zero_point, to_i=input_type_dq)
-input = g.op("DequantizeLinear", input_value, input_scale, input_zero_point)
-# weight (Tensor) Quantized weight of type torch.qint8
-weight_type_dq = torch.onnx.TensorProtoDataType.INT8
-weight = g.op("Cast", weight, to_i=weight_type_dq)
-weight_scale = g.op("Cast", weight_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
-weight_zero_point = g.op("Cast", weight_zero_point, to_i=weight_type_dq)
-weight = g.op("DequantizeLinear", weight, weight_scale, weight_zero_point)
-# bias (Tensor) None or fp32 bias of type torch.float
-bias = g.op("Cast", bias, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
+input, _, _ = sym_help.dequantize_helper(g, q_input)
+weight, _, _ = sym_help.dequantize_helper(g, q_weight)
output = linear(g, input, weight, bias)
-if op_scale is None:
-op_scale = input_scale
-elif op_scale.type().scalarType() != "Float":
-op_scale = g.op("Cast", op_scale, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-if op_zero_point is None:
-op_zero_point = input_zero_point
-elif op_zero_point.type().scalarType() != "Byte":
-op_zero_point = g.op("Cast", op_zero_point, to_i=sym_help.cast_pytorch_to_onnx["Byte"])
-output = g.op("QuantizeLinear", output, op_scale, op_zero_point)
-return g.op("prim::TupleConstruct", output, op_scale, op_zero_point)
+return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
+@staticmethod
+def add(g, x, y, op_scale, op_zero_point):
+x, _, _ = sym_help.dequantize_helper(g, x)
+y, _, _ = sym_help.dequantize_helper(g, y)
+output = add(g, x, y)
+return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
def mul(g, x, y, op_scale, op_zero_point):
x, _, _ = sym_help.dequantize_helper(g, x)
y, _, _ = sym_help.dequantize_helper(g, y)
output = mul(g, x, y)
return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
def hardswish(g, x, op_scale, op_zero_point):
x, _, _ = sym_help.dequantize_helper(g, x)
output = hardswish(g, x)
return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
def conv2d_relu(g, q_input, q_weight, bias, stride, padding, dilation, groups, op_scale, op_zero_point):
input, _, _ = sym_help.dequantize_helper(g, q_input)
weight, _, _ = sym_help.dequantize_helper(g, q_weight)
output = conv2d(g, input, weight, bias, stride, padding, dilation, groups)
output = relu(g, output)
return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
def conv2d(g, q_input, q_weight, bias, stride, padding, dilation, groups, op_scale, op_zero_point):
input, _, _ = sym_help.dequantize_helper(g, q_input)
weight, _, _ = sym_help.dequantize_helper(g, q_weight)
output = conv2d(g, input, weight, bias, stride, padding, dilation, groups)
return sym_help.quantize_helper(g, output, op_scale, op_zero_point)


@@ -7,7 +7,7 @@ import torch
import torch.onnx.symbolic_helper as sym_help
import warnings
-from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType
+from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType, quantized_args
from torch.onnx.symbolic_opset9 import expand, unused, mul
from torch.onnx.symbolic_opset9 import linalg_vector_norm as lvn
from torch.nn.modules.utils import _single, _pair, _triple
@@ -791,7 +791,7 @@ def narrow(g, input, dim, start, length):
end = g.op("Add", start, length)
return _slice_helper(g, input, axes=dim, starts=start, ends=end, dynamic_slice=True)
+@quantized_args(True, False, False)
@parse_args("v", "i", "i")
def flatten(g, input, start_dim, end_dim):
dim = sym_help._get_tensor_rank(input)


@@ -52,3 +52,18 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
new_running_mean.setType(running_mean.type())
new_running_var.setType(running_var.type())
return res
class Quantized:
"""
https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
"""
domain = "quantized"
@staticmethod
def hardswish(g, x, op_scale, op_zero_point):
x, _, _ = sym_help.dequantize_helper(g, x)
output = hardswish(g, x)
return sym_help.quantize_helper(g, output, op_scale, op_zero_point)


@@ -13,7 +13,7 @@ from functools import partial
from functools import wraps
import torch.onnx.symbolic_helper as sym_help
-from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType
+from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType, quantized_args
from typing import Optional
from sys import maxsize as maxsize
@@ -949,6 +949,7 @@ avg_pool3d = _avg_pool("avg_pool3d", _triple)
def _adaptive_pool(name, type, tuple_fn, fn=None):
+@quantized_args(True, False)
def symbolic_fn(g, input, output_size):
# _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
# by executing a GlobalPool.
@@ -1957,6 +1958,8 @@ def hardswish(g, self):
return g.op("Mul", self, hs)
+# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
+@quantized_args(True, scale=1.0 / 256.0, zero_point=0)
@parse_args("v")
def hardsigmoid(g, self):
# Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid.
@@ -2570,6 +2573,7 @@ def erf(g, input):
return g.op("Erf", input)
+@quantized_args(True, False, False)
@parse_args("v", "i", "i")
def flatten(g, input, start_dim, end_dim):
dim = sym_help._get_tensor_rank(input)


@@ -131,14 +131,14 @@ class UnsupportedOperatorError(RuntimeError):
def __init__(self, domain, opname, version):
supported_version = get_op_supported_version(opname, domain, version)
if domain in ["", "aten", "prim", "quantized"]:
-msg = "Exporting the operator " + opname + " to ONNX opset version " + str(version) + " is not supported. "
+msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
if supported_version is not None:
-msg += "Support for this operator was added in version " + str(supported_version) + \
-    ", try exporting with this version."
+msg += (f"Support for this operator was added in version {supported_version}, "
+        "try exporting with this version.")
else:
msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
else:
-msg = ("ONNX export failed on an operator with unrecognized namespace {}::{}. "
+msg = (f"ONNX export failed on an operator with unrecognized namespace {domain}::{opname}. "
"If you are trying to export a custom operator, make sure you registered "
-"it with the right domain and version.".format(domain, opname))
+"it with the right domain and version.")
super().__init__(msg)