mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40284 Adds graph mode handling for inplace hardswish, and test coverage for functional hardswish. Test Plan: ``` python test/test_quantization.py TestQuantizeScriptPTSQOps.test_hardswish ``` Imported from OSS Differential Revision: D22140628 fbshipit-source-id: 55a514f7dc1130d510f69ee4e611d7cb5e08d02e
3362 lines
137 KiB
Python
3362 lines
137 KiB
Python
# -*- coding: utf-8 -*-
|
|
# torch
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.jit
|
|
import torch.jit.quantized
|
|
from torch._C import parse_ir
|
|
|
|
# torch.quantization
|
|
from torch.quantization import (
|
|
QConfig,
|
|
default_dynamic_qconfig,
|
|
default_observer,
|
|
per_channel_dynamic_qconfig,
|
|
default_per_channel_weight_observer,
|
|
default_qconfig,
|
|
get_default_qconfig,
|
|
quantize,
|
|
quantize_dynamic,
|
|
default_weight_observer,
|
|
default_histogram_observer,
|
|
default_eval_fn,
|
|
fuse_modules,
|
|
quantize_jit,
|
|
quantize_dynamic_jit,
|
|
)
|
|
|
|
# torch.quantization.quantize_jit
|
|
from torch.quantization.quantize_jit import (
|
|
convert_jit,
|
|
convert_dynamic_jit,
|
|
fuse_conv_bn_jit,
|
|
prepare_jit,
|
|
prepare_dynamic_jit,
|
|
script_qconfig,
|
|
)
|
|
|
|
# Testing utils
|
|
from torch.testing._internal.common_quantized import (
|
|
override_qengines,
|
|
qengine_is_fbgemm,
|
|
qengine_is_qnnpack,
|
|
)
|
|
|
|
from torch.testing._internal.common_quantization import (
|
|
QuantizationTestCase,
|
|
skipIfNoFBGEMM,
|
|
get_script_module,
|
|
SingleLayerLinearModel,
|
|
SkipQuantModel,
|
|
NestedModel,
|
|
ConvModel,
|
|
default_per_channel_qconfig,
|
|
test_only_eval_fn,
|
|
ConvBnModel,
|
|
)
|
|
# Annotated models
|
|
from torch.testing._internal.common_quantization import (
|
|
AnnotatedSingleLayerLinearModel,
|
|
AnnotatedSkipQuantModel,
|
|
AnnotatedNestedModel,
|
|
AnnotatedConvModel,
|
|
AnnotatedConvBnModel,
|
|
)
|
|
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.common_utils import suppress_warnings
|
|
from torch.testing._internal.common_utils import TemporaryFileName
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
from torch.testing._internal.jit_utils import attrs_with_prefix
|
|
from torch.testing._internal.jit_utils import get_forward
|
|
from torch.testing._internal.jit_utils import get_forward_graph
|
|
|
|
from torch.jit._recursive import wrap_cpp_module
|
|
|
|
# Standard library
|
|
import io
|
|
import copy
|
|
import itertools
|
|
import unittest
|
|
import numpy as np
|
|
|
|
class TestQuantizeJitPasses(QuantizationTestCase):
|
|
""" Test graph mode quantization passes used by quantize_jit
|
|
"""
|
|
def test_foldbn_trivial(self):
|
|
# Test trivial case
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestModule, self).__init__()
|
|
self.conv = torch.nn.Conv2d(1, 20, 5, 1)
|
|
self.bn = torch.nn.BatchNorm2d(num_features=20)
|
|
self.bn.eps = 0.0023
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
return x
|
|
|
|
# Check that the transformation doesn't change numerics
|
|
for tracing_mode in [True, False]:
|
|
eager = TestModule()
|
|
eager.eval()
|
|
if tracing_mode:
|
|
x = torch.rand(1, 1, 6, 6)
|
|
scripted_or_traced = torch.jit.trace(eager, x)
|
|
else:
|
|
scripted_or_traced = torch.jit.script(eager)
|
|
scripted_or_traced.eval()
|
|
|
|
# Check that in the original script module's forward we have two
|
|
# CallMethod nodes. One of them should be for conv.forward and the other
|
|
# for bn.forward.
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
|
.run(str(get_forward(scripted_or_traced._c).graph))
|
|
|
|
# Run FoldConvBatchnorm2d pass.
|
|
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
|
|
|
|
# Check that after the pass one of the CallMethods is gone (supposedly,
|
|
# the bn.forward).
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced._c)))
|
|
|
|
# Check that the transformation doesn't change numerics
|
|
x = torch.rand(1, 1, 6, 6)
|
|
self.assertEqual(eager(x), scripted_or_traced(x))
|
|
|
|
def test_foldbn_trivial_nobias(self):
|
|
# Test trivial case
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestModule, self).__init__()
|
|
self.conv = torch.nn.Conv2d(1, 20, 5, 1, bias=False)
|
|
self.bn = torch.nn.BatchNorm2d(num_features=20)
|
|
# to make sure new bias is not zero
|
|
self.bn.eps = 0.0027
|
|
self.bn.bias = torch.nn.Parameter(torch.rand([20]))
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
return x
|
|
|
|
for tracing_mode in [True, False]:
|
|
eager = TestModule()
|
|
eager.eval()
|
|
if tracing_mode:
|
|
x = torch.rand(1, 1, 6, 6)
|
|
scripted_or_traced = torch.jit.trace(eager, x)
|
|
else:
|
|
scripted_or_traced = torch.jit.script(eager)
|
|
scripted_or_traced.eval()
|
|
|
|
# Check that in the original script module's forward we have two
|
|
# CallMethod nodes. One of them should be for conv.forward and the other
|
|
# for bn.forward.
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced._c)))
|
|
|
|
# Run FoldConvBatchnorm2d pass.
|
|
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
|
|
|
|
# Check that after the pass one of the CallMethods is gone (supposedly,
|
|
# the bn.forward).
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced._c)))
|
|
|
|
# Check that the transformation doesn't change numerics
|
|
x = torch.rand(1, 1, 6, 6)
|
|
self.assertEqual(eager(x), scripted_or_traced(x))
|
|
|
|
def test_foldbn_in_submodule(self):
|
|
# Test that we find Conv-BN patterns in submodules
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(SubModule, self).__init__()
|
|
self.conv = torch.nn.Conv2d(1, 20, 5, 1)
|
|
self.bn = torch.nn.BatchNorm2d(num_features=20)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
return x
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestModule, self).__init__()
|
|
self.sub = SubModule()
|
|
|
|
def forward(self, x):
|
|
x = self.sub(x)
|
|
return x
|
|
|
|
for tracing_mode in [True, False]:
|
|
eager = TestModule()
|
|
eager.eval()
|
|
if tracing_mode:
|
|
x = torch.rand(1, 1, 10, 10)
|
|
scripted_or_traced = torch.jit.trace(eager, x)
|
|
else:
|
|
scripted_or_traced = torch.jit.script(eager)
|
|
scripted_or_traced.eval()
|
|
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
|
|
|
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
|
|
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
|
|
|
x = torch.rand(1, 1, 10, 10)
|
|
self.assertEqual(eager(x), scripted_or_traced(x))
|
|
|
|
def test_foldbn_in_customConv2D(self):
|
|
# Make sure a custom Conv2D class is not folded
|
|
# as we do not know it does.
|
|
class CustomConv2D(torch.nn.Module):
|
|
def __init__(self, a, b, c, d):
|
|
super(CustomConv2D, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return F.relu(x)
|
|
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(SubModule, self).__init__()
|
|
self.conv = CustomConv2D(1, 20, 5, 1)
|
|
self.bn = torch.nn.BatchNorm2d(num_features=20)
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = self.bn(x)
|
|
return x
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestModule, self).__init__()
|
|
self.sub = SubModule()
|
|
|
|
def forward(self, x):
|
|
x = self.sub(x)
|
|
return x
|
|
|
|
for tracing_mode in [True, False]:
|
|
eager = TestModule()
|
|
eager.eval()
|
|
if tracing_mode:
|
|
x = torch.rand(1, 20, 10, 10)
|
|
scripted_or_traced = torch.jit.trace(eager, x)
|
|
else:
|
|
scripted_or_traced = torch.jit.script(eager)
|
|
scripted_or_traced.eval()
|
|
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
|
|
|
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
|
|
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
|
|
|
x = torch.rand(1, 20, 10, 10)
|
|
self.assertEqual(eager(x), scripted_or_traced(x))
|
|
|
|
def test_foldbn_shared_classtype(self):
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self, bias=False):
|
|
super(TestModule, self).__init__()
|
|
self.conv1 = torch.nn.Conv2d(5, 5, 3, bias=bias)
|
|
self.bn1 = torch.nn.BatchNorm2d(num_features=5)
|
|
self.bn1.running_mean.fill_(-0.2)
|
|
self.bn1.bias = torch.nn.Parameter(torch.rand([5]))
|
|
# to make sure new bias is not zero
|
|
self.bn1.eps = 0.0023
|
|
self.conv2 = torch.nn.Conv2d(5, 5, 3, bias=bias)
|
|
self.bn2 = torch.nn.BatchNorm2d(num_features=5)
|
|
self.bn2.eps = 0.0029
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.bn1(x)
|
|
x = self.relu(x)
|
|
x = self.conv2(x)
|
|
x = self.bn2(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
for tracing_mode in [True, False]:
|
|
for bias in [True, False]:
|
|
eager = TestModule(bias).eval()
|
|
if tracing_mode:
|
|
x = torch.rand(1, 5, 6, 6)
|
|
scripted_or_traced = torch.jit.trace(eager, x).copy()
|
|
else:
|
|
scripted_or_traced = torch.jit.script(eager).copy()
|
|
torch._C._jit_pass_dedup_module_uses(scripted_or_traced ._c)
|
|
folded = fuse_conv_bn_jit(scripted_or_traced)
|
|
x = torch.rand(1, 5, 6, 6)
|
|
self.assertEqual(eager(x), scripted_or_traced(x))
|
|
|
|
def test_foldbn_complex_cases(self):
|
|
# This test case attempt to try combinations of conv2d with bias/nobias
|
|
# as well as BatchNorm with affine/no-affine along with varying the
|
|
# number of layers.
|
|
# this only works when default dtype is double
|
|
torch.set_default_dtype(torch.double)
|
|
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self, num_blocks, enable_bias, enable_affine):
|
|
super(SubModule, self).__init__()
|
|
layers = []
|
|
for i in range(num_blocks):
|
|
layers.append(torch.nn.Conv2d(20, 20, 5, 1, bias=enable_bias))
|
|
bn_obj = torch.nn.BatchNorm2d(num_features=20, affine=enable_affine)
|
|
if enable_affine:
|
|
bn_obj.weight = torch.nn.Parameter(torch.rand_like(bn_obj.weight))
|
|
bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias))
|
|
bn_obj.running_mean = torch.rand_like(bn_obj.running_mean)
|
|
bn_obj.running_var = torch.rand_like(bn_obj.running_var)
|
|
layers.append(bn_obj)
|
|
self.layers = nn.Sequential(*layers)
|
|
|
|
def forward(self, x):
|
|
return self.layers(x)
|
|
|
|
class TestModule(torch.nn.Module):
|
|
def __init__(self, num_blocks, enable_bias, enable_affine):
|
|
super(TestModule, self).__init__()
|
|
self.sub = SubModule(num_blocks, enable_bias, enable_affine)
|
|
|
|
def forward(self, x):
|
|
x = self.sub(x)
|
|
return x
|
|
|
|
bias_affine_options = itertools.product([True, False], [True, False], [True, False], [1, 2])
|
|
for (tracing_mode, enable_bias, enable_bn_affine, num_layers) in bias_affine_options:
|
|
eager = TestModule(num_layers, enable_bias, enable_bn_affine)
|
|
eager.eval()
|
|
|
|
if tracing_mode:
|
|
x = torch.rand(1, 20, 10, 10)
|
|
scripted_or_traced = torch.jit.trace(eager, x)
|
|
else:
|
|
scripted_or_traced = torch.jit.script(eager)
|
|
scripted_or_traced.eval()
|
|
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers * 2, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
|
|
|
|
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
|
|
|
|
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers, exactly=True) \
|
|
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
|
|
|
|
x = torch.rand(1, 20, 10, 10)
|
|
self.assertEqual(eager(x), scripted_or_traced(x))
|
|
torch.set_default_dtype(torch.float)
|
|
|
|
def test_fuse_linear(self):
|
|
input_strs = ["""
|
|
graph(%input, %weight, %bias, %4):
|
|
# CHECK-NOT: aten::t
|
|
# CHECK-NOT: aten::addmm
|
|
# CHECK: aten::linear
|
|
%weight_t = aten::t(%weight)
|
|
%res = aten::addmm(%bias, %input, %weight_t, %4, %4)
|
|
return (%res)""", """
|
|
graph(%input, %weight, %bias, %4):
|
|
# CHECK-NOT: aten::t
|
|
# CHECK-NOT: aten::matmul
|
|
# CHECK-NOT: aten::add_
|
|
# CHECK: aten::linear
|
|
%weight_t = aten::t(%weight)
|
|
%output = aten::matmul(%input, %weight_t)
|
|
%res = aten::add_(%output, %bias, %4)
|
|
return (%res)""", """
|
|
graph(%input, %weight):
|
|
# CHECK-NOT: aten::t
|
|
# CHECK-NOT: aten::matmul
|
|
# CHECK: aten::linear
|
|
%weight_t = aten::t(%weight)
|
|
%output = aten::matmul(%input, %weight_t)
|
|
return (%output)"""]
|
|
for input_str in input_strs:
|
|
graph = parse_ir(input_str)
|
|
torch._C._jit_pass_fuse_linear(graph)
|
|
FileCheck().run(input_str, graph)
|
|
|
|
def test_insert_observers(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 5, 3)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
m = torch.jit.script(M())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# for input and output of conv
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 2
|
|
# for weight
|
|
assert len(attrs_with_prefix(m.conv, '_observer_')) == 1
|
|
|
|
def test_insert_observers_child_qconfig(self):
|
|
class Sub(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Sub, self).__init__()
|
|
self.fc = torch.nn.Linear(5, 5)
|
|
|
|
def forward(self, x):
|
|
return self.fc(x)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 5, 3)
|
|
self.sub = Sub()
|
|
|
|
def forward(self, x):
|
|
return self.sub(self.conv(x))
|
|
|
|
m = torch.jit.script(M())
|
|
qconfig_dict = {'sub.fc': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# input and output of sub
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 2
|
|
# not quantized
|
|
assert len(attrs_with_prefix(m.conv, '_observer_')) == 0
|
|
# no observers since we observe in the outer most call site
|
|
assert len(attrs_with_prefix(m.sub, '_observer_')) == 0
|
|
# weight of linear
|
|
assert len(attrs_with_prefix(m.sub.fc, '_observer_')) == 1
|
|
|
|
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
|
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
|
|
" with instruction set support avx2 or newer.")
|
|
def test_insert_observers_skip_values(self):
|
|
class ConvFunctionalReLU(torch.nn.Module):
|
|
def __init__(self):
|
|
super(ConvFunctionalReLU, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 5, 3)
|
|
|
|
def forward(self, x):
|
|
return F.relu(self.conv(x))
|
|
|
|
class ConvReLUModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(ConvReLUModule, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 5, 3)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.conv(x))
|
|
|
|
class AddReLUModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(AddReLUModule, self).__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
out = self.conv(x)
|
|
out += x
|
|
return self.relu(out)
|
|
|
|
class AddFunctionalReLU(torch.nn.Module):
|
|
def __init__(self):
|
|
super(AddFunctionalReLU, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
out = self.conv(x)
|
|
out += x
|
|
return F.relu(out)
|
|
|
|
def attrs_with_prefix(module, prefix):
|
|
return [x for x, _ in module._modules._c.items()
|
|
if x.startswith(prefix)]
|
|
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = torch.jit.script(ConvFunctionalReLU())
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# observer for weight of conv
|
|
assert len(attrs_with_prefix(m.conv, '_observer_')) == 1
|
|
# observer for input of conv and output of relu
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 2
|
|
|
|
m = torch.jit.script(ConvReLUModule())
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# observer for input of conv and output of relu
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 2
|
|
# observer for weight of conv
|
|
assert len(attrs_with_prefix(m.conv, '_observer_')) == 1
|
|
# observer for output of relu
|
|
assert len(attrs_with_prefix(m.relu, '_observer_')) == 0
|
|
|
|
m = torch.jit.script(AddReLUModule())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
assert len(attrs_with_prefix(m, '_observer')) == 3
|
|
assert len(attrs_with_prefix(m.relu, '_observer')) == 0
|
|
FileCheck().check('aten::add_') \
|
|
.check_not('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('ReLU = prim::GetAttr') \
|
|
.run(str(get_forward_graph(m._c)))
|
|
|
|
m = torch.jit.script(AddFunctionalReLU())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
assert len(attrs_with_prefix(m, '_observer')) == 3
|
|
FileCheck().check('aten::add_') \
|
|
.check_not('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('CallFunction') \
|
|
.check('Observer = prim::GetAttr[name="_observer_') \
|
|
.run(str(get_forward_graph(m._c)))
|
|
|
|
def test_insert_observers_weight_dtype(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 5, 3)
|
|
|
|
def forward(self, x):
|
|
return F.relu(self.conv(x))
|
|
|
|
m = torch.jit.script(M())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
activation_dtypes = set(obs.getattr('dtype') for x, obs in m._modules._c.items()
|
|
if x.startswith('_observer_'))
|
|
weight_dtypes = set(obs.getattr('dtype') for x, obs in m.conv._modules._c.items()
|
|
if x.startswith('_observer_'))
|
|
assert len(activation_dtypes) == 1, 'Expected to have 1 activation dtype'
|
|
assert len(weight_dtypes) == 1, 'Expected to have 1 weight dtype'
|
|
assert list(activation_dtypes)[0] != list(weight_dtypes)[0], 'Expected activation dtype to '
|
|
' be different from wegiht dtype'
|
|
|
|
def test_insert_observers_for_reused_weight(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
|
|
def forward(self, x, y, weight):
|
|
x = F.conv2d(x, weight)
|
|
y = F.conv2d(y, weight)
|
|
return x + y
|
|
|
|
m = torch.jit.script(M()).eval()
|
|
m = prepare_jit(m, {'': default_qconfig})
|
|
# 3 for x, y, weight, one for output of each F.conv2d and one for output of add
|
|
assert len(attrs_with_prefix(m, '_observer')) == 6
|
|
|
|
def test_insert_observers_shared_class_type(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 5, 3).float()
|
|
self.conv2 = torch.nn.Conv2d(3, 5, 3).float()
|
|
|
|
def forward(self, x):
|
|
return self.conv2(self.conv1(x))
|
|
|
|
m = torch.jit.script(M())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# conv1 and conv2 shares the same type, we need to
|
|
# make sure we didn't quantize the type twice
|
|
conv1_observers = attrs_with_prefix(m.conv1, '_observer_')
|
|
conv2_observers = attrs_with_prefix(m.conv2, '_observer_')
|
|
assert len(conv1_observers) == 1, \
|
|
'Expected to have 1 observer submodules'
|
|
assert len(conv2_observers) == 1, \
|
|
'Expected to have 1 observer submodules'
|
|
assert conv1_observers == conv2_observers, \
|
|
'Expect conv1 and conv2 to have same observers since the class type is shared'
|
|
|
|
def test_insert_observers_for_general_ops(self):
|
|
""" Make sure we skip observers for ops that doesn't require
|
|
observation, e.g. flatten
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = torch.flatten(x)
|
|
return x
|
|
|
|
m = torch.jit.script(M())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# input and output of conv
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 2
|
|
FileCheck().check('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('prim::GetAttr[name="conv"]') \
|
|
.check('prim::CallMethod') \
|
|
.check('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('aten::flatten') \
|
|
.check_not('Observer = prim::GetAttr[name="_observer_') \
|
|
.run(m.graph)
|
|
|
|
# TODO: this is too long, split this to test_insert_observers.py and remove
|
|
# insrt_observers prefix
|
|
def test_insert_observers_propagate_observed(self):
|
|
""" Make sure we propagate observed property through general ops
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = torch.flatten(x)
|
|
# we don't want to insert observer for input of self.conv2
|
|
# because output of self.conv1 is already observed
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
m = torch.jit.script(M())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# input and output of conv
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 3
|
|
FileCheck().check('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('prim::GetAttr[name="conv1"]') \
|
|
.check('prim::CallMethod') \
|
|
.check('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('aten::flatten') \
|
|
.check_not('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('prim::GetAttr[name="conv2"]') \
|
|
.check('Observer = prim::GetAttr[name="_observer_') \
|
|
.run(m.graph)
|
|
|
|
def test_insert_observers_propagate_observed_in_submodule(self):
|
|
""" Make sure we propagate observed property through general ops
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
|
|
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.avgpool(x)
|
|
# we don't want to insert observer for input of self.conv2
|
|
# because output of self.conv1 is already observed
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
m = torch.jit.script(M())
|
|
qconfig_dict = {'': default_qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# input and output of conv
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 3
|
|
FileCheck().check('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('prim::GetAttr[name="conv1"]') \
|
|
.check('prim::CallMethod') \
|
|
.check('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('prim::CallMethod') \
|
|
.check_not('Observer = prim::GetAttr[name="_observer_') \
|
|
.check('prim::GetAttr[name="conv2"]') \
|
|
.check('Observer = prim::GetAttr[name="_observer_') \
|
|
.run(m.graph)
|
|
|
|
def test_insert_observers_propagate_observed_for_function(self):
|
|
def channel_shuffle(x, groups):
|
|
# type: (torch.Tensor, int) -> torch.Tensor
|
|
batchsize, num_channels, height, width = x.data.size()
|
|
channels_per_group = num_channels // groups
|
|
# reshape
|
|
x = x.view(batchsize, groups,
|
|
channels_per_group, height, width)
|
|
x = torch.transpose(x, 1, 2).contiguous()
|
|
# flatten
|
|
x = x.view(batchsize, -1, height, width)
|
|
return x
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 1).float()
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = channel_shuffle(x, 1)
|
|
x = self.conv2(x)
|
|
return x
|
|
|
|
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
|
|
m = torch.jit.script(M()).eval()
|
|
m = prepare_jit(m, {'': default_qconfig})
|
|
# we want to test that channel_shuffle is going to pass
|
|
# the observed property from the output of conv1 to input of conv2
|
|
# so that we don't insert observers for input of conv2
|
|
assert len(attrs_with_prefix(m, '_observer_',)) == 3
|
|
|
|
def test_insert_observers_for_if(self):
|
|
class QuantProp(torch.nn.Module):
|
|
def __init__(self, use_skip):
|
|
super(QuantProp, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 1).float()
|
|
self.use_skip = use_skip
|
|
|
|
def forward(self, x):
|
|
if self.use_skip:
|
|
x = self.conv(x)
|
|
return torch.reshape(x, x.shape)
|
|
else:
|
|
x = self.conv(x)
|
|
return torch.reshape(x, x.shape)
|
|
|
|
class Res(torch.nn.Module):
|
|
def __init__(self, use_skip):
|
|
super(Res, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 1).float()
|
|
self.use_skip = use_skip
|
|
|
|
def forward(self, x):
|
|
if self.use_skip:
|
|
return self.conv(x)
|
|
else:
|
|
return self.conv(x)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.quant_prop = QuantProp(True)
|
|
self.res = Res(False)
|
|
|
|
def forward(self, x):
|
|
x = self.quant_prop(x)
|
|
x = self.res(x)
|
|
return x
|
|
|
|
data = [torch.rand(1, 3, 10, 10, dtype=torch.float)]
|
|
result = {False : [1, 2, 2], True : [2, 1, 0]}
|
|
for tracing in [True, False]:
|
|
if tracing:
|
|
m = torch.jit.trace(M(), data).eval()
|
|
else:
|
|
m = torch.jit.script(M()).eval()
|
|
m = prepare_jit(m, {'': default_qconfig})
|
|
assert len(attrs_with_prefix(m, '_observer_',)) == result[tracing][0]
|
|
assert len(attrs_with_prefix(m.quant_prop, '_observer_',)) == result[tracing][1]
|
|
assert len(attrs_with_prefix(m.res, '_observer_',)) == result[tracing][2]
|
|
|
|
def test_insert_observers_for_nested_if(self):
|
|
class Res(torch.nn.Module):
|
|
def __init__(self, use_skip):
|
|
super(Res, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 1).float()
|
|
self.cond = use_skip
|
|
self.use_skip = use_skip
|
|
|
|
def forward(self, x):
|
|
if self.use_skip:
|
|
if self.cond:
|
|
return self.conv(x)
|
|
else:
|
|
return self.conv(x)
|
|
else:
|
|
return self.conv(x)
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.res1 = Res(True)
|
|
self.res2 = Res(False)
|
|
|
|
def forward(self, x):
|
|
x = self.res1(x)
|
|
x = self.res2(x)
|
|
return x
|
|
|
|
data = torch.rand((1, 3, 10, 10), dtype=torch.float)
|
|
result = {True : 3, False : 1}
|
|
for tracing in [True, False]:
|
|
if tracing:
|
|
m = torch.jit.trace(M(), data).eval()
|
|
else:
|
|
m = torch.jit.script(M()).eval()
|
|
m = prepare_jit(m, {'': default_qconfig})
|
|
assert len(attrs_with_prefix(m, '_observer_')) == result[tracing]
|
|
|
|
def test_insert_observers_for_if_consistent_observation(self):
|
|
""" check quantization for if works as long as
|
|
output of all branches are quantized/observed consistently
|
|
"""
|
|
class M(torch.nn.Module):
|
|
def __init__(self, cond):
|
|
super(M, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 3, 3).float()
|
|
self.cond = cond
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
# x is already observed
|
|
if self.cond:
|
|
x = torch.flatten(x)
|
|
return x
|
|
|
|
class M2(torch.nn.Module):
|
|
def __init__(self, cond):
|
|
super(M2, self).__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
|
|
self.cond = cond
|
|
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
if self.cond:
|
|
x = self.conv2(x)
|
|
# x will be observed in the branch
|
|
else:
|
|
x = torch.flatten(x)
|
|
# since output for both branch are quantized
|
|
# the if node is quantized consistently
|
|
return x
|
|
|
|
data = torch.rand((1, 3, 5, 5), dtype=torch.float)
|
|
options = list(itertools.product([True, False], [True, False]))
|
|
for cond, tracing in options:
|
|
if tracing:
|
|
m = torch.jit.trace(M(cond), data)
|
|
else:
|
|
m = torch.jit.script(M(cond))
|
|
m = prepare_jit(m, {'': default_qconfig})
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 2
|
|
|
|
for cond, tracing in options:
|
|
if tracing:
|
|
m = torch.jit.trace(M2(cond), data)
|
|
else:
|
|
m = torch.jit.script(M2(cond))
|
|
m = prepare_jit(m, {'': default_qconfig})
|
|
num_observers = 2 if tracing and not cond else 3
|
|
assert len(attrs_with_prefix(m, '_observer_')) == num_observers
|
|
|
|
def test_insert_quant_dequant(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 5, 3).float()
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
for is_per_channel in [True, False]:
|
|
m = torch.jit.script(M())
|
|
observer = default_per_channel_weight_observer.with_args(ch_axis=1) \
|
|
if is_per_channel else default_observer
|
|
qconfig_dict = {'': QConfig(activation=observer, weight=observer)}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
|
|
|
|
m(data)
|
|
m = convert_jit(m, debug=True)
|
|
assert len(m._modules._c.items()) == 1, \
|
|
'Expected to have single submodule of conv'
|
|
# make sure the quantized model is executable
|
|
m(data)
|
|
quant_func = "aten::quantize_per_channel" if is_per_channel \
|
|
else "aten::quantize_per_tensor"
|
|
FileCheck().check_count(quant_func, 3, exactly=True) \
|
|
.run(m.graph)
|
|
|
|
def test_insert_quant_dequant_shared_class_type(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
|
|
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
|
|
|
|
def forward(self, x):
|
|
return self.conv2(self.conv1(x))
|
|
|
|
for is_per_channel in [True, False]:
|
|
m = torch.jit.script(M())
|
|
observer = default_per_channel_weight_observer.with_args(ch_axis=1) \
|
|
if is_per_channel else default_observer
|
|
qconfig = QConfig(activation=observer, weight=observer)
|
|
qconfig_dict = {'': qconfig}
|
|
m = prepare_jit(m, qconfig_dict)
|
|
# observers for input, output and value between conv1/conv2
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 3, \
|
|
'Expected to have 3 obervers'
|
|
# observer for weight
|
|
assert len(attrs_with_prefix(m.conv1, '_observer_')) == 1, \
|
|
'Expected to have 1 obervers'
|
|
# observer for weight
|
|
assert len(attrs_with_prefix(m.conv2, '_observer_')) == 1, \
|
|
'Expected to have 1 obervers'
|
|
|
|
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
|
|
m(data)
|
|
m = convert_jit(m, debug=True)
|
|
m(data)
|
|
assert m.conv1._c._type() == m.conv2._c._type()
|
|
|
|
# check all observers have been removed
|
|
assert len(attrs_with_prefix(m, '_observer_')) == 0, \
|
|
'Expected to have 0 obervers'
|
|
assert len(attrs_with_prefix(m.conv1, '_observer_')) == 0, \
|
|
'Expected to have 0 obervers'
|
|
assert len(attrs_with_prefix(m.conv2, '_observer_')) == 0, \
|
|
'Expected to have 0 obervers'
|
|
|
|
quant_func = "aten::quantize_per_channel" if is_per_channel \
|
|
else "aten::quantize_per_tensor"
|
|
for module in ['conv1', 'conv2']:
|
|
conv = m._c.getattr(module)
|
|
# quantize weight
|
|
FileCheck().check(quant_func) \
|
|
.check_next("aten::dequantize") \
|
|
.check("prim::CallMethod[name=\"_conv_forward\"]") \
|
|
.check("return") \
|
|
.run(get_forward_graph(conv))
|
|
# no quantize node in _conv_forward
|
|
FileCheck().check_not(quant_func) \
|
|
.check("aten::conv2d") \
|
|
.check_not(quant_func) \
|
|
.check("return") \
|
|
.run(conv._get_method('_conv_forward').graph)
|
|
|
|
def test_dedup_module_uses(self):
    """Check ``_jit_pass_dedup_module_uses`` on a module invoked twice.

    ``M.forward`` calls the single ``self.relu`` submodule at two call sites.
    Before the pass the scripted module has exactly one ``relu*`` submodule.
    After the pass (and re-wrapping the C++ module) it must have two, and the
    output must match the reference output computed before the pass.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.relu(x)
            x -= 0.5
            return self.relu(x)

    def _count_relu_modules(module):
        # Number of direct submodules whose attribute name starts with 'relu'.
        return len([name for name, _ in module._modules._c.items()
                    if name.startswith('relu')])

    data = torch.randn((2, 2))
    m = torch.jit.script(M())
    ref_res = m(data)
    # This check runs BEFORE the pass; the message previously said "after".
    assert _count_relu_modules(m) == 1, \
        "Expected to have 1 relu module before dedup module uses"
    torch._C._jit_pass_dedup_module_uses(m._c)
    m = torch.jit._recursive.wrap_cpp_module(m._c)
    res = m(data)
    assert _count_relu_modules(m) == 2, \
        "Expected to have 2 relu modules after dedup module uses"
    self.assertEqual(res, ref_res)
|
|
|
|
def test_replicate_dequantize(self):
    """Check that ``_jit_pass_replicate_dequantize`` duplicates a shared dequantize.

    The output of the single ``aten::dequantize`` feeds both the conv call and
    the in-place add. After the pass the graph must contain two dequantize
    nodes, and running the forward must reproduce the reference result.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 1).float()

        def forward(self, x):
            x = torch.dequantize(x)
            r = self.conv(x)
            r += x
            return r

    float_input = torch.randn([1, 3, 10, 10], dtype=torch.float)
    q_input = torch.quantize_per_tensor(float_input, 0.5, 1, torch.quint8)
    scripted = torch.jit.script(M())
    expected = scripted(q_input)

    graph = scripted.graph
    FileCheck().check_count("aten::dequantize", 1, exactly=True).run(graph)
    torch._C._jit_pass_replicate_dequantize(graph)
    FileCheck().check_count("aten::dequantize", 2, exactly=True).run(graph)

    actual = get_forward(scripted._c)(q_input)
    self.assertEqual(actual, expected)
|
|
|
|
def test_replicate_dequantize_in_block(self):
    """Check dequantize replication when the uses are inside ``prim::If`` branches.

    One branch calls the conv and the other adds a scalar. After the pass the
    graph must hold two dequantize nodes: one immediately before the conv
    ``CallMethod`` and one immediately before the ``aten::add``. Numerics must
    be unchanged.
    """
    class M(torch.nn.Module):
        def __init__(self, cond):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 1).float()

            self.cond = cond

        def forward(self, x):
            x = torch.dequantize(x)
            if self.cond:
                x = self.conv(x)
            else:
                x = x + 3
            return x

    float_input = torch.randn([1, 3, 10, 10], dtype=torch.float)
    q_input = torch.quantize_per_tensor(float_input, 0.5, 1, torch.quint8)
    scripted = torch.jit.script(M(True))
    expected = scripted(q_input)

    graph = scripted.graph
    FileCheck().check_count("aten::dequantize", 1, exactly=True).run(graph)
    torch._C._jit_pass_replicate_dequantize(graph)
    FileCheck().check_count("aten::dequantize", 2, exactly=True).run(graph)
    # first copy: directly in front of the conv method call
    (FileCheck()
     .check("aten::dequantize")
     .check_next("CallMethod")
     .run(graph))
    # second copy: directly in front of the add
    (FileCheck()
     .check("aten::dequantize")
     .check("aten::dequantize")
     .check_next("aten::add")
     .run(graph))

    actual = get_forward(scripted._c)(q_input)
    self.assertEqual(actual, expected)
|
|
|
|
def test_swap_functional_linear(self):
    """Check that ``_jit_pass_swap_functional_linear`` inlines ``F.linear``.

    The scripted graph initially calls ``F.linear`` through ``CallFunction``.
    After the pass it must contain ``aten::linear`` and no ``CallFunction``,
    and the module must produce the same output as before.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, weight, bias):
            x = torch.dequantize(x)
            weight = torch.dequantize(weight)
            x = F.linear(x, weight, bias)
            x = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
            return x

    q_x = torch.quantize_per_tensor(torch.rand((10, 5), dtype=torch.float),
                                    scale=0.5, zero_point=1, dtype=torch.quint8)
    q_weight = torch.quantize_per_tensor(torch.rand((5, 5), dtype=torch.float),
                                         scale=0.5, zero_point=1, dtype=torch.qint8)
    bias = torch.rand(5, dtype=torch.float)
    scripted = torch.jit.script(M())
    expected = scripted(q_x, q_weight, bias)

    FileCheck().check("CallFunction").run(scripted.graph)
    torch._C._jit_pass_swap_functional_linear(scripted.graph)
    (FileCheck()
     .check("aten::linear")
     .check_not("CallFunction")
     .run(scripted.graph))

    actual = scripted(q_x, q_weight, bias)
    self.assertEqual(actual, expected)
|
|
|
|
def test_replicate_quantize_for_if(self):
    """Check that quantize nodes on a ``prim::If`` output move into its branches.

    The quantize nodes for the output of ``prim::If`` should be moved inside
    each branch so that the quantization patterns there can be matched. Two
    ``Res`` modules, each with two branches that call a conv, must end up
    with four fused ``quantized::conv2d`` calls.
    """
    class Res(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 1).float()
            self.use_skip = True

        def forward(self, x, cond):
            # type: (Tensor, bool) -> Tensor
            # to avoid being frozen
            self.use_skip = cond
            if self.use_skip:
                return self.conv(x)
            else:
                return self.conv(x)

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.res1 = Res()
            self.res2 = Res()

        def forward(self, x):
            x = self.res1(x, True)
            x = self.res2(x, False)
            return x

    calib_data = [(torch.rand((1, 3, 10, 10), dtype=torch.float),
                   torch.randint(0, 1, (1,), dtype=torch.long))
                  for _ in range(2)]
    scripted = torch.jit.script(M()).eval()
    quantized = quantize_jit(scripted, {'': default_qconfig},
                             test_only_eval_fn, [calib_data])
    # the patterns in both branches of both Res modules must be fused
    (FileCheck()
     .check_count("quantized::conv2d(", 4, exactly=True)
     .run(quantized.graph))
|
|
|
|
def test_finalize_for_linear(self):
    """Check the finalized graph of a quantized ``nn.Linear`` model.

    The graph must contain exactly one ``aten::quantize_per_tensor`` (for the
    input), no ``quantized::linear_prepack`` (it is folded), and a
    ``quantized::linear`` call.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(5, 5).float()

        def forward(self, x):
            return self.fc(x)

    calib_data = [(torch.rand((1, 5), dtype=torch.float),
                   torch.randint(0, 1, (1,), dtype=torch.long))
                  for _ in range(2)]
    scripted = torch.jit.script(M()).eval()
    quantized = quantize_jit(scripted, {'': default_qconfig},
                             test_only_eval_fn, [calib_data])
    (FileCheck()
     .check_count("aten::quantize_per_tensor", 1, exactly=True)
     .check_not("quantized::linear_prepack")
     .check("quantized::linear")
     .run(quantized.graph))
|
|
|
|
def test_finalize_debug(self):
    """Check the graph produced by ``quantize_jit(..., debug=True)``.

    The graph is expected to keep the float ``aten::conv2d`` and
    ``aten::avg_pool2d`` ops instead of ``quantized::conv2d``. It must also
    contain the ``q_scale`` / ``q_zero_point`` / ``dtype`` ->
    ``quantize_per_tensor`` sequence followed by a ``dequantize``.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3).float()
            self.avgpool = torch.nn.AvgPool2d(3)

        def forward(self, x):
            x = self.conv(x)
            x = self.avgpool(x)
            return x

    calib_data = [(torch.rand((1, 3, 10, 10), dtype=torch.float),
                   torch.randint(0, 1, (1,), dtype=torch.long))
                  for _ in range(2)]
    scripted = torch.jit.script(M()).eval()
    debug_model = quantize_jit(scripted, {'': default_qconfig},
                               test_only_eval_fn, [calib_data], debug=True)
    (FileCheck()
     .check_not("quantized::conv2d")
     .check("aten::conv2d")
     .check("aten::avg_pool2d")
     .check("aten::q_scale")
     .check_next("aten::q_zero_point")
     .check_next("prim::dtype")
     .check_next("aten::quantize_per_tensor")
     .check("aten::dequantize")
     .run(debug_model.graph))
|
|
|
|
def test_finalize_no_extra_dequantize(self):
    """Check that using ``x.size(0)`` on a quantized output adds no dequantize.

    The model multiplies a quantized conv output by its own ``size(0)``. The
    finalized graph is expected to contain no ``aten::dequantize(`` call.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3).float()

        def forward(self, x):
            x = self.conv(x)
            return x.size(0) * x

    scripted = torch.jit.script(M()).eval()
    quantized = quantize_jit(scripted, {'': default_qconfig},
                             test_only_eval_fn, [self.img_data])
    FileCheck().check_not("aten::dequantize(").run(quantized.graph)
|
|
|
|
def test_module_list(self):
    """Check quantization of submodules in an ``nn.ModuleList`` called in a loop.

    After ``prepare_jit`` the top-level module must have three attributes
    prefixed with ``_observer``. After calibration and ``convert_jit`` the
    graph must contain two ``quantized::linear`` calls.
    """
    class SimpleLinearLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(5, 5).float()

        def forward(self, x):
            return self.fc(x)

    class ComplexModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList([SimpleLinearLayer() for _ in range(2)])

        def forward(self, x):
            # type: (torch.Tensor) -> List[torch.Tensor]
            states = []
            for layer in self.layers:
                val = layer(x)
                states.append(val)
            return states

    sample = torch.rand((1, 5), dtype=torch.float)
    prepared = prepare_jit(torch.jit.script(ComplexModel()).eval(),
                           {'': default_qconfig})
    assert len(attrs_with_prefix(prepared, '_observer')) == 3
    prepared(sample)  # calibration pass
    converted = convert_jit(prepared, debug=False)
    (FileCheck()
     .check("quantized::linear")
     .check("quantized::linear")
     .run(converted.graph))
|
|
|
|
def test_conv_trace(self):
    """Check the op names of traced convolutions after ``prepare_jit``.

    After tracing Conv1d/2d/3d and running ``prepare_jit``, each conv
    submodule's forward graph must use ``aten::convNd`` and must not contain
    ``aten::_convolution``.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1d = torch.nn.Conv1d(3, 3, 3).float()
            self.conv2d = torch.nn.Conv2d(3, 3, 3).float()
            self.conv3d = torch.nn.Conv3d(3, 3, 3).float()

        def forward(self, x, y, z):
            a = self.conv1d(x)
            b = self.conv2d(y)
            c = self.conv3d(z)
            return (a, b, c)

    example_inputs = (torch.rand((1, 3, 10), dtype=torch.float),
                      torch.rand((1, 3, 10, 10), dtype=torch.float),
                      torch.rand((1, 3, 10, 10, 10), dtype=torch.float))
    traced = torch.jit.trace(M(), example_inputs).eval()
    prepared = prepare_jit(traced, {'': default_qconfig})
    for conv_name in ('conv1d', 'conv2d', 'conv3d'):
        forward_graph = str(get_forward_graph(getattr(prepared, conv_name)._c))
        (FileCheck()
         .check('aten::' + conv_name)
         .check_not("aten::_convolution")
         .run(forward_graph))
|
|
|
|
@skipIfNoFBGEMM
def test_replicate_dequant_same_value(self):
    """Check that ``x * x`` on a quantized conv output becomes ``quantized::mul``.

    Both operands of the multiply are the same quantized conv output. The
    resulting graph must contain ``quantized::mul(`` and no ``aten::mul``.
    Judging by the test name, this exercises replication of a dequantize
    whose value has two uses.

    The test is skipped when FBGEMM is unavailable. It uses the shared
    ``skipIfNoFBGEMM`` decorator, like every other FBGEMM-dependent test,
    instead of a hand-written ``unittest.skipUnless`` check.
    """
    class Mul(torch.nn.Module):
        def __init__(self):
            super(Mul, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3).float()

        def forward(self, x):
            x = self.conv(x)
            return x * x

    data = [(torch.rand((1, 3, 10, 10), dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]

    qconfig_dict = {'': default_qconfig}
    model = torch.jit.script(Mul()).eval()
    m = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data])
    FileCheck().check("quantized::mul(") \
        .check_not("aten::mul") \
        .run(m.graph)
|
|
|
|
class TestQuantizeJitOps(QuantizationTestCase):
|
|
""" Test graph mode post training static quantization works
|
|
for individual ops end to end.
|
|
"""
|
|
@skipIfNoFBGEMM
def test_linear(self):
    """Check graph mode quantization of ``nn.Linear`` and functional ``F.linear``.

    Without a ReLU, each model must quantize to ``quantized::linear``, with a
    single input ``quantize_per_tensor`` and no ``linear_prepack`` left in
    the graph.

    With a following ReLU (module or functional), the pair must fuse into
    ``quantized::linear_relu``. No standalone linear or relu op, float or
    quantized, may remain.
    """
    class ModuleLinear(torch.nn.Module):
        def __init__(self, has_relu=False, f_relu=False):
            super(ModuleLinear, self).__init__()
            self.linear = torch.nn.Linear(30, 4).float()
            if has_relu:
                if f_relu:
                    self.relu = F.relu
                else:
                    self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            return self.relu(self.linear(x))

    class FuncLinear(torch.nn.Module):
        def __init__(self, has_relu=False, f_relu=False):
            super(FuncLinear, self).__init__()
            self.w = torch.randn(4, 30)
            self.b = torch.randn(4)
            if has_relu:
                if f_relu:
                    self.relu = F.relu
                else:
                    self.relu = torch.nn.ReLU()
            else:
                self.relu = torch.nn.Identity()

        def forward(self, x):
            return self.relu(F.linear(x, self.w, self.b))

    data = [(torch.rand((1, 30), dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]

    for model in [ModuleLinear(has_relu=False),
                  FuncLinear(has_relu=False)]:
        model = self.checkGraphModeOp(model, data, "quantized::linear",
                                      tracing=False)
        # the input is quantized exactly once ...
        FileCheck() \
            .check_count("aten::quantize_per_tensor", 1, exactly=True) \
            .run(model.graph)
        # ... and the weight prepack has been folded away
        FileCheck().check_not("quantized::linear_prepack") \
            .run(model.graph)

    for f_relu in [True, False]:
        for model in [ModuleLinear(has_relu=True, f_relu=f_relu),
                      FuncLinear(has_relu=True, f_relu=f_relu)]:
            model = self.checkGraphModeOp(model, data,
                                          "quantized::linear_relu",
                                          tracing=False)
            # linear + relu must be fully fused; FileCheck.run raises on
            # failure, so its (None) result is not kept
            FileCheck().check_not("aten::linear") \
                .check_not("aten::relu") \
                .check_not("quantized::linear(") \
                .check_not("quantized::relu(") \
                .run(model.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_conv(self):
    """Check quantization of Conv1d/2d/3d, both traced and scripted.

    Each conv must become ``quantized::convNd``. The graph must contain
    exactly one ``aten::quantize_per_tensor`` (for the input) and no
    ``quantized::convNd_prepack`` (it is folded).
    """
    conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

    class Conv(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.conv = conv_module[dim](3, 3, 3).float()

        def forward(self, x):
            return self.conv(x)

    for dim in (1, 2, 3):
        quantized_op = "quantized::conv{}d".format(dim)
        for tracing in (True, False):
            model = self.checkGraphModeOp(
                Conv(dim), self.img_data_dict[dim], quantized_op, tracing)
            (FileCheck()
             .check_count("aten::quantize_per_tensor", 1, exactly=True)
             .run(model.graph))

            FileCheck().check_not(quantized_op + "_prepack").run(model.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_conv_relu(self):
    """Check conv1d_relu / conv2d_relu / conv3d_relu fusion.

    Each ConvNd is followed by a ReLU in one of four forms: ``nn.ReLU``
    (in-place and not), ``F.relu``, or in-place ``F.relu``. Both traced and
    scripted models must fuse the pair into ``quantized::convNd_relu``, with
    no standalone conv or relu op, float or quantized, left in the graph.
    """
    conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

    class ConvNdRelu(torch.nn.Module):
        def __init__(self, dim, inplace):
            super(ConvNdRelu, self).__init__()
            self.conv = conv_module[dim](3, 3, 3).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x):
            return self.relu(self.conv(x))

    class ConvNdFunctionalRelu(torch.nn.Module):
        def __init__(self, dim):
            super(ConvNdFunctionalRelu, self).__init__()
            self.conv = conv_module[dim](3, 3, 3).float()

        def forward(self, x):
            return F.relu(self.conv(x))

    class ConvNdInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self, dim):
            super(ConvNdInplaceFunctionalRelu, self).__init__()
            self.conv = conv_module[dim](3, 3, 3).float()

        def forward(self, x):
            return F.relu(self.conv(x), True)

    options = itertools.product([1, 2, 3], [True, False])
    for dim, tracing in options:
        # op base name shared by every FileCheck pattern below, e.g. "conv2d"
        conv_name = "conv{}d".format(dim)
        for orig_m in [ConvNdRelu(dim, True),
                       ConvNdRelu(dim, False),
                       ConvNdFunctionalRelu(dim),
                       ConvNdInplaceFunctionalRelu(dim)]:
            m = self.checkGraphModeOp(
                orig_m, self.img_data_dict[dim],
                "quantized::{}_relu(".format(conv_name), tracing=tracing)

            FileCheck().check_not("aten::{}(".format(conv_name)) \
                .check_not("aten::relu") \
                .check_not("quantized::{}(".format(conv_name)) \
                .check_not("quantized::relu(") \
                .run(m.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_add_alpha(self):
    """Check quant fusion for several ``aten::add`` calls sharing a constant alpha.

    Three tensor adds use the same constant ``alpha`` as their third
    argument. All three must become ``quantized::add``, with no
    ``aten::add`` / ``aten::add_`` left.
    """
    class QuantizedAdd(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            z = x + y
            w = y + z
            return z + w

    calib_data = [(torch.randn(1, 2, 5, 5, dtype=torch.float),
                   torch.randn(1, 2, 5, 5, dtype=torch.float),
                   torch.randint(0, 1, (1,), dtype=torch.long))
                  for _ in range(2)]
    for tracing in (True, False):
        quantized = self.checkGraphModeOp(QuantizedAdd(), calib_data,
                                          "quantized::add", tracing)
        (FileCheck()
         .check_count("quantized::add", 3, exactly=True)
         .run(quantized.graph))
        (FileCheck()
         .check_not("aten::add")
         .check_not("aten::add_")
         .run(quantized.graph))
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_add_relu_alpha(self):
    """Check add_relu fusion for ``aten::add`` calls sharing a constant alpha.

    Every model performs two add + relu steps. The models cover each variant
    of out-of-place / in-place add combined with ``nn.ReLU`` (in-place or
    not), ``F.relu``, or in-place ``F.relu``.

    For both tracing and scripting, the graph must contain exactly two
    ``quantized::add_relu(`` calls and no standalone add or relu op, float
    or quantized.
    """
    class AddRelu(torch.nn.Module):
        def __init__(self, inplace):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x + y
            x = self.relu(x)
            x = x + y
            return self.relu(x)

    class InplaceAddRelu(torch.nn.Module):
        def __init__(self, inplace):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x += y
            x = self.relu(x)
            x += y
            return self.relu(x)

    class AddFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x + y
            x = F.relu(x)
            x = x + y
            return F.relu(x)

    class InplaceAddFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x += y
            x = F.relu(x)
            x += y
            return F.relu(x)

    class AddInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x + y
            x = F.relu(x, True)
            x = x + y
            return F.relu(x, True)

    class InplaceAddInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x += y
            x = F.relu(x, True)
            x += y
            return F.relu(x, True)

    calib_data = [(torch.rand((1, 2, 5, 5), dtype=torch.float),
                   torch.rand((1, 2, 5, 5), dtype=torch.float),
                   torch.randint(0, 1, (1,), dtype=torch.long))
                  for _ in range(2)]
    candidates = [AddRelu(True), AddRelu(False),
                  InplaceAddRelu(True), InplaceAddRelu(False),
                  AddFunctionalRelu(), InplaceAddFunctionalRelu(),
                  AddInplaceFunctionalRelu(), InplaceAddInplaceFunctionalRelu()]
    for m_orig, tracing in itertools.product(candidates, [True, False]):
        fused = self.checkGraphModeOp(m_orig, calib_data,
                                      "quantized::add_relu(", tracing=tracing)
        (FileCheck()
         .check_count("quantized::add_relu(", 2, exactly=True)
         .run(fused.graph))
        (FileCheck()
         .check_not("aten::add(")
         .check_not("aten::add_(")
         .check_not("aten::relu(")
         .check_not("aten::relu_(")
         .check_not("quantized::add(")
         .check_not("quantized::relu(")
         .run(fused.graph))
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_add(self):
    """Check when a tensor-tensor ``+`` / ``+=`` is quantized.

    Models whose add operands come from convs are expected to produce
    ``quantized::add`` with no ``aten::add`` / ``aten::add_`` left. Models
    that add the raw inputs directly are expected to keep ``aten::add`` and
    contain no ``quantized::add``. Each model is checked both traced and
    scripted.
    """
    class QuantizedAdd(torch.nn.Module):
        def __init__(self):
            super(QuantizedAdd, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            return x + y

    class QuantizedInplaceAdd(torch.nn.Module):
        def __init__(self):
            super(QuantizedInplaceAdd, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x += y
            return x

    class NonQuantizedAdd(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedAdd, self).__init__()

        def forward(self, x, y):
            return x + y

    class NonQuantizedInplaceAdd(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedInplaceAdd, self).__init__()

        def forward(self, x, y):
            x += y
            return x

    data = [(torch.randn(1, 2, 5, 5, dtype=torch.float),
             torch.randn(1, 2, 5, 5, dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for m_orig, quantized in [(QuantizedAdd(), True),
                              (QuantizedInplaceAdd(), True),
                              (NonQuantizedAdd(), False),
                              (NonQuantizedInplaceAdd(), False)]:
        for tracing in [True, False]:
            op = "quantized::add" if quantized else "aten::add"
            # Always start from the float module: rebinding the loop variable
            # to the result would make the second (scripting) iteration run
            # on an already-quantized module.
            m = self.checkGraphModeOp(m_orig, data, op, tracing)
            # TODO: remove after refactor of checkGraphModeOp
            if quantized:
                FileCheck().check_not("aten::add") \
                    .check_not("aten::add_") \
                    .run(m.graph)
            else:
                FileCheck().check_not("quantized::add") \
                    .run(m.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_add_scalar(self):
    """Check when a tensor + scalar (``+ 3`` / ``+= 3``) is quantized.

    When the tensor comes from a conv, the graph is expected to use
    ``quantized::add_scalar`` with no ``aten::add`` / ``aten::add_`` left.
    When the raw input is used, ``aten::add`` is expected to stay and no
    ``quantized::add_scalar`` may appear. Numerics are not compared
    (``check=False``; see TODO). Each model is checked both traced and
    scripted.
    """
    class QuantizedAddScalar(torch.nn.Module):
        def __init__(self):
            super(QuantizedAddScalar, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            return x + 3

    class QuantizedInplaceAddScalar(torch.nn.Module):
        def __init__(self):
            super(QuantizedInplaceAddScalar, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            x += 3
            return x

    class NonQuantizedAddScalar(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedAddScalar, self).__init__()

        def forward(self, x):
            return x + 3

    class NonQuantizedInplaceAddScalar(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedInplaceAddScalar, self).__init__()

        def forward(self, x):
            x += 3
            return x

    data = [(torch.randn(1, 2, 5, 5, dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for m_orig, quantized in [(QuantizedAddScalar(), True),
                              (QuantizedInplaceAddScalar(), True),
                              (NonQuantizedAddScalar(), False),
                              (NonQuantizedInplaceAddScalar(), False)]:
        for tracing in [True, False]:
            op = "quantized::add_scalar" if quantized else "aten::add"
            # Quantize the float module each time; reusing the loop variable
            # for the result would re-quantize an already-quantized module.
            # TODO: fix debug=True numerics
            m = self.checkGraphModeOp(m_orig, data, op, tracing, check=False)
            # TODO: remove after refactor of checkGraphModeOp
            if quantized:
                FileCheck().check_not("aten::add") \
                    .check_not("aten::add_") \
                    .run(m.graph)
            else:
                FileCheck().check_not("quantized::add_scalar") \
                    .run(m.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_add_relu(self):
    """Check fusion of a tensor-tensor add followed by ReLU.

    The models cover out-of-place / in-place add combined with ``nn.ReLU``
    (in-place or not), ``F.relu``, or in-place ``F.relu``. For both tracing
    and scripting, the pair must fuse into ``quantized::add_relu(``, with no
    standalone add or relu op, float or quantized, left in the graph.
    """
    class AddRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(AddRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x + y
            return self.relu(x)

    class InplaceAddRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(InplaceAddRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x += y
            return self.relu(x)

    class AddFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(AddFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x + y
            return F.relu(x)

    class InplaceAddFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceAddFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x += y
            return F.relu(x)

    class AddInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(AddInplaceFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x + y
            return F.relu(x, True)

    class InplaceAddInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceAddInplaceFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x += y
            return F.relu(x, True)

    data = [(torch.rand((1, 2, 5, 5), dtype=torch.float),
             torch.rand((1, 2, 5, 5), dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for m_orig in [AddRelu(True), AddRelu(False),
                   InplaceAddRelu(True), InplaceAddRelu(False),
                   AddFunctionalRelu(), InplaceAddFunctionalRelu(),
                   AddInplaceFunctionalRelu(), InplaceAddInplaceFunctionalRelu()]:
        for tracing in [True, False]:
            # Use the float module on every iteration (previously the loop
            # variable was overwritten with the quantized result).
            m = self.checkGraphModeOp(m_orig, data, "quantized::add_relu(", tracing)
            FileCheck().check_not("aten::add(") \
                .check_not("aten::add_(") \
                .check_not("aten::relu(") \
                .check_not("aten::relu_(") \
                .check_not("quantized::add(") \
                .check_not("quantized::relu(") \
                .run(m.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_add_scalar_relu(self):
    """Check fusion of a tensor + scalar add followed by ReLU.

    The models cover out-of-place / in-place scalar add combined with
    ``nn.ReLU`` (in-place or not), ``F.relu``, or in-place ``F.relu``.

    For both tracing and scripting, the graph must use
    ``quantized::add_scalar_relu`` (or its ``_out`` variant). No standalone
    add, relu, ``quantized::add_scalar(`` or ``quantized::relu(`` may remain.
    Numerics are not compared (``check=False``; see TODO).
    """
    class AddScalarRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(AddScalarRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x):
            x = self.conv(x)
            return self.relu(x + 3)

    class InplaceAddScalarRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(InplaceAddScalarRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x):
            x = self.conv(x)
            x += 3
            return self.relu(x)

    class AddScalarFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(AddScalarFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            return F.relu(x + 3)

    class InplaceAddScalarFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceAddScalarFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            x += 3
            return F.relu(x)

    class AddScalarInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(AddScalarInplaceFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            return F.relu(x + 3, True)

    class InplaceAddScalarInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceAddScalarInplaceFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            x += 3
            return F.relu(x, True)

    data = [(torch.rand((1, 2, 5, 5), dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for m_orig in [AddScalarRelu(True), AddScalarRelu(False),
                   InplaceAddScalarRelu(True), InplaceAddScalarRelu(False),
                   AddScalarFunctionalRelu(),
                   InplaceAddScalarFunctionalRelu(),
                   AddScalarInplaceFunctionalRelu(),
                   InplaceAddScalarInplaceFunctionalRelu()]:
        for tracing in [True, False]:
            # quantized::add_scalar_relu or quantized::add_scalar_relu_out
            # TODO: split this after refactor of checkGraphModeOp
            # TODO: fix debug=True numerics
            # Quantize the float module each time instead of the previous
            # iteration's quantized output.
            m = self.checkGraphModeOp(m_orig, data, "quantized::add_scalar_relu",
                                      tracing, check=False)
            FileCheck().check_not("aten::add(") \
                .check_not("aten::add_(") \
                .check_not("aten::relu(") \
                .check_not("aten::relu_(") \
                .check_not("quantized::add_scalar(") \
                .check_not("quantized::relu(") \
                .run(m.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_cat(self):
    """Check that the output of ``cat`` is quantized only when its inputs are.

    Concatenating two conv outputs must produce ``quantized::cat`` with no
    ``aten::cat`` left. Concatenating the raw float inputs must keep
    ``aten::cat`` and produce no ``quantized::cat``.
    """
    class QuantizedCat(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            return torch.cat([x, y], 1)

    class NonQuantizedCat(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.cat([x, y], 1)

    calib_data = [(torch.randn(1, 2, 5, 5, dtype=torch.float),
                   torch.randn(1, 2, 5, 5, dtype=torch.float),
                   torch.randint(0, 1, (1,), dtype=torch.long))
                  for _ in range(2)]
    for tracing in (True, False):
        with_q_inputs = self.checkGraphModeOp(QuantizedCat(), calib_data,
                                              "quantized::cat", tracing)
        FileCheck().check_not("aten::cat").run(with_q_inputs.graph)

        with_float_inputs = self.checkGraphModeOp(NonQuantizedCat(), calib_data,
                                                  "aten::cat", tracing)
        FileCheck().check_not("quantized::cat").run(with_float_inputs.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_qbatch_norm(self):
    """Check quantization of BatchNorm2d/3d, both traced and scripted.

    Each model must quantize to ``quantized::batch_norm``, with no
    ``aten::batch_norm`` left.
    """
    bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

    class M(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.bn = bn_module[dim](3).to(torch.float)

        def forward(self, x):
            return self.bn(x)

    for tracing in (True, False):
        for dim in (2, 3):
            quantized = self.checkGraphModeOp(M(dim), self.img_data_dict[dim],
                                              "quantized::batch_norm", tracing)

            FileCheck().check_not("aten::batch_norm").run(quantized.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_qbatch_norm_relu(self):
    """Check fusion of BatchNorm2d/3d followed by ReLU.

    The ReLU appears as ``nn.ReLU`` (in-place or not), ``F.relu``, or
    in-place ``F.relu``. For both tracing and scripting, the pair must fuse
    into ``quantized::batch_norm_relu``, with no ``aten::batch_norm`` /
    ``aten::relu`` / ``aten::relu_`` left in the graph.
    """
    bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

    class BNRelu(torch.nn.Module):
        def __init__(self, dim, inplace):
            super().__init__()
            self.bn = bn_module[dim](3).to(torch.float)
            self.relu = torch.nn.ReLU(inplace=inplace)

        def forward(self, x):
            return self.relu(self.bn(x))

    class BNFuncRelu(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.bn = bn_module[dim](3).to(torch.float)

        def forward(self, x):
            return F.relu(self.bn(x), False)

    class BNFuncInplaceRelu(torch.nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.bn = bn_module[dim](3).to(torch.float)

        def forward(self, x):
            return F.relu(self.bn(x), True)

    for tracing in (True, False):
        for dim in (2, 3):
            variants = [BNRelu(dim, True), BNRelu(dim, False),
                        BNFuncRelu(dim), BNFuncInplaceRelu(dim)]
            for instance in variants:
                fused = self.checkGraphModeOp(instance, self.img_data_dict[dim],
                                              "quantized::batch_norm_relu", tracing)
                (FileCheck()
                 .check_not("aten::batch_norm")
                 .check_not("aten::relu")
                 .check_not("aten::relu_")
                 .run(fused.graph))
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_mul(self):
    """Check when a tensor-tensor ``*`` / ``*=`` is quantized.

    Models whose operands come from convs are expected to produce
    ``quantized::mul`` with no ``aten::mul`` / ``aten::mul_`` left. Models
    that multiply the raw inputs are expected to keep ``aten::mul`` and
    contain no ``quantized::mul``. Each model is checked both traced and
    scripted.
    """
    class QuantizedMul(torch.nn.Module):
        def __init__(self):
            super(QuantizedMul, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            return x * y

    class QuantizedInplaceMul(torch.nn.Module):
        def __init__(self):
            super(QuantizedInplaceMul, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x *= y
            return x

    class NonQuantizedMul(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedMul, self).__init__()

        def forward(self, x, y):
            return x * y

    class NonQuantizedInplaceMul(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedInplaceMul, self).__init__()

        def forward(self, x, y):
            x *= y
            return x

    data = [(torch.randn(1, 2, 10, 10, dtype=torch.float),
             torch.randn(1, 2, 10, 10, dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for m_orig, quantized in [(QuantizedMul(), True),
                              (QuantizedInplaceMul(), True),
                              (NonQuantizedMul(), False),
                              (NonQuantizedInplaceMul(), False)]:
        for tracing in [True, False]:
            op = "quantized::mul" if quantized else "aten::mul"
            # Keep the float module separate from the result so the
            # scripting iteration does not re-quantize the traced output.
            m = self.checkGraphModeOp(m_orig, data, op, tracing)
            # TODO: remove after refactor of checkGraphModeOp
            if quantized:
                FileCheck().check_not("aten::mul") \
                    .check_not("aten::mul_") \
                    .run(m.graph)
            else:
                FileCheck().check_not("quantized::mul") \
                    .run(m.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_mul_scalar(self):
    """Check graph mode quantization of ``mul`` by a Python scalar.

    Multiplying a conv output by ``3`` (out-of-place or in-place) is
    expected to lower to ``quantized::mul_scalar``; multiplying a raw float
    input is expected to stay ``aten::mul``. Each case runs under tracing
    and scripting.
    """
    class QuantizedMulScalar(torch.nn.Module):
        def __init__(self):
            super(QuantizedMulScalar, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            return x * 3

    class QuantizedInplaceMulScalar(torch.nn.Module):
        def __init__(self):
            super(QuantizedInplaceMulScalar, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            x *= 3
            return x

    class NonQuantizedMulScalar(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedMulScalar, self).__init__()

        def forward(self, x):
            return x * 3

    class NonQuantizedInplaceMulScalar(torch.nn.Module):
        def __init__(self):
            super(NonQuantizedInplaceMulScalar, self).__init__()

        def forward(self, x):
            x *= 3
            return x

    data = [(torch.randn(1, 2, 5, 5, dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for eager_model, quantized in [(QuantizedMulScalar(), True),
                                   (QuantizedInplaceMulScalar(), True),
                                   (NonQuantizedMulScalar(), False),
                                   (NonQuantizedInplaceMulScalar(), False)]:
        op = "quantized::mul_scalar" if quantized else "aten::mul"
        for tracing in [True, False]:
            # Separate name for the result so the scripting iteration still
            # receives the float module rather than the quantized one.
            # TODO: fix debug=True numerics
            model = self.checkGraphModeOp(eager_model, data, op, tracing, check=False)
            # TODO: remove after refactor of checkGraphModeOp
            if quantized:
                FileCheck().check_not("aten::mul") \
                           .check_not("aten::mul_") \
                           .run(model.graph)
            else:
                FileCheck().check_not("quantized::mul_scalar") \
                           .run(model.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_mul_relu(self):
    """Check that ``mul`` followed by ReLU fuses into ``quantized::mul_relu``.

    The mul is either out-of-place or in-place. The ReLU is a module
    (inplace True/False), ``F.relu``, or ``F.relu(..., True)``. After
    quantization no standalone mul/relu op, float or quantized, may remain.
    """
    class MulRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(MulRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x * y
            return self.relu(x)

    class InplaceMulRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(InplaceMulRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x *= y
            return self.relu(x)

    class MulFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(MulFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x * y
            return F.relu(x)

    class InplaceMulFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceMulFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x *= y
            return F.relu(x)

    class MulInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(MulInplaceFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x = x * y
            return F.relu(x, True)

    class InplaceMulInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceMulInplaceFunctionalRelu, self).__init__()
            self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
            self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x, y):
            x = self.conv1(x)
            y = self.conv2(y)
            x *= y
            return F.relu(x, True)

    data = [(torch.rand((1, 2, 5, 5), dtype=torch.float),
             torch.rand((1, 2, 5, 5), dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for eager_model in [MulRelu(True), MulRelu(False),
                        InplaceMulRelu(True), InplaceMulRelu(False),
                        MulFunctionalRelu(), InplaceMulFunctionalRelu(),
                        MulInplaceFunctionalRelu(), InplaceMulInplaceFunctionalRelu()]:
        for tracing in [True, False]:
            # Do not rebind the loop variable: the scripting iteration must
            # quantize the float module, not the traced+quantized one.
            model = self.checkGraphModeOp(eager_model, data, "quantized::mul_relu(", tracing)
            FileCheck().check_not("aten::mul(") \
                       .check_not("aten::mul_(") \
                       .check_not("aten::relu(") \
                       .check_not("aten::relu_(") \
                       .check_not("quantized::mul(") \
                       .check_not("quantized::relu(") \
                       .run(model.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_quantized_mul_scalar_relu(self):
    """Check that scalar ``mul`` followed by ReLU fuses into ``quantized::mul_scalar_relu``.

    The scalar mul (by ``3``) is out-of-place or in-place. The ReLU is a
    module (inplace True/False), ``F.relu``, or ``F.relu(..., True)``.
    After quantization no standalone mul/relu op, float or quantized, may
    remain.
    """
    class MulScalarRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(MulScalarRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x):
            x = self.conv(x)
            return self.relu(x * 3)

    class InplaceMulScalarRelu(torch.nn.Module):
        def __init__(self, inplace):
            super(InplaceMulScalarRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()
            self.relu = torch.nn.ReLU(inplace)

        def forward(self, x):
            x = self.conv(x)
            x *= 3
            return self.relu(x)

    class MulScalarFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(MulScalarFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            return F.relu(x * 3)

    class InplaceMulScalarFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceMulScalarFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            x *= 3
            return F.relu(x)

    class MulScalarInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(MulScalarInplaceFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            return F.relu(x * 3, True)

    class InplaceMulScalarInplaceFunctionalRelu(torch.nn.Module):
        def __init__(self):
            super(InplaceMulScalarInplaceFunctionalRelu, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()

        def forward(self, x):
            x = self.conv(x)
            x *= 3
            return F.relu(x, True)

    data = [(torch.randn(1, 2, 5, 5, dtype=torch.float),
             torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    for eager_model in [MulScalarRelu(True), MulScalarRelu(False),
                        InplaceMulScalarRelu(True), InplaceMulScalarRelu(False),
                        MulScalarFunctionalRelu(),
                        InplaceMulScalarFunctionalRelu(),
                        MulScalarInplaceFunctionalRelu(),
                        InplaceMulScalarInplaceFunctionalRelu()]:
        for tracing in [True, False]:
            # quantized::mul_scalar_relu or quantized::mul_scalar_relu_out
            # TODO: fix debug=True numerics
            # The result gets its own name so the scripting iteration still
            # starts from the float module.
            model = self.checkGraphModeOp(eager_model, data, "quantized::mul_scalar_relu", tracing, check=False)
            FileCheck().check_not("aten::mul(") \
                       .check_not("aten::mul_(") \
                       .check_not("aten::relu(") \
                       .check_not("aten::relu_(") \
                       .check_not("quantized::mul_scalar(") \
                       .check_not("quantized::relu(") \
                       .run(model.graph)
|
|
|
|
def test_hardswish(self):
    """Hardswish lowers to quantized::hardswish.

    Covers the nn.Hardswish module and F.hardswish with inplace True and
    False, each under tracing and scripting. No float hardswish op may
    remain.
    """
    class FunctionalHardswish(torch.nn.Module):
        def __init__(self, inplace):
            super(FunctionalHardswish, self).__init__()
            self.inplace = inplace

        def forward(self, x):
            return torch.nn.functional.hardswish(x, inplace=self.inplace)

    candidates = [
        torch.nn.Hardswish(),
        FunctionalHardswish(True),
        FunctionalHardswish(False),
    ]

    # Tracing is the outer loop, candidates the inner one.
    for tracing in [True, False]:
        for candidate in candidates:
            quantized = self.checkGraphModeOp(
                candidate, self.img_data, "quantized::hardswish", tracing)
            FileCheck().check_not("aten::hardswish") \
                       .check_not("aten::hardswish_") \
                       .run(quantized.graph)
|
|
|
|
def test_elu(self):
    """ELU lowers to quantized::elu.

    Covers the nn.ELU module and F.elu, each with inplace True/False and
    under tracing and scripting.
    """
    class FunctionalELU(torch.nn.Module):
        def __init__(self, inplace=False):
            super(FunctionalELU, self).__init__()
            self.inplace = inplace

        def forward(self, x):
            return torch.nn.functional.elu(x, inplace=self.inplace)

    elu_factories = [torch.nn.ELU, FunctionalELU]
    combos = itertools.product([True, False], [True, False], elu_factories)
    for tracing, inplace, factory in combos:
        quantized = self.checkGraphModeOp(
            factory(inplace=inplace), self.img_data, "quantized::elu", tracing)
        FileCheck().check_not("aten::elu") \
                   .check_not("aten::elu_") \
                   .run(quantized.graph)
|
|
|
|
def test_layer_norm(self):
    """nn.LayerNorm lowers to quantized::layer_norm under tracing and scripting."""
    samples = []
    for _ in range(2):
        samples.append((torch.rand((1, 2, 5, 5), dtype=torch.float),
                        torch.randint(0, 1, (1,), dtype=torch.long)))
    norm = torch.nn.LayerNorm([2, 5, 5])
    for tracing in (True, False):
        quantized = self.checkGraphModeOp(norm, samples, "quantized::layer_norm", tracing)
        FileCheck().check_not("aten::layer_norm").run(quantized.graph)
|
|
|
|
def test_group_norm(self):
    """nn.GroupNorm(2, 4) lowers to quantized::group_norm under tracing and scripting."""
    samples = []
    for _ in range(2):
        samples.append((torch.rand((1, 4, 5, 5), dtype=torch.float),
                        torch.randint(0, 1, (1,), dtype=torch.long)))
    norm = torch.nn.GroupNorm(2, 4)
    for tracing in (True, False):
        quantized = self.checkGraphModeOp(norm, samples, "quantized::group_norm", tracing)
        FileCheck().check_not("aten::group_norm").run(quantized.graph)
|
|
|
|
def test_instance_norm(self):
    """InstanceNorm{1,2,3}d lower to quantized::instance_norm under tracing and scripting."""
    # Input shape per dimensionality; 4 channels to match InstanceNorm*(4).
    shapes = {1: (1, 4, 5), 2: (1, 4, 5, 1), 3: (1, 4, 5, 1, 1)}
    data = {
        dim: [(torch.rand(shape, dtype=torch.float),
               torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
        for dim, shape in shapes.items()
    }
    norm_classes = {1: torch.nn.InstanceNorm1d,
                    2: torch.nn.InstanceNorm2d,
                    3: torch.nn.InstanceNorm3d}

    for dim in (1, 2, 3):
        for tracing in (True, False):
            norm = norm_classes[dim](4)
            quantized = self.checkGraphModeOp(
                norm, data[dim], "quantized::instance_norm", tracing)
            FileCheck().check_not("aten::instance_norm").run(quantized.graph)
|
|
|
|
@skipIfNoFBGEMM
def test_clamp(self):
    """Check that clamp-family ops stay inside a single quantized region.

    The model chains relu6 / clamp / hardtanh variants (module, functional,
    method and in-place forms) after a conv. Each of ``aten::clamp``,
    ``aten::hardtanh`` and ``aten::hardtanh_`` is passed to
    checkGraphModeOp as the op to look for, under tracing and scripting.
    The quantized graph must contain exactly one ``aten::quantize_per_tensor``.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(2, 2, 2).float()
            self.relu6 = torch.nn.ReLU6()
            self.relu6_ = torch.nn.ReLU6(True)
            self.hardtanh = torch.nn.Hardtanh()
            self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

        def forward(self, x):
            # The exact sequence of ops below is what the test exercises;
            # results of the in-place calls are intentionally discarded.
            x = self.conv(x)
            x = self.relu6(x)
            self.relu6_(x)
            x = F.relu6(x)
            x = torch.clamp(x, -3, 3)
            x = x.clamp(-2.5, 2.5)
            # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
            x = self.hardtanh(x)
            self.hardtanh_(x)
            x = F.hardtanh(x)
            F.hardtanh_(x)
            return x

    data = [(torch.rand((1, 2, 5, 5), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
    options = itertools.product(["aten::clamp", "aten::hardtanh", "aten::hardtanh_"], [True, False])
    for op, tracing in options:
        m = self.checkGraphModeOp(M(), data, op, tracing)
        # A single quantize_per_tensor means the whole clamp chain is
        # computed on the quantized tensor produced for the conv input.
        FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
                   .run(m.graph)

    # NOTE(review): this check only runs on `m` from the last loop iteration
    # (scripting, "aten::hardtanh_"); confirm whether it was meant to be
    # applied inside the loop for every combination.
    FileCheck().check_count("aten::dequantize", 1, exactly=True) \
               .run(m.graph)
|
|
|
|
def test_general_shape_ops(self):
    """Check that dequantize is swapped past all supported general shape ops.

    Covers ops such as aten::flatten, reshape, view, transpose, list/tuple
    construct/unpack, and others. The ops are not executed: the model is
    only scripted, observed and converted, and the converted graph is
    inspected.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
            self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
            self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
            self.dropout = torch.nn.Dropout()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            # Ops sit between two convs; shapes are not consistent, so this
            # forward is only meant to be scripted, never run.
            x = self.conv(x)
            x = self.maxpool1d(x)
            x = self.maxpool2d(x)
            x = self.maxpool3d(x)
            x = torch.flatten(x)
            x = torch.max(x)
            x = torch.min(x)
            x = x.reshape([-1])
            x = x.resize_(1, 1, x.numel())
            x = x.view(-1)
            # prim::ListConstruct
            xs = [x, x]
            # prim::ListUnpack
            x, y = xs
            # prim::TupleConstruct
            xs = (x, x)
            # prim::TupleUnpack
            x, y = xs
            x = x.transpose(1, 2)
            x = x.contiguous()
            x, y = torch.chunk(x, 2)
            x = F.dropout(x)
            x = self.dropout(x)
            x, _ = torch.sort(x)
            x = x.permute(0, 2, 3, 1)
            x = torch.repeat_interleave(x, 3, 1)
            x = self.relu(x)
            x = F.relu(x)
            x.relu_()
            x = x.squeeze(0)
            x.squeeze_(0)
            x = torch.squeeze(x, 0)
            x = x.unsqueeze(0)
            x.unsqueeze_(0)
            x = torch.unsqueeze(x, 0)
            x = x.detach()
            x.detach_()
            x = x.repeat(4, 2)
            y = []
            y.append(x)
            x, _ = y
            x = self.conv(x)
            return x

    data = torch.rand(1, 3, 10, 10)
    # This model is not executable since we just put all ops
    # in the same forward, therefore we only test scripting
    m = torch.jit.script(M())
    qconfig = script_qconfig(default_qconfig)
    # dummy data to suppress warning
    get_forward(qconfig.activation)(data)
    get_forward(qconfig.weight)(data)

    # Insert observers directly via the C++ pass (no calibration run).
    m = wrap_cpp_module(torch._C._jit_pass_insert_observers(
        m._c, 'forward', {'': qconfig}, inplace=False))
    m = convert_jit(m)
    # This checks that the dequantize from the output of the first conv
    # is propagated to the end, so that we don't insert extra
    # observers, and that both quantized::conv2d patterns are fused.
    # The one quantize_per_tensor is for the input.
    FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
               .check_count("quantized::conv2d", 2, exactly=True) \
               .check("aten::dequantize") \
               .run(m.graph)
|
|
|
|
def test_general_value_ops(self):
    """Check the patterns produced for general value ops like aten::avg_pool2d.

    The ops are not executed: the model is only scripted, observed and
    converted (in debug and non-debug mode), and the graphs are inspected.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.avg_pool1d = torch.nn.AvgPool1d(3)
            self.avg_pool2d = torch.nn.AvgPool2d(3)
            self.avg_pool3d = torch.nn.AvgPool3d(3)
            self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
            self.leaky_relu = torch.nn.LeakyReLU()
            self.hardsigmoid = torch.nn.Hardsigmoid()
            self.sigmoid = torch.nn.Sigmoid()
            self.tanh = torch.nn.Tanh()

        def forward(self, x):
            # Every op between the two convs contributes to the
            # quantize_per_tensor count checked below; keep the tally in
            # `num_op_by_num_quant` in sync when editing this list.
            x = self.conv(x)
            x = self.avg_pool1d(x)
            x = self.avg_pool2d(x)
            x = self.avg_pool3d(x)
            x = self.adaptive_avg_pool1d(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.adaptive_avg_pool3d(x)
            x = F.avg_pool1d(x, 3)
            x = F.avg_pool2d(x, 3)
            x = F.avg_pool3d(x, 3)
            x = F.adaptive_avg_pool1d(x, (1))
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = F.adaptive_avg_pool3d(x, (1, 1, 1))
            x = torch.mean(x)
            x = torch.mean(x, [2, 3], False)
            x = x.mean()
            x = x.mean([2, 3], True)
            # interpolate node will introduce 3 quantize_per_tensor ops
            x = F.interpolate(x, 4, mode='nearest')  # interpolate node
            x = F.upsample(x, (32, 32))  # interpolate node
            x = F.upsample_nearest(x, (32, 32))  # interpolate node
            x = F.interpolate(x, 4, mode='linear')  # common node
            x = F.upsample_bilinear(x, (32, 32))  # common node
            x = self.leaky_relu(x)
            x = F.leaky_relu(x)
            x.leaky_relu_()
            x = self.hardsigmoid(x)
            x = F.hardsigmoid(x)
            x.hardsigmoid_()
            x = self.sigmoid(x)
            x = torch.sigmoid(x)
            # F.sigmoid is deprecated
            x = x.sigmoid()
            x.sigmoid_()
            x = self.tanh(x)
            # F.tanh is deprecated
            x = torch.tanh(x)
            x = x.tanh()
            x.tanh_()
            x = self.conv(x)
            return x

    # This model is not executable since we just put all ops
    # in the same forward, therefore we only test scripting
    m = torch.jit.script(M())
    qconfig = script_qconfig(default_qconfig)
    # dummy data to suppress warning
    data = torch.rand(1, 3, 10, 10)
    get_forward(qconfig.activation)(data)
    get_forward(qconfig.weight)(data)

    # Insert observers directly via the C++ pass (no calibration run).
    m = wrap_cpp_module(torch._C._jit_pass_insert_observers(
        m._c, 'forward', {'': qconfig}, inplace=False))
    # The debug-mode model (before finalize) keeps the unfused patterns,
    # which numerically match the quantized model; we verify this by
    # counting aten::quantize_per_tensor calls.
    # conv has 3 quantize_per_tensor for activations and 1 for weight,
    # and for N general value ops between the convs we should have
    # N + 1 quantize_per_tensor between these ops.
    m1 = convert_jit(m, debug=True)
    # NB: This needs to be updated when we add more ops to test.
    # Maps the number of quantize_per_tensor calls an op introduces to the
    # number of ops of that kind; e.g. key `3` means each such op
    # introduces 3 quantize_per_tensor calls.
    num_op_by_num_quant = {1: 32, 2: 2, 3: 3}
    num_quantize_per_tensor = 1  # for output
    for num_quant, num_op in num_op_by_num_quant.items():
        num_quantize_per_tensor += num_op * num_quant
    FileCheck().check_count("aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True) \
               .run(m1.graph)

    # This checks that the dequantize from the output of the first conv
    # is propagated to the end, so that we don't insert extra
    # observers, and that both quantized::conv2d patterns are fused.
    # The one quantize_per_tensor is for the input.
    m2 = convert_jit(m, debug=False)
    FileCheck().check_count("aten::quantize_per_tensor(", 1, exactly=True) \
               .run(m2.graph)
    FileCheck().check_count("quantized::conv2d(", 2, exactly=True) \
               .check("aten::dequantize(") \
               .run(m2.graph)
|
|
|
|
class TestQuantizeDynamicJitPasses(QuantizationTestCase):
    """Tests for the graph mode dynamic quantization passes.

    Exercises prepare_dynamic_jit, quantize_dynamic_jit and
    convert_dynamic_jit: observer placement, the inserted
    choose_qparams/quantize/dequantize patterns, and weight qparams.
    """

    def test_prepare_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        m = torch.jit.script(M())
        m = prepare_dynamic_jit(m, {'': default_dynamic_qconfig})

        # for input of FC for dynamic quant
        self.assertEqual(len(attrs_with_prefix(m, '_observer_')), 1)
        # for weight
        self.assertEqual(len(attrs_with_prefix(m.fc, '_observer_')), 1)
        FileCheck().check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
                   .check('prim::GetAttr[name="fc"]') \
                   .check('prim::CallMethod') \
                   .check_not('Observer = prim::GetAttr[name="_observer_') \
                   .run(m.graph)

    def test_prepare_dynamic_child_qconfig(self):
        class Sub(torch.nn.Module):
            def __init__(self):
                super(Sub, self).__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)
                self.sub = Sub()

            def forward(self, x):
                return self.sub(self.conv(x))

        m = torch.jit.script(M())
        # only quantize child module.
        m = prepare_dynamic_jit(m, {'sub.fc': default_dynamic_qconfig})

        # input of sub for dynamic quant
        self.assertEqual(len(attrs_with_prefix(m, '_observer_')), 1)
        # not quantized
        self.assertEqual(len(attrs_with_prefix(m.conv, '_observer_')), 0)
        # no observers since we observe in the outer most call site
        self.assertEqual(len(attrs_with_prefix(m.sub, '_observer_')), 0)
        # weight of linear
        self.assertEqual(len(attrs_with_prefix(m.sub.fc, '_observer_')), 1)
        FileCheck().check('prim::GetAttr[name="sub') \
                   .check('prim::CallMethod') \
                   .check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
                   .check('prim::CallMethod') \
                   .check_not('Observer = prim::GetAttr[name="_observer_') \
                   .run(m.graph)

    def test_insert_quant_dequant_linear_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc1 = torch.nn.Linear(5, 5).float()
                self.fc2 = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                x = self.fc1(x)
                return self.fc2(x)

        for is_per_channel in [True, False]:
            m = torch.jit.script(M())
            qconfig = per_channel_dynamic_qconfig if is_per_channel else default_dynamic_qconfig
            m = quantize_dynamic_jit(m, {'': qconfig}, debug=True)
            self.assertEqual(len(m._modules._c.items()), 2,
                             'Expected to have two submodule of linear')

            wt_quant_func = "aten::quantize_per_channel" if is_per_channel \
                else "aten::quantize_per_tensor"
            act_quant_func = "aten::quantize_per_tensor"
            # quantizing activations
            FileCheck().check("aten::_choose_qparams_per_tensor") \
                       .check_next(act_quant_func) \
                       .check_next("aten::dequantize") \
                       .check("aten::_choose_qparams_per_tensor") \
                       .check_next(act_quant_func) \
                       .check_next("aten::dequantize") \
                       .check(wt_quant_func) \
                       .check_next("aten::dequantize") \
                       .check_not(wt_quant_func) \
                       .check("return") \
                       .run(m.graph)

    @override_qengines
    def test_dynamic_multi_op(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)

            def forward(self, x):
                x = x + 5
                return self.fc1(x)

        x = torch.randn(5, 5)
        for tracing in [True, False]:
            model = self.checkGraphModeOp(M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True)
            # add op is not dynamically quantized.
            FileCheck().check("aten::add") \
                       .run(model.graph)

    @override_qengines
    def test_dynamic_quant_multi_uses(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                size1 = x.size()
                size2 = x.size()
                return self.fc(x), size1, size2

        x = torch.randn(5, 5)
        for tracing in [True, False]:
            model = self.checkGraphModeOp(M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True)
            FileCheck().check_not("aten::_choose_qparams_per_tensor") \
                       .run(model.graph)

    @override_qengines
    def test_dynamic_shared_weights(self):
        class myMod(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.linear = nn.Linear(5, 5)
                self.linear.weight = weight

            def forward(self, x):
                return self.linear(x)

        class DynamicModel(torch.nn.Module):
            def __init__(self):
                super(DynamicModel, self).__init__()
                self.weight = torch.nn.Parameter(torch.ones(5, 5))
                self.mod1 = myMod(self.weight)

            def forward(self, x):
                y = self.mod1(x)
                z = torch.nn.functional.linear(y, self.weight)
                return z

        model = torch.jit.script(DynamicModel()).eval()
        data = torch.randn(5, 5, dtype=torch.float)
        quant_ops = ['mod1', '']
        counts = [1, 2]
        for op, count in zip(quant_ops, counts):
            qconfig_dict = {op: default_dynamic_qconfig}
            m1 = quantize_dynamic_jit(model, qconfig_dict)
            out_graph = m1(data)

            FileCheck().check_count("quantized::linear_dynamic(", count, exactly=True) \
                       .check_not("aten::_choose_qparams_per_tensor") \
                       .run(m1.graph)

            # Explicitly call forward on model before convert
            m2 = prepare_dynamic_jit(model, qconfig_dict)
            m2(data)
            m2 = convert_dynamic_jit(m2, debug=False)
            out_ref = m2(data)
            self.assertEqual(out_graph, out_ref)

    @override_qengines
    def test_dynamic_with_if(self):
        class Res(torch.nn.Module):
            def __init__(self):
                super(Res, self).__init__()
                self.weight = torch.nn.Parameter(torch.ones(5, 5))

            def forward(self, x, cond):
                # type: (Tensor, bool) -> Tensor
                if cond:
                    return torch.nn.functional.linear(x, self.weight)
                else:
                    return torch.nn.functional.linear(x, self.weight)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.res1 = Res()
                self.res2 = Res()

            def forward(self, x):
                x = self.res1(x, True)
                x = self.res2(x, False)
                return x

        model = torch.jit.script(M()).eval()
        data = torch.randn(5, 5, dtype=torch.float)
        qconfig_dict = {'': default_dynamic_qconfig}
        for tracing in [True, False]:
            m1 = self.checkGraphModeOp(M(), data, "quantized::linear_dynamic", tracing=tracing, dynamic=True)
            FileCheck().check_count("quantized::linear_dynamic(", 2, exactly=True) \
                       .check_not("aten::_choose_qparams_per_tensor") \
                       .run(m1.graph)

        # Check to make sure weight observers run correctly
        ref_qparams = []
        qconfig = script_qconfig(default_dynamic_qconfig)
        wt_module = wrap_cpp_module(qconfig.weight)
        for wt in [model.res1.weight, model.res2.weight]:
            wt_module(wt)
            qparams = wt_module.calculate_qparams()
            ref_qparams.append((qparams[0].item(), qparams[1].item()))

        m2 = quantize_dynamic_jit(model, qconfig_dict, debug=True)
        graph_params = []
        # The numeric prefixes are the names the pass gives to the stored
        # weight qparams inside each submodule.
        for name, obs in m2._modules._c.items():
            if name == 'res1':
                graph_params.append((obs.getattr('6_scale_0'), obs.getattr('6_zero_point_0')))
            elif name == 'res2':
                graph_params.append((obs.getattr('10_scale_0'), obs.getattr('10_zero_point_0')))
        self.assertEqual(ref_qparams, graph_params)

    def test_dynamic_weight_observer(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5).float()
                self.fc2 = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                x = self.fc(x)
                return self.fc2(x)

        qconfig_dict = {'': default_dynamic_qconfig}
        eager_model = M().eval()
        x = torch.rand(5, 5)
        for tracing in [True, False]:
            model = get_script_module(eager_model, tracing, x)
            qconfig = script_qconfig(default_dynamic_qconfig)
            ref_qparams = []
            wt_module = wrap_cpp_module(qconfig.weight)
            for wt in [model.fc.weight, model.fc2.weight]:
                wt_module(wt)
                qparams = wt_module.calculate_qparams()
                ref_qparams.append((qparams[0].item(), qparams[1].item()))
            model = quantize_dynamic_jit(model, qconfig_dict, debug=True)
            # Traced and scripted graphs number the stored qparams differently.
            prefix = '4' if tracing else '3'
            graph_params = []
            # Do not reuse `x` as the loop variable here: it is the sample
            # input handed to get_script_module on the next iteration.
            for _name, obs in model._modules._c.items():
                graph_params.append((obs.getattr(prefix + '_scale_0'),
                                     obs.getattr(prefix + '_zero_point_0')))
            self.assertEqual(ref_qparams, graph_params)
|
|
|
|
class TestQuantizeDynamicJitOps(QuantizationTestCase):
    """End-to-end checks that graph mode post training dynamic quantization
    handles individual ops.
    """

    @override_qengines
    def test_quantized_linear_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                return self.fc(x)

        sample = torch.rand(5, 5)
        for tracing in (True, False):
            # The returned quantized model is not inspected further;
            # checkGraphModeOp asserts that the quantized op is present.
            self.checkGraphModeOp(M(), sample, "quantized::linear_dynamic",
                                  tracing=tracing, dynamic=True)
|
|
|
|
class TestQuantizeJitJit(JitTestCase):
|
|
def _test_lower_graph_impl(self, model, data):
    """Quantize ``model`` in eager mode, then trace, save/load and export it to ONNX.

    The model is prepared and converted with ``default_qconfig``; no
    calibration pass runs between prepare and convert. The converted model
    is run once on ``data``, then traced, round-tripped through
    ``torch.jit.save``/``torch.jit.load``, and exported with
    ``operator_export_type=ONNX_ATEN_FALLBACK``.

    Args:
        model: eager float module; its ``qconfig`` attribute is overwritten.
        data: example input for ``model``.
    """
    model.qconfig = torch.quantization.default_qconfig
    model = torch.quantization.prepare(model)
    model = torch.quantization.convert(model)

    # Smoke check: the converted eager model must run on the sample input.
    model(data)
    input_names = ["x"]

    def export_to_onnx(model, example_input, input_names):
        # `example_input` avoids shadowing the builtin `input`.
        outputs = model(example_input)

        traced = torch.jit.trace(model, example_input)
        buf = io.BytesIO()
        torch.jit.save(traced, buf)
        buf.seek(0)

        loaded = torch.jit.load(buf)
        f = io.BytesIO()
        torch.onnx.export(loaded, example_input, f, input_names=input_names, example_outputs=outputs,
                          operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

    # The helper returns nothing; success means the export did not raise.
    export_to_onnx(model, data, input_names)
|
|
|
|
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
                     ' with instruction set support avx2 or newer.')
def test_lower_graph_linear(self):
    """Run _test_lower_graph_impl on a QuantWrapper around Linear(5, 10)."""
    wrapped = torch.quantization.QuantWrapper(
        torch.nn.Linear(5, 10, bias=True)).to(dtype=torch.float)
    sample = torch.from_numpy(
        np.random.rand(1, 2, 5).astype(np.float32)).to(dtype=torch.float)
    self._test_lower_graph_impl(wrapped, sample)
|
|
|
|
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
                     ' with instruction set support avx2 or newer.')
def test_lower_graph_conv2d(self):
    """Run _test_lower_graph_impl on a QuantWrapper around Conv2d(3, 5, 2)."""
    wrapped = torch.quantization.QuantWrapper(
        torch.nn.Conv2d(3, 5, 2, bias=True)).to(dtype=torch.float)
    sample = torch.from_numpy(
        np.random.rand(1, 3, 6, 6).astype(np.float32)).to(dtype=torch.float)
    self._test_lower_graph_impl(wrapped, sample)
|
|
|
|
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
                     ' with instruction set support avx2 or newer.')
@unittest.skip("onnx opset9 does not support quantize_per_tensor and caffe2 "
               "does not support conv3d")
def test_lower_graph_conv3d(self):
    """Run _test_lower_graph_impl on a QuantWrapper around Conv3d(3, 5, 2) (currently skipped)."""
    wrapped = torch.quantization.QuantWrapper(
        torch.nn.Conv3d(3, 5, 2, bias=True)).to(dtype=torch.float)
    sample = torch.from_numpy(
        np.random.rand(1, 3, 6, 6, 6).astype(np.float32)).to(dtype=torch.float)
    self._test_lower_graph_impl(wrapped, sample)
|
|
|
|
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
                     ' with instruction set support avx2 or newer.')
def test_rnn_cell_quantized(self):
    """Compare quantized LSTM/GRU/RNN cells against their float versions.

    Each cell's weights are replaced with hand-picked values. The cell is
    then quantized with torch.jit.quantized.quantize_rnn_cell_modules,
    wrapped in a ScriptModule and round-tripped via
    getExportImportCopyWithPacking. Its outputs must be allclose to a deep
    copy of the float cell.
    """
    d_in, d_hid = 2, 2

    for cell in [
        torch.nn.LSTMCell(d_in, d_hid).float(),
        torch.nn.GRUCell(d_in, d_hid).float(),
        torch.nn.RNNCell(d_in, d_hid).float(),
    ]:
        # Number of stacked gate matrices in weight_ih / weight_hh.
        if isinstance(cell, torch.nn.LSTMCell):
            num_chunks = 4
        elif isinstance(cell, torch.nn.GRUCell):
            num_chunks = 3
        elif isinstance(cell, torch.nn.RNNCell):
            num_chunks = 1

        # Replace parameter values s.t. the range of values is exactly
        # 255, thus we will have 0 quantization error in the quantized
        # GEMM call. This is for testing purposes.
        #
        # Note that the current implementation does not support
        # accumulation values outside of the range representable by a
        # 16 bit integer, instead resulting in a saturated value. We
        # must take care that in our test we do not end up with a dot
        # product that overflows the int16 range, e.g.
        # (255*127+255*127) = 64770. So, we hardcode the test values
        # here and ensure a mix of signedness.
        vals = [[100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155]]
        # Keep d_hid rows per gate chunk.
        vals = vals[:d_hid * num_chunks]
        cell.weight_ih = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)
        cell.weight_hh = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)

        # Float reference, copied before `cell` is replaced by its quantized version.
        ref = copy.deepcopy(cell)

        cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
        x = torch.tensor([[100, -155],
                          [-155, 100],
                          [100, -155]], dtype=torch.float)
        h0_vals = [[-155, 100],
                   [-155, 155],
                   [100, -155]]
        hx = torch.tensor(h0_vals, dtype=torch.float)
        # LSTM cells take an (h, c) tuple; the others take h alone.
        if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
            cx = torch.tensor(h0_vals, dtype=torch.float)
            hiddens = (hx, cx)
        else:
            hiddens = hx

        # Two wrapper variants because the script_method type comment
        # must match the hidden-state type of the wrapped cell.
        if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
            class ScriptWrapper(torch.jit.ScriptModule):
                def __init__(self, cell):
                    super(ScriptWrapper, self).__init__()
                    self.cell = cell

                @torch.jit.script_method
                def forward(self, x, hiddens):
                    # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
                    return self.cell(x, hiddens)
        else:

            class ScriptWrapper(torch.jit.ScriptModule):
                def __init__(self, cell):
                    super(ScriptWrapper, self).__init__()
                    self.cell = cell

                @torch.jit.script_method
                def forward(self, x, hiddens):
                    # type: (torch.Tensor, torch.Tensor) -> torch.Tensor
                    return self.cell(x, hiddens)

        cell = ScriptWrapper(cell)
        # NOTE(review): this result is overwritten below; the call appears
        # to only check that the wrapper runs before export/import.
        outs = cell(x, hiddens)
        cell = self.getExportImportCopyWithPacking(cell)

        outs = cell(x, hiddens)
        ref_outs = ref(x, hiddens)

        # For GRU/RNN cells the wrapper returns a single Tensor (per its
        # type comment), so len/zip below compare it row by row.
        self.assertEqual(len(outs), len(ref_outs))
        for out, ref_out in zip(outs, ref_outs):
            torch.testing.assert_allclose(out, ref_out)
|
|
|
|
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                     'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
                     ' with instruction set support avx2 or newer.')
def test_rnn_quantized(self):
    """Compare quantized ``LSTM``/``GRU`` modules against the float reference.

    Each float module gets hand-picked weights and is quantized to int8 and
    fp16 with ``torch.jit.quantized.quantize_rnn_modules``.  The eager
    outputs are compared with the float module.  The quantized modules are
    then wrapped in a ``ScriptModule`` and compared again, both directly and
    after an export/import round trip.
    """
    d_in, d_hid = 2, 2

    for cell in [
        torch.nn.LSTM(d_in, d_hid).float(),
        torch.nn.GRU(d_in, d_hid).float(),
    ]:

        # Replace parameter values s.t. the range of values is exactly
        # 255, thus we will have 0 quantization error in the quantized
        # GEMM call. This is for testing purposes.
        #
        # Note that the current implementation does not support
        # accumulation values outside of the range representable by a
        # 16 bit integer, instead resulting in a saturated value. We
        # must take care that in our test we do not end up with a dot
        # product that overflows the int16 range, e.g.
        # (255*127+255*127) = 64770. So, we hardcode the test values
        # here and ensure a mix of signedness.
        vals = [[100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155]]
        if isinstance(cell, torch.nn.LSTM):
            num_chunks = 4
        elif isinstance(cell, torch.nn.GRU):
            num_chunks = 3
        vals = vals[:d_hid * num_chunks]
        cell.weight_ih_l0 = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)
        cell.weight_hh_l0 = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)

        ref = copy.deepcopy(cell)
        # quantize_rnn_modules returns new modules: from here on ``cell`` is
        # still the float torch.nn.LSTM / torch.nn.GRU, and only cell_int8 /
        # cell_fp16 are quantized.
        cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)
        cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)

        niter = 10
        x = torch.tensor([[100, -155],
                          [-155, 100],
                          [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
        h0_vals = [[-155, 100],
                   [-155, 155],
                   [100, -155]]
        hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
        cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)

        if isinstance(ref, torch.nn.LSTM):
            hiddens = (hx, cx)
        elif isinstance(ref, torch.nn.GRU):
            hiddens = hx

        ref_out, ref_hid = ref(x, hiddens)

        # Compare int8 quantized to unquantized
        output_int8, final_hiddens_int8 = cell_int8(x, hiddens)

        torch.testing.assert_allclose(output_int8, ref_out)
        # ``ref_h`` (not ``ref``) so the reference module is not shadowed.
        for out, ref_h in zip(final_hiddens_int8, ref_hid):
            torch.testing.assert_allclose(out, ref_h)

        # Compare fp16 quantized to unquantized
        output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)

        torch.testing.assert_allclose(output_fp16, ref_out)
        for out, ref_h in zip(final_hiddens_fp16, ref_hid):
            torch.testing.assert_allclose(out, ref_h)

        def compare_quantized_unquantized(ScriptWrapper, cell):
            # Script-wrap ``cell`` and compare it, and its export/import
            # copy, against the float reference outputs.
            wrapper = ScriptWrapper(cell)

            # Compare quantize scripted module to unquantized
            script_out, script_hid = wrapper(x, hiddens)
            torch.testing.assert_allclose(script_out, ref_out)
            for out, ref_h in zip(script_hid, ref_hid):
                torch.testing.assert_allclose(out, ref_h)

            # Compare export/import to unquantized
            export_import_wrapper = self.getExportImportCopyWithPacking(wrapper)
            ei_out, ei_hid = export_import_wrapper(x, hiddens)
            torch.testing.assert_allclose(ei_out, ref_out)
            for out, ref_h in zip(ei_hid, ref_hid):
                torch.testing.assert_allclose(out, ref_h)

        # Dispatch on the quantized module.  Testing the float ``cell`` here
        # can never match, which silently skipped every scripted comparison.
        if isinstance(cell_int8, torch.jit.quantized.QuantizedGRU):
            class ScriptWrapper(torch.jit.ScriptModule):
                def __init__(self, cell):
                    super(ScriptWrapper, self).__init__()
                    self.cell = cell

                @torch.jit.script_method
                def forward(self, x, hiddens):
                    # type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
                    return self.cell(x, hiddens)

            compare_quantized_unquantized(ScriptWrapper, cell_int8)
        elif isinstance(cell_int8, torch.jit.quantized.QuantizedLSTM):
            for quantized_cell in [cell_int8, cell_fp16]:
                # TorchScript reads exactly one ``# type:`` line per
                # signature, so the full signature (including the return
                # type after ``->``) must stay on a single line.
                class ScriptWrapper(torch.jit.ScriptModule):
                    def __init__(self, cell):
                        super(ScriptWrapper, self).__init__()
                        self.cell = cell

                    @torch.jit.script_method
                    def forward(self, x, hiddens):
                        # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
                        return self.cell(x, hiddens)

                compare_quantized_unquantized(ScriptWrapper, quantized_cell)
|
|
|
|
if 'fbgemm' in torch.backends.quantized.supported_engines:
    # Suppression: using deprecated quant api
    @suppress_warnings
    def test_quantization_modules(self):
        """Check the deprecated ``quantize_linear_modules`` for int8 and fp16.

        A tiny model with one ``Linear`` layer and hand-picked weights is
        quantized, traced, round-tripped through export/import and compared
        with the float output.  ``quantize_linear_modules`` swaps child
        ``Linear`` modules on the module it is given, so each dtype is
        applied to its own deep copy of the float model.  Otherwise the fp16
        call finds no ``Linear`` left to convert and re-tests the int8
        module.
        """
        import copy

        K1, N1 = 2, 2

        class FooBar(torch.nn.Module):
            def __init__(self):
                super(FooBar, self).__init__()
                self.linear1 = torch.nn.Linear(K1, N1).float()

            def forward(self, x):
                x = self.linear1(x)
                return x

        fb = FooBar()
        fb.linear1.weight = torch.nn.Parameter(
            torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False)
        fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False)

        # ``x`` is only the tracing example; ``value`` is the checked input.
        x = (torch.rand(1, K1).float() - 0.5) / 10.0
        value = torch.tensor([[100, -150]], dtype=torch.float)

        y_ref = fb(value)

        fb_int8 = torch.jit.quantized.quantize_linear_modules(copy.deepcopy(fb))
        traced_int8 = torch.jit.trace(fb_int8, (x,))
        fb_int8 = self.getExportImportCopyWithPacking(traced_int8)
        y_int8 = fb_int8(value)

        fb_fp16 = torch.jit.quantized.quantize_linear_modules(copy.deepcopy(fb), torch.float16)
        traced_fp16 = torch.jit.trace(fb_fp16, (x,))
        fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16)
        y_fp16 = fb_fp16(value)

        torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3)
        torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
|
|
|
|
def _test_pickle_checkpoint_qtensor(self, device):
    """Save quantized tensors from inside TorchScript and reload them.

    A scripted module calls ``torch.save`` on a per-tensor and a
    per-channel quantized tensor.  The resulting file is loaded back in
    Python and the tensors are compared with the originals.

    Args:
        device: device the quantized tensors are moved to before saving.
    """
    with TemporaryFileName() as fname:
        class M(torch.jit.ScriptModule):
            __constants__ = ['fname']

            def __init__(self):
                super(M, self).__init__()
                self.fname = fname

            @torch.jit.script_method
            def forward(self, x, y):
                torch.save((x, y), self.fname)
                return y

        q = torch.quantize_per_tensor(
            torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
        qc = torch.quantize_per_channel(
            torch.rand(2, 3, dtype=torch.float),
            scales=torch.tensor([0.1, 0.5, 0.01]),
            zero_points=torch.tensor([10, 0, 20]),
            axis=1, dtype=torch.quint8).to(device)
        m = M()
        m(q, qc)
        # Load through the handle that was opened for this purpose rather
        # than leaving it unused and re-opening the file by name.
        with open(fname, "rb") as handle:
            loaded_q, loaded_qc = torch.load(handle)
        self.assertEqual(loaded_q, q)
        self.assertEqual(loaded_qc, qc)
|
|
|
|
def test_pickle_checkpoint_qtensor(self):
    """Run the scripted quantized-tensor checkpoint round trip on CPU."""
    target_device = 'cpu'
    self._test_pickle_checkpoint_qtensor(target_device)
|
|
|
|
def test_serialize_qtensor(self):
    """Save and reload a ScriptModule whose buffer is a quantized tensor.

    Both a per-tensor and a per-channel quantized buffer are covered.  The
    reloaded module must return a value equal to the original module's.
    """
    class SimpleQTensor(torch.jit.ScriptModule):
        def __init__(self, per_channel):
            super(SimpleQTensor, self).__init__()
            float_data = torch.rand(5, 5).float()
            if per_channel:
                scales = torch.rand(5, dtype=torch.float64) + 0.1
                zero_points = torch.randint(5, 15, (5,))
                quantized = torch.quantize_per_channel(float_data, scales, zero_points, 1, torch.quint8)
            else:
                quantized = torch.quantize_per_tensor(float_data, 0.2, 10, torch.quint8)
            self.register_buffer('x', quantized)

        @torch.jit.script_method
        def forward(self):
            return self.x

    for per_channel in (False, True):
        original = SimpleQTensor(per_channel)
        stream = io.BytesIO()
        torch.jit.save(original, stream)
        stream.seek(0)
        restored = torch.jit.load(stream)
        self.assertEqual(restored(), original())
|
|
|
|
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, "requires FBGEMM")
def test_erase_class_tensor_shapes(self):
    """Erase shape information from the graph of a scripted packed-weight module.

    A module holding a packed quantized linear weight, with custom
    ``__getstate__``/``__setstate__``, is scripted and its graph is passed
    through ``torch._C._jit_pass_erase_shape_information``.  The test makes
    no assertions; it passes as long as neither step raises.
    """
    class Linear(torch.nn.Module):
        def __init__(self, in_features, out_features):
            super(Linear, self).__init__()
            # qint8 weight created with _empty_affine_quantized (values
            # presumably uninitialized); only its shape and dtype are relied on.
            qweight = torch._empty_affine_quantized(
                [out_features, in_features], scale=1, zero_point=0,
                dtype=torch.qint8)
            self._packed_weight = torch.ops.quantized.linear_prepack(qweight)

        # State = (element 0 of linear_unpack's result, training flag).
        @torch.jit.export
        def __getstate__(self):
            return (torch.ops.quantized.linear_unpack(self._packed_weight)[0], self.training)

        def forward(self):
            return self._packed_weight

        # Inverse of __getstate__: re-pack state[0] and restore the flag.
        @torch.jit.export
        def __setstate__(self, state):
            self._packed_weight = torch.ops.quantized.linear_prepack(state[0])
            self.training = state[1]

        @property
        def weight(self):
            return torch.ops.quantized.linear_unpack(self._packed_weight)[0]

        @weight.setter
        def weight(self, w):
            self._packed_weight = torch.ops.quantized.linear_prepack(w)

    # NOTE(review): emit hooks are presumably disabled so scripting this
    # module does not trigger the suite's emit-hook round-trip checks; confirm.
    with torch.jit._disable_emit_hooks():
        x = torch.jit.script(Linear(10, 10))
        torch._C._jit_pass_erase_shape_information(x.graph)
|
|
|
|
|
|
class TestQuantizeJit(QuantizationTestCase):
|
|
@override_qengines
def test_single_linear(self):
    r"""Quantize a single linear layer in eager mode and in graph mode
    (traced and scripted) and check that all results agree.
    """
    engine = torch.backends.quantized.engine
    sample = self.calib_data[0][0]

    # Eager mode reference.
    annotated_linear_model = AnnotatedSingleLayerLinearModel(engine).eval()
    linear_model = SingleLayerLinearModel().eval()
    # Reuse the eager weights so the two quantized models are comparable.
    eager_fc = annotated_linear_model.fc1.module
    linear_model.fc1.weight = torch.nn.Parameter(eager_fc.weight.detach())
    linear_model.fc1.bias = torch.nn.Parameter(eager_fc.bias.detach())
    model_eager = quantize(annotated_linear_model, test_only_eval_fn, self.calib_data)

    # Graph mode, for both a traced and a scripted model.
    qconfig_dict = {'': get_default_qconfig(engine)}
    graph_models = (
        torch.jit.trace(linear_model, sample),
        torch.jit.script(linear_model),
    )
    result_eager = model_eager(sample)
    for model_under_test in graph_models:
        model_quantized = quantize_jit(
            model_under_test, qconfig_dict, test_only_eval_fn,
            [self.calib_data], inplace=False)
        self.assertEqual(model_quantized(sample), result_eager)
|
|
|
|
@skipIfNoFBGEMM
def test_observer_with_ignored_function(self):
    r"""Check that several observer configurations (including ones whose
    observers use ignored functions) give the same result in graph mode as
    in eager mode.
    """
    sample = self.calib_data[0][0]

    # Eager mode
    annotated_linear_model = AnnotatedSingleLayerLinearModel('fbgemm').eval()
    observer_configs = [
        QConfig(activation=default_observer,
                weight=default_weight_observer),
        QConfig(activation=default_histogram_observer,
                weight=default_weight_observer),
        QConfig(activation=default_observer,
                weight=default_per_channel_weight_observer),
    ]
    for qconfig in observer_configs:
        annotated_linear_model.qconfig = qconfig
        linear_model = SingleLayerLinearModel().eval()
        # Reuse the eager weights so the two quantized models are comparable.
        eager_fc = annotated_linear_model.fc1.module
        linear_model.fc1.weight = torch.nn.Parameter(eager_fc.weight.detach())
        linear_model.fc1.bias = torch.nn.Parameter(eager_fc.bias.detach())
        model_eager = quantize(annotated_linear_model, test_only_eval_fn,
                               self.calib_data)

        qconfig_dict = {'': qconfig}
        graph_models = (
            torch.jit.trace(linear_model, sample),
            torch.jit.script(linear_model),
        )
        result_eager = model_eager(sample)
        for model_under_test in graph_models:
            model_quantized = quantize_jit(
                model_under_test, qconfig_dict, test_only_eval_fn,
                [self.calib_data], inplace=False)
            self.assertEqual(model_quantized(sample), result_eager)
|
|
|
|
@override_qengines
def test_conv(self):
    r"""Quantize a single conv layer in eager mode and in graph mode
    (traced and scripted) and check that all results agree.
    """
    engine = torch.backends.quantized.engine
    sample = self.img_data[0][0]

    # Eager mode reference.
    annotated_conv_model = AnnotatedConvModel(engine).eval()
    conv_model = ConvModel().eval()
    # Reuse the eager weight so the two quantized models are comparable.
    conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach())
    model_eager = quantize(annotated_conv_model, default_eval_fn, self.img_data)

    # Graph mode, for both a traced and a scripted model.
    qconfig_dict = {'': get_default_qconfig(engine)}
    graph_models = (
        torch.jit.trace(conv_model, sample),
        torch.jit.script(conv_model),
    )
    result_eager = model_eager(sample)
    for model_under_test in graph_models:
        model_quantized = quantize_jit(
            model_under_test, qconfig_dict, default_eval_fn,
            [self.img_data], inplace=False)
        self.assertEqual(model_quantized(sample), result_eager)
|
|
|
|
@override_qengines
def test_conv_bn(self):
    r"""Quantize conv + bn in eager mode (after explicit fusion) and in
    scripted graph mode, and check that both results agree.
    """
    sample = self.img_data[0][0]

    # Eager mode: copy the conv weight into the float model, then fuse.
    conv_model = AnnotatedConvBnModel().eval()
    conv_model_to_script = ConvBnModel().eval()
    conv_model_to_script.conv.weight = torch.nn.Parameter(conv_model.conv.weight.detach())
    fuse_modules(conv_model, ['conv', 'bn'], inplace=True)
    model_eager = quantize(conv_model, default_eval_fn, self.img_data)

    # Graph mode on the scripted float model.
    scripted = torch.jit.script(conv_model_to_script)
    model_script = quantize_jit(
        scripted, {'': default_qconfig}, default_eval_fn,
        [self.img_data], inplace=False)

    result_eager = model_eager(sample)
    result_script = model_script(sample)
    self.assertEqual(result_eager, result_script)
|
|
|
|
@override_qengines
def test_nested(self):
    """Quantize a nested model with per-submodule qconfigs in eager mode and
    in graph mode (traced and scripted) and check that all results agree.
    """
    engine = torch.backends.quantized.engine
    sample = self.calib_data[0][0]

    # Eager mode
    eager_model = AnnotatedNestedModel(engine).eval()

    # Graph mode: mirror every linear layer's parameters from the eager model.
    script_model = NestedModel().eval()
    layer_pairs = [
        (script_model.sub1.fc, eager_model.sub1.fc),
        (script_model.sub2.fc1, eager_model.sub2.fc1.module),
        (script_model.sub2.fc2, eager_model.sub2.fc2),
        (script_model.fc3, eager_model.fc3.module),
    ]
    for dst, src in layer_pairs:
        dst.weight = torch.nn.Parameter(src.weight.detach())
        dst.bias = torch.nn.Parameter(src.bias.detach())

    model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data)
    fc1_qconfig = default_per_channel_qconfig if qengine_is_fbgemm() else default_qconfig
    qconfig_dict = {
        'sub2.fc1': fc1_qconfig,
        'fc3': default_qconfig
    }
    graph_models = (
        torch.jit.trace(script_model, sample),
        torch.jit.script(script_model),
    )
    result_eager = model_eager(sample)
    for model_under_test in graph_models:
        model_quantized = quantize_jit(
            model_under_test, qconfig_dict, test_only_eval_fn,
            [self.calib_data], inplace=False)
        self.assertEqual(model_quantized(sample), result_eager)
|
|
|
|
@override_qengines
def test_skip_quant(self):
    """ Test None qconfig: the top-level ``fc`` is left unquantized in both
    eager and graph mode, and the results of the two modes agree.
    """
    engine = torch.backends.quantized.engine
    sample = self.calib_data[0][0]

    # Eager mode
    eager_model = AnnotatedSkipQuantModel(engine).eval()

    # Graph mode: mirror the eager model's linear parameters.
    script_model = SkipQuantModel().eval()
    eager_sub = eager_model.sub.module
    for dst, src in (
        (script_model.sub.fc1, eager_sub.fc1),
        (script_model.sub.fc2, eager_sub.fc2),
        (script_model.fc, eager_model.fc),
    ):
        dst.weight = torch.nn.Parameter(src.weight.detach())
        dst.bias = torch.nn.Parameter(src.bias.detach())

    eager_model.fuse_modules()

    model_eager = quantize(eager_model, test_only_eval_fn, self.calib_data)
    qconfig_dict = {
        '': get_default_qconfig(engine),
        'fc': None
    }
    graph_models = (
        torch.jit.trace(script_model, sample),
        torch.jit.script(script_model),
    )
    result_eager = model_eager(sample)
    for model_under_test in graph_models:
        model_quantized = quantize_jit(
            model_under_test, qconfig_dict, test_only_eval_fn,
            [self.calib_data], inplace=False)
        self.assertEqual(model_quantized(sample), result_eager)
|
|
|
|
@override_qengines
def test_single_linear_dynamic(self):
    r"""Compare the result of dynamic quantization of single linear layer in
    eager mode and graph mode.

    Only exercised when the active quantized engine is qnnpack; for any
    other engine the body is skipped.  For both a traced and a scripted
    model, the graph-mode result is compared with the eager result, and
    with the ``debug=True`` graph-mode result.
    """
    if qengine_is_qnnpack():
        # eager mode
        annotated_linear_model = AnnotatedSingleLayerLinearModel('qnnpack').eval()
        linear_model = SingleLayerLinearModel().eval()
        # copy the weight from eager mode so that we can
        # compare the result of the two quantized models later
        linear_model.fc1.weight = torch.nn.Parameter(annotated_linear_model.fc1.module.weight.detach())
        linear_model.fc1.bias = torch.nn.Parameter(annotated_linear_model.fc1.module.bias.detach())
        qconfig_dict = {'': default_dynamic_qconfig}
        model_eager = quantize_dynamic(annotated_linear_model, qconfig_dict)

        model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
        model_script = torch.jit.script(linear_model)
        result_eager = model_eager(self.calib_data[0][0])

        for model_under_test in [model_traced, model_script]:
            model_quantized = quantize_dynamic_jit(
                model_under_test,
                qconfig_dict)
            self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)

            # Check to make sure choose_qparams->quant->dequant->linear is numerically
            # equivalent to the final quantized model.
            model_fake_quantized = quantize_dynamic_jit(
                model_under_test,
                qconfig_dict,
                debug=True)
            self.assertEqual(model_fake_quantized(self.calib_data[0][0]), result_eager)
|