[ONNX] Add infra for quantized model export and support quantized mobilenet v3 (#72215)
* Add infrastructure and helper functions to enable future work for other quantized operators and models.
* Add export for quantized operators needed by torchvision mobilenet v3 large.
* ATen namespace: hardsigmoid, flatten, adaptive_avg_pool, quantize_per_tensor, dequantize.
* Quantized namespace: conv2d, conv2d_relu, hardswish, add, mul.
* Numerous bug fixes in unpack_quantized_weights.cpp, symbolic functions, and unit tests.
Co-authored-by: BowenBao <bowbao@microsoft.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73102
This commit is contained in:
parent 785ebb9d6d
commit 40de6b80ee
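For readers who want to try the new path end to end, a minimal sketch (assuming a torchvision build that ships quantized mobilenet_v3_large; the trace step is required because PackedParams contain plain ints, as noted in the diff below):

import torch
import torchvision

model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
model.eval()
dummy = torch.randn(1, 3, 224, 224)
# Quantized modules currently must be traced to a ScriptModule before export.
traced = torch.jit.trace(model, dummy)
# QuantizeLinear/DequantizeLinear require opset >= 10.
torch.onnx.export(traced, dummy, "mobilenet_v3_quant.onnx", opset_version=13)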
@@ -44,6 +44,8 @@ from torch.onnx.symbolic_helper import _unimplemented
 from torch.onnx.utils import unpack_quantized_tensor
 
+_ORT_PROVIDERS = ["CPUExecutionProvider"]
+
 def flatten_tuples(elem):
     tup = []
     for t in elem:
@@ -99,7 +101,7 @@ def convert_to_onnx(model, input=None, opset_version=9, do_constant_folding=True
     # suppress ort warnings.
     # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
     so.log_severity_level = 3
-    ort_sess = onnxruntime.InferenceSession(f.getvalue(), so)
+    ort_sess = onnxruntime.InferenceSession(f.getvalue(), so, providers=_ORT_PROVIDERS)
     return ort_sess
 
@@ -373,7 +375,9 @@ class TestONNXRuntime(unittest.TestCase):
         # suppress ort warnings.
         # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
         ort_sess_opt.log_severity_level = 3
-        ort_sess = onnxruntime.InferenceSession(model_file_name, sess_options=ort_sess_opt)
+        ort_sess = onnxruntime.InferenceSession(model_file_name,
+                                                sess_options=ort_sess_opt,
+                                                providers=_ORT_PROVIDERS)
         input_copy = copy.deepcopy(input)
         ort_outs = run_ort(ort_sess, input_copy)
         ort_compare_with_pytorch(ort_outs, output, rtol, atol)
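For context: the explicit providers argument matches newer onnxruntime releases, which expect the execution providers to be selected when a session is created rather than defaulting silently. A minimal sketch of the equivalent session setup (assuming a model.onnx file on disk; CPUExecutionProvider is always available):

import onnxruntime

so = onnxruntime.SessionOptions()
so.log_severity_level = 3  # 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
sess = onnxruntime.InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])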
@@ -730,6 +734,48 @@ class TestONNXRuntime(unittest.TestCase):
                       dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}},
                       rtol=1e-3, atol=1e-5)
 
+    @disableScriptTest()
+    def test_mobilenet_v3(self):
+        model = torchvision.models.quantization.mobilenet_v3_large(pretrained=False)
+        dummy_input = torch.randn(1, 3, 224, 224)
+        self.run_test(model, (dummy_input,))
+
+    @unittest.skip("Fixed in PyTorch master. RuntimeError: Error(s) in loading state_dict for QuantizableMobileNetV3")
+    @disableScriptTest()
+    def test_mobilenet_v3_quant(self):
+        model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
+        from PIL import Image
+        from torchvision import transforms
+        data_dir = os.path.join(os.path.dirname(__file__), "assets")
+        path = os.path.join(data_dir, "grace_hopper_517x606.jpg")
+        input_image = Image.open(path)
+        # Based on example from https://pytorch.org/hub/pytorch_vision_resnet/
+        preprocess = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+        ])
+        input_tensor = preprocess(input_image).unsqueeze(0)
+
+        # Due to precision error from quantization, check only that the top prediction matches.
+        class TopPredictor(torch.nn.Module):
+            def __init__(self, mobilenet):
+                super().__init__()
+                self.mobilenet = mobilenet
+
+            def forward(self, x):
+                x = self.mobilenet(x)
+                _, topk_catid = torch.topk(x[0], 1)
+                return topk_catid
+
+        # Currently, we need to convert the model to ScriptModule before export.
+        # The reason is that PackedParams contains int (not tensor).
+        # Then it fails when the exporter calls _trace_and_get_graph_from_model().
+        # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/1547858
+        model = torch.jit.trace(TopPredictor(model), input_tensor)
+        self.run_test(model, (input_tensor, ))
+
     @disableScriptTest()
     def test_word_language_model_RNN_TANH(self):
         self.run_word_language_model("RNN_TANH")
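The TopPredictor wrapper above exists because elementwise comparison of logits is too strict under quantization noise, so only the argmax class id is compared. A hedged sketch of that comparison outside the test harness (top1_matches is a hypothetical helper, not part of the test suite):

import numpy as np
import torch

def top1_matches(torch_logits: torch.Tensor, ort_logits: np.ndarray) -> bool:
    # Compare only the predicted class, not the raw scores.
    return int(torch_logits.argmax()) == int(ort_logits.argmax())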
@@ -10413,11 +10459,14 @@ class TestONNXRuntime(unittest.TestCase):
         loaded_model = onnx.load_from_string(f.getvalue())
         self.assertEqual(loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128)
 
+    # NOTE: For quantization tests, choose scale and zero point carefully
+    # such that inputs and outputs do not always overflow/underflow.
+    # Otherwise test results could be inaccurate.
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_quantized_linear(self):
-        model = torch.nn.quantized.Linear(1, 2)
-        input = torch.rand(1, 1)
-        input_tensor = torch.quantize_per_tensor(input, 1, 0, torch.quint8)
+        model = torch.nn.quantized.Linear(4, 8)
+        input = torch.randn(4, 4)
+        input_tensor = torch.quantize_per_tensor(input, 0.5, 128, torch.quint8)
         # Currently, we need to convert the model to ScriptModule before export.
         # The reason is that PackedParams contains int (not tensor).
         # Then it fails when the exporter calls _trace_and_get_graph_from_model().
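To make the NOTE concrete: quint8 quantization computes q = clamp(round(x / scale) + zero_point, 0, 255), so scale=0.5 with zero_point=128 represents inputs in [-64, 63.5] and saturates everything outside that range. A small illustration (expected outputs checked by hand):

import torch

x = torch.tensor([-100.0, -64.0, 0.0, 63.5, 100.0])
q = torch.quantize_per_tensor(x, scale=0.5, zero_point=128, dtype=torch.quint8)
print(q.int_repr())    # tensor([  0,   0, 128, 255, 255], dtype=torch.uint8)
print(q.dequantize())  # tensor([-64.0000, -64.0000, 0.0000, 63.5000, 63.5000])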
@@ -10425,6 +10474,100 @@ class TestONNXRuntime(unittest.TestCase):
         self.run_test(torch.jit.trace(model, input_tensor), (input_tensor,))
-        self.run_test(torch.jit.script(model), (input_tensor,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_conv2d(self):
+        model = torch.nn.quantized.Conv2d(16, 33, 3, stride=2)
+        # Manually initialize model weight and bias to random numbers.
+        # By default all zeros.
+        q_weight = torch.quantize_per_tensor(torch.randn(33, 16, 3, 3), 0.2, 0, torch.qint8)
+        bias = torch.randn(33)
+        model.set_weight_bias(q_weight, bias)
+        input = torch.randn(3, 16, 32, 32)
+        q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_adaptive_avg_pool2d(self):
+        model = torch.nn.AdaptiveAvgPool2d((5, 7))
+        input = torch.randn(4, 3, 10, 14)
+        q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_conv2d_relu(self):
+        model = torch.nn.intrinsic.quantized.ConvReLU2d(16, 33, 3, stride=2)
+        # Manually initialize model weight and bias to random numbers.
+        # By default all zeros.
+        q_weight = torch.quantize_per_tensor(torch.randn(33, 16, 3, 3), 0.2, 0, torch.qint8)
+        bias = torch.randn(33)
+        model.set_weight_bias(q_weight, bias)
+        input = torch.randn(3, 16, 32, 32)
+        q_input = torch.quantize_per_tensor(input, 0.2, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_hardswish(self):
+        model = torch.nn.quantized.Hardswish(1, 0)
+        input = torch.randn(2, 6)
+        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_hardsigmoid(self):
+        model = torch.nn.Hardsigmoid()
+        input = torch.randn(2, 6)
+        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
+        self.run_test(torch.jit.trace(model, q_input), (q_input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_flatten(self):
+        class FlattenModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.flatten(input)
+
+        x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
+        self.run_test(torch.jit.trace(FlattenModel(), x), x)
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_arithmetic(self):
+        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
+        y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
+
+        class ArithmeticModel(torch.nn.Module):
+            def forward(self, x, y):
+                o = torch.nn.quantized.QFunctional().add(x, y)
+                o = torch.nn.quantized.QFunctional().mul(o, x)
+                return o
+
+        self.run_test(torch.jit.trace(ArithmeticModel(), (x, y)), (x, y))
+
+        class ArithmeticModel2(torch.nn.Module):
+            def forward(self, x, y):
+                o = torch.ops.quantized.add(x, y, 0.4, 100)
+                o = torch.ops.quantized.mul(o, x, 0.4, 100)
+                return o
+
+        self.run_test(torch.jit.trace(ArithmeticModel2(), (x, y)), (x, y))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantize_per_tensor(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return (torch.quantize_per_tensor(x, 0.2, 0, torch.qint8),
+                        torch.quantize_per_tensor(x, 0.2, 128, torch.quint8))
+
+        x = torch.randn(4, 6)
+        self.run_test(torch.jit.trace(Module(), x), x)
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_dequantize(self):
+        class Module(torch.nn.Module):
+            def forward(self, x):
+                return torch.dequantize(x)
+
+        x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 0, torch.qint8)
+        self.run_test(torch.jit.trace(Module(), x), x)
 
 def make_test(name, base, layer, bidirectional, initial_state,
               variable_length, dropout, script_test_min_opset_version,
               **extra_kwargs):
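As a side note on the arithmetic tests above: quantized tensors do not support plain + or *, which is why the models go through QFunctional or the torch.ops.quantized overloads with an explicit output scale and zero point. A standalone sketch:

import torch

x = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
y = torch.quantize_per_tensor(torch.randn(3, 4), 0.2, 128, torch.quint8)
qf = torch.nn.quantized.QFunctional()
print(qf.add(x, y))                             # output params taken from QFunctional state
print(torch.ops.quantized.add(x, y, 0.4, 100))  # output scale/zero_point passed explicitly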
@@ -891,39 +891,6 @@ class TestUtilityFuns_opset9(_BaseTestCase):
         self.assertEqual(next(iter).kind(), "onnx::Constant")
         self.assertEqual(next(iter).kind(), "aten::cosine_similarity")
 
-    def test_quantized_fallthrough(self):
-        # Test Quantized op
-        class QModule(torch.nn.Module):
-            def __init__(self):
-                super(QModule, self).__init__()
-                self.quant1 = torch.ao.quantization.QuantStub()
-                self.dequant = torch.ao.quantization.DeQuantStub()
-
-            def forward(self, x):
-                res = self.quant1(x)
-                return self.dequant(res)
-
-        model = QModule()
-        torch.backends.quantized.engine = "qnnpack"
-        pt_inputs = (torch.randn(1, 2, 3, 4))
-        model.qconfig = torch.ao.quantization.default_qconfig
-        q_model = torch.ao.quantization.prepare(model, inplace=False)
-        q_model = torch.ao.quantization.convert(q_model, inplace=False)
-
-        q_model.eval()
-
-        graph, _, __ = self._model_to_graph(q_model, pt_inputs,
-                                            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
-                                            input_names=["pt_inputs"],
-                                            dynamic_axes={"pt_inputs": [0, 1, 2, 3]})
-
-        iter = graph.nodes()
-        self.assertEqual(next(iter).kind(), "onnx::Constant")
-        self.assertEqual(next(iter).kind(), "onnx::Constant")
-        self.assertEqual(next(iter).kind(), "onnx::Constant")
-        self.assertEqual(next(iter).kind(), "aten::quantize_per_tensor")
-        self.assertEqual(next(iter).kind(), "aten::dequantize")
-
     # prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
     @skipIfUnsupportedMaxOpsetVersion(10)
     def test_prim_fallthrough(self):
@@ -775,10 +775,12 @@ static void eraseTupleConstruct(Block* block) {
         for (auto* input : output_node->inputs()) {
           block->insertOutput(index + (input_index++), input);
         }
       }
+      index += input_index;
     } else {
       index++;
     }
   }
 }
 
 void removeMaxPoolUnusedOutput(Block* b) {
   for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
@@ -130,18 +130,21 @@ Node* CreateQuantizedBiasCaffe2(
 }
 
 std::vector<Node*> CreateQuantizedWeights(
-    std::vector<float> data,
+    int8_t* data,
     std::shared_ptr<Graph>& graph,
-    std::vector<int64_t> shapes,
+    const std::vector<int64_t>& shapes,
+    const std::vector<int64_t>& strides,
     float scale,
     int64_t zero_point) {
   Node* const_node_1 = graph->create(prim::Constant);
   auto const_value =
-      at::from_blob(data.data(), c10::IntArrayRef(shapes), at::kFloat)
+      at::from_blob(
+          data, c10::IntArrayRef(shapes), c10::IntArrayRef(strides), at::kChar)
           .to(at::kCPU);
-  auto options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
-  at::Tensor const_value_copy = at::empty(c10::IntArrayRef(shapes), options);
-  const_value.copy_(const_value);
+  auto options = c10::TensorOptions().dtype(at::kChar).device(at::kCPU);
+  at::Tensor const_value_copy = at::empty_strided(
+      c10::IntArrayRef(shapes), c10::IntArrayRef(strides), options);
+  const_value_copy.copy_(const_value);
   const_node_1->t_(Symbol::attr("value"), const_value_copy);
 
   Node* const_node_2 = graph->create(prim::Constant);
@@ -150,6 +153,7 @@ std::vector<Node*> CreateQuantizedWeights(
   auto const_shape =
       at::from_blob(scale_v.data(), c10::IntArrayRef(scale_shapes), at::kFloat)
           .to(at::kCPU);
+  options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
   at::Tensor const_shape_copy =
       at::empty(c10::IntArrayRef(scale_shapes), options);
   const_shape_copy.copy_(const_shape);
@@ -162,6 +166,7 @@ std::vector<Node*> CreateQuantizedWeights(
       at::from_blob(
           zero_point_v.data(), c10::IntArrayRef(zero_shapes), at::kInt)
           .to(at::kCPU);
+  options = c10::TensorOptions().dtype(at::kInt).device(at::kCPU);
   at::Tensor const_zero_copy =
       at::empty(c10::IntArrayRef(zero_shapes), options);
   const_zero_copy.copy_(const_zero);
@@ -402,9 +407,10 @@ void unpackQuantizedWeightsHelper(
     std::tie(unpacked_weight, bias) = op.call(packed_weight);
   }
 
-  // Permute weights
   std::vector<int64_t> wt_sizes = unpacked_weight.sizes().vec();
-  if (unpacked_weight.ndimension() == 4) {
+  std::vector<int64_t> wt_strides = unpacked_weight.strides().vec();
+  if (unpacked_weight.ndimension() == 4 && caffe2) {
+    // Permute weights
     unpacked_weight.permute({0, 2, 3, 1});
     wt_sizes = {
         unpacked_weight.size(0),
@@ -416,13 +422,13 @@ void unpackQuantizedWeightsHelper(
   // Remove packed_params
   qlinear_node->removeInput(1);
 
-  // Convert from int8 to uint8
   int8_t* inp_data =
       reinterpret_cast<int8_t*>(unpacked_weight.data_ptr<c10::qint8>());
-  const int64_t weight_zp = unpacked_weight.q_zero_point() + 128;
-  const int64_t wt_numel = unpacked_weight.numel();
 
   if (caffe2) {
+    // Convert from int8 to uint8
+    const int64_t weight_zp = unpacked_weight.q_zero_point() + 128;
+    const int64_t wt_numel = unpacked_weight.numel();
     // Create caffe2::Int8GivenTensorFill node
     std::ostringstream os;
     for (const auto i : c10::irange(wt_numel)) {
@@ -434,27 +440,21 @@ void unpackQuantizedWeightsHelper(
     c2_weight->insertBefore(qlinear_node);
     qlinear_node->insertInput(1, c2_weight->output());
   } else {
-    std::vector<float> unpacked_weight_values;
-    unpacked_weight_values.reserve(unpacked_weight.numel());
-    auto unpacked_weight_data =
-        reinterpret_cast<int8_t*>(unpacked_weight.data_ptr<c10::qint8>());
-    for (const auto i : c10::irange(unpacked_weight.numel())) {
-      unpacked_weight_values.push_back(
-          static_cast<float>(unpacked_weight_data[i]));
-    }
-    std::vector<Node*> c2_weight = CreateQuantizedWeights(
-        unpacked_weight_values,
+    std::vector<Node*> unpacked_wt = CreateQuantizedWeights(
+        inp_data,
         graph,
         wt_sizes,
+        wt_strides,
         static_cast<float>(unpacked_weight.q_scale()),
-        weight_zp);
+        unpacked_weight.q_zero_point());
     graph->setInsertPoint(qlinear_node);
-    c2_weight[0]->insertBefore(qlinear_node);
-    qlinear_node->insertInput(1, c2_weight[0]->output());
-    c2_weight[1]->insertBefore(qlinear_node);
-    qlinear_node->insertInput(2, c2_weight[1]->output());
-    c2_weight[2]->insertBefore(qlinear_node);
-    qlinear_node->insertInput(3, c2_weight[2]->output());
+    Node* quant_node = graph->create(prim::TupleConstruct);
+    for (auto* n : unpacked_wt) {
+      n->insertBefore(qlinear_node);
+      quant_node->addInput(n->output());
+    }
+    quant_node->insertBefore(qlinear_node);
+    qlinear_node->insertInput(1, quant_node->output());
   }
 
   // Add bias
@@ -507,10 +507,8 @@ void unpackQuantizedWeightsHelper(
         CreateQuantizedBias(bias_values, graph, original_bias.sizes().vec());
     bias->insertBefore(qlinear_node);
     // For quantized_linear inputs, the order is input, weight, bias, ....
-    // We unpack weight into 3 values. then it is
-    // input, weight_value, weight_scale, weight_zero_point, bias, ...
-    // Therefore bias is at location 4.
-    qlinear_node->insertInput(4, bias->output());
+    // Therefore bias is at location 2.
+    qlinear_node->insertInput(2, bias->output());
   }
 
   // add conv arguments: stride, padding, dilation, groups
@@ -520,6 +518,7 @@ void unpackQuantizedWeightsHelper(
   conv_ints_args.push_back(stride);
   conv_ints_args.push_back(padding);
   conv_ints_args.push_back(dilation);
+  // skip (input, weight, bias)
+  const size_t arg_offset = 3;
   for (const auto i : c10::irange(conv_ints_args.size())) {
     Node* ints_node =
@@ -616,32 +615,35 @@ void UnpackQuantizedWeights(
       "quantized::linear_unpack",
       QuantizedParamsType::LINEAR,
       caffe2);
-  if (caffe2) {
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv2d,
-        "quantized::conv2d_unpack",
-        QuantizedParamsType::CONV);
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv2d_relu,
-        "quantized::conv2d_unpack",
-        QuantizedParamsType::CONV);
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv3d,
-        "quantized::conv3d_unpack",
-        QuantizedParamsType::CONV);
-    unpackQuantizedWeightsHelper(
-        graph,
-        paramsDict,
-        qconv3d_relu,
-        "quantized::conv3d_unpack",
-        QuantizedParamsType::CONV);
-  } else {
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv2d,
+      "quantized::conv2d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv2d_relu,
+      "quantized::conv2d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv3d,
+      "quantized::conv3d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  unpackQuantizedWeightsHelper(
+      graph,
+      paramsDict,
+      qconv3d_relu,
+      "quantized::conv3d_unpack",
+      QuantizedParamsType::CONV,
+      caffe2);
+  if (!caffe2) {
     UnpackQuantizedTensorInputs(graph);
   }
   GRAPH_DUMP("After UnpackQuantizedWeights: ", graph);
@@ -134,7 +134,8 @@ def _unpack_list(list_value):
 
 def _unpack_tuple(tuple_value):
     tuple_node = tuple_value.node()
-    assert tuple_node.kind() == "prim::TupleConstruct"
+    if tuple_node.kind() != "prim::TupleConstruct":
+        raise RuntimeError("ONNX symbolic expected node type `prim::TupleConstruct`, got `{}`".format(tuple_node))
     return list(tuple_node.inputs())
 
 # Check if list_value is output from prim::ListConstruct
@@ -192,6 +193,74 @@ def parse_args(*arg_descriptors):
         return wrapper
     return decorator
 
+
+def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
+    """A decorator which extends support for the quantized version of the base operator.
+    Quantization is detected by examining the arguments that are annotated by
+    `arg_q_descriptors`.
+    If quantization is detected, the base operator symbolic function is wrapped with
+    argument dequantization and output quantization.
+    Otherwise, only the base symbolic function is invoked.
+
+    For example:
+        @quantized_args(True, False)
+        def foo(g, x, y):
+            return x + y
+
+    is equivalent to
+
+        def q_foo(g, x, y):
+            if is_quantized_tensor(x):
+                x = dequantize(x)
+                out = foo(g, x, y)
+                return quantize(out)
+            else:
+                return foo(g, x, y)
+
+    Args:
+        arg_q_descriptors: list of bool, where each element represents whether the
+            argument is a QTensor for the quantized version of this operator.
+        scale: float, default None, quantized output scale. If None, derive from
+            the first quantized input scale.
+        zero_point: int, default None, quantized output zero point. If None,
+            derive from the first quantized input zero point.
+    """
+    def decorator(fn):
+        fn._scale = scale
+        fn._zero_point = zero_point
+
+        @wraps(fn)
+        def wrapper(g, *args, **kwargs):
+            _scale = fn._scale
+            if _scale is not None:
+                _scale = g.op("Constant", value_t=torch.tensor(_scale))
+            _zero_point = fn._zero_point
+            if _zero_point is not None:
+                _zero_point = g.op("Constant", value_t=torch.tensor(_zero_point))
+
+            # Some args may be optional, so the length may be smaller.
+            assert len(arg_q_descriptors) >= len(args)
+            desc_args = tuple(zip(arg_q_descriptors[:len(args)], args))
+            # Run the regular symbolic function if none of the arguments are QTensors.
+            if not any((desc and arg.node().kind() == "prim::TupleConstruct") for desc, arg in desc_args):
+                return fn(g, *args, **kwargs)
+
+            dequantized_args = []
+            for desc, arg in desc_args:
+                if desc:
+                    dequantized_arg, scale, zero_point = dequantize_helper(g, arg)
+                    dequantized_args.append(dequantized_arg)
+                    if _scale is None:
+                        _scale = scale
+                    if _zero_point is None:
+                        _zero_point = zero_point
+                else:
+                    dequantized_args.append(arg)
+            # TODO: only support single output
+            output = fn(g, *dequantized_args, **kwargs)
+
+            return quantize_helper(g, output, _scale, _zero_point)
+        return wrapper
+    return decorator
+
 def _scalar(x):
     """Convert a scalar tensor into a Python value."""
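An eager-mode analogue may help when reading the decorator: it arranges dequantize -> fp32 op -> requantize in the exported graph, with the output scale and zero point either fixed by the decorator or inherited from the first quantized input. A sketch under that reading (quantized_flatten is a hypothetical name, not exporter code):

import torch

def quantized_flatten(q_input, start_dim=0, end_dim=-1):
    x = q_input.dequantize()                    # like DequantizeLinear
    out = torch.flatten(x, start_dim, end_dim)  # the base fp32 operator
    # Requantize with the input's own scale/zero_point, as the decorator
    # does when scale/zero_point are not fixed.
    return torch.quantize_per_tensor(out, q_input.q_scale(), q_input.q_zero_point(), q_input.dtype)

q = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 0.1, 128, torch.quint8)
print(quantized_flatten(q).shape)  # torch.Size([24])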
@@ -826,6 +895,30 @@ def _handle_reduce_dim_none(g, self, op_name):
         return g.op(op_name, self, keepdims_i=1)
     return g.op(op_name, self, keepdims_i=0)
 
+
+def dequantize_helper(g, qtensor, qdtype=None):
+    tensor, scale, zero_point = _unpack_tuple(qtensor)
+    input_qdtype = cast_pytorch_to_onnx[tensor.type().scalarType()]
+    if qdtype is None:
+        if input_qdtype is not None:
+            qdtype = input_qdtype
+        else:
+            qdtype = torch.onnx.TensorProtoDataType.UINT8
+    value = g.op("Cast", tensor, to_i=qdtype)
+    scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+    zero_point = g.op("Cast", zero_point, to_i=qdtype)
+    return g.op("DequantizeLinear", value, scale, zero_point), scale, zero_point
+
+
+def quantize_helper(g, tensor, scale, zero_point):
+    assert scale is not None
+    if scale.type().scalarType() != "Float":
+        scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+
+    assert zero_point is not None
+    if zero_point.type().scalarType() not in ("Byte", "Char"):
+        zero_point = g.op("Cast", zero_point, to_i=torch.onnx.TensorProtoDataType.UINT8)
+    output = g.op("QuantizeLinear", tensor, scale, zero_point)
+    return g.op("prim::TupleConstruct", output, scale, zero_point)
+
 # ---------------------------------------------------------------------
 # ONNX operator version
 # ---------------------------------------------------------------------
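These helpers lean on the ONNX operator semantics (opset >= 10): DequantizeLinear computes y = (x - zero_point) * scale and QuantizeLinear computes x = saturate(round(y / scale) + zero_point). A quick numeric check of the round trip:

import numpy as np

scale, zero_point = 0.2, 128
x = np.array([0, 128, 255], dtype=np.uint8)
y = (x.astype(np.float32) - zero_point) * scale  # [-25.6, 0.0, 25.4]
x_back = np.clip(np.rint(y / scale) + zero_point, 0, 255).astype(np.uint8)
assert (x == x_back).all()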
@@ -10,7 +10,7 @@ import torch.onnx.utils
 import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _unimplemented
 import torch.onnx.symbolic_opset9
-from torch.onnx.symbolic_opset9 import linear
+from torch.onnx.symbolic_opset9 import linear, conv2d, add, mul, hardswish, relu
 
 from sys import maxsize
@@ -322,39 +322,75 @@ def isfinite(g, input):
     return __not_(g, __or_(g, inf_node, nan_node))
 
 
+# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
+def quantize_per_tensor(g, input, scale, zero_point, dtype):
+    dtype = sym_help._get_const(dtype, "i", "dtype")
+    zero_point = g.op("Cast", zero_point, to_i=sym_help.scalar_type_to_onnx[dtype])
+    scale = g.op("Cast", scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+    return sym_help.quantize_helper(g, input, scale, zero_point)
+
+
+def dequantize(g, input):
+    return sym_help.dequantize_helper(g, input)[0]
+
+
 class Quantized:
     """
     https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
+
+    Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were introduced in opset version 10.
     """
     domain = "quantized"
 
-    # DequantizeLinear was added in opset version 10.
     @staticmethod
-    def linear(g, input_original, weight, weight_scale, weight_zero_point, bias, op_scale, op_zero_point):
-        input_value, input_scale, input_zero_point = sym_help._unpack_tuple(input_original)
-        # From https://pytorch.org/docs/master/generated/torch.nn.quantized.functional.linear.html
-        # input (Tensor) – Quantized input of type torch.quint8
-        input_type_dq = torch.onnx.TensorProtoDataType.UINT8
-        input_value = g.op("Cast", input_value, to_i=input_type_dq)
-        input_scale = g.op("Cast", input_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
-        input_zero_point = g.op("Cast", input_zero_point, to_i=input_type_dq)
-        input = g.op("DequantizeLinear", input_value, input_scale, input_zero_point)
-        # weight (Tensor) – Quantized weight of type torch.qint8
-        weight_type_dq = torch.onnx.TensorProtoDataType.INT8
-        weight = g.op("Cast", weight, to_i=weight_type_dq)
-        weight_scale = g.op("Cast", weight_scale, to_i=torch.onnx.TensorProtoDataType.FLOAT)
-        weight_zero_point = g.op("Cast", weight_zero_point, to_i=weight_type_dq)
-        weight = g.op("DequantizeLinear", weight, weight_scale, weight_zero_point)
-        # bias (Tensor) – None or fp32 bias of type torch.float
-        bias = g.op("Cast", bias, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+    def linear(g, q_input, q_weight, bias, op_scale, op_zero_point):
+        input, _, _ = sym_help.dequantize_helper(g, q_input)
+        weight, _, _ = sym_help.dequantize_helper(g, q_weight)
 
         output = linear(g, input, weight, bias)
 
-        if op_scale is None:
-            op_scale = input_scale
-        elif op_scale.type().scalarType() != "Float":
-            op_scale = g.op("Cast", op_scale, to_i=sym_help.cast_pytorch_to_onnx["Float"])
+        return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
 
-        if op_zero_point is None:
-            op_zero_point = input_zero_point
-        elif op_zero_point.type().scalarType() != "Byte":
-            op_zero_point = g.op("Cast", op_zero_point, to_i=sym_help.cast_pytorch_to_onnx["Byte"])
-        output = g.op("QuantizeLinear", output, op_scale, op_zero_point)
-        return g.op("prim::TupleConstruct", output, op_scale, op_zero_point)
+    @staticmethod
+    def add(g, x, y, op_scale, op_zero_point):
+        x, _, _ = sym_help.dequantize_helper(g, x)
+        y, _, _ = sym_help.dequantize_helper(g, y)
+
+        output = add(g, x, y)
+
+        return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    def mul(g, x, y, op_scale, op_zero_point):
+        x, _, _ = sym_help.dequantize_helper(g, x)
+        y, _, _ = sym_help.dequantize_helper(g, y)
+
+        output = mul(g, x, y)
+
+        return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    def hardswish(g, x, op_scale, op_zero_point):
+        x, _, _ = sym_help.dequantize_helper(g, x)
+
+        output = hardswish(g, x)
+
+        return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    def conv2d_relu(g, q_input, q_weight, bias, stride, padding, dilation, groups, op_scale, op_zero_point):
+        input, _, _ = sym_help.dequantize_helper(g, q_input)
+        weight, _, _ = sym_help.dequantize_helper(g, q_weight)
+
+        output = conv2d(g, input, weight, bias, stride, padding, dilation, groups)
+        output = relu(g, output)
+
+        return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    def conv2d(g, q_input, q_weight, bias, stride, padding, dilation, groups, op_scale, op_zero_point):
+        input, _, _ = sym_help.dequantize_helper(g, q_input)
+        weight, _, _ = sym_help.dequantize_helper(g, q_weight)
+
+        output = conv2d(g, input, weight, bias, stride, padding, dilation, groups)
+
+        return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
@@ -7,7 +7,7 @@ import torch
 import torch.onnx.symbolic_helper as sym_help
 import warnings
 
-from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType
+from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType, quantized_args
 from torch.onnx.symbolic_opset9 import expand, unused, mul
 from torch.onnx.symbolic_opset9 import linalg_vector_norm as lvn
 from torch.nn.modules.utils import _single, _pair, _triple
@@ -791,7 +791,7 @@ def narrow(g, input, dim, start, length):
     end = g.op("Add", start, length)
     return _slice_helper(g, input, axes=dim, starts=start, ends=end, dynamic_slice=True)
 
-
+@quantized_args(True, False, False)
 @parse_args("v", "i", "i")
 def flatten(g, input, start_dim, end_dim):
     dim = sym_help._get_tensor_rank(input)
@@ -52,3 +52,18 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
     new_running_mean.setType(running_mean.type())
     new_running_var.setType(running_var.type())
     return res
+
+
+class Quantized:
+    """
+    https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
+    """
+    domain = "quantized"
+
+    @staticmethod
+    def hardswish(g, x, op_scale, op_zero_point):
+        x, _, _ = sym_help.dequantize_helper(g, x)
+
+        output = hardswish(g, x)
+
+        return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
@@ -13,7 +13,7 @@ from functools import partial
 from functools import wraps
 
 import torch.onnx.symbolic_helper as sym_help
-from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType
+from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented, ScalarType, quantized_args
 
 from typing import Optional
 from sys import maxsize as maxsize
@@ -949,6 +949,7 @@ avg_pool3d = _avg_pool("avg_pool3d", _triple)
 
 
 def _adaptive_pool(name, type, tuple_fn, fn=None):
+    @quantized_args(True, False)
     def symbolic_fn(g, input, output_size):
         # _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
         # by executing a GlobalPool.
@@ -1957,6 +1958,8 @@ def hardswish(g, self):
     return g.op("Mul", self, hs)
 
 
+# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
+@quantized_args(True, scale=1.0 / 256.0, zero_point=0)
 @parse_args("v")
 def hardsigmoid(g, self):
     # Set alpha_f to 1 / 6 to make op equivalent to PyTorch's definition of Hardsigmoid.
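The fixed parameters make sense because hardsigmoid's output lies in [0, 1]: with scale 1/256 and zero point 0, quint8 covers [0, 255/256] with no wasted range, matching the kernel referenced in the comment. A small sanity check:

import torch

x = torch.randn(10)
q = torch.quantize_per_tensor(torch.nn.functional.hardsigmoid(x), 1.0 / 256.0, 0, torch.quint8)
assert q.dequantize().min() >= 0 and q.dequantize().max() <= 255 / 256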
@@ -2570,6 +2573,7 @@ def erf(g, input):
     return g.op("Erf", input)
 
 
+@quantized_args(True, False, False)
 @parse_args("v", "i", "i")
 def flatten(g, input, start_dim, end_dim):
     dim = sym_help._get_tensor_rank(input)
@@ -131,14 +131,14 @@ class UnsupportedOperatorError(RuntimeError):
     def __init__(self, domain, opname, version):
         supported_version = get_op_supported_version(opname, domain, version)
         if domain in ["", "aten", "prim", "quantized"]:
-            msg = "Exporting the operator " + opname + " to ONNX opset version " + str(version) + " is not supported. "
+            msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
             if supported_version is not None:
-                msg += "Support for this operator was added in version " + str(supported_version) + \
-                    ", try exporting with this version."
+                msg += (f"Support for this operator was added in version {supported_version}, "
+                        "try exporting with this version.")
             else:
                 msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
         else:
-            msg = ("ONNX export failed on an operator with unrecognized namespace {}::{}. "
+            msg = (f"ONNX export failed on an operator with unrecognized namespace {domain}::{opname}. "
                    "If you are trying to export a custom operator, make sure you registered "
-                   "it with the right domain and version.".format(domain, opname))
+                   "it with the right domain and version.")
         super().__init__(msg)