[Easy] enable PYFMT for torch/quantization/eager (#150761)
All modifications were made with tooling; the exact command is:

```bash
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150761
Approved by: https://github.com/jerryzh168
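For readers who want to see the kind of rewrite this command produces, the sketch below is illustrative only (it is not part of the commit). It assumes the `black` package is installed and uses a snippet distilled from this diff; PYFMT in PyTorch's lintrunner setup also sorts imports (usort), which this sketch does not cover.

```python
# Minimal sketch: reproduce the black-style normalization seen throughout this
# commit (double quotes, magic trailing commas, wrapping of long calls).
import black

SRC = (
    "x = get_param(submodule, 'bias')\n"
    "self.assertTrue(self.compute_sqnr(float_bias, artificial_bias) > 30, "
    "'Correcting quantized bias produced too much noise, sqnr score too low')\n"
)

# format_str parses the snippet and re-emits it in the formatter's style,
# which matches the "+" side of the hunks below.
print(black.format_str(SRC, mode=black.Mode()))
```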
parent 91b090c912
commit 8895c290f4
.lintrunner.toml
@@ -1165,14 +1165,6 @@ exclude_patterns = [
     'test/quantization/core/test_utils.py',
     'test/quantization/core/test_workflow_module.py',
     'test/quantization/core/test_workflow_ops.py',
-    'test/quantization/eager/__init__.py',
-    'test/quantization/eager/test_bias_correction_eager.py',
-    'test/quantization/eager/test_equalize_eager.py',
-    'test/quantization/eager/test_fuse_eager.py',
-    'test/quantization/eager/test_model_numerics.py',
-    'test/quantization/eager/test_numeric_suite_eager.py',
-    'test/quantization/eager/test_quantize_eager_ptq.py',
-    'test/quantization/eager/test_quantize_eager_qat.py',
     'test/quantization/fx/__init__.py',
     'test/quantization/fx/test_equalize_fx.py',
     'test/quantization/fx/test_model_report_fx.py',
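Removing these entries from `exclude_patterns` is what opts the eager-mode test files into the PYFMT linter; the rest of the commit is the formatter's output on those files. A hedged sketch of re-checking just those files locally (the ability to pass explicit paths to lintrunner is an assumption about its CLI; only the flags quoted in the commit message are taken from this commit):

```python
# Run the PYFMT lint in check-only mode on the newly un-excluded files.
import subprocess

files = [
    "test/quantization/eager/test_bias_correction_eager.py",
    "test/quantization/eager/test_equalize_eager.py",
    "test/quantization/eager/test_fuse_eager.py",
    "test/quantization/eager/test_model_numerics.py",
]
subprocess.run(["lintrunner", "--take", "PYFMT", *files], check=True)
```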
test/quantization/eager/test_bias_correction_eager.py
@@ -1,24 +1,23 @@
 # Owner(s): ["oncall: quantization"]
 
+import copy
+
 import torch
-import torch.nn as nn
-from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.testing._internal.common_quantization import skipIfNoFBGEMM
-
-from torch.ao.quantization import default_qconfig
-from torch.ao.quantization import QuantWrapper
 import torch.ao.ns._numeric_suite as ns
+import torch.nn as nn
+from torch.ao.quantization import default_qconfig, QuantWrapper
 from torch.ao.quantization._correct_bias import (
     _supported_modules,
     _supported_modules_quantized,
     bias_correction,
     get_module,
     get_param,
-    parent_child_names
+    parent_child_names,
 )
-
-import copy
+from torch.testing._internal.common_quantization import (
+    QuantizationTestCase,
+    skipIfNoFBGEMM,
+)
 
 
 class TestBiasCorrectionEager(QuantizationTestCase):
@@ -28,9 +27,9 @@ class TestBiasCorrectionEager(QuantizationTestCase):
         return 20 * torch.log10(Ps / Pn)
 
     def correct_artificial_bias_quantize(self, float_model, img_data):
-        ''' Adding artificial bias and testing if bias persists after bias
+        """Adding artificial bias and testing if bias persists after bias
         correction. This test case changes the bias of a quantized submodule
-        '''
+        """
         artificial_model = copy.deepcopy(float_model)
         artificial_model.qconfig = default_qconfig
         torch.ao.quantization.prepare(artificial_model, inplace=True)
@@ -41,12 +40,17 @@ class TestBiasCorrectionEager(QuantizationTestCase):
         # manually changing bias
         for name, submodule in artificial_model.named_modules():
             if type(submodule) in _supported_modules:
-                x = get_param(submodule, 'bias')
-                weight = get_param(submodule, 'weight')
+                x = get_param(submodule, "bias")
+                weight = get_param(submodule, "weight")
                 if x is not None:
                     submodule.set_weight_bias(weight, x.data * 3)
 
-        bias_correction(float_model, artificial_model, img_data, target_modules=_supported_modules_quantized)
+        bias_correction(
+            float_model,
+            artificial_model,
+            img_data,
+            target_modules=_supported_modules_quantized,
+        )
 
         # Trims off the shadow module,
         for name, submodule in artificial_model.named_modules():
@@ -58,11 +62,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
         for name, artificial_submodule in artificial_model.named_modules():
             if type(artificial_submodule) in _supported_modules_quantized:
                 submodule = get_module(float_model, name)
-                float_bias = get_param(submodule, 'bias')
-                artificial_bias = get_param(artificial_submodule, 'bias')
+                float_bias = get_param(submodule, "bias")
+                artificial_bias = get_param(artificial_submodule, "bias")
 
-                self.assertTrue(self.compute_sqnr(float_bias, artificial_bias) > 30,
-                                "Correcting quantized bias produced too much noise, sqnr score too low")
+                self.assertTrue(
+                    self.compute_sqnr(float_bias, artificial_bias) > 30,
+                    "Correcting quantized bias produced too much noise, sqnr score too low",
+                )
 
     @skipIfNoFBGEMM
     def test_linear_chain(self):
@@ -78,9 +84,15 @@ class TestBiasCorrectionEager(QuantizationTestCase):
             return x
 
         float_model = QuantWrapper(LinearChain())
-        img_data = [(torch.rand(10, 3, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
-                    for _ in range(50)]
+        img_data = [
+            (
+                torch.rand(10, 3, dtype=torch.float),
+                torch.randint(0, 1, (2,), dtype=torch.long),
+            )
+            for _ in range(50)
+        ]
         self.correct_artificial_bias_quantize(float_model, img_data)
 
     @skipIfNoFBGEMM
@@ -97,7 +109,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
             return x
 
         float_model = QuantWrapper(ConvChain())
-        img_data = [(torch.rand(10, 3, 125, 125, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
-                    for _ in range(50)]
+        img_data = [
+            (
+                torch.rand(10, 3, 125, 125, dtype=torch.float),
+                torch.randint(0, 1, (2,), dtype=torch.long),
+            )
+            for _ in range(50)
+        ]
         self.correct_artificial_bias_quantize(float_model, img_data)
test/quantization/eager/test_equalize_eager.py
@@ -1,20 +1,19 @@
 # Owner(s): ["oncall: quantization"]
 
-import torch
-import torch.nn as nn
-
-from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.ao.quantization.fuse_modules import fuse_modules
-
-import torch.ao.quantization._equalize as _equalize
-
 import copy
 
+import torch
+import torch.ao.quantization._equalize as _equalize
+import torch.nn as nn
+from torch.ao.quantization.fuse_modules import fuse_modules
+from torch.testing._internal.common_quantization import QuantizationTestCase
 
 
 class TestEqualizeEager(QuantizationTestCase):
     def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
-        ''' Checks the channel ranges of tensor1, tensor2 are the same,
+        """Checks the channel ranges of tensor1, tensor2 are the same,
         which is an indication that equalization has been applied correctly
-        '''
+        """
         output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
         input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)
@@ -23,18 +22,17 @@ class TestEqualizeEager(QuantizationTestCase):
         self.assertEqual(output_channel_tensor1, input_channel_tensor2)
 
     def getModule(self, model, name):
-        ''' Given the name is a submodule to a model, return the submodule
-        '''
+        """Given the name is a submodule to a model, return the submodule"""
         curr = model
-        name = name.split('.')
+        name = name.split(".")
         for subname in name:
             curr = curr._modules[subname]
         return curr
 
     def test_cross_layer_equalization(self):
-        ''' applies _equalize.cross_layer_equalization on two modules and checks
+        """applies _equalize.cross_layer_equalization on two modules and checks
         to make sure channels ranges are equivalent
-        '''
+        """
         module1 = nn.Conv2d(3, 4, 2)
         module2 = nn.Linear(4, 4)
@@ -45,13 +43,18 @@ class TestEqualizeEager(QuantizationTestCase):
         mod_tensor1, mod_tensor2 = module1.weight, module2.weight
 
-        self.checkChannelsEqualized(mod_tensor1, mod_tensor2, module1_output_channel_axis, module2_input_channel_axis)
+        self.checkChannelsEqualized(
+            mod_tensor1,
+            mod_tensor2,
+            module1_output_channel_axis,
+            module2_input_channel_axis,
+        )
 
     def test_converged(self):
-        ''' Sanity checks on _equalize.converged working
+        """Sanity checks on _equalize.converged working
         identical modules should return true
         modules with high difference in weights should return false
-        '''
+        """
         module1 = nn.Linear(3, 3)
         module2 = nn.Linear(3, 3)
@@ -59,18 +62,19 @@ class TestEqualizeEager(QuantizationTestCase):
         module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))
 
         # input is a dictionary
-        dictionary_1 = {'linear1': module1}
-        dictionary_2 = {'linear1': module2}
+        dictionary_1 = {"linear1": module1}
+        dictionary_2 = {"linear1": module2}
         self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
         self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))
 
     def test_equalize(self):
-        ''' First checks to see if _equalize.equalize can handle multiple
+        """First checks to see if _equalize.equalize can handle multiple
         pair modules as input
         then checks correctness of the function by ensuring the equalized
         and unequalized versions of the model yield the same output
         given the same input
-        '''
+        """
 
         class ChainModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -83,13 +87,16 @@ class TestEqualizeEager(QuantizationTestCase):
         chain1 = ChainModule()
         chain2 = copy.deepcopy(chain1)
 
-        _equalize.equalize(chain1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
-        linear1 = self.getModule(chain1, 'linear1')
-        linear2 = self.getModule(chain1, 'linear2')
-        linear3 = self.getModule(chain1, 'linear3')
+        _equalize.equalize(
+            chain1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
+        )
+        linear1 = self.getModule(chain1, "linear1")
+        linear2 = self.getModule(chain1, "linear2")
+        linear3 = self.getModule(chain1, "linear3")
 
         self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
         self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
@@ -98,7 +105,7 @@ class TestEqualizeEager(QuantizationTestCase):
         self.assertEqual(chain1(input), chain2(input))
 
     def test_equalize_fused_convrelu(self):
-        ''' Checks to see if eager mode equalization supports fused
+        """Checks to see if eager mode equalization supports fused
         ConvReLU2d models
 
         A model with 3 ConvReLU2d is constructed. Next, the conv2d and relu
@@ -106,7 +113,8 @@ class TestEqualizeEager(QuantizationTestCase):
         equalization applied. Finally, we ensure that the channels have been
         equalized and that the equalized and unequalized versions of the model
         yield the same output given the same input
-        '''
+        """
+
         class M(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -128,13 +136,15 @@ class TestEqualizeEager(QuantizationTestCase):
         model = M()
 
-        fused_model1 = fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']])
+        fused_model1 = fuse_modules(
+            model, [["conv1", "relu1"], ["conv2", "relu2"], ["conv3", "relu3"]]
+        )
         fused_model2 = copy.deepcopy(fused_model1)
 
-        _equalize.equalize(fused_model1, [['conv1', 'conv2'], ['conv2', 'conv3']], 1e-6)
-        conv1 = self.getModule(fused_model1, 'conv1')[0]
-        conv2 = self.getModule(fused_model1, 'conv2')[0]
-        conv3 = self.getModule(fused_model1, 'conv3')[0]
+        _equalize.equalize(fused_model1, [["conv1", "conv2"], ["conv2", "conv3"]], 1e-6)
+        conv1 = self.getModule(fused_model1, "conv1")[0]
+        conv2 = self.getModule(fused_model1, "conv2")[0]
+        conv3 = self.getModule(fused_model1, "conv3")[0]
 
         self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1)
         self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1)
@@ -144,7 +154,7 @@ class TestEqualizeEager(QuantizationTestCase):
         self.assertEqual(fused_model1(input), model(input))
 
     def test_equalize_fused_linearrelu(self):
-        ''' Checks to see if eager mode equalization supports fused
+        """Checks to see if eager mode equalization supports fused
         LinearReLU models
 
         A model with 3 LinearReLU is constructed. Next, the linear and relu
@@ -152,7 +162,8 @@ class TestEqualizeEager(QuantizationTestCase):
         equalization applied. Finally, we ensure that the channels have been
         equalized and that the equalized and unequalized versions of the model
         yield the same output given the same input
-        '''
+        """
+
         class M(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -174,13 +185,17 @@ class TestEqualizeEager(QuantizationTestCase):
         model = M()
 
-        fused_model1 = fuse_modules(model, [['linear1', 'relu1'], ['linear2', 'relu2'], ['linear3', 'relu3']])
+        fused_model1 = fuse_modules(
+            model, [["linear1", "relu1"], ["linear2", "relu2"], ["linear3", "relu3"]]
+        )
         fused_model2 = copy.deepcopy(fused_model1)
 
-        _equalize.equalize(fused_model1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
-        linear1 = self.getModule(fused_model1, 'linear1')[0]
-        linear2 = self.getModule(fused_model1, 'linear2')[0]
-        linear3 = self.getModule(fused_model1, 'linear3')[0]
+        _equalize.equalize(
+            fused_model1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
+        )
+        linear1 = self.getModule(fused_model1, "linear1")[0]
+        linear2 = self.getModule(fused_model1, "linear2")[0]
+        linear3 = self.getModule(fused_model1, "linear3")[0]
 
         self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
         self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
test/quantization/eager/test_fuse_eager.py
@@ -3,37 +3,35 @@
 import copy
 
 import torch
-import torch.nn as nn
-import torch.ao.nn.quantized as nnq
 import torch.ao.nn.intrinsic as nni
-import torch.ao.nn.intrinsic.quantized as nniq
 import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.quantized as nnq
+import torch.nn as nn
 from torch.ao.quantization import (
-    quantize,
-    prepare,
     convert,
-    prepare_qat,
-    quantize_qat,
+    default_qat_qconfig,
+    default_qconfig,
     fuse_modules,
     fuse_modules_qat,
+    prepare,
+    prepare_qat,
     QConfig,
-    default_qconfig,
-    default_qat_qconfig,
+    quantize,
+    quantize_qat,
 )
-
 from torch.testing._internal.common_quantization import (
-    QuantizationTestCase,
-    ModelForFusion,
-    ModelWithSequentialFusion,
-    ModelForLinearBNFusion,
-    ModelForFusionWithBias,
     ModelForConvTransposeBNFusion,
+    ModelForFusion,
+    ModelForFusionWithBias,
+    ModelForLinearBNFusion,
+    ModelWithSequentialFusion,
+    QuantizationTestCase,
     SingleLayerLinearModel,
+    skipIfNoFBGEMM,
     test_only_eval_fn,
     test_only_train_fn,
-    skipIfNoFBGEMM,
 )
-
 from torch.testing._internal.common_quantized import (
     override_quantized_engine,
     supported_qengines,
@@ -45,23 +43,38 @@ class TestFuseEager(QuantizationTestCase):
     def test_fuse_module_train(self):
         model = ModelForFusion(default_qat_qconfig).train()
         # Test step by step fusion
-        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
-        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
-        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
-                         msg="Fused Conv + BN + Relu first layer")
-        self.assertEqual(type(model.bn1), torch.nn.Identity,
-                         msg="Fused Conv + BN + Relu (skipped BN)")
-        self.assertEqual(type(model.relu1), torch.nn.Identity,
-                         msg="Fused Conv + BN + Relu (skipped Relu)")
+        model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
+        model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
+        self.assertEqual(
+            type(model.conv1),
+            nni.ConvBnReLU2d,
+            msg="Fused Conv + BN + Relu first layer",
+        )
+        self.assertEqual(
+            type(model.bn1),
+            torch.nn.Identity,
+            msg="Fused Conv + BN + Relu (skipped BN)",
+        )
+        self.assertEqual(
+            type(model.relu1),
+            torch.nn.Identity,
+            msg="Fused Conv + BN + Relu (skipped Relu)",
+        )
 
-        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
-                         msg="Fused submodule Conv + BN")
-        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
-                         msg="Fused submodule Conv + BN (skipped BN)")
-        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
-                         msg="Non-fused submodule Conv")
-        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
-                         msg="Non-fused submodule ReLU")
+        self.assertEqual(
+            type(model.sub1.conv), nni.ConvBn2d, msg="Fused submodule Conv + BN"
+        )
+        self.assertEqual(
+            type(model.sub1.bn),
+            torch.nn.Identity,
+            msg="Fused submodule Conv + BN (skipped BN)",
+        )
+        self.assertEqual(
+            type(model.sub2.conv), torch.nn.Conv2d, msg="Non-fused submodule Conv"
+        )
+        self.assertEqual(
+            type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
+        )
         model = prepare_qat(model)
         self.checkObservers(model)
@@ -89,69 +102,121 @@ class TestFuseEager(QuantizationTestCase):
         test_only_eval_fn(model, self.img_data_1d)
         self.checkNoQconfig(model)
 
-        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
+        ):
             checkQuantized(model)
 
         model = ModelForFusion(default_qat_qconfig).train()
         model = fuse_modules_qat(
-            model,
-            [['conv1', 'bn1', 'relu1'],
-             ['sub1.conv', 'sub1.bn']])
+            model, [["conv1", "bn1", "relu1"], ["sub1.conv", "sub1.bn"]]
+        )
         model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
-        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
+        ):
             checkQuantized(model)
 
     def test_fuse_module_eval(self):
         model = ModelForFusion(default_qconfig)
         model.eval()
         model = fuse_modules(
             model,
-            [['conv3', 'bn3', 'relu4'],
-             ['conv1', 'bn1', 'relu1'],
-             ['conv2', 'relu2'],
-             ['bn2', 'relu3'],
-             ['sub1.conv', 'sub1.bn']])
-        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
-                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
-        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
-                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
-        self.assertEqual(type(model.conv1[1]), nn.ReLU,
-                         msg="Fused Conv + BN + Relu second layer (Relu only)")
-        self.assertEqual(type(model.bn1), nn.Identity,
-                         msg="Fused Conv + BN + Relu second layer (Skipped BN)")
-        self.assertEqual(type(model.relu1), nn.Identity,
-                         msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
-        self.assertEqual(type(model.conv2), nni.ConvReLU3d,
-                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
-        self.assertEqual(type(model.bn2), nni.BNReLU3d,
-                         msg="Fused BN + Relu first layer (Relu is folded))")
-        self.assertEqual(type(model.relu3), nn.Identity,
-                         msg="Fused BN + Relu second layer (Skipped Relu)")
-        self.assertEqual(type(model.conv2[0]), nn.Conv3d,
-                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
-        self.assertEqual(type(model.conv2[1]), nn.ReLU,
-                         msg="Fused Conv + BN + Relu second layer (Relu only)")
-        self.assertEqual(type(model.relu2), nn.Identity,
-                         msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
+            [
+                ["conv3", "bn3", "relu4"],
+                ["conv1", "bn1", "relu1"],
+                ["conv2", "relu2"],
+                ["bn2", "relu3"],
+                ["sub1.conv", "sub1.bn"],
+            ],
+        )
+        self.assertEqual(
+            type(model.conv1),
+            nni.ConvReLU2d,
+            msg="Fused Conv + BN + Relu first layer (BN is folded)",
+        )
+        self.assertEqual(
+            type(model.conv1[0]),
+            nn.Conv2d,
+            msg="Fused Conv + BN + Relu (Conv + folded BN only)",
+        )
+        self.assertEqual(
+            type(model.conv1[1]),
+            nn.ReLU,
+            msg="Fused Conv + BN + Relu second layer (Relu only)",
+        )
+        self.assertEqual(
+            type(model.bn1),
+            nn.Identity,
+            msg="Fused Conv + BN + Relu second layer (Skipped BN)",
+        )
+        self.assertEqual(
+            type(model.relu1),
+            nn.Identity,
+            msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
+        )
+        self.assertEqual(
+            type(model.conv2),
+            nni.ConvReLU3d,
+            msg="Fused Conv + BN + Relu first layer (BN is folded)",
+        )
+        self.assertEqual(
+            type(model.bn2),
+            nni.BNReLU3d,
+            msg="Fused BN + Relu first layer (Relu is folded))",
+        )
+        self.assertEqual(
+            type(model.relu3),
+            nn.Identity,
+            msg="Fused BN + Relu second layer (Skipped Relu)",
+        )
+        self.assertEqual(
+            type(model.conv2[0]),
+            nn.Conv3d,
+            msg="Fused Conv + BN + Relu (Conv + folded BN only)",
+        )
+        self.assertEqual(
+            type(model.conv2[1]),
+            nn.ReLU,
+            msg="Fused Conv + BN + Relu second layer (Relu only)",
+        )
+        self.assertEqual(
+            type(model.relu2),
+            nn.Identity,
+            msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
+        )
 
-        self.assertEqual(type(model.conv3), nni.ConvReLU1d,
-                         msg="Fused Conv + Relu for Conv1d (folded BN)")
-        self.assertEqual(type(model.conv3[0]), nn.Conv1d,
-                         msg="Fused Conv + Relu for Conv1d ")
-        self.assertEqual(type(model.conv3[1]), nn.ReLU,
-                         msg="Fused Conv + Relu for Conv1d")
-        self.assertEqual(type(model.bn3), nn.Identity,
-                         msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
+        self.assertEqual(
+            type(model.conv3),
+            nni.ConvReLU1d,
+            msg="Fused Conv + Relu for Conv1d (folded BN)",
+        )
+        self.assertEqual(
+            type(model.conv3[0]), nn.Conv1d, msg="Fused Conv + Relu for Conv1d "
+        )
+        self.assertEqual(
+            type(model.conv3[1]), nn.ReLU, msg="Fused Conv + Relu for Conv1d"
+        )
+        self.assertEqual(
+            type(model.bn3),
+            nn.Identity,
+            msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)",
+        )
 
-        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
-                         msg="Fused submodule Conv + folded BN")
-        self.assertEqual(type(model.sub1.bn), nn.Identity,
-                         msg="Fused submodule (skipped BN)")
-        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
-                         msg="Non-fused submodule Conv")
-        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
-                         msg="Non-fused submodule ReLU")
+        self.assertEqual(
+            type(model.sub1.conv), nn.Conv2d, msg="Fused submodule Conv + folded BN"
+        )
+        self.assertEqual(
+            type(model.sub1.bn), nn.Identity, msg="Fused submodule (skipped BN)"
+        )
+        self.assertEqual(
+            type(model.sub2.conv), nn.Conv2d, msg="Non-fused submodule Conv"
+        )
+        self.assertEqual(
+            type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
+        )
 
         model = prepare(model)
         self.checkObservers(model)
@@ -176,11 +241,14 @@ class TestFuseEager(QuantizationTestCase):
         model = ModelForFusion(default_qconfig).eval()
         model = fuse_modules(
             model,
-            [['conv1', 'bn1', 'relu1'],
-             ['conv2', 'relu2'],
-             ['bn2', 'relu3'],
-             ['sub1.conv', 'sub1.bn'],
-             ['conv3', 'bn3', 'relu4']])
+            [
+                ["conv1", "bn1", "relu1"],
+                ["conv2", "relu2"],
+                ["bn2", "relu3"],
+                ["sub1.conv", "sub1.bn"],
+                ["conv3", "bn3", "relu4"],
+            ],
+        )
         model = quantize(model, test_only_eval_fn, [self.img_data_1d])
         checkQuantized(model)
@@ -190,27 +258,46 @@ class TestFuseEager(QuantizationTestCase):
         model = ModelWithSequentialFusion().train()
         model.to(torch.float)
         fuse_modules_qat(
-            model, [['conv1', 'relu1'] ,
-                    ['features.0.0', 'features.0.1', 'features.0.2'],
-                    ['features.1.0', 'features.1.1', 'features.1.2'],
-                    ['features.2.0', 'features.2.1', 'features.2.2'],
-                    ['classifier.0', 'classifier.1']],
-            inplace=True)
-        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
-                         msg="Fused Conv + Relu: nni.ConvReLU2d")
-        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
-                         msg="Fused Conv + Relu: Conv2d")
-        self.assertEqual(type(model.conv1[1]), nn.ReLU,
-                         msg="Fused Conv + Relu: Relu")
-        self.assertEqual(type(model.relu1), nn.Identity,
-                         msg="Fused Conv + Relu: Identity")
+            model,
+            [
+                ["conv1", "relu1"],
+                ["features.0.0", "features.0.1", "features.0.2"],
+                ["features.1.0", "features.1.1", "features.1.2"],
+                ["features.2.0", "features.2.1", "features.2.2"],
+                ["classifier.0", "classifier.1"],
+            ],
+            inplace=True,
+        )
+        self.assertEqual(
+            type(model.conv1),
+            nni.ConvReLU2d,
+            msg="Fused Conv + Relu: nni.ConvReLU2d",
+        )
+        self.assertEqual(
+            type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
+        )
+        self.assertEqual(
+            type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
+        )
+        self.assertEqual(
+            type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
+        )
         for i in range(3):
-            self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
-                             msg="Fused submodule Conv + folded BN")
-            self.assertEqual(type(model.features[i][1]), nn.Identity,
-                             msg="Fused submodule (skipped BN)")
-            self.assertEqual(type(model.features[i][2]), nn.Identity,
-                             msg="Non-fused submodule Conv")
+            self.assertEqual(
+                type(model.features[i][0]),
+                nni.ConvBnReLU2d,
+                msg="Fused submodule Conv + folded BN",
+            )
+            self.assertEqual(
+                type(model.features[i][1]),
+                nn.Identity,
+                msg="Fused submodule (skipped BN)",
+            )
+            self.assertEqual(
+                type(model.features[i][2]),
+                nn.Identity,
+                msg="Non-fused submodule Conv",
+            )
         self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
         self.assertEqual(type(model.classifier[1]), nn.Identity)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
@@ -218,17 +305,26 @@ class TestFuseEager(QuantizationTestCase):
        self.checkObservers(model)
        model(self.img_data_2d[0][0])

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
            self.assertEqual(type(model.relu1), nn.Identity)

            for i in range(3):
-                self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
-                                 msg="Fused submodule Conv + folded BN")
-                self.assertEqual(type(model.features[i][1]), nn.Identity,
-                                 msg="Fused submodule (skipped BN)")
-                self.assertEqual(type(model.features[i][2]), nn.Identity,
-                                 msg="Non-fused submodule Conv")
+                self.assertEqual(
+                    type(model.features[i][0]),
+                    nniqat.ConvBnReLU2d,
+                    msg="Fused submodule Conv + folded BN",
+                )
+                self.assertEqual(
+                    type(model.features[i][1]),
+                    nn.Identity,
+                    msg="Fused submodule (skipped BN)",
+                )
+                self.assertEqual(
+                    type(model.features[i][2]),
+                    nn.Identity,
+                    msg="Non-fused submodule Conv",
+                )
            self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)
@@ -245,27 +341,45 @@ class TestFuseEager(QuantizationTestCase):
            model.to(torch.float)
            fuse_modules(
                model,
-                [['conv1', 'relu1'],
-                 ['features.0.0', 'features.0.1', 'features.0.2'],
-                 ['features.1.0', 'features.1.1', 'features.1.2'],
-                 ['features.2.0', 'features.2.1', 'features.2.2'],
-                 ['classifier.0', 'classifier.1']],
-                inplace=True)
-            self.assertEqual(type(model.conv1), nni.ConvReLU2d,
-                             msg="Fused Conv + Relu: nni.ConvReLU2d")
-            self.assertEqual(type(model.conv1[0]), nn.Conv2d,
-                             msg="Fused Conv + Relu: Conv2d")
-            self.assertEqual(type(model.conv1[1]), nn.ReLU,
-                             msg="Fused Conv + Relu: Relu")
-            self.assertEqual(type(model.relu1), nn.Identity,
-                             msg="Fused Conv + Relu: Identity")
+                [
+                    ["conv1", "relu1"],
+                    ["features.0.0", "features.0.1", "features.0.2"],
+                    ["features.1.0", "features.1.1", "features.1.2"],
+                    ["features.2.0", "features.2.1", "features.2.2"],
+                    ["classifier.0", "classifier.1"],
+                ],
+                inplace=True,
+            )
+            self.assertEqual(
+                type(model.conv1),
+                nni.ConvReLU2d,
+                msg="Fused Conv + Relu: nni.ConvReLU2d",
+            )
+            self.assertEqual(
+                type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
+            )
+            self.assertEqual(
+                type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
+            )
+            self.assertEqual(
+                type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
+            )
            for i in range(3):
-                self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
-                                 msg="Fused submodule Conv + folded BN")
-                self.assertEqual(type(model.features[i][1]), nn.Identity,
-                                 msg="Fused submodule (skipped BN)")
-                self.assertEqual(type(model.features[i][2]), nn.Identity,
-                                 msg="Non-fused submodule Conv")
+                self.assertEqual(
+                    type(model.features[i][0]),
+                    nni.ConvReLU2d,
+                    msg="Fused submodule Conv + folded BN",
+                )
+                self.assertEqual(
+                    type(model.features[i][1]),
+                    nn.Identity,
+                    msg="Fused submodule (skipped BN)",
+                )
+                self.assertEqual(
+                    type(model.features[i][2]),
+                    nn.Identity,
+                    msg="Non-fused submodule Conv",
+                )
            self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
            self.assertEqual(type(model.classifier[1]), nn.Identity)
            model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
@@ -297,12 +411,12 @@ class TestFuseEager(QuantizationTestCase):
        out_ref = model_ref(self.img_data_2d[0][0])

        # fused model
-        model_orig.qconfig = QConfig(activation=torch.nn.Identity,
-                                     weight=torch.nn.Identity)
+        model_orig.qconfig = QConfig(
+            activation=torch.nn.Identity, weight=torch.nn.Identity
+        )
        model = fuse_modules_qat(
-            model_orig,
-            [["conv1", "bn1", "relu1"],
-             ["conv2", "bn2"]])
+            model_orig, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]]
+        )
        prep_model = prepare_qat(model, inplace=False)
        # output with fusion but no observers.
        out_fused = prep_model(self.img_data_2d[0][0])
@@ -332,7 +446,6 @@ class TestFuseEager(QuantizationTestCase):
        checkQAT(model)

-
    def test_fusion_linear_bn_eval(self):
        model = ModelForLinearBNFusion().train()
        inp1 = torch.randn(8, 20)
@@ -357,7 +470,9 @@ class TestFuseEager(QuantizationTestCase):
        model.eval()
        golden = model(inp2)

-        model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]])
+        model = fuse_modules(
+            model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]]
+        )
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.bn2), nn.Identity)
        self.assertEqual(type(model.bn3), nn.Identity)
@@ -384,50 +499,68 @@ class TestFuseEager(QuantizationTestCase):
        model = ModelForFusion(default_qat_qconfig).train()

        counter = {
-            'pre_forwards': 0,
-            'forwards': 0,
+            "pre_forwards": 0,
+            "forwards": 0,
        }
        fused = False

        def fw_pre_hook(fused_module_class, h_module, input):
            if fused:
-                self.assertEqual(type(h_module), fused_module_class,
-                                 "After fusion owner of the first module's forward pre hook is not a fused module")
-            counter['pre_forwards'] += 1
+                self.assertEqual(
+                    type(h_module),
+                    fused_module_class,
+                    "After fusion owner of the first module's forward pre hook is not a fused module",
+                )
+            counter["pre_forwards"] += 1

        def fw_hook(fused_module_class, h_module, input, output):
            if fused:
-                self.assertEqual(type(h_module), fused_module_class,
-                                 "After fusion owner of the last module's forward hook is not a fused module")
-            counter['forwards'] += 1
+                self.assertEqual(
+                    type(h_module),
+                    fused_module_class,
+                    "After fusion owner of the last module's forward hook is not a fused module",
+                )
+            counter["forwards"] += 1

        # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference
-        model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
-        model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
-        model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
+        model.conv1.register_forward_pre_hook(
+            lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)
+        )
+        model.sub1.conv.register_forward_pre_hook(
+            lambda *args: fw_pre_hook(nni.ConvBn2d, *args)
+        )
+        model.relu1.register_forward_hook(
+            lambda *args: fw_hook(nni.ConvBnReLU2d, *args)
+        )
        model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))

        test_only_eval_fn(model, self.img_data_1d)
-        self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
-        self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
+        self.assertEqual(counter["pre_forwards"], 2 * len(self.img_data_1d))
+        self.assertEqual(counter["forwards"], 2 * len(self.img_data_1d))

-        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
-        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
+        model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
+        model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])

        fused = True
-        before_fusion_pre_count = counter['pre_forwards']
-        before_fusion_post_count = counter['forwards']
+        before_fusion_pre_count = counter["pre_forwards"]
+        before_fusion_post_count = counter["forwards"]
        test_only_eval_fn(model, self.img_data_1d)
-        self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
-        self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))
+        self.assertEqual(
+            counter["pre_forwards"] - before_fusion_pre_count, 2 * len(self.img_data_1d)
+        )
+        self.assertEqual(
+            counter["forwards"] - before_fusion_post_count, 2 * len(self.img_data_1d)
+        )

    def test_fuse_modules_with_nested_hooks(self):
        r"""Test case that checks whether a nested module with sub-sub modules registered with hooks
        can be safely fused. Safeguard for issues similar to https://github.com/pytorch/pytorch/issues/105063
        in the future.
        """
@@ -435,28 +568,32 @@ class TestFuseEager(QuantizationTestCase):
                for sub_model in model.modules():
                    if isinstance(sub_model, nn.Sequential):
                        for layer in sub_model:
-                            if hasattr(layer, 'register_forward_hook'):
+                            if hasattr(layer, "register_forward_hook"):
                                layer.register_forward_hook(myhook)

-                fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True)
+                fuse_modules(
+                    model,
+                    [["features.0.0", "features.0.1", "features.0.2"]],
+                    inplace=True,
+                )
                self.assertEqual(
                    type(model.features[0][0]),
                    nni.ConvReLU2d,
-                    msg="Fused submodule Conv + folded BN"
+                    msg="Fused submodule Conv + folded BN",
                )
                self.assertEqual(
                    type(model.features[0][1]),
                    nn.Identity,
-                    msg="Fused submodule (skipped BN)"
+                    msg="Fused submodule (skipped BN)",
                )
                self.assertEqual(
                    type(model.features[0][2]),
                    nn.Identity,
-                    msg="Non-fused submodule Conv"
+                    msg="Non-fused submodule Conv",
                )


-if __name__ == '__main__':
+if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
test/quantization/eager/test_model_numerics.py
@@ -1,17 +1,17 @@
 # Owner(s): ["oncall: quantization"]
 
 import torch
 
 from torch.testing._internal.common_quantization import (
-    QuantizationTestCase,
     ModelMultipleOps,
     ModelMultipleOpsNoAvgPool,
+    QuantizationTestCase,
 )
 from torch.testing._internal.common_quantized import (
     override_quantized_engine,
     supported_qengines,
 )
 
 
 class TestModelNumericsEager(QuantizationTestCase):
     def test_float_quant_compare_per_tensor(self):
         for qengine in supported_qengines:
@@ -25,16 +25,24 @@ class TestModelNumericsEager(QuantizationTestCase):
             qModel = torch.ao.quantization.QuantWrapper(my_model)
             qModel.eval()
             qModel.qconfig = torch.ao.quantization.default_qconfig
-            torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+            torch.ao.quantization.fuse_modules(
+                qModel.module, [["conv1", "bn1", "relu1"]], inplace=True
+            )
             torch.ao.quantization.prepare(qModel, inplace=True)
             qModel(calib_data)
             torch.ao.quantization.convert(qModel, inplace=True)
             out_q = qModel(eval_data)
-            SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
+            SQNRdB = 20 * torch.log10(
+                torch.norm(out_ref) / torch.norm(out_ref - out_q)
+            )
             # Quantized model output should be close to floating point model output numerically
             # Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired
             # output
-            self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB')
+            self.assertGreater(
+                SQNRdB,
+                30,
+                msg="Quantized model numerics diverge from float, expect SQNR > 30 dB",
+            )
 
     def test_float_quant_compare_per_channel(self):
         # Test for per-channel Quant
@@ -47,7 +55,9 @@ class TestModelNumericsEager(QuantizationTestCase):
             q_model = torch.ao.quantization.QuantWrapper(my_model)
             q_model.eval()
             q_model.qconfig = torch.ao.quantization.default_per_channel_qconfig
-            torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+            torch.ao.quantization.fuse_modules(
+                q_model.module, [["conv1", "bn1", "relu1"]], inplace=True
+            )
             torch.ao.quantization.prepare(q_model)
             q_model(calib_data)
             torch.ao.quantization.convert(q_model)
@@ -55,7 +65,11 @@ class TestModelNumericsEager(QuantizationTestCase):
             SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
             # Quantized model output should be close to floating point model output numerically
             # Setting target SQNR to be 35 dB
-            self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
+            self.assertGreater(
+                SQNRdB,
+                35,
+                msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
+            )
 
     def test_fake_quant_true_quant_compare(self):
         for qengine in supported_qengines:
@@ -69,7 +83,9 @@ class TestModelNumericsEager(QuantizationTestCase):
             fq_model = torch.ao.quantization.QuantWrapper(my_model)
             fq_model.train()
             fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
-            torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+            torch.ao.quantization.fuse_modules_qat(
+                fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
+            )
             torch.ao.quantization.prepare_qat(fq_model)
             fq_model.eval()
             fq_model.apply(torch.ao.quantization.disable_fake_quant)
@@ -78,14 +94,26 @@ class TestModelNumericsEager(QuantizationTestCase):
             fq_model.apply(torch.ao.quantization.enable_fake_quant)
             fq_model.apply(torch.ao.quantization.disable_observer)
             out_fq = fq_model(eval_data)
-            SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
+            SQNRdB = 20 * torch.log10(
+                torch.norm(out_ref) / torch.norm(out_ref - out_fq)
+            )
             # Quantized model output should be close to floating point model output numerically
             # Setting target SQNR to be 35 dB
-            self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB')
+            self.assertGreater(
+                SQNRdB,
+                35,
+                msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
+            )
             torch.ao.quantization.convert(fq_model)
             out_q = fq_model(eval_data)
-            SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10))
-            self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB')
+            SQNRdB = 20 * torch.log10(
+                torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10)
+            )
+            self.assertGreater(
+                SQNRdB,
+                60,
+                msg="Fake quant and true quant numerics diverge, expect SQNR > 60 dB",
+            )
 
     # Test to compare weight only quantized model numerics and
     # activation only quantized model numerics with float
@@ -95,8 +123,10 @@ class TestModelNumericsEager(QuantizationTestCase):
         torch.manual_seed(67)
         calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
         eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
-        qconfigset = {torch.ao.quantization.default_weight_only_qconfig,
-                      torch.ao.quantization.default_activation_only_qconfig}
+        qconfigset = {
+            torch.ao.quantization.default_weight_only_qconfig,
+            torch.ao.quantization.default_activation_only_qconfig,
+        }
         SQNRTarget = [35, 45]
         for idx, qconfig in enumerate(qconfigset):
             my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
@@ -105,7 +135,9 @@ class TestModelNumericsEager(QuantizationTestCase):
             fq_model = torch.ao.quantization.QuantWrapper(my_model)
             fq_model.train()
             fq_model.qconfig = qconfig
-            torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True)
+            torch.ao.quantization.fuse_modules_qat(
+                fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
+            )
             torch.ao.quantization.prepare_qat(fq_model)
             fq_model.eval()
             fq_model.apply(torch.ao.quantization.disable_fake_quant)
@@ -114,11 +146,19 @@ class TestModelNumericsEager(QuantizationTestCase):
             fq_model.apply(torch.ao.quantization.enable_fake_quant)
             fq_model.apply(torch.ao.quantization.disable_observer)
             out_fq = fq_model(eval_data)
-            SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
-            self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')
+            SQNRdB = 20 * torch.log10(
+                torch.norm(out_ref) / torch.norm(out_ref - out_fq)
+            )
+            self.assertGreater(
+                SQNRdB,
+                SQNRTarget[idx],
+                msg="Quantized model numerics diverge from float",
+            )
 
 
-if __name__ == '__main__':
-    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
-                       "\tpython test/test_quantization.py TESTNAME\n\n"
-                       "instead.")
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_quantization.py TESTNAME\n\n"
+        "instead."
+    )
@@ -2,43 +2,45 @@
# ruff: noqa: F841

import unittest

import torch
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.ns._numeric_suite import (
compare_model_outputs,
compare_model_stub,
compare_weights,
get_matching_activations,
OutputLogger,
prepare_model_outputs,
Shadow,
ShadowLogger,
)
from torch.ao.quantization import (
convert,
default_qconfig,
DeQuantStub,
prepare,
quantize,
quantize_dynamic,
QuantStub,
)
from torch.testing._internal.common_quantization import (
AnnotatedConvBnReLUModel,
AnnotatedConvModel,
AnnotatedConvTransposeModel,
AnnotatedSingleLayerLinearModel,
AnnotatedTwoLayerLinearModel,
LSTMwithHiddenDynamicModel,
QuantizationTestCase,
SingleLayerLinearDynamicModel,
skip_if_no_torchvision,
test_only_eval_fn,
)
from torch.testing._internal.common_quantized import override_qengines
from torch.testing._internal.common_utils import IS_ARM64


class SubModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

@@ -200,10 +202,18 @@ class TestNumericSuiteEager(QuantizationTestCase):
for i, val in enumerate(v["quantized"]):
self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)

model_list = [
AnnotatedConvModel(qengine),
AnnotatedConvTransposeModel(
"qnnpack"
), # ConvT cannot use per channel weights
AnnotatedConvBnReLUModel(qengine),
]
module_swap_list = [
nn.Conv2d,
nn.intrinsic.modules.fused.ConvReLU2d,
nn.ConvTranspose2d,
]
for model in model_list:
model.eval()
if hasattr(model, "fuse_model"):

@@ -279,7 +289,6 @@ class TestNumericSuiteEager(QuantizationTestCase):
self.assertTrue(isinstance(q_model.mod1, Shadow))
self.assertFalse(isinstance(q_model.conv, Shadow))

@override_qengines
def test_compare_model_stub_functional_static(self):
r"""Compare the output of static quantized functional layer and its float shadow module"""

@@ -486,7 +495,9 @@ class TestNumericSuiteEager(QuantizationTestCase):
for i, val in enumerate(v["quantized"]):
self.assertTrue(len(v["float"][i]) == len(v["quantized"][i]))
if i == 0:
self.assertTrue(
v["float"][i][0].shape == v["quantized"][i][0].shape
)
else:
self.assertTrue(
v["float"][i][0].shape == v["quantized"][i][0].shape

@@ -540,12 +551,23 @@ class TestNumericSuiteEager(QuantizationTestCase):

@skip_if_no_torchvision
def _test_vision_model(self, float_model):
float_model.to("cpu")
float_model.eval()
float_model.fuse_model()
float_model.qconfig = torch.ao.quantization.default_qconfig
img_data = [
(
torch.rand(2, 3, 224, 224, dtype=torch.float),
torch.randint(0, 1, (2,), dtype=torch.long),
)
for _ in range(2)
]
qmodel = quantize(
float_model,
torch.ao.quantization.default_eval_fn,
[img_data],
inplace=False,
)

wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())

@@ -560,9 +582,11 @@ class TestNumericSuiteEager(QuantizationTestCase):
# 'quantized', containing the activations of floating point and quantized model at matching locations.
act_compare_dict = compare_model_outputs(float_model, qmodel, data)

for key in act_compare_dict:
compute_error(
act_compare_dict[key]["float"][0],
act_compare_dict[key]["quantized"][0].dequantize(),
)

prepare_model_outputs(float_model, qmodel)

@@ -579,10 +603,12 @@ class TestNumericSuiteEager(QuantizationTestCase):
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
def test_mobilenet_v2(self):
from torchvision.models.quantization import mobilenet_v2

self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))

@skip_if_no_torchvision
@unittest.skipIf(IS_ARM64, "Not working on arm right now")
def test_mobilenet_v3(self):
from torchvision.models.quantization import mobilenet_v3_large

self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))
File diff suppressed because it is too large
@@ -3,6 +3,8 @@
import copy
import math

from hypothesis import given, strategies as st

import torch
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat

@@ -12,8 +14,6 @@ import torch.ao.nn.quantized.dynamic as nnqd
import torch.backends.mkldnn
import torch.nn as nn
import torch.testing._internal.hypothesis_utils as hu
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.ao.quantization import (
convert,

@@ -50,42 +50,63 @@ from torch.testing._internal.common_quantization import (
test_only_train_fn,
TwoLayerLinearModel,
)

from torch.testing._internal.common_quantized import (
override_qengines,
override_quantized_engine,
supported_qengines,
)

from torch.testing._internal.common_utils import skipIfNoXNNPACK


hu.assert_deadline_disabled()
from functools import reduce


class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
"""
Conv-BN fusion implemented with explicit folding. Useful
to verify numerical equivalency with non-folded version.
"""

def __init__(
self,
# ConvNd args
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
bias,
padding_mode,
# BatchNormNd args
# num_features: out_channels
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
nn.modules.conv._ConvNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
False,
padding_mode,
)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig
self.eps = eps
self.momentum = momentum

@@ -103,7 +124,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
if bias:
self.bias = nn.Parameter(torch.empty(out_channels))
else:
self.register_parameter("bias", None)
self.reset_bn_parameters()

def reset_running_stats(self):

@@ -123,7 +144,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
def reset_parameters(self):
super().reset_parameters()
# A hack to avoid resetting on undefined parameters
if hasattr(self, "gamma"):
self.reset_bn_parameters()

def update_bn_stats(self):

@@ -161,33 +182,50 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
if self.bias is not None:
zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
else:
zero_bias = torch.zeros(
self.out_channels, device=scaled_weight.device, dtype=input.dtype
)
conv = self._conv_forward(
input, self.weight_fake_quant(scaled_weight), zero_bias
)

if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var
if self.bias is not None:
conv_orig = conv / scale_factor.reshape(
[1, -1, 1, 1]
) + self.bias.reshape([1, -1, 1, 1])
else:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
n = float(conv_orig.numel() / conv_orig.size()[1])
unbiased_batch_var = batch_var * (n / (n - 1))
batch_rstd = torch.ones_like(
batch_var, memory_format=torch.contiguous_format
) / torch.sqrt(batch_var + self.eps)

conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + (
self.beta - self.gamma * batch_rstd * batch_mean
).reshape([1, -1, 1, 1])
self.running_mean = (
exponential_average_factor * batch_mean.detach()
+ (1 - exponential_average_factor) * self.running_mean
)
self.running_var = (
exponential_average_factor * unbiased_batch_var.detach()
+ (1 - exponential_average_factor) * self.running_var
)
else:
if self.bias is None:
conv = conv + (
self.beta - self.gamma * self.running_mean / running_std
).reshape([1, -1, 1, 1])
else:
conv = conv + (
self.gamma * (self.bias - self.running_mean) / running_std
+ self.beta
).reshape([1, -1, 1, 1])
return conv

def extra_repr(self):
@@ -200,23 +238,37 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
@classmethod
def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, (
"qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
if not qconfig:
assert hasattr(
mod, "qconfig"
), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
qconfig = mod.qconfig
conv, bn = mod[0], mod[1]
qat_convbn = cls(
conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
conv.dilation,
conv.groups,
conv.bias is not None,
conv.padding_mode,
bn.eps,
bn.momentum,
False,
qconfig,
)
qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias
qat_convbn.gamma = bn.weight

@@ -226,41 +278,69 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
qat_convbn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn


class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d

def __init__(
self,
# ConvNd args
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=None,
padding_mode="zeros",
# BatchNorm2d args
# num_features: out_channels
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
dilation = _pair(dilation)
_ReferenceConvBnNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
False,
_pair(0),
groups,
bias,
padding_mode,
eps,
momentum,
freeze_bn,
qconfig,
)


class TestQuantizeEagerQAT(QuantizationTestCase):
def setUp(self):
super().setUp()

self.embed_linear_data_train = [
[
torch.randint(0, 10, (12, 12), dtype=torch.long),
torch.randn((12, 1), dtype=torch.float),
]
for _ in range(2)
]
self.embed_data = [[torch.randint(0, 10, (12, 1))]]


def test_manual(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
@@ -279,8 +359,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

checkQuantized(model)

model = quantize_qat(
ManualLinearQATModel(qengine), test_only_train_fn, [self.train_data]
)
checkQuantized(model)

def test_dropout(self):

@@ -301,8 +382,11 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

checkQuantized(model)

model = quantize_qat(
ManualDropoutQATModel(qengine),
test_only_train_fn,
[self.train_data],
)
checkQuantized(model)

def test_eval_only_fake_quant(self):

@@ -342,7 +426,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)

model = ManualConvLinearQATModel()
model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model)

@skipIfNoXNNPACK

@@ -351,7 +437,7 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
Supported only with qengine=qnnpack, which uses symmetric
kernels from xnnpack library."""
for qengine in supported_qengines:
if qengine != "qnnpack":
continue
with override_quantized_engine(qengine):
model = ManualConvLinearSymmQATModel()

@@ -373,17 +459,20 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model)

model = ManualConvLinearSymmQATModel()
model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model)

def test_dynamic_qat_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
# Dynamic QAT without memoryless observers should fail
with self.assertRaisesRegex(
ValueError,
"Dynamic QAT requires a memoryless observer."
+ "This means a MovingAverage observer with averaging constant equal to 1",
):
model = ManualLinearDynamicQATModel(default_qat_qconfig)
model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})

@@ -409,14 +498,23 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

test_only_train_fn(model, self.embed_linear_data_train)
# make sure activation_post_process is inserted after Linear.
self.assertEqual(
type(model.linear.activation_post_process),
FusedMovingAvgObsFakeQuantize,
)
# make sure that Embedding has a noop for activation.
self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
# make sure that FakeQuant zero_points are correct dtype
self.assertEqual(
model.emb.weight_fake_quant.zero_point.dtype, torch.float32
)
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)

model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)

def checkQuantized(model):
# make sure Embedding is now a QuantizedEmbedding

@@ -430,7 +528,6 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

checkQuantized(model)

def test_embedding_bag_linear(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):

@@ -442,9 +539,15 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
# make sure not activation_post_process is inserted for EmbeddingBag
self.assertFalse(hasattr(model, "activation_post_process"))
# make sure that FakeQuant zero_points are correct dtype
self.assertEqual(
model.emb.weight_fake_quant.zero_point.dtype, torch.float32
)
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)
model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)

def checkQuantized(model):
# Make sure EmbeddingBag is now a quantized EmbeddingBag.

@@ -505,7 +608,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
torch.ao.quantization.prepare(model, inplace=True)
torch.ao.quantization.convert(model, inplace=True)
self.assertEqual(
set(model.state_dict().keys()), set(quant_state_dict.keys())
)
model.eval()
model.load_state_dict(quant_state_dict)
out = model(x)

@@ -513,20 +618,19 @@ class TestQuantizeEagerQAT(QuantizationTestCase):

@override_qengines
def test_forward_hooks_preserved(self):
r"""Test QAT on preserving pre forward and post forward hooks of original model"""
qengine = torch.backends.quantized.engine
model = QuantStubModel()
counter = {
"pre_forwards": 0,
"forwards": 0,
}

def fw_pre_hook(h_module, input):
counter["pre_forwards"] += 1

def fw_hook(h_module, input, output):
counter["forwards"] += 1

model.fc.register_forward_pre_hook(fw_pre_hook)
model.fc.register_forward_hook(fw_hook)

@@ -537,15 +641,24 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
def checkHooksIsPresent(model, before_convert=True):
forward_hooks = 1
if before_convert:
self.assertEqual(
len(model.quant._forward_hooks.values()),
1,
"Quantization observer hook has disappeared",
)
forward_hooks = 2
self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
self.assertEqual(
len(model.fc._forward_pre_hooks.values()),
1,
"Extra pre forward hooks have appeared on a layer",
)
self.assertEqual(
len(model.fc._forward_hooks.values()),
forward_hooks,
"Extra post forward hooks have appeared on a layer",
)

checkHooksIsPresent(model, True)
x = torch.rand(2, 5, dtype=torch.float)
@@ -600,32 +713,40 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)

# Test constructor parameters checks here.
with self.assertRaisesRegex(
AssertionError, "qconfig must be provided for QAT module"
):
nnqat.EmbeddingBag(10, 5, qconfig=None)

with self.assertRaisesRegex(
AssertionError,
"Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)

# Test from_float checks here.
embed = nn.Embedding(10, 5)
with self.assertRaisesRegex(
AssertionError, "qat.EmbeddingBag.from_float only works for EmbeddingBag"
):
nnqat.EmbeddingBag.from_float(embed)
embed_bag = nn.EmbeddingBag(10, 5)
with self.assertRaisesRegex(
AssertionError, "Input float module must have qconfig defined"
):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = None
with self.assertRaisesRegex(
AssertionError, "Input float module must have a valid qconfig"
):
nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = default_qat_qconfig
with self.assertRaisesRegex(
AssertionError,
"Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag.from_float(embed_bag)

def test_embedding_qat_qconfig_equal(self):

@@ -636,8 +757,10 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model = ManualEmbeddingBagLinear().train()
model = prepare_qat(model)

self.assertTrue(
qconfig_equals(model.emb.qconfig, default_embedding_qat_qconfig)
)


class TestQuantizeEagerQATNumerics(QuantizationTestCase):
def _test_activation_convert_numerics_impl(self, Act, data):

@@ -683,24 +806,26 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m = M().train()
m.qconfig = default_qat_qconfig
m = prepare_qat(m)
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertEqual(
type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize
)
data = torch.randn(1, 3, 2, 4)
before_convert = m(data)
m = convert(m)
after_convert = m(data)
self.assertEqual(before_convert, after_convert)
# make sure activation post process is removed
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
# verify fake quant module is removd
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
# verify that hooks are removed
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

# make sure no fake quantize module is inserted for eval mode

def checkNoFQModule(m):
for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)

@@ -734,50 +859,52 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
# make sure ReLU module is not changed
self.assertTrue(type(m.relu), nn.ReLU)

@given(
batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(["zeros", "circular"]),
use_relu=st.booleans(),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
zero_gamma=st.booleans(),
has_bias=st.booleans(),
use_slow_fusion=st.booleans(),
)
def test_conv_bn_relu(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
use_relu,
eps,
momentum,
freeze_bn,
zero_gamma,
has_bias,
use_slow_fusion,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
@@ -792,7 +919,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
(dilation_h, dilation_w),
groups,
has_bias,
padding_mode,
).to(dtype=torch.double)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
relu_op = ReLU()

@@ -811,7 +938,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=True,
qconfig=default_qat_qconfig,
).to(dtype=torch.double)
qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion

@@ -826,7 +953,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)

# align inputs and internal parameters
input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
if has_bias:
conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())

@@ -840,17 +974,20 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)

if not use_relu:

def relu_op(x):  # noqa: F811
return x

if freeze_bn:

def ref_op(x):
x = conv_op(x)
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * (
bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)
).reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x)
return x

else:
ref_op = compose([conv_op, bn_op, relu_op])
@@ -882,51 +1019,64 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-10
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(
num_batches_tracked_ref,
num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)

@given(
batch_size=st.integers(2, 4),
input_channels_per_group=st.sampled_from([2, 3, 4]),
height=st.integers(5, 10),
width=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]),
groups=st.integers(1, 3),
kernel_h=st.integers(1, 3),
kernel_w=st.integers(1, 3),
stride_h=st.integers(1, 2),
stride_w=st.integers(1, 2),
pad_h=st.integers(0, 2),
pad_w=st.integers(0, 2),
dilation=st.integers(1, 1),
padding_mode=st.sampled_from(["zeros", "circular"]),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
momentum=st.sampled_from([0.1, 0.2, 0.3]),
freeze_bn=st.booleans(),
bias=st.booleans(),
)
def test_conv_bn_folded_vs_unfolded(
self,
batch_size,
input_channels_per_group,
height,
width,
output_channels_per_group,
groups,
kernel_h,
kernel_w,
stride_h,
stride_w,
pad_h,
pad_w,
dilation,
padding_mode,
eps,
momentum,
freeze_bn,
bias,
):
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups

@@ -945,7 +1095,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig,
).to(dtype=torch.double)

qat_ref_op = _ReferenceConvBn2d(

@@ -961,7 +1111,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps,
momentum,
freeze_bn=freeze_bn,
qconfig=default_qat_qconfig,
).to(dtype=torch.double)

qat_op.apply(torch.ao.quantization.disable_fake_quant)

@@ -981,7 +1131,6 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)

for i in range(5):
# make sure that calling model.train() does not override the
# bn freeze setting
qat_op.train()

@@ -990,7 +1139,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op_optim.zero_grad()
qat_ref_op_optim.zero_grad()

input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
input_clone = input.detach().clone().requires_grad_()

if i > 2:
@@ -1030,12 +1186,23 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):

precision = 1e-5
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(
num_batches_tracked_ref,
num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)

qat_op_optim.step()
qat_ref_op_optim.step()

@@ -1048,7 +1215,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

@@ -1071,7 +1238,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4),
)
m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = default_symmetric_qnnpack_qat_qconfig
m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0])

@@ -1093,14 +1260,13 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
)
data = torch.randn(4, 4)
m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m = torch.ao.quantization.fuse_modules_qat(m, [["1", "2"]])
mp = prepare_qat(m)
mp(data)
mq = convert(mp)
self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity)


@skipIfNoXNNPACK
@override_qengines
def test_linear_precomputed_fake_quant(self):

@@ -1124,10 +1290,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m_ref.activation_post_process = activation
m_ref.qconfig = qconfig
m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
self.assertTrue(
m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale
)


if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead."
)