[Easy] enable PYFMT for torch/quantization/eager (#150761)

All modifications were made automatically with the following command:

```bash
lintrunner -a --take "PYFMT" --all-files
```
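
For a narrower run, the same linter can in principle be pointed at just the files touched here; this is a sketch, assuming lintrunner accepts explicit file paths and that omitting `-a` reports violations without rewriting files:

```bash
# Apply PYFMT to the eager-mode quantization tests only (sketch, not the command used in this PR)
lintrunner -a --take PYFMT test/quantization/eager/*.py

# Report formatting issues without modifying files
lintrunner --take PYFMT test/quantization/eager/*.py
```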
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150761
Approved by: https://github.com/jerryzh168
Authored by FFFrog on 2025-04-07 17:02:16 +08:00; committed by PyTorch MergeBot
parent 91b090c912
commit 8895c290f4
8 changed files with 1473 additions and 809 deletions

.lintrunner.toml

```diff
@@ -1165,14 +1165,6 @@ exclude_patterns = [
     'test/quantization/core/test_utils.py',
     'test/quantization/core/test_workflow_module.py',
     'test/quantization/core/test_workflow_ops.py',
-    'test/quantization/eager/__init__.py',
-    'test/quantization/eager/test_bias_correction_eager.py',
-    'test/quantization/eager/test_equalize_eager.py',
-    'test/quantization/eager/test_fuse_eager.py',
-    'test/quantization/eager/test_model_numerics.py',
-    'test/quantization/eager/test_numeric_suite_eager.py',
-    'test/quantization/eager/test_quantize_eager_ptq.py',
-    'test/quantization/eager/test_quantize_eager_qat.py',
     'test/quantization/fx/__init__.py',
     'test/quantization/fx/test_equalize_fx.py',
     'test/quantization/fx/test_model_report_fx.py',
```
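
Removing these entries from `exclude_patterns` is what opts the eager-mode test files into the PYFMT linter. A quick, hypothetical way to check whether any eager test path is still listed somewhere in the config (assuming the repository root contains `.lintrunner.toml`):

```bash
# Any matching lines would indicate the paths are still referenced in the lint config
grep -n "test/quantization/eager/" .lintrunner.toml
```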

test/quantization/eager/test_bias_correction_eager.py

```diff
@@ -1,24 +1,23 @@
 # Owner(s): ["oncall: quantization"]
+import copy
+
 import torch
-import torch.nn as nn
-from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.testing._internal.common_quantization import skipIfNoFBGEMM
-from torch.ao.quantization import default_qconfig
-from torch.ao.quantization import QuantWrapper
 import torch.ao.ns._numeric_suite as ns
+import torch.nn as nn
+from torch.ao.quantization import default_qconfig, QuantWrapper
 from torch.ao.quantization._correct_bias import (
     _supported_modules,
     _supported_modules_quantized,
     bias_correction,
     get_module,
     get_param,
-    parent_child_names
+    parent_child_names,
+)
+from torch.testing._internal.common_quantization import (
+    QuantizationTestCase,
+    skipIfNoFBGEMM,
 )
-import copy


 class TestBiasCorrectionEager(QuantizationTestCase):
@@ -28,9 +27,9 @@ class TestBiasCorrectionEager(QuantizationTestCase):
         return 20 * torch.log10(Ps / Pn)

     def correct_artificial_bias_quantize(self, float_model, img_data):
-        ''' Adding artificial bias and testing if bias persists after bias
+        """Adding artificial bias and testing if bias persists after bias
         correction. This test case changes the bias of a quantized submodule
-        '''
+        """
         artificial_model = copy.deepcopy(float_model)
         artificial_model.qconfig = default_qconfig
         torch.ao.quantization.prepare(artificial_model, inplace=True)
@@ -41,12 +40,17 @@ class TestBiasCorrectionEager(QuantizationTestCase):
         # manually changing bias
         for name, submodule in artificial_model.named_modules():
             if type(submodule) in _supported_modules:
-                x = get_param(submodule, 'bias')
-                weight = get_param(submodule, 'weight')
+                x = get_param(submodule, "bias")
+                weight = get_param(submodule, "weight")
                 if x is not None:
                     submodule.set_weight_bias(weight, x.data * 3)

-        bias_correction(float_model, artificial_model, img_data, target_modules=_supported_modules_quantized)
+        bias_correction(
+            float_model,
+            artificial_model,
+            img_data,
+            target_modules=_supported_modules_quantized,
+        )

         # Trims off the shadow module,
         for name, submodule in artificial_model.named_modules():
@@ -58,11 +62,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
         for name, artificial_submodule in artificial_model.named_modules():
             if type(artificial_submodule) in _supported_modules_quantized:
                 submodule = get_module(float_model, name)
-                float_bias = get_param(submodule, 'bias')
-                artificial_bias = get_param(artificial_submodule, 'bias')
+                float_bias = get_param(submodule, "bias")
+                artificial_bias = get_param(artificial_submodule, "bias")

-                self.assertTrue(self.compute_sqnr(float_bias, artificial_bias) > 30,
-                                "Correcting quantized bias produced too much noise, sqnr score too low")
+                self.assertTrue(
+                    self.compute_sqnr(float_bias, artificial_bias) > 30,
+                    "Correcting quantized bias produced too much noise, sqnr score too low",
+                )

     @skipIfNoFBGEMM
     def test_linear_chain(self):
@@ -78,9 +84,15 @@ class TestBiasCorrectionEager(QuantizationTestCase):
                 x = self.linear2(x)
                 x = self.linear3(x)
                 return x

         float_model = QuantWrapper(LinearChain())
-        img_data = [(torch.rand(10, 3, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
-                    for _ in range(50)]
+        img_data = [
+            (
+                torch.rand(10, 3, dtype=torch.float),
+                torch.randint(0, 1, (2,), dtype=torch.long),
+            )
+            for _ in range(50)
+        ]
         self.correct_artificial_bias_quantize(float_model, img_data)

     @skipIfNoFBGEMM
@@ -97,7 +109,13 @@ class TestBiasCorrectionEager(QuantizationTestCase):
                 x = self.conv2d2(x)
                 x = self.conv2d3(x)
                 return x

         float_model = QuantWrapper(ConvChain())
-        img_data = [(torch.rand(10, 3, 125, 125, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long))
-                    for _ in range(50)]
+        img_data = [
+            (
+                torch.rand(10, 3, 125, 125, dtype=torch.float),
+                torch.randint(0, 1, (2,), dtype=torch.long),
+            )
+            for _ in range(50)
+        ]
         self.correct_artificial_bias_quantize(float_model, img_data)
```

test/quantization/eager/test_equalize_eager.py

```diff
@@ -1,20 +1,19 @@
 # Owner(s): ["oncall: quantization"]
-import torch
-import torch.nn as nn
-from torch.testing._internal.common_quantization import QuantizationTestCase
-from torch.ao.quantization.fuse_modules import fuse_modules
-import torch.ao.quantization._equalize as _equalize
 import copy

+import torch
+import torch.ao.quantization._equalize as _equalize
+import torch.nn as nn
+from torch.ao.quantization.fuse_modules import fuse_modules
+from torch.testing._internal.common_quantization import QuantizationTestCase
+

 class TestEqualizeEager(QuantizationTestCase):
     def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
-        ''' Checks the channel ranges of tensor1, tensor2 are the same,
+        """Checks the channel ranges of tensor1, tensor2 are the same,
         which is an indication that equalization has been applied correctly
-        '''
+        """
         output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
         input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)
@@ -23,18 +22,17 @@ class TestEqualizeEager(QuantizationTestCase):
         self.assertEqual(output_channel_tensor1, input_channel_tensor2)

     def getModule(self, model, name):
-        ''' Given the name is a submodule to a model, return the submodule
-        '''
+        """Given the name is a submodule to a model, return the submodule"""
         curr = model
-        name = name.split('.')
+        name = name.split(".")
         for subname in name:
             curr = curr._modules[subname]
         return curr

     def test_cross_layer_equalization(self):
-        ''' applies _equalize.cross_layer_equalization on two modules and checks
+        """applies _equalize.cross_layer_equalization on two modules and checks
         to make sure channels ranges are equivalent
-        '''
+        """
         module1 = nn.Conv2d(3, 4, 2)
         module2 = nn.Linear(4, 4)
@@ -45,13 +43,18 @@ class TestEqualizeEager(QuantizationTestCase):
         mod_tensor1, mod_tensor2 = module1.weight, module2.weight
-        self.checkChannelsEqualized(mod_tensor1, mod_tensor2, module1_output_channel_axis, module2_input_channel_axis)
+        self.checkChannelsEqualized(
+            mod_tensor1,
+            mod_tensor2,
+            module1_output_channel_axis,
+            module2_input_channel_axis,
+        )

     def test_converged(self):
-        ''' Sanity checks on _equalize.converged working
+        """Sanity checks on _equalize.converged working
         identical modules should return true
         modules with high difference in weights should return false
-        '''
+        """
         module1 = nn.Linear(3, 3)
         module2 = nn.Linear(3, 3)
@@ -59,18 +62,19 @@ class TestEqualizeEager(QuantizationTestCase):
         module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))

         # input is a dictionary
-        dictionary_1 = {'linear1': module1}
-        dictionary_2 = {'linear1': module2}
+        dictionary_1 = {"linear1": module1}
+        dictionary_2 = {"linear1": module2}
         self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
         self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))

     def test_equalize(self):
-        ''' First checks to see if _equalize.equalize can handle multiple
+        """First checks to see if _equalize.equalize can handle multiple
         pair modules as input
         then checks correctness of the function by ensuring the equalized
         and unequalized versions of the model yield the same output
         given the same input
-        '''
+        """
+
         class ChainModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -83,13 +87,16 @@ class TestEqualizeEager(QuantizationTestCase):
                 x = self.linear2(x)
                 x = self.linear3(x)
                 return x

         chain1 = ChainModule()
         chain2 = copy.deepcopy(chain1)
-        _equalize.equalize(chain1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
-        linear1 = self.getModule(chain1, 'linear1')
-        linear2 = self.getModule(chain1, 'linear2')
-        linear3 = self.getModule(chain1, 'linear3')
+        _equalize.equalize(
+            chain1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
+        )
+        linear1 = self.getModule(chain1, "linear1")
+        linear2 = self.getModule(chain1, "linear2")
+        linear3 = self.getModule(chain1, "linear3")

         self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
         self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
@@ -98,7 +105,7 @@ class TestEqualizeEager(QuantizationTestCase):
         self.assertEqual(chain1(input), chain2(input))

     def test_equalize_fused_convrelu(self):
-        ''' Checks to see if eager mode equalization supports fused
+        """Checks to see if eager mode equalization supports fused
         ConvReLU2d models

         A model with 3 ConvReLU2d is constructed. Next, the conv2d and relu
@@ -106,7 +113,8 @@ class TestEqualizeEager(QuantizationTestCase):
         equalization applied. Finally, we ensure that the channels have been
         equalized and that the equalized and unequalized versions of the model
         yield the same output given the same input
-        '''
+        """
+
         class M(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -128,13 +136,15 @@ class TestEqualizeEager(QuantizationTestCase):
         model = M()
-        fused_model1 = fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['conv3', 'relu3']])
+        fused_model1 = fuse_modules(
+            model, [["conv1", "relu1"], ["conv2", "relu2"], ["conv3", "relu3"]]
+        )
         fused_model2 = copy.deepcopy(fused_model1)
-        _equalize.equalize(fused_model1, [['conv1', 'conv2'], ['conv2', 'conv3']], 1e-6)
-        conv1 = self.getModule(fused_model1, 'conv1')[0]
-        conv2 = self.getModule(fused_model1, 'conv2')[0]
-        conv3 = self.getModule(fused_model1, 'conv3')[0]
+        _equalize.equalize(fused_model1, [["conv1", "conv2"], ["conv2", "conv3"]], 1e-6)
+        conv1 = self.getModule(fused_model1, "conv1")[0]
+        conv2 = self.getModule(fused_model1, "conv2")[0]
+        conv3 = self.getModule(fused_model1, "conv3")[0]

         self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1)
         self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1)
@@ -144,7 +154,7 @@ class TestEqualizeEager(QuantizationTestCase):
         self.assertEqual(fused_model1(input), model(input))

     def test_equalize_fused_linearrelu(self):
-        ''' Checks to see if eager mode equalization supports fused
+        """Checks to see if eager mode equalization supports fused
         LinearReLU models

         A model with 3 LinearReLU is constructed. Next, the linear and relu
@@ -152,7 +162,8 @@ class TestEqualizeEager(QuantizationTestCase):
         equalization applied. Finally, we ensure that the channels have been
         equalized and that the equalized and unequalized versions of the model
         yield the same output given the same input
-        '''
+        """
+
         class M(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -174,13 +185,17 @@ class TestEqualizeEager(QuantizationTestCase):
         model = M()
-        fused_model1 = fuse_modules(model, [['linear1', 'relu1'], ['linear2', 'relu2'], ['linear3', 'relu3']])
+        fused_model1 = fuse_modules(
+            model, [["linear1", "relu1"], ["linear2", "relu2"], ["linear3", "relu3"]]
+        )
         fused_model2 = copy.deepcopy(fused_model1)
-        _equalize.equalize(fused_model1, [['linear1', 'linear2'], ['linear2', 'linear3']], 1e-6)
-        linear1 = self.getModule(fused_model1, 'linear1')[0]
-        linear2 = self.getModule(fused_model1, 'linear2')[0]
-        linear3 = self.getModule(fused_model1, 'linear3')[0]
+        _equalize.equalize(
+            fused_model1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
+        )
+        linear1 = self.getModule(fused_model1, "linear1")[0]
+        linear2 = self.getModule(fused_model1, "linear2")[0]
+        linear3 = self.getModule(fused_model1, "linear3")[0]

         self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
         self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)
```

test/quantization/eager/test_fuse_eager.py

```diff
@@ -3,37 +3,35 @@
 import copy

 import torch
-import torch.nn as nn
-import torch.ao.nn.quantized as nnq
 import torch.ao.nn.intrinsic as nni
-import torch.ao.nn.intrinsic.quantized as nniq
 import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.intrinsic.quantized as nniq
+import torch.ao.nn.quantized as nnq
+import torch.nn as nn
 from torch.ao.quantization import (
-    quantize,
-    prepare,
     convert,
-    prepare_qat,
-    quantize_qat,
+    default_qat_qconfig,
+    default_qconfig,
     fuse_modules,
     fuse_modules_qat,
+    prepare,
+    prepare_qat,
     QConfig,
-    default_qconfig,
-    default_qat_qconfig,
+    quantize,
+    quantize_qat,
 )
 from torch.testing._internal.common_quantization import (
-    QuantizationTestCase,
-    ModelForFusion,
-    ModelWithSequentialFusion,
-    ModelForLinearBNFusion,
-    ModelForFusionWithBias,
     ModelForConvTransposeBNFusion,
+    ModelForFusion,
+    ModelForFusionWithBias,
+    ModelForLinearBNFusion,
+    ModelWithSequentialFusion,
+    QuantizationTestCase,
     SingleLayerLinearModel,
+    skipIfNoFBGEMM,
     test_only_eval_fn,
     test_only_train_fn,
-    skipIfNoFBGEMM,
 )
 from torch.testing._internal.common_quantized import (
     override_quantized_engine,
     supported_qengines,
```
@ -45,23 +43,38 @@ class TestFuseEager(QuantizationTestCase):
def test_fuse_module_train(self): def test_fuse_module_train(self):
model = ModelForFusion(default_qat_qconfig).train() model = ModelForFusion(default_qat_qconfig).train()
# Test step by step fusion # Test step by step fusion
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1']) model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn']) model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
self.assertEqual(type(model.conv1), nni.ConvBnReLU2d, self.assertEqual(
msg="Fused Conv + BN + Relu first layer") type(model.conv1),
self.assertEqual(type(model.bn1), torch.nn.Identity, nni.ConvBnReLU2d,
msg="Fused Conv + BN + Relu (skipped BN)") msg="Fused Conv + BN + Relu first layer",
self.assertEqual(type(model.relu1), torch.nn.Identity, )
msg="Fused Conv + BN + Relu (skipped Relu)") self.assertEqual(
type(model.bn1),
torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped BN)",
)
self.assertEqual(
type(model.relu1),
torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped Relu)",
)
self.assertEqual(type(model.sub1.conv), nni.ConvBn2d, self.assertEqual(
msg="Fused submodule Conv + BN") type(model.sub1.conv), nni.ConvBn2d, msg="Fused submodule Conv + BN"
self.assertEqual(type(model.sub1.bn), torch.nn.Identity, )
msg="Fused submodule Conv + BN (skipped BN)") self.assertEqual(
self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d, type(model.sub1.bn),
msg="Non-fused submodule Conv") torch.nn.Identity,
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, msg="Fused submodule Conv + BN (skipped BN)",
msg="Non-fused submodule ReLU") )
self.assertEqual(
type(model.sub2.conv), torch.nn.Conv2d, msg="Non-fused submodule Conv"
)
self.assertEqual(
type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
)
model = prepare_qat(model) model = prepare_qat(model)
self.checkObservers(model) self.checkObservers(model)
@ -89,69 +102,121 @@ class TestFuseEager(QuantizationTestCase):
test_only_eval_fn(model, self.img_data_1d) test_only_eval_fn(model, self.img_data_1d)
self.checkNoQconfig(model) self.checkNoQconfig(model)
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): with self.assertRaisesRegex(
RuntimeError,
"Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
):
checkQuantized(model) checkQuantized(model)
model = ModelForFusion(default_qat_qconfig).train() model = ModelForFusion(default_qat_qconfig).train()
model = fuse_modules_qat( model = fuse_modules_qat(
model, model, [["conv1", "bn1", "relu1"], ["sub1.conv", "sub1.bn"]]
[['conv1', 'bn1', 'relu1'], )
['sub1.conv', 'sub1.bn']])
model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train]) model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"): with self.assertRaisesRegex(
RuntimeError,
"Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
):
checkQuantized(model) checkQuantized(model)
def test_fuse_module_eval(self): def test_fuse_module_eval(self):
model = ModelForFusion(default_qconfig) model = ModelForFusion(default_qconfig)
model.eval() model.eval()
model = fuse_modules( model = fuse_modules(
model, model,
[['conv3', 'bn3', 'relu4'], [
['conv1', 'bn1', 'relu1'], ["conv3", "bn3", "relu4"],
['conv2', 'relu2'], ["conv1", "bn1", "relu1"],
['bn2', 'relu3'], ["conv2", "relu2"],
['sub1.conv', 'sub1.bn']]) ["bn2", "relu3"],
self.assertEqual(type(model.conv1), nni.ConvReLU2d, ["sub1.conv", "sub1.bn"],
msg="Fused Conv + BN + Relu first layer (BN is folded)") ],
self.assertEqual(type(model.conv1[0]), nn.Conv2d, )
msg="Fused Conv + BN + Relu (Conv + folded BN only)") self.assertEqual(
self.assertEqual(type(model.conv1[1]), nn.ReLU, type(model.conv1),
msg="Fused Conv + BN + Relu second layer (Relu only)") nni.ConvReLU2d,
self.assertEqual(type(model.bn1), nn.Identity, msg="Fused Conv + BN + Relu first layer (BN is folded)",
msg="Fused Conv + BN + Relu second layer (Skipped BN)") )
self.assertEqual(type(model.relu1), nn.Identity, self.assertEqual(
msg="Fused Conv + BN + Relu second layer (Skipped Relu)") type(model.conv1[0]),
self.assertEqual(type(model.conv2), nni.ConvReLU3d, nn.Conv2d,
msg="Fused Conv + BN + Relu first layer (BN is folded)") msg="Fused Conv + BN + Relu (Conv + folded BN only)",
self.assertEqual(type(model.bn2), nni.BNReLU3d, )
msg="Fused BN + Relu first layer (Relu is folded))") self.assertEqual(
self.assertEqual(type(model.relu3), nn.Identity, type(model.conv1[1]),
msg="Fused BN + Relu second layer (Skipped Relu)") nn.ReLU,
self.assertEqual(type(model.conv2[0]), nn.Conv3d, msg="Fused Conv + BN + Relu second layer (Relu only)",
msg="Fused Conv + BN + Relu (Conv + folded BN only)") )
self.assertEqual(type(model.conv2[1]), nn.ReLU, self.assertEqual(
msg="Fused Conv + BN + Relu second layer (Relu only)") type(model.bn1),
self.assertEqual(type(model.relu2), nn.Identity, nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)") msg="Fused Conv + BN + Relu second layer (Skipped BN)",
)
self.assertEqual(
type(model.relu1),
nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
)
self.assertEqual(
type(model.conv2),
nni.ConvReLU3d,
msg="Fused Conv + BN + Relu first layer (BN is folded)",
)
self.assertEqual(
type(model.bn2),
nni.BNReLU3d,
msg="Fused BN + Relu first layer (Relu is folded))",
)
self.assertEqual(
type(model.relu3),
nn.Identity,
msg="Fused BN + Relu second layer (Skipped Relu)",
)
self.assertEqual(
type(model.conv2[0]),
nn.Conv3d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)",
)
self.assertEqual(
type(model.conv2[1]),
nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)",
)
self.assertEqual(
type(model.relu2),
nn.Identity,
msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
)
self.assertEqual(type(model.conv3), nni.ConvReLU1d, self.assertEqual(
msg="Fused Conv + Relu for Conv1d (folded BN)") type(model.conv3),
self.assertEqual(type(model.conv3[0]), nn.Conv1d, nni.ConvReLU1d,
msg="Fused Conv + Relu for Conv1d ") msg="Fused Conv + Relu for Conv1d (folded BN)",
self.assertEqual(type(model.conv3[1]), nn.ReLU, )
msg="Fused Conv + Relu for Conv1d") self.assertEqual(
self.assertEqual(type(model.bn3), nn.Identity, type(model.conv3[0]), nn.Conv1d, msg="Fused Conv + Relu for Conv1d "
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)") )
self.assertEqual(
type(model.conv3[1]), nn.ReLU, msg="Fused Conv + Relu for Conv1d"
)
self.assertEqual(
type(model.bn3),
nn.Identity,
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)",
)
self.assertEqual(type(model.sub1.conv), nn.Conv2d, self.assertEqual(
msg="Fused submodule Conv + folded BN") type(model.sub1.conv), nn.Conv2d, msg="Fused submodule Conv + folded BN"
self.assertEqual(type(model.sub1.bn), nn.Identity, )
msg="Fused submodule (skipped BN)") self.assertEqual(
self.assertEqual(type(model.sub2.conv), nn.Conv2d, type(model.sub1.bn), nn.Identity, msg="Fused submodule (skipped BN)"
msg="Non-fused submodule Conv") )
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU, self.assertEqual(
msg="Non-fused submodule ReLU") type(model.sub2.conv), nn.Conv2d, msg="Non-fused submodule Conv"
)
self.assertEqual(
type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
)
model = prepare(model) model = prepare(model)
self.checkObservers(model) self.checkObservers(model)
@ -176,11 +241,14 @@ class TestFuseEager(QuantizationTestCase):
model = ModelForFusion(default_qconfig).eval() model = ModelForFusion(default_qconfig).eval()
model = fuse_modules( model = fuse_modules(
model, model,
[['conv1', 'bn1', 'relu1'], [
['conv2', 'relu2'], ["conv1", "bn1", "relu1"],
['bn2', 'relu3'], ["conv2", "relu2"],
['sub1.conv', 'sub1.bn'], ["bn2", "relu3"],
['conv3', 'bn3', 'relu4']]) ["sub1.conv", "sub1.bn"],
["conv3", "bn3", "relu4"],
],
)
model = quantize(model, test_only_eval_fn, [self.img_data_1d]) model = quantize(model, test_only_eval_fn, [self.img_data_1d])
checkQuantized(model) checkQuantized(model)
@ -190,27 +258,46 @@ class TestFuseEager(QuantizationTestCase):
model = ModelWithSequentialFusion().train() model = ModelWithSequentialFusion().train()
model.to(torch.float) model.to(torch.float)
fuse_modules_qat( fuse_modules_qat(
model, [['conv1', 'relu1'] , model,
['features.0.0', 'features.0.1', 'features.0.2'], [
['features.1.0', 'features.1.1', 'features.1.2'], ["conv1", "relu1"],
['features.2.0', 'features.2.1', 'features.2.2'], ["features.0.0", "features.0.1", "features.0.2"],
['classifier.0', 'classifier.1']], ["features.1.0", "features.1.1", "features.1.2"],
inplace=True) ["features.2.0", "features.2.1", "features.2.2"],
self.assertEqual(type(model.conv1), nni.ConvReLU2d, ["classifier.0", "classifier.1"],
msg="Fused Conv + Relu: nni.ConvReLU2d") ],
self.assertEqual(type(model.conv1[0]), nn.Conv2d, inplace=True,
msg="Fused Conv + Relu: Conv2d") )
self.assertEqual(type(model.conv1[1]), nn.ReLU, self.assertEqual(
msg="Fused Conv + Relu: Relu") type(model.conv1),
self.assertEqual(type(model.relu1), nn.Identity, nni.ConvReLU2d,
msg="Fused Conv + Relu: Identity") msg="Fused Conv + Relu: nni.ConvReLU2d",
)
self.assertEqual(
type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
)
self.assertEqual(
type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
)
self.assertEqual(
type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
)
for i in range(3): for i in range(3):
self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d, self.assertEqual(
msg="Fused submodule Conv + folded BN") type(model.features[i][0]),
self.assertEqual(type(model.features[i][1]), nn.Identity, nni.ConvBnReLU2d,
msg="Fused submodule (skipped BN)") msg="Fused submodule Conv + folded BN",
self.assertEqual(type(model.features[i][2]), nn.Identity, )
msg="Non-fused submodule Conv") self.assertEqual(
type(model.features[i][1]),
nn.Identity,
msg="Fused submodule (skipped BN)",
)
self.assertEqual(
type(model.features[i][2]),
nn.Identity,
msg="Non-fused submodule Conv",
)
self.assertEqual(type(model.classifier[0]), nni.LinearReLU) self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity) self.assertEqual(type(model.classifier[1]), nn.Identity)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
@ -218,17 +305,26 @@ class TestFuseEager(QuantizationTestCase):
self.checkObservers(model) self.checkObservers(model)
model(self.img_data_2d[0][0]) model(self.img_data_2d[0][0])
def checkQAT(model): def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvReLU2d) self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
self.assertEqual(type(model.relu1), nn.Identity) self.assertEqual(type(model.relu1), nn.Identity)
for i in range(3): for i in range(3):
self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d, self.assertEqual(
msg="Fused submodule Conv + folded BN") type(model.features[i][0]),
self.assertEqual(type(model.features[i][1]), nn.Identity, nniqat.ConvBnReLU2d,
msg="Fused submodule (skipped BN)") msg="Fused submodule Conv + folded BN",
self.assertEqual(type(model.features[i][2]), nn.Identity, )
msg="Non-fused submodule Conv") self.assertEqual(
type(model.features[i][1]),
nn.Identity,
msg="Fused submodule (skipped BN)",
)
self.assertEqual(
type(model.features[i][2]),
nn.Identity,
msg="Non-fused submodule Conv",
)
self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU) self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity) self.assertEqual(type(model.classifier[1]), nn.Identity)
@ -245,27 +341,45 @@ class TestFuseEager(QuantizationTestCase):
model.to(torch.float) model.to(torch.float)
fuse_modules( fuse_modules(
model, model,
[['conv1', 'relu1'], [
['features.0.0', 'features.0.1', 'features.0.2'], ["conv1", "relu1"],
['features.1.0', 'features.1.1', 'features.1.2'], ["features.0.0", "features.0.1", "features.0.2"],
['features.2.0', 'features.2.1', 'features.2.2'], ["features.1.0", "features.1.1", "features.1.2"],
['classifier.0', 'classifier.1']], ["features.2.0", "features.2.1", "features.2.2"],
inplace=True) ["classifier.0", "classifier.1"],
self.assertEqual(type(model.conv1), nni.ConvReLU2d, ],
msg="Fused Conv + Relu: nni.ConvReLU2d") inplace=True,
self.assertEqual(type(model.conv1[0]), nn.Conv2d, )
msg="Fused Conv + Relu: Conv2d") self.assertEqual(
self.assertEqual(type(model.conv1[1]), nn.ReLU, type(model.conv1),
msg="Fused Conv + Relu: Relu") nni.ConvReLU2d,
self.assertEqual(type(model.relu1), nn.Identity, msg="Fused Conv + Relu: nni.ConvReLU2d",
msg="Fused Conv + Relu: Identity") )
self.assertEqual(
type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
)
self.assertEqual(
type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
)
self.assertEqual(
type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
)
for i in range(3): for i in range(3):
self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d, self.assertEqual(
msg="Fused submodule Conv + folded BN") type(model.features[i][0]),
self.assertEqual(type(model.features[i][1]), nn.Identity, nni.ConvReLU2d,
msg="Fused submodule (skipped BN)") msg="Fused submodule Conv + folded BN",
self.assertEqual(type(model.features[i][2]), nn.Identity, )
msg="Non-fused submodule Conv") self.assertEqual(
type(model.features[i][1]),
nn.Identity,
msg="Fused submodule (skipped BN)",
)
self.assertEqual(
type(model.features[i][2]),
nn.Identity,
msg="Non-fused submodule Conv",
)
self.assertEqual(type(model.classifier[0]), nni.LinearReLU) self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity) self.assertEqual(type(model.classifier[1]), nn.Identity)
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine) model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
@ -297,12 +411,12 @@ class TestFuseEager(QuantizationTestCase):
out_ref = model_ref(self.img_data_2d[0][0]) out_ref = model_ref(self.img_data_2d[0][0])
# fused model # fused model
model_orig.qconfig = QConfig(activation=torch.nn.Identity, model_orig.qconfig = QConfig(
weight=torch.nn.Identity) activation=torch.nn.Identity, weight=torch.nn.Identity
)
model = fuse_modules_qat( model = fuse_modules_qat(
model_orig, model_orig, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]]
[["conv1", "bn1", "relu1"], )
["conv2", "bn2"]])
prep_model = prepare_qat(model, inplace=False) prep_model = prepare_qat(model, inplace=False)
# output with fusion but no observers. # output with fusion but no observers.
out_fused = prep_model(self.img_data_2d[0][0]) out_fused = prep_model(self.img_data_2d[0][0])
@ -332,7 +446,6 @@ class TestFuseEager(QuantizationTestCase):
checkQAT(model) checkQAT(model)
def test_fusion_linear_bn_eval(self): def test_fusion_linear_bn_eval(self):
model = ModelForLinearBNFusion().train() model = ModelForLinearBNFusion().train()
inp1 = torch.randn(8, 20) inp1 = torch.randn(8, 20)
@ -357,7 +470,9 @@ class TestFuseEager(QuantizationTestCase):
model.eval() model.eval()
golden = model(inp2) golden = model(inp2)
model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]]) model = fuse_modules(
model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]]
)
self.assertEqual(type(model.bn1), nn.Identity) self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.bn2), nn.Identity) self.assertEqual(type(model.bn2), nn.Identity)
self.assertEqual(type(model.bn3), nn.Identity) self.assertEqual(type(model.bn3), nn.Identity)
@ -384,50 +499,68 @@ class TestFuseEager(QuantizationTestCase):
model = ModelForFusion(default_qat_qconfig).train() model = ModelForFusion(default_qat_qconfig).train()
counter = { counter = {
'pre_forwards': 0, "pre_forwards": 0,
'forwards': 0, "forwards": 0,
} }
fused = False fused = False
def fw_pre_hook(fused_module_class, h_module, input): def fw_pre_hook(fused_module_class, h_module, input):
if fused: if fused:
self.assertEqual(type(h_module), fused_module_class, self.assertEqual(
"After fusion owner of the first module's forward pre hook is not a fused module") type(h_module),
counter['pre_forwards'] += 1 fused_module_class,
"After fusion owner of the first module's forward pre hook is not a fused module",
)
counter["pre_forwards"] += 1
def fw_hook(fused_module_class, h_module, input, output): def fw_hook(fused_module_class, h_module, input, output):
if fused: if fused:
self.assertEqual(type(h_module), fused_module_class, self.assertEqual(
"After fusion owner of the last module's forward hook is not a fused module") type(h_module),
counter['forwards'] += 1 fused_module_class,
"After fusion owner of the last module's forward hook is not a fused module",
)
counter["forwards"] += 1
# Registering two pre and two post forward hooks, thus expecting counter increment by two each inference # Registering two pre and two post forward hooks, thus expecting counter increment by two each inference
model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)) model.conv1.register_forward_pre_hook(
model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args)) lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)
model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args)) )
model.sub1.conv.register_forward_pre_hook(
lambda *args: fw_pre_hook(nni.ConvBn2d, *args)
)
model.relu1.register_forward_hook(
lambda *args: fw_hook(nni.ConvBnReLU2d, *args)
)
model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args)) model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))
test_only_eval_fn(model, self.img_data_1d) test_only_eval_fn(model, self.img_data_1d)
self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d)) self.assertEqual(counter["pre_forwards"], 2 * len(self.img_data_1d))
self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d)) self.assertEqual(counter["forwards"], 2 * len(self.img_data_1d))
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1']) model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn']) model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
fused = True fused = True
before_fusion_pre_count = counter['pre_forwards'] before_fusion_pre_count = counter["pre_forwards"]
before_fusion_post_count = counter['forwards'] before_fusion_post_count = counter["forwards"]
test_only_eval_fn(model, self.img_data_1d) test_only_eval_fn(model, self.img_data_1d)
self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d)) self.assertEqual(
self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d)) counter["pre_forwards"] - before_fusion_pre_count, 2 * len(self.img_data_1d)
)
self.assertEqual(
counter["forwards"] - before_fusion_post_count, 2 * len(self.img_data_1d)
)
def test_fuse_modules_with_nested_hooks(self): def test_fuse_modules_with_nested_hooks(self):
r"""Test case that checks whether a nested module with sub-sub modules registered with hooks r"""Test case that checks whether a nested module with sub-sub modules registered with hooks
can be safely fused. Safeguard for issues similar to https://github.com/pytorch/pytorch/issues/105063 can be safely fused. Safeguard for issues similar to https://github.com/pytorch/pytorch/issues/105063
in the future. in the future.
""" """
def myhook(*x): def myhook(*x):
return "" return ""
for qengine in supported_qengines: for qengine in supported_qengines:
with override_quantized_engine(qengine): with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().eval() model = ModelWithSequentialFusion().eval()
@ -435,28 +568,32 @@ class TestFuseEager(QuantizationTestCase):
for sub_model in model.modules(): for sub_model in model.modules():
if isinstance(sub_model, nn.Sequential): if isinstance(sub_model, nn.Sequential):
for layer in sub_model: for layer in sub_model:
if hasattr(layer, 'register_forward_hook'): if hasattr(layer, "register_forward_hook"):
layer.register_forward_hook(myhook) layer.register_forward_hook(myhook)
fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True) fuse_modules(
model,
[["features.0.0", "features.0.1", "features.0.2"]],
inplace=True,
)
self.assertEqual( self.assertEqual(
type(model.features[0][0]), type(model.features[0][0]),
nni.ConvReLU2d, nni.ConvReLU2d,
msg="Fused submodule Conv + folded BN" msg="Fused submodule Conv + folded BN",
) )
self.assertEqual( self.assertEqual(
type(model.features[0][1]), type(model.features[0][1]),
nn.Identity, nn.Identity,
msg="Fused submodule (skipped BN)" msg="Fused submodule (skipped BN)",
) )
self.assertEqual( self.assertEqual(
type(model.features[0][2]), type(model.features[0][2]),
nn.Identity, nn.Identity,
msg="Non-fused submodule Conv" msg="Non-fused submodule Conv",
) )
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError( raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n" "This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n" "\tpython test/test_quantization.py TESTNAME\n\n"

test/quantization/eager/test_model_numerics.py

```diff
@@ -1,17 +1,17 @@
 # Owner(s): ["oncall: quantization"]
 import torch
 from torch.testing._internal.common_quantization import (
-    QuantizationTestCase,
     ModelMultipleOps,
     ModelMultipleOpsNoAvgPool,
+    QuantizationTestCase,
 )
 from torch.testing._internal.common_quantized import (
     override_quantized_engine,
     supported_qengines,
 )


 class TestModelNumericsEager(QuantizationTestCase):
     def test_float_quant_compare_per_tensor(self):
         for qengine in supported_qengines:
```
@ -25,16 +25,24 @@ class TestModelNumericsEager(QuantizationTestCase):
qModel = torch.ao.quantization.QuantWrapper(my_model) qModel = torch.ao.quantization.QuantWrapper(my_model)
qModel.eval() qModel.eval()
qModel.qconfig = torch.ao.quantization.default_qconfig qModel.qconfig = torch.ao.quantization.default_qconfig
torch.ao.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.ao.quantization.fuse_modules(
qModel.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare(qModel, inplace=True) torch.ao.quantization.prepare(qModel, inplace=True)
qModel(calib_data) qModel(calib_data)
torch.ao.quantization.convert(qModel, inplace=True) torch.ao.quantization.convert(qModel, inplace=True)
out_q = qModel(eval_data) out_q = qModel(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q)) SQNRdB = 20 * torch.log10(
torch.norm(out_ref) / torch.norm(out_ref - out_q)
)
# Quantized model output should be close to floating point model output numerically # Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired # Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired
# output # output
self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB') self.assertGreater(
SQNRdB,
30,
msg="Quantized model numerics diverge from float, expect SQNR > 30 dB",
)
def test_float_quant_compare_per_channel(self): def test_float_quant_compare_per_channel(self):
# Test for per-channel Quant # Test for per-channel Quant
@ -47,7 +55,9 @@ class TestModelNumericsEager(QuantizationTestCase):
q_model = torch.ao.quantization.QuantWrapper(my_model) q_model = torch.ao.quantization.QuantWrapper(my_model)
q_model.eval() q_model.eval()
q_model.qconfig = torch.ao.quantization.default_per_channel_qconfig q_model.qconfig = torch.ao.quantization.default_per_channel_qconfig
torch.ao.quantization.fuse_modules(q_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.ao.quantization.fuse_modules(
q_model.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare(q_model) torch.ao.quantization.prepare(q_model)
q_model(calib_data) q_model(calib_data)
torch.ao.quantization.convert(q_model) torch.ao.quantization.convert(q_model)
@ -55,7 +65,11 @@ class TestModelNumericsEager(QuantizationTestCase):
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q)) SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q))
# Quantized model output should be close to floating point model output numerically # Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 35 dB # Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB') self.assertGreater(
SQNRdB,
35,
msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
)
def test_fake_quant_true_quant_compare(self): def test_fake_quant_true_quant_compare(self):
for qengine in supported_qengines: for qengine in supported_qengines:
@ -69,7 +83,9 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model = torch.ao.quantization.QuantWrapper(my_model) fq_model = torch.ao.quantization.QuantWrapper(my_model)
fq_model.train() fq_model.train()
fq_model.qconfig = torch.ao.quantization.default_qat_qconfig fq_model.qconfig = torch.ao.quantization.default_qat_qconfig
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.ao.quantization.fuse_modules_qat(
fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare_qat(fq_model) torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval() fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant) fq_model.apply(torch.ao.quantization.disable_fake_quant)
@ -78,14 +94,26 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model.apply(torch.ao.quantization.enable_fake_quant) fq_model.apply(torch.ao.quantization.enable_fake_quant)
fq_model.apply(torch.ao.quantization.disable_observer) fq_model.apply(torch.ao.quantization.disable_observer)
out_fq = fq_model(eval_data) out_fq = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq)) SQNRdB = 20 * torch.log10(
torch.norm(out_ref) / torch.norm(out_ref - out_fq)
)
# Quantized model output should be close to floating point model output numerically # Quantized model output should be close to floating point model output numerically
# Setting target SQNR to be 35 dB # Setting target SQNR to be 35 dB
self.assertGreater(SQNRdB, 35, msg='Quantized model numerics diverge from float, expect SQNR > 35 dB') self.assertGreater(
SQNRdB,
35,
msg="Quantized model numerics diverge from float, expect SQNR > 35 dB",
)
torch.ao.quantization.convert(fq_model) torch.ao.quantization.convert(fq_model)
out_q = fq_model(eval_data) out_q = fq_model(eval_data)
SQNRdB = 20 * torch.log10(torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10)) SQNRdB = 20 * torch.log10(
self.assertGreater(SQNRdB, 60, msg='Fake quant and true quant numerics diverge, expect SQNR > 60 dB') torch.norm(out_fq) / (torch.norm(out_fq - out_q) + 1e-10)
)
self.assertGreater(
SQNRdB,
60,
msg="Fake quant and true quant numerics diverge, expect SQNR > 60 dB",
)
# Test to compare weight only quantized model numerics and # Test to compare weight only quantized model numerics and
# activation only quantized model numerics with float # activation only quantized model numerics with float
@ -95,8 +123,10 @@ class TestModelNumericsEager(QuantizationTestCase):
torch.manual_seed(67) torch.manual_seed(67)
calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32) calib_data = torch.rand(2048, 3, 15, 15, dtype=torch.float32)
eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32) eval_data = torch.rand(10, 3, 15, 15, dtype=torch.float32)
qconfigset = {torch.ao.quantization.default_weight_only_qconfig, qconfigset = {
torch.ao.quantization.default_activation_only_qconfig} torch.ao.quantization.default_weight_only_qconfig,
torch.ao.quantization.default_activation_only_qconfig,
}
SQNRTarget = [35, 45] SQNRTarget = [35, 45]
for idx, qconfig in enumerate(qconfigset): for idx, qconfig in enumerate(qconfigset):
my_model = ModelMultipleOpsNoAvgPool().to(torch.float32) my_model = ModelMultipleOpsNoAvgPool().to(torch.float32)
@ -105,7 +135,9 @@ class TestModelNumericsEager(QuantizationTestCase):
fq_model = torch.ao.quantization.QuantWrapper(my_model) fq_model = torch.ao.quantization.QuantWrapper(my_model)
fq_model.train() fq_model.train()
fq_model.qconfig = qconfig fq_model.qconfig = qconfig
torch.ao.quantization.fuse_modules_qat(fq_model.module, [['conv1', 'bn1', 'relu1']], inplace=True) torch.ao.quantization.fuse_modules_qat(
fq_model.module, [["conv1", "bn1", "relu1"]], inplace=True
)
torch.ao.quantization.prepare_qat(fq_model) torch.ao.quantization.prepare_qat(fq_model)
fq_model.eval() fq_model.eval()
fq_model.apply(torch.ao.quantization.disable_fake_quant) fq_model.apply(torch.ao.quantization.disable_fake_quant)
```diff
@@ -114,11 +146,19 @@ class TestModelNumericsEager(QuantizationTestCase):
             fq_model.apply(torch.ao.quantization.enable_fake_quant)
             fq_model.apply(torch.ao.quantization.disable_observer)
             out_fq = fq_model(eval_data)
-            SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_fq))
-            self.assertGreater(SQNRdB, SQNRTarget[idx], msg='Quantized model numerics diverge from float')
+            SQNRdB = 20 * torch.log10(
+                torch.norm(out_ref) / torch.norm(out_ref - out_fq)
+            )
+            self.assertGreater(
+                SQNRdB,
+                SQNRTarget[idx],
+                msg="Quantized model numerics diverge from float",
+            )


-if __name__ == '__main__':
-    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
-                       "\tpython test/test_quantization.py TESTNAME\n\n"
-                       "instead.")
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_quantization.py TESTNAME\n\n"
+        "instead."
+    )
```

test/quantization/eager/test_numeric_suite_eager.py

```diff
@@ -2,43 +2,45 @@
 # ruff: noqa: F841
 import unittest

 import torch
-import torch.nn as nn
 import torch.ao.nn.quantized as nnq
-from torch.ao.quantization import (
-    DeQuantStub,
-    QuantStub,
-    convert,
-    default_qconfig,
-    prepare,
-    quantize,
-    quantize_dynamic,
-)
+import torch.nn as nn
 from torch.ao.ns._numeric_suite import (
-    OutputLogger,
-    Shadow,
-    ShadowLogger,
     compare_model_outputs,
     compare_model_stub,
     compare_weights,
-    prepare_model_outputs,
     get_matching_activations,
+    OutputLogger,
+    prepare_model_outputs,
+    Shadow,
+    ShadowLogger,
+)
+from torch.ao.quantization import (
+    convert,
+    default_qconfig,
+    DeQuantStub,
+    prepare,
+    quantize,
+    quantize_dynamic,
+    QuantStub,
 )
 from torch.testing._internal.common_quantization import (
     AnnotatedConvBnReLUModel,
     AnnotatedConvModel,
     AnnotatedConvTransposeModel,
     AnnotatedSingleLayerLinearModel,
-    LSTMwithHiddenDynamicModel,
     AnnotatedTwoLayerLinearModel,
+    LSTMwithHiddenDynamicModel,
     QuantizationTestCase,
     SingleLayerLinearDynamicModel,
-    test_only_eval_fn,
     skip_if_no_torchvision,
+    test_only_eval_fn,
 )
 from torch.testing._internal.common_quantized import override_qengines
 from torch.testing._internal.common_utils import IS_ARM64


 class SubModule(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
```
@ -200,10 +202,18 @@ class TestNumericSuiteEager(QuantizationTestCase):
for i, val in enumerate(v["quantized"]): for i, val in enumerate(v["quantized"]):
self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) self.assertTrue(v["float"][i].shape == v["quantized"][i].shape)
model_list = [AnnotatedConvModel(qengine), model_list = [
AnnotatedConvTransposeModel("qnnpack"), # ConvT cannot use per channel weights AnnotatedConvModel(qengine),
AnnotatedConvBnReLUModel(qengine)] AnnotatedConvTransposeModel(
module_swap_list = [nn.Conv2d, nn.intrinsic.modules.fused.ConvReLU2d, nn.ConvTranspose2d] "qnnpack"
), # ConvT cannot use per channel weights
AnnotatedConvBnReLUModel(qengine),
]
module_swap_list = [
nn.Conv2d,
nn.intrinsic.modules.fused.ConvReLU2d,
nn.ConvTranspose2d,
]
for model in model_list: for model in model_list:
model.eval() model.eval()
if hasattr(model, "fuse_model"): if hasattr(model, "fuse_model"):
@ -279,7 +289,6 @@ class TestNumericSuiteEager(QuantizationTestCase):
self.assertTrue(isinstance(q_model.mod1, Shadow)) self.assertTrue(isinstance(q_model.mod1, Shadow))
self.assertFalse(isinstance(q_model.conv, Shadow)) self.assertFalse(isinstance(q_model.conv, Shadow))
@override_qengines @override_qengines
def test_compare_model_stub_functional_static(self): def test_compare_model_stub_functional_static(self):
r"""Compare the output of static quantized functional layer and its float shadow module""" r"""Compare the output of static quantized functional layer and its float shadow module"""
@ -486,7 +495,9 @@ class TestNumericSuiteEager(QuantizationTestCase):
for i, val in enumerate(v["quantized"]): for i, val in enumerate(v["quantized"]):
self.assertTrue(len(v["float"][i]) == len(v["quantized"][i])) self.assertTrue(len(v["float"][i]) == len(v["quantized"][i]))
if i == 0: if i == 0:
self.assertTrue(v["float"][i][0].shape == v["quantized"][i][0].shape) self.assertTrue(
v["float"][i][0].shape == v["quantized"][i][0].shape
)
else: else:
self.assertTrue( self.assertTrue(
v["float"][i][0].shape == v["quantized"][i][0].shape v["float"][i][0].shape == v["quantized"][i][0].shape
@ -540,12 +551,23 @@ class TestNumericSuiteEager(QuantizationTestCase):
@skip_if_no_torchvision @skip_if_no_torchvision
def _test_vision_model(self, float_model): def _test_vision_model(self, float_model):
float_model.to('cpu') float_model.to("cpu")
float_model.eval() float_model.eval()
float_model.fuse_model() float_model.fuse_model()
float_model.qconfig = torch.ao.quantization.default_qconfig float_model.qconfig = torch.ao.quantization.default_qconfig
img_data = [(torch.rand(2, 3, 224, 224, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)] img_data = [
qmodel = quantize(float_model, torch.ao.quantization.default_eval_fn, [img_data], inplace=False) (
torch.rand(2, 3, 224, 224, dtype=torch.float),
torch.randint(0, 1, (2,), dtype=torch.long),
)
for _ in range(2)
]
qmodel = quantize(
float_model,
torch.ao.quantization.default_eval_fn,
[img_data],
inplace=False,
)
wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict()) wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
@ -560,9 +582,11 @@ class TestNumericSuiteEager(QuantizationTestCase):
# 'quantized', containing the activations of floating point and quantized model at matching locations. # 'quantized', containing the activations of floating point and quantized model at matching locations.
act_compare_dict = compare_model_outputs(float_model, qmodel, data) act_compare_dict = compare_model_outputs(float_model, qmodel, data)
for key in act_compare_dict: for key in act_compare_dict:
compute_error(act_compare_dict[key]['float'][0], act_compare_dict[key]['quantized'][0].dequantize()) compute_error(
act_compare_dict[key]["float"][0],
act_compare_dict[key]["quantized"][0].dequantize(),
)
prepare_model_outputs(float_model, qmodel) prepare_model_outputs(float_model, qmodel)
@ -579,10 +603,12 @@ class TestNumericSuiteEager(QuantizationTestCase):
@unittest.skipIf(IS_ARM64, "Not working on arm right now") @unittest.skipIf(IS_ARM64, "Not working on arm right now")
def test_mobilenet_v2(self): def test_mobilenet_v2(self):
from torchvision.models.quantization import mobilenet_v2 from torchvision.models.quantization import mobilenet_v2
self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False)) self._test_vision_model(mobilenet_v2(pretrained=True, quantize=False))
@skip_if_no_torchvision @skip_if_no_torchvision
@unittest.skipIf(IS_ARM64, "Not working on arm right now") @unittest.skipIf(IS_ARM64, "Not working on arm right now")
def test_mobilenet_v3(self): def test_mobilenet_v3(self):
from torchvision.models.quantization import mobilenet_v3_large from torchvision.models.quantization import mobilenet_v3_large
self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False)) self._test_vision_model(mobilenet_v3_large(pretrained=True, quantize=False))

(File diff suppressed because it is too large.)

test/quantization/eager/test_quantize_eager_qat.py

```diff
@@ -3,6 +3,8 @@
 import copy
 import math

+from hypothesis import given, strategies as st
+
 import torch
 import torch.ao.nn.intrinsic.qat as nniqat
 import torch.ao.nn.qat as nnqat
@@ -12,8 +14,6 @@ import torch.ao.nn.quantized.dynamic as nnqd
 import torch.backends.mkldnn
 import torch.nn as nn
 import torch.testing._internal.hypothesis_utils as hu
-from hypothesis import given, strategies as st
-
 from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
 from torch.ao.quantization import (
     convert,
```
@ -50,42 +50,63 @@ from torch.testing._internal.common_quantization import (
test_only_train_fn, test_only_train_fn,
TwoLayerLinearModel, TwoLayerLinearModel,
) )
from torch.testing._internal.common_quantized import ( from torch.testing._internal.common_quantized import (
override_qengines, override_qengines,
override_quantized_engine, override_quantized_engine,
supported_qengines, supported_qengines,
) )
from torch.testing._internal.common_utils import skipIfNoXNNPACK from torch.testing._internal.common_utils import skipIfNoXNNPACK
hu.assert_deadline_disabled() hu.assert_deadline_disabled()
from functools import reduce from functools import reduce
class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd): class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
""" """
Conv-BN fusion implemented with explicit folding. Useful Conv-BN fusion implemented with explicit folding. Useful
to verify numerical equivalency with non-folded version. to verify numerical equivalency with non-folded version.
""" """
def __init__(self,
# ConvNd args def __init__(
in_channels, out_channels, kernel_size, stride, self,
padding, dilation, transposed, output_padding, # ConvNd args
groups, in_channels,
bias, out_channels,
padding_mode, kernel_size,
# BatchNormNd args stride,
# num_features: out_channels padding,
eps=1e-05, momentum=0.1, dilation,
# affine: True transposed,
# track_running_stats: True output_padding,
# Args for this module groups,
freeze_bn=False, bias,
qconfig=None): padding_mode,
nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size, # BatchNormNd args
stride, padding, dilation, transposed, # num_features: out_channels
output_padding, groups, False, padding_mode) eps=1e-05,
assert qconfig, 'qconfig must be provided for QAT module' momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
nn.modules.conv._ConvNd.__init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
False,
padding_mode,
)
assert qconfig, "qconfig must be provided for QAT module"
self.qconfig = qconfig self.qconfig = qconfig
self.eps = eps self.eps = eps
self.momentum = momentum self.momentum = momentum
@ -103,7 +124,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
if bias: if bias:
self.bias = nn.Parameter(torch.empty(out_channels)) self.bias = nn.Parameter(torch.empty(out_channels))
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
self.reset_bn_parameters() self.reset_bn_parameters()
def reset_running_stats(self): def reset_running_stats(self):
@ -123,7 +144,7 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
def reset_parameters(self): def reset_parameters(self):
super().reset_parameters() super().reset_parameters()
# A hack to avoid resetting on undefined parameters # A hack to avoid resetting on undefined parameters
if hasattr(self, 'gamma'): if hasattr(self, "gamma"):
self.reset_bn_parameters() self.reset_bn_parameters()
def update_bn_stats(self): def update_bn_stats(self):
@ -161,33 +182,50 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
if self.bias is not None: if self.bias is not None:
zero_bias = torch.zeros_like(self.bias, dtype=input.dtype) zero_bias = torch.zeros_like(self.bias, dtype=input.dtype)
else: else:
zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device, dtype=input.dtype) zero_bias = torch.zeros(
conv = self._conv_forward(input, self.weight_fake_quant(scaled_weight), zero_bias) self.out_channels, device=scaled_weight.device, dtype=input.dtype
)
conv = self._conv_forward(
input, self.weight_fake_quant(scaled_weight), zero_bias
)
if self.training and not self.freeze_bn: if self.training and not self.freeze_bn:
# recovering original conv to get original batch_mean and batch_var # recovering original conv to get original batch_mean and batch_var
if self.bias is not None: if self.bias is not None:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1]) conv_orig = conv / scale_factor.reshape(
[1, -1, 1, 1]
) + self.bias.reshape([1, -1, 1, 1])
else: else:
conv_orig = conv / scale_factor.reshape([1, -1, 1, 1]) conv_orig = conv / scale_factor.reshape([1, -1, 1, 1])
batch_mean = torch.mean(conv_orig, dim=[0, 2, 3]) batch_mean = torch.mean(conv_orig, dim=[0, 2, 3])
batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False) batch_var = torch.var(conv_orig, dim=[0, 2, 3], unbiased=False)
n = float(conv_orig.numel() / conv_orig.size()[1]) n = float(conv_orig.numel() / conv_orig.size()[1])
unbiased_batch_var = batch_var * (n / (n - 1)) unbiased_batch_var = batch_var * (n / (n - 1))
batch_rstd = torch.ones_like(batch_var, memory_format=torch.contiguous_format) / torch.sqrt(batch_var + self.eps) batch_rstd = torch.ones_like(
batch_var, memory_format=torch.contiguous_format
) / torch.sqrt(batch_var + self.eps)
conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + \ conv = (self.gamma * batch_rstd).reshape([1, -1, 1, 1]) * conv_orig + (
(self.beta - self.gamma * batch_rstd * batch_mean).reshape([1, -1, 1, 1]) self.beta - self.gamma * batch_rstd * batch_mean
self.running_mean = exponential_average_factor * batch_mean.detach() + \ ).reshape([1, -1, 1, 1])
(1 - exponential_average_factor) * self.running_mean self.running_mean = (
self.running_var = exponential_average_factor * unbiased_batch_var.detach() + \ exponential_average_factor * batch_mean.detach()
(1 - exponential_average_factor) * self.running_var + (1 - exponential_average_factor) * self.running_mean
)
self.running_var = (
exponential_average_factor * unbiased_batch_var.detach()
+ (1 - exponential_average_factor) * self.running_var
)
else: else:
if self.bias is None: if self.bias is None:
conv = conv + (self.beta - self.gamma * self.running_mean / conv = conv + (
running_std).reshape([1, -1, 1, 1]) self.beta - self.gamma * self.running_mean / running_std
).reshape([1, -1, 1, 1])
else: else:
conv = conv + (self.gamma * (self.bias - self.running_mean) / running_std + self.beta).reshape([1, -1, 1, 1]) conv = conv + (
self.gamma * (self.bias - self.running_mean) / running_std
+ self.beta
).reshape([1, -1, 1, 1])
return conv return conv
def extra_repr(self): def extra_repr(self):
@ -200,23 +238,37 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
@classmethod @classmethod
def from_float(cls, mod, qconfig=None): def from_float(cls, mod, qconfig=None):
r"""Create a qat module from a float module or qparams_dict r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user or directly from user
""" """
assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \ assert type(mod) == cls._FLOAT_MODULE, (
cls._FLOAT_MODULE.__name__ "qat."
+ cls.__name__
+ ".from_float only works for "
+ cls._FLOAT_MODULE.__name__
)
if not qconfig: if not qconfig:
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' assert hasattr(
assert mod.qconfig, 'Input float module must have a valid qconfig' mod, "qconfig"
), "Input float module must have qconfig defined"
assert mod.qconfig, "Input float module must have a valid qconfig"
qconfig = mod.qconfig qconfig = mod.qconfig
conv, bn = mod[0], mod[1] conv, bn = mod[0], mod[1]
qat_convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size, qat_convbn = cls(
conv.stride, conv.padding, conv.dilation, conv.in_channels,
conv.groups, conv.bias is not None, conv.out_channels,
conv.padding_mode, conv.kernel_size,
bn.eps, bn.momentum, conv.stride,
False, conv.padding,
qconfig) conv.dilation,
conv.groups,
conv.bias is not None,
conv.padding_mode,
bn.eps,
bn.momentum,
False,
qconfig,
)
qat_convbn.weight = conv.weight qat_convbn.weight = conv.weight
qat_convbn.bias = conv.bias qat_convbn.bias = conv.bias
qat_convbn.gamma = bn.weight qat_convbn.gamma = bn.weight
@ -226,41 +278,69 @@ class _ReferenceConvBnNd(torch.nn.Conv2d, torch.nn.modules.conv._ConvNd):
qat_convbn.num_batches_tracked = bn.num_batches_tracked qat_convbn.num_batches_tracked = bn.num_batches_tracked
return qat_convbn return qat_convbn
class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d): class _ReferenceConvBn2d(_ReferenceConvBnNd, nn.Conv2d):
_FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvBn2d
def __init__(self, def __init__(
# ConvNd args self,
in_channels, out_channels, kernel_size, stride=1, # ConvNd args
padding=0, dilation=1, groups=1, in_channels,
bias=None, out_channels,
padding_mode='zeros', kernel_size,
# BatchNorm2d args stride=1,
# num_features: out_channels padding=0,
eps=1e-05, momentum=0.1, dilation=1,
# affine: True groups=1,
# track_running_stats: True bias=None,
# Args for this module padding_mode="zeros",
freeze_bn=False, # BatchNorm2d args
qconfig=None): # num_features: out_channels
eps=1e-05,
momentum=0.1,
# affine: True
# track_running_stats: True
# Args for this module
freeze_bn=False,
qconfig=None,
):
kernel_size = _pair(kernel_size) kernel_size = _pair(kernel_size)
stride = _pair(stride) stride = _pair(stride)
padding = _pair(padding) padding = _pair(padding)
dilation = _pair(dilation) dilation = _pair(dilation)
_ReferenceConvBnNd.__init__(self, in_channels, out_channels, kernel_size, stride, _ReferenceConvBnNd.__init__(
padding, dilation, False, _pair(0), groups, bias, padding_mode, self,
eps, momentum, freeze_bn, qconfig) in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
False,
_pair(0),
groups,
bias,
padding_mode,
eps,
momentum,
freeze_bn,
qconfig,
)
class TestQuantizeEagerQAT(QuantizationTestCase): class TestQuantizeEagerQAT(QuantizationTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.embed_linear_data_train = [[torch.randint(0, 10, (12, 12), dtype=torch.long), self.embed_linear_data_train = [
torch.randn((12, 1), dtype=torch.float)] [
for _ in range(2)] torch.randint(0, 10, (12, 12), dtype=torch.long),
torch.randn((12, 1), dtype=torch.float),
]
for _ in range(2)
]
self.embed_data = [[torch.randint(0, 10, (12, 1))]] self.embed_data = [[torch.randint(0, 10, (12, 1))]]
def test_manual(self): def test_manual(self):
for qengine in supported_qengines: for qengine in supported_qengines:
with override_quantized_engine(qengine): with override_quantized_engine(qengine):
@ -279,8 +359,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model) checkQuantized(model)
model = quantize_qat(ManualLinearQATModel(qengine), test_only_train_fn, model = quantize_qat(
[self.train_data]) ManualLinearQATModel(qengine), test_only_train_fn, [self.train_data]
)
checkQuantized(model) checkQuantized(model)
def test_dropout(self): def test_dropout(self):
@ -301,8 +382,11 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model) checkQuantized(model)
model = quantize_qat(ManualDropoutQATModel(qengine), test_only_train_fn, model = quantize_qat(
[self.train_data]) ManualDropoutQATModel(qengine),
test_only_train_fn,
[self.train_data],
)
checkQuantized(model) checkQuantized(model)
def test_eval_only_fake_quant(self): def test_eval_only_fake_quant(self):
@ -342,7 +426,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model) checkQuantized(model)
model = ManualConvLinearQATModel() model = ManualConvLinearQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train]) model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model) checkQuantized(model)
@skipIfNoXNNPACK @skipIfNoXNNPACK
@ -351,7 +437,7 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
Supported only with qengine=qnnpack, which uses symmetric Supported only with qengine=qnnpack, which uses symmetric
kernels from xnnpack library.""" kernels from xnnpack library."""
for qengine in supported_qengines: for qengine in supported_qengines:
if qengine != 'qnnpack': if qengine != "qnnpack":
continue continue
with override_quantized_engine(qengine): with override_quantized_engine(qengine):
model = ManualConvLinearSymmQATModel() model = ManualConvLinearSymmQATModel()
@ -373,17 +459,20 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model) checkQuantized(model)
model = ManualConvLinearSymmQATModel() model = ManualConvLinearSymmQATModel()
model = quantize_qat(model, test_only_train_fn, [self.img_data_2d_train]) model = quantize_qat(
model, test_only_train_fn, [self.img_data_2d_train]
)
checkQuantized(model) checkQuantized(model)
def test_dynamic_qat_linear(self): def test_dynamic_qat_linear(self):
for qengine in supported_qengines: for qengine in supported_qengines:
with override_quantized_engine(qengine): with override_quantized_engine(qengine):
# Dynamic QAT without memoryless observers should fail # Dynamic QAT without memoryless observers should fail
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
"Dynamic QAT requires a memoryless observer." + ValueError,
"This means a MovingAverage observer with averaging constant equal to 1" "Dynamic QAT requires a memoryless observer."
): + "This means a MovingAverage observer with averaging constant equal to 1",
):
model = ManualLinearDynamicQATModel(default_qat_qconfig) model = ManualLinearDynamicQATModel(default_qat_qconfig)
model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear}) model = prepare_qat(model, mapping={torch.nn.Linear: nnqatd.Linear})
@ -409,14 +498,23 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
test_only_train_fn(model, self.embed_linear_data_train) test_only_train_fn(model, self.embed_linear_data_train)
# make sure activation_post_process is inserted after Linear. # make sure activation_post_process is inserted after Linear.
self.assertEqual(type(model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize) self.assertEqual(
type(model.linear.activation_post_process),
FusedMovingAvgObsFakeQuantize,
)
# make sure that Embedding has a noop for activation. # make sure that Embedding has a noop for activation.
self.assertEqual(type(model.emb.activation_post_process), NoopObserver) self.assertEqual(type(model.emb.activation_post_process), NoopObserver)
# make sure that FakeQuant zero_points are correct dtype # make sure that FakeQuant zero_points are correct dtype
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32) self.assertEqual(
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32) model.emb.weight_fake_quant.zero_point.dtype, torch.float32
)
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)
model = convert(model, mapping=get_embedding_static_quant_module_mappings()) model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)
def checkQuantized(model): def checkQuantized(model):
# make sure Embedding is now a QuantizedEmbedding # make sure Embedding is now a QuantizedEmbedding
@ -430,7 +528,6 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
checkQuantized(model) checkQuantized(model)
def test_embedding_bag_linear(self): def test_embedding_bag_linear(self):
for qengine in supported_qengines: for qengine in supported_qengines:
with override_quantized_engine(qengine): with override_quantized_engine(qengine):
@ -442,9 +539,15 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
# make sure not activation_post_process is inserted for EmbeddingBag # make sure not activation_post_process is inserted for EmbeddingBag
self.assertFalse(hasattr(model, "activation_post_process")) self.assertFalse(hasattr(model, "activation_post_process"))
# make sure that FakeQuant zero_points are correct dtype # make sure that FakeQuant zero_points are correct dtype
self.assertEqual(model.emb.weight_fake_quant.zero_point.dtype, torch.float32) self.assertEqual(
self.assertEqual(model.linear.weight_fake_quant.zero_point.dtype, torch.int32) model.emb.weight_fake_quant.zero_point.dtype, torch.float32
model = convert(model, mapping=get_embedding_static_quant_module_mappings()) )
self.assertEqual(
model.linear.weight_fake_quant.zero_point.dtype, torch.int32
)
model = convert(
model, mapping=get_embedding_static_quant_module_mappings()
)
def checkQuantized(model): def checkQuantized(model):
# Make sure EmbeddingBag is now a quantized EmbeddingBag. # Make sure EmbeddingBag is now a quantized EmbeddingBag.
@ -505,7 +608,9 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine) model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
torch.ao.quantization.prepare(model, inplace=True) torch.ao.quantization.prepare(model, inplace=True)
torch.ao.quantization.convert(model, inplace=True) torch.ao.quantization.convert(model, inplace=True)
self.assertEqual(set(model.state_dict().keys()), set(quant_state_dict.keys())) self.assertEqual(
set(model.state_dict().keys()), set(quant_state_dict.keys())
)
model.eval() model.eval()
model.load_state_dict(quant_state_dict) model.load_state_dict(quant_state_dict)
out = model(x) out = model(x)
@ -513,20 +618,19 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
@override_qengines @override_qengines
def test_forward_hooks_preserved(self): def test_forward_hooks_preserved(self):
r"""Test QAT on preserving pre forward and post forward hooks of original model r"""Test QAT on preserving pre forward and post forward hooks of original model"""
"""
qengine = torch.backends.quantized.engine qengine = torch.backends.quantized.engine
model = QuantStubModel() model = QuantStubModel()
counter = { counter = {
'pre_forwards': 0, "pre_forwards": 0,
'forwards': 0, "forwards": 0,
} }
def fw_pre_hook(h_module, input): def fw_pre_hook(h_module, input):
counter['pre_forwards'] += 1 counter["pre_forwards"] += 1
def fw_hook(h_module, input, output): def fw_hook(h_module, input, output):
counter['forwards'] += 1 counter["forwards"] += 1
model.fc.register_forward_pre_hook(fw_pre_hook) model.fc.register_forward_pre_hook(fw_pre_hook)
model.fc.register_forward_hook(fw_hook) model.fc.register_forward_hook(fw_hook)
@ -537,15 +641,24 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
def checkHooksIsPresent(model, before_convert=True): def checkHooksIsPresent(model, before_convert=True):
forward_hooks = 1 forward_hooks = 1
if before_convert: if before_convert:
self.assertEqual(len(model.quant._forward_hooks.values()), 1, self.assertEqual(
"Quantization observer hook has disappeared") len(model.quant._forward_hooks.values()),
1,
"Quantization observer hook has disappeared",
)
forward_hooks = 2 forward_hooks = 2
self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values()) self.assertObjectIn(fw_pre_hook, model.fc._forward_pre_hooks.values())
self.assertObjectIn(fw_hook, model.fc._forward_hooks.values()) self.assertObjectIn(fw_hook, model.fc._forward_hooks.values())
self.assertEqual(len(model.fc._forward_pre_hooks.values()), 1, self.assertEqual(
"Extra pre forward hooks have appeared on a layer") len(model.fc._forward_pre_hooks.values()),
self.assertEqual(len(model.fc._forward_hooks.values()), forward_hooks, 1,
"Extra post forward hooks have appeared on a layer") "Extra pre forward hooks have appeared on a layer",
)
self.assertEqual(
len(model.fc._forward_hooks.values()),
forward_hooks,
"Extra post forward hooks have appeared on a layer",
)
checkHooksIsPresent(model, True) checkHooksIsPresent(model, True)
x = torch.rand(2, 5, dtype=torch.float) x = torch.rand(2, 5, dtype=torch.float)
@ -600,32 +713,40 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine) default_qat_qconfig = get_default_qat_qconfig(torch.backends.quantized.engine)
# Test constructor parameters checks here. # Test constructor parameters checks here.
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(
"qconfig must be provided for QAT module"): AssertionError, "qconfig must be provided for QAT module"
):
nnqat.EmbeddingBag(10, 5, qconfig=None) nnqat.EmbeddingBag(10, 5, qconfig=None)
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(
"Embedding Bag weights requires a qscheme of " + AssertionError,
"torch.per_channel_affine_float_qparams"): "Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig) nnqat.EmbeddingBag(10, 5, qconfig=default_qat_qconfig)
# Test from_float checks here. # Test from_float checks here.
embed = nn.Embedding(10, 5) embed = nn.Embedding(10, 5)
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(
"qat.EmbeddingBag.from_float only works for EmbeddingBag"): AssertionError, "qat.EmbeddingBag.from_float only works for EmbeddingBag"
):
nnqat.EmbeddingBag.from_float(embed) nnqat.EmbeddingBag.from_float(embed)
embed_bag = nn.EmbeddingBag(10, 5) embed_bag = nn.EmbeddingBag(10, 5)
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(
"Input float module must have qconfig defined"): AssertionError, "Input float module must have qconfig defined"
):
nnqat.EmbeddingBag.from_float(embed_bag) nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = None embed_bag.qconfig = None
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(
"Input float module must have a valid qconfig"): AssertionError, "Input float module must have a valid qconfig"
):
nnqat.EmbeddingBag.from_float(embed_bag) nnqat.EmbeddingBag.from_float(embed_bag)
embed_bag.qconfig = default_qat_qconfig embed_bag.qconfig = default_qat_qconfig
with self.assertRaisesRegex(AssertionError, with self.assertRaisesRegex(
"Embedding Bag weights requires a qscheme of " + AssertionError,
"torch.per_channel_affine_float_qparams"): "Embedding Bag weights requires a qscheme of "
+ "torch.per_channel_affine_float_qparams",
):
nnqat.EmbeddingBag.from_float(embed_bag) nnqat.EmbeddingBag.from_float(embed_bag)
def test_embedding_qat_qconfig_equal(self): def test_embedding_qat_qconfig_equal(self):
@ -636,8 +757,10 @@ class TestQuantizeEagerQAT(QuantizationTestCase):
model = ManualEmbeddingBagLinear().train() model = ManualEmbeddingBagLinear().train()
model = prepare_qat(model) model = prepare_qat(model)
self.assertTrue(qconfig_equals(model.emb.qconfig, self.assertTrue(
default_embedding_qat_qconfig)) qconfig_equals(model.emb.qconfig, default_embedding_qat_qconfig)
)
class TestQuantizeEagerQATNumerics(QuantizationTestCase): class TestQuantizeEagerQATNumerics(QuantizationTestCase):
def _test_activation_convert_numerics_impl(self, Act, data): def _test_activation_convert_numerics_impl(self, Act, data):
@ -683,24 +806,26 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m = M().train() m = M().train()
m.qconfig = default_qat_qconfig m.qconfig = default_qat_qconfig
m = prepare_qat(m) m = prepare_qat(m)
for attr in ['sigmoid', 'hardsigmoid', 'tanh']: for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertEqual(type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize) self.assertEqual(
type(getattr(m, attr).activation_post_process), FixedQParamsFakeQuantize
)
data = torch.randn(1, 3, 2, 4) data = torch.randn(1, 3, 2, 4)
before_convert = m(data) before_convert = m(data)
m = convert(m) m = convert(m)
after_convert = m(data) after_convert = m(data)
self.assertEqual(before_convert, after_convert) self.assertEqual(before_convert, after_convert)
# make sure activation post process is removed # make sure activation post process is removed
for attr in ['sigmoid', 'hardsigmoid', 'tanh']: for attr in ["sigmoid", "hardsigmoid", "tanh"]:
# verify fake quant module is removed # verify fake quant module is removed
self.assertFalse(hasattr(getattr(m, attr), 'activation_post_process')) self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
# verify that hooks are removed # verify that hooks are removed
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)
# make sure no fake quantize module is inserted for eval mode # make sure no fake quantize module is inserted for eval mode
def checkNoFQModule(m): def checkNoFQModule(m):
for attr in ['sigmoid', 'hardsigmoid', 'tanh']: for attr in ["sigmoid", "hardsigmoid", "tanh"]:
self.assertFalse(hasattr(getattr(m, attr), "activation_post_process")) self.assertFalse(hasattr(getattr(m, attr), "activation_post_process"))
self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0) self.assertTrue(len(getattr(m, attr)._forward_hooks.items()) == 0)
@ -734,50 +859,52 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
# make sure ReLU module is not changed # make sure ReLU module is not changed
self.assertTrue(type(m.relu), nn.ReLU) self.assertTrue(type(m.relu), nn.ReLU)
@given(batch_size=st.integers(2, 4), @given(
input_channels_per_group=st.sampled_from([2, 3, 4]), batch_size=st.integers(2, 4),
height=st.integers(5, 10), input_channels_per_group=st.sampled_from([2, 3, 4]),
width=st.integers(5, 10), height=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]), width=st.integers(5, 10),
groups=st.integers(1, 3), output_channels_per_group=st.sampled_from([2, 3]),
kernel_h=st.integers(1, 3), groups=st.integers(1, 3),
kernel_w=st.integers(1, 3), kernel_h=st.integers(1, 3),
stride_h=st.integers(1, 2), kernel_w=st.integers(1, 3),
stride_w=st.integers(1, 2), stride_h=st.integers(1, 2),
pad_h=st.integers(0, 2), stride_w=st.integers(1, 2),
pad_w=st.integers(0, 2), pad_h=st.integers(0, 2),
dilation=st.integers(1, 1), pad_w=st.integers(0, 2),
padding_mode=st.sampled_from(['zeros', 'circular']), dilation=st.integers(1, 1),
use_relu=st.booleans(), padding_mode=st.sampled_from(["zeros", "circular"]),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]), use_relu=st.booleans(),
momentum=st.sampled_from([0.1, 0.2, 0.3]), eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
freeze_bn=st.booleans(), momentum=st.sampled_from([0.1, 0.2, 0.3]),
zero_gamma=st.booleans(), freeze_bn=st.booleans(),
has_bias=st.booleans(), zero_gamma=st.booleans(),
use_slow_fusion=st.booleans()) has_bias=st.booleans(),
use_slow_fusion=st.booleans(),
)
def test_conv_bn_relu( def test_conv_bn_relu(
self, self,
batch_size, batch_size,
input_channels_per_group, input_channels_per_group,
height, height,
width, width,
output_channels_per_group, output_channels_per_group,
groups, groups,
kernel_h, kernel_h,
kernel_w, kernel_w,
stride_h, stride_h,
stride_w, stride_w,
pad_h, pad_h,
pad_w, pad_w,
dilation, dilation,
padding_mode, padding_mode,
use_relu, use_relu,
eps, eps,
momentum, momentum,
freeze_bn, freeze_bn,
zero_gamma, zero_gamma,
has_bias, has_bias,
use_slow_fusion, use_slow_fusion,
): ):
input_channels = input_channels_per_group * groups input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups output_channels = output_channels_per_group * groups
@ -792,7 +919,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
(dilation_h, dilation_w), (dilation_h, dilation_w),
groups, groups,
has_bias, has_bias,
padding_mode padding_mode,
).to(dtype=torch.double) ).to(dtype=torch.double)
bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double) bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.double)
relu_op = ReLU() relu_op = ReLU()
@ -811,7 +938,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps, eps,
momentum, momentum,
freeze_bn=True, freeze_bn=True,
qconfig=default_qat_qconfig qconfig=default_qat_qconfig,
).to(dtype=torch.double) ).to(dtype=torch.double)
qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion qat_op._enable_slow_path_for_better_numerical_stability = use_slow_fusion
@ -826,7 +953,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats) qat_op.apply(torch.ao.nn.intrinsic.qat.update_bn_stats)
# align inputs and internal parameters # align inputs and internal parameters
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True) input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
conv_op.weight = torch.nn.Parameter(qat_op.weight.detach()) conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
if has_bias: if has_bias:
conv_op.bias = torch.nn.Parameter(qat_op.bias.detach()) conv_op.bias = torch.nn.Parameter(qat_op.bias.detach())
@ -840,17 +974,20 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x) return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
if not use_relu: if not use_relu:
def relu_op(x): # noqa: F811 def relu_op(x): # noqa: F811
return x return x
if freeze_bn: if freeze_bn:
def ref_op(x): def ref_op(x):
x = conv_op(x) x = conv_op(x)
x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \ x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * (
(bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)) \ bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)
.reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1]) ).reshape([1, -1, 1, 1]) + bn_op.bias.reshape([1, -1, 1, 1])
x = relu_op(x) x = relu_op(x)
return x return x
else: else:
ref_op = compose([conv_op, bn_op, relu_op]) ref_op = compose([conv_op, bn_op, relu_op])
@ -882,51 +1019,64 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
num_batches_tracked_actual = qat_op.bn.num_batches_tracked num_batches_tracked_actual = qat_op.bn.num_batches_tracked
precision = 1e-10 precision = 1e-10
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0) self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0) self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0) self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0) self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0) self.assertEqual(
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0) num_batches_tracked_ref,
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0) num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)
@given(batch_size=st.integers(2, 4), @given(
input_channels_per_group=st.sampled_from([2, 3, 4]), batch_size=st.integers(2, 4),
height=st.integers(5, 10), input_channels_per_group=st.sampled_from([2, 3, 4]),
width=st.integers(5, 10), height=st.integers(5, 10),
output_channels_per_group=st.sampled_from([2, 3]), width=st.integers(5, 10),
groups=st.integers(1, 3), output_channels_per_group=st.sampled_from([2, 3]),
kernel_h=st.integers(1, 3), groups=st.integers(1, 3),
kernel_w=st.integers(1, 3), kernel_h=st.integers(1, 3),
stride_h=st.integers(1, 2), kernel_w=st.integers(1, 3),
stride_w=st.integers(1, 2), stride_h=st.integers(1, 2),
pad_h=st.integers(0, 2), stride_w=st.integers(1, 2),
pad_w=st.integers(0, 2), pad_h=st.integers(0, 2),
dilation=st.integers(1, 1), pad_w=st.integers(0, 2),
padding_mode=st.sampled_from(['zeros', 'circular']), dilation=st.integers(1, 1),
eps=st.sampled_from([1e-5, 1e-4, 1e-3]), padding_mode=st.sampled_from(["zeros", "circular"]),
momentum=st.sampled_from([0.1, 0.2, 0.3]), eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
freeze_bn=st.booleans(), momentum=st.sampled_from([0.1, 0.2, 0.3]),
bias=st.booleans()) freeze_bn=st.booleans(),
bias=st.booleans(),
)
def test_conv_bn_folded_vs_unfolded( def test_conv_bn_folded_vs_unfolded(
self, self,
batch_size, batch_size,
input_channels_per_group, input_channels_per_group,
height, height,
width, width,
output_channels_per_group, output_channels_per_group,
groups, groups,
kernel_h, kernel_h,
kernel_w, kernel_w,
stride_h, stride_h,
stride_w, stride_w,
pad_h, pad_h,
pad_w, pad_w,
dilation, dilation,
padding_mode, padding_mode,
eps, eps,
momentum, momentum,
freeze_bn, freeze_bn,
bias, bias,
): ):
input_channels = input_channels_per_group * groups input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups output_channels = output_channels_per_group * groups
@ -945,7 +1095,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps, eps,
momentum, momentum,
freeze_bn=freeze_bn, freeze_bn=freeze_bn,
qconfig=default_qat_qconfig qconfig=default_qat_qconfig,
).to(dtype=torch.double) ).to(dtype=torch.double)
qat_ref_op = _ReferenceConvBn2d( qat_ref_op = _ReferenceConvBn2d(
@ -961,7 +1111,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
eps, eps,
momentum, momentum,
freeze_bn=freeze_bn, freeze_bn=freeze_bn,
qconfig=default_qat_qconfig qconfig=default_qat_qconfig,
).to(dtype=torch.double) ).to(dtype=torch.double)
qat_op.apply(torch.ao.quantization.disable_fake_quant) qat_op.apply(torch.ao.quantization.disable_fake_quant)
@ -981,7 +1131,6 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr) qat_ref_op_optim = torch.optim.SGD(qat_ref_op.parameters(), lr=lr)
for i in range(5): for i in range(5):
# make sure that calling model.train() does not override the # make sure that calling model.train() does not override the
# bn freeze setting # bn freeze setting
qat_op.train() qat_op.train()
@ -990,7 +1139,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
qat_op_optim.zero_grad() qat_op_optim.zero_grad()
qat_ref_op_optim.zero_grad() qat_ref_op_optim.zero_grad()
input = torch.randn(batch_size, input_channels, height, width, dtype=torch.double, requires_grad=True) input = torch.randn(
batch_size,
input_channels,
height,
width,
dtype=torch.double,
requires_grad=True,
)
input_clone = input.detach().clone().requires_grad_() input_clone = input.detach().clone().requires_grad_()
if i > 2: if i > 2:
@ -1030,12 +1186,23 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
precision = 1e-5 precision = 1e-5
self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0) self.assertEqual(input_grad_ref, input_grad_actual, atol=precision, rtol=0)
self.assertEqual(weight_grad_ref, weight_grad_actual, atol=precision, rtol=0) self.assertEqual(
weight_grad_ref, weight_grad_actual, atol=precision, rtol=0
)
self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0) self.assertEqual(gamma_grad_ref, gamma_grad_actual, atol=precision, rtol=0)
self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0) self.assertEqual(beta_grad_ref, beta_grad_actual, atol=precision, rtol=0)
self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual, atol=precision, rtol=0) self.assertEqual(
self.assertEqual(running_mean_ref, running_mean_actual, atol=precision, rtol=0) num_batches_tracked_ref,
self.assertEqual(running_var_ref, running_var_actual, atol=precision, rtol=0) num_batches_tracked_actual,
atol=precision,
rtol=0,
)
self.assertEqual(
running_mean_ref, running_mean_actual, atol=precision, rtol=0
)
self.assertEqual(
running_var_ref, running_var_actual, atol=precision, rtol=0
)
qat_op_optim.step() qat_op_optim.step()
qat_ref_op_optim.step() qat_ref_op_optim.step()
@ -1048,7 +1215,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4), nn.BatchNorm1d(4),
) )
m_ref_copy = copy.deepcopy(m_ref) m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']]) m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m_ref_copy[0].qconfig = qconfig m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0]) m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
@ -1071,7 +1238,7 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
nn.BatchNorm1d(4), nn.BatchNorm1d(4),
) )
m_ref_copy = copy.deepcopy(m_ref) m_ref_copy = copy.deepcopy(m_ref)
m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [['0', '1']]) m_ref_copy = torch.ao.quantization.fuse_modules_qat(m_ref_copy, [["0", "1"]])
qconfig = default_symmetric_qnnpack_qat_qconfig qconfig = default_symmetric_qnnpack_qat_qconfig
m_ref_copy[0].qconfig = qconfig m_ref_copy[0].qconfig = qconfig
m = nniqat.LinearBn1d.from_float(m_ref_copy[0]) m = nniqat.LinearBn1d.from_float(m_ref_copy[0])
@ -1093,14 +1260,13 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
) )
data = torch.randn(4, 4) data = torch.randn(4, 4)
m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) m.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
m = torch.ao.quantization.fuse_modules_qat(m, [['1', '2']]) m = torch.ao.quantization.fuse_modules_qat(m, [["1", "2"]])
mp = prepare_qat(m) mp = prepare_qat(m)
mp(data) mp(data)
mq = convert(mp) mq = convert(mp)
self.assertTrue(type(mq[1]) == nnq.Linear) self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity) self.assertTrue(type(mq[2]) == nn.Identity)
@skipIfNoXNNPACK @skipIfNoXNNPACK
@override_qengines @override_qengines
def test_linear_precomputed_fake_quant(self): def test_linear_precomputed_fake_quant(self):
@ -1124,10 +1290,14 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
m_ref.activation_post_process = activation m_ref.activation_post_process = activation
m_ref.qconfig = qconfig m_ref.qconfig = qconfig
m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True) m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale) self.assertTrue(
m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale
)
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_quantization.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_quantization.py TESTNAME\n\n"
"instead."
)
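
Since the change is intended to be formatting-only, a quick sanity pass is to run the touched eager-mode QAT suites through the aggregator module, as the guard at the end of the file instructs. A minimal sketch (the class names come from the hunks above; that the aggregator accepts a class name as TESTNAME is an assumption):

```python
# Sketch: run the reformatted eager-mode QAT suites via the aggregator module.
# Assumes execution from the repository root and that a test class name is a
# valid TESTNAME, per the RuntimeError message in the file above.
import subprocess
import sys

for suite in ("TestQuantizeEagerQAT", "TestQuantizeEagerQATNumerics"):
    subprocess.run(
        [sys.executable, "test/test_quantization.py", suite],
        check=True,
    )
```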