# Owner(s): ["oncall: quantization"]

import copy

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import (
    quantize,
    prepare,
    convert,
    prepare_qat,
    quantize_qat,
    fuse_modules,
    fuse_modules_qat,
    QConfig,
    default_qconfig,
    default_qat_qconfig,
)

from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    ModelForFusion,
    ModelWithSequentialFusion,
    ModelForLinearBNFusion,
    ModelForFusionWithBias,
    ModelForConvTransposeBNFusion,
    SingleLayerLinearModel,
    test_only_eval_fn,
    test_only_train_fn,
    skipIfNoFBGEMM,
)

from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
)

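# Eager-mode fusion tests: fuse_modules / fuse_modules_qat replace
# (Conv, BN, ReLU)-style patterns with fused intrinsic modules and leave
# nn.Identity placeholders in the original module slots.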
@skipIfNoFBGEMM
class TestFuseEager(QuantizationTestCase):
    def test_fuse_module_train(self):
        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
        self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
                         msg="Fused Conv + BN + Relu first layer")
        self.assertEqual(type(model.bn1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), torch.nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped Relu)")

        self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
                         msg="Fused submodule Conv + BN")
        self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
                         msg="Fused submodule Conv + BN (skipped BN)")
        self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")
        model = prepare_qat(model)
        self.checkObservers(model)

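        # After prepare_qat, the fused intrinsic modules should be swapped
        # for their QAT counterparts (nniqat.*).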
        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data_1d_train)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)

        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules_qat(
            model,
            [['conv1', 'bn1', 'relu1'],
             ['sub1.conv', 'sub1.bn']])
        model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
        with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
            checkQuantized(model)

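    # In eval mode, fusion folds BatchNorm into the preceding conv weights,
    # so the fused module contains only Conv + ReLU (no BN term left to run).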
    def test_fuse_module_eval(self):
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(
            model,
            [['conv3', 'bn3', 'relu4'],
             ['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'],
             ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn']])
        self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                         msg="Fused Conv + BN + Relu first layer (BN is folded)")
        self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                         msg="Fused Conv + BN + Relu (Conv + folded BN only)")
        self.assertEqual(type(model.conv1[1]), nn.ReLU,
                         msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         msg="Fused Conv + BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2), nni.ConvReLU3d,
                         msg="Fused Conv + Relu for Conv3d")
        self.assertEqual(type(model.bn2), nni.BNReLU3d,
                         msg="Fused BN + Relu first layer (Relu is folded)")
        self.assertEqual(type(model.relu3), nn.Identity,
                         msg="Fused BN + Relu second layer (Skipped Relu)")
        self.assertEqual(type(model.conv2[0]), nn.Conv3d,
                         msg="Fused Conv + Relu for Conv3d (Conv only)")
        self.assertEqual(type(model.conv2[1]), nn.ReLU,
                         msg="Fused Conv + Relu for Conv3d (Relu only)")
        self.assertEqual(type(model.relu2), nn.Identity,
                         msg="Fused Conv + Relu for Conv3d (Skipped Relu)")

        self.assertEqual(type(model.conv3), nni.ConvReLU1d,
                         msg="Fused Conv + Relu for Conv1d (folded BN)")
        self.assertEqual(type(model.conv3[0]), nn.Conv1d,
                         msg="Fused Conv + Relu for Conv1d")
        self.assertEqual(type(model.conv3[1]), nn.ReLU,
                         msg="Fused Conv + Relu for Conv1d")
        self.assertEqual(type(model.bn3), nn.Identity,
                         msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")

        self.assertEqual(type(model.sub1.conv), nn.Conv2d,
                         msg="Fused submodule Conv + folded BN")
        self.assertEqual(type(model.sub1.bn), nn.Identity,
                         msg="Fused submodule (skipped BN)")
        self.assertEqual(type(model.sub2.conv), nn.Conv2d,
                         msg="Non-fused submodule Conv")
        self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
                         msg="Non-fused submodule ReLU")

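        # Standard post-training flow: prepare inserts observers, the eval
        # data calibrates them, convert swaps in the quantized modules.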
        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data_1d)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        checkQuantized(model)

        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(
            model,
            [['conv1', 'bn1', 'relu1'],
             ['conv2', 'relu2'],
             ['bn2', 'relu3'],
             ['sub1.conv', 'sub1.bn'],
             ['conv3', 'bn3', 'relu4']])
        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
        checkQuantized(model)

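    # Fusion also works on modules addressed by dotted paths inside
    # nn.Sequential containers, e.g. 'features.0.0'.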
    def test_fusion_sequential_model_train(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().train()
                model.to(torch.float)
                fuse_modules_qat(
                    model, [['conv1', 'relu1'],
                            ['features.0.0', 'features.0.1', 'features.0.2'],
                            ['features.1.0', 'features.1.1', 'features.1.2'],
                            ['features.2.0', 'features.2.1', 'features.2.2'],
                            ['classifier.0', 'classifier.1']],
                    inplace=True)
                self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]), nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1), nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                                     msg="Fused submodule Conv + BN + ReLU")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Fused submodule (skipped ReLU)")
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                prepare_qat(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    for i in range(3):
                        self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                         msg="Fused submodule Conv + BN + ReLU")
                        self.assertEqual(type(model.features[i][1]), nn.Identity,
                                         msg="Fused submodule (skipped BN)")
                        self.assertEqual(type(model.features[i][2]), nn.Identity,
                                         msg="Fused submodule (skipped ReLU)")
                    self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                    self.assertEqual(type(model.classifier[1]), nn.Identity)

                checkQAT(model)
                model(self.img_data_2d[1][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

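    # Same model fused in eval mode: BN is folded at fusion time, so each
    # features[i][0] becomes nni.ConvReLU2d instead of nni.ConvBnReLU2d.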
    def test_fusion_sequential_model_eval(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()
                model.to(torch.float)
                fuse_modules(
                    model,
                    [['conv1', 'relu1'],
                     ['features.0.0', 'features.0.1', 'features.0.2'],
                     ['features.1.0', 'features.1.1', 'features.1.2'],
                     ['features.2.0', 'features.2.1', 'features.2.2'],
                     ['classifier.0', 'classifier.1']],
                    inplace=True)
                self.assertEqual(type(model.conv1), nni.ConvReLU2d,
                                 msg="Fused Conv + Relu: nni.ConvReLU2d")
                self.assertEqual(type(model.conv1[0]), nn.Conv2d,
                                 msg="Fused Conv + Relu: Conv2d")
                self.assertEqual(type(model.conv1[1]), nn.ReLU,
                                 msg="Fused Conv + Relu: Relu")
                self.assertEqual(type(model.relu1), nn.Identity,
                                 msg="Fused Conv + Relu: Identity")
                for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                                     msg="Fused submodule Conv + folded BN")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Fused submodule (skipped ReLU)")
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

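    # Shared post-convert checks for the two sequential-fusion tests above.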
    def checkModelWithSequentialQuantized(self, model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.relu1), nn.Identity)
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
            self.assertEqual(type(model.features[i][1]), nn.Identity)
            self.assertEqual(type(model.features[i][2]), nn.Identity)
        self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)

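    # Numerics check: with identity observers, QAT fusion of Conv (with bias)
    # + BN must not change the model output, and the BN parameters and
    # buffers must carry over into the fused module unchanged.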
    def test_fusion_conv_with_bias(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_orig = ModelForFusionWithBias().train()

                # reference model
                model_ref = copy.deepcopy(model_orig)
                # output with no fusion.
                out_ref = model_ref(self.img_data_2d[0][0])

                # fused model
                model_orig.qconfig = QConfig(activation=torch.nn.Identity,
                                             weight=torch.nn.Identity)
                model = fuse_modules_qat(
                    model_orig,
                    [["conv1", "bn1", "relu1"],
                     ["conv2", "bn2"]])
                prep_model = prepare_qat(model, inplace=False)
                # output with fusion but no observers.
                out_fused = prep_model(self.img_data_2d[0][0])

                self.assertEqual(out_ref, out_fused)

                def checkBN(bn_ref, bn):
                    self.assertEqual(bn_ref.weight, bn.weight)
                    self.assertEqual(bn_ref.bias, bn.bias)
                    self.assertEqual(bn_ref.running_mean, bn.running_mean)
                    self.assertEqual(bn_ref.running_var, bn.running_var)

                checkBN(model_ref.bn1, prep_model.conv1.bn)
                checkBN(model_ref.bn2, prep_model.conv2.bn)

                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare_qat(model, inplace=True)

                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                    self.assertEqual(type(model.bn1), nn.Identity)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                    self.assertEqual(type(model.bn2), nn.Identity)

                checkQAT(model)

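    # Linear + BN fusion in eval mode folds the BN affine transform and
    # running statistics into fc's weight and bias; the fused model must
    # reproduce the unfused output.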
    def test_fusion_linear_bn_eval(self):
        model = ModelForLinearBNFusion().train()
        inp1 = torch.randn(8, 20)
        inp2 = torch.randn(8, 20)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)

        model = fuse_modules(model, [["fc", "bn"]])
        self.assertEqual(type(model.bn), nn.Identity)
        self.assertEqual(golden, model(inp2))

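    # The same folding invariant holds for ConvTranspose + BN pairs.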
    def test_fusion_convtranspose_bn_eval(self):
        model = ModelForConvTransposeBNFusion().train()
        inp1 = torch.randn(8, 3, 16)
        inp2 = torch.randn(8, 3, 16)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)

        model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]])
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.bn2), nn.Identity)
        self.assertEqual(type(model.bn3), nn.Identity)

        self.assertEqual(golden, model(inp2))

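    # fuse_modules accepts a custom fuser_func; here it replaces every
    # listed module with nn.Identity.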
    def test_fuse_function_customization(self):
        dummy_model = SingleLayerLinearModel().train()
        dummy_model.eval()

        # A custom fuser function
        def custom_fuse_func(module, is_qat, add_fuser_mapping):
            return [torch.nn.Identity()]

        dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func)
        self.assertEqual(type(dummy_model.fc1), nn.Identity)

    def test_forward_hooks_preserved(self):
        r"""Test case that checks whether forward pre hooks of the first module and
        forward hooks of the last module in the modules list passed to the fusion
        function are preserved.
        (e.g. before fusion: [nn.Conv2d (with forward pre hooks), nn.BatchNorm2d, nn.ReLU (with forward hooks)]
        after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity])
        """
        model = ModelForFusion(default_qat_qconfig).train()

        counter = {
            'pre_forwards': 0,
            'forwards': 0,
        }
        fused = False

        def fw_pre_hook(fused_module_class, h_module, input):
            if fused:
                self.assertEqual(type(h_module), fused_module_class,
                                 "After fusion owner of the first module's forward pre hook is not a fused module")
            counter['pre_forwards'] += 1

        def fw_hook(fused_module_class, h_module, input, output):
            if fused:
                self.assertEqual(type(h_module), fused_module_class,
                                 "After fusion owner of the last module's forward hook is not a fused module")
            counter['forwards'] += 1

        # Two forward pre hooks and two forward hooks are registered, so each
        # inference should increment each counter by two.
        model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
        model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
        model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
        model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))

        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))

        model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
        model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])

        fused = True
        before_fusion_pre_count = counter['pre_forwards']
        before_fusion_post_count = counter['forwards']
        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
        self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))

    def test_fuse_modules_with_nested_hooks(self):
        r"""Test case that checks whether a nested module with hooks registered
        on its sub-submodules can be fused safely. Safeguard against future
        regressions of https://github.com/pytorch/pytorch/issues/105063.
        """
        def myhook(*x):
            return ""

        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()

                for sub_model in model.modules():
                    if isinstance(sub_model, nn.Sequential):
                        for layer in sub_model:
                            if hasattr(layer, 'register_forward_hook'):
                                layer.register_forward_hook(myhook)

                fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True)
                self.assertEqual(
                    type(model.features[0][0]),
                    nni.ConvReLU2d,
                    msg="Fused submodule Conv + folded BN"
                )
                self.assertEqual(
                    type(model.features[0][1]),
                    nn.Identity,
                    msg="Fused submodule (skipped BN)"
                )
                self.assertEqual(
                    type(model.features[0][2]),
                    nn.Identity,
                    msg="Fused submodule (skipped ReLU)"
                )


if __name__ == '__main__':
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )