# Owner(s): ["oncall: quantization"]

import copy

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.nn as nn
from torch.ao.quantization import (
    convert,
    default_qat_qconfig,
    default_qconfig,
    fuse_modules,
    fuse_modules_qat,
    prepare,
    prepare_qat,
    QConfig,
    quantize,
    quantize_qat,
)
from torch.testing._internal.common_quantization import (
    ModelForConvTransposeBNFusion,
    ModelForFusion,
    ModelForFusionWithBias,
    ModelForLinearBNFusion,
    ModelWithSequentialFusion,
    QuantizationTestCase,
    SingleLayerLinearModel,
    skipIfNoFBGEMM,
    test_only_eval_fn,
    test_only_train_fn,
)
from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
)


@skipIfNoFBGEMM
class TestFuseEager(QuantizationTestCase):
    def test_fuse_module_train(self):
        model = ModelForFusion(default_qat_qconfig).train()
        # Test step by step fusion
        model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
        model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])
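        # Fusion replaces the first module in each list with the fused module
        # and turns the remaining entries into nn.Identity placeholders, so
        # the original attribute paths (model.bn1, model.relu1, ...) remain
        # valid, as the assertions below verify.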
        self.assertEqual(
            type(model.conv1),
            nni.ConvBnReLU2d,
            msg="Fused Conv + BN + Relu first layer",
        )
        self.assertEqual(
            type(model.bn1),
            torch.nn.Identity,
            msg="Fused Conv + BN + Relu (skipped BN)",
        )
        self.assertEqual(
            type(model.relu1),
            torch.nn.Identity,
            msg="Fused Conv + BN + Relu (skipped Relu)",
        )

        self.assertEqual(
            type(model.sub1.conv), nni.ConvBn2d, msg="Fused submodule Conv + BN"
        )
        self.assertEqual(
            type(model.sub1.bn),
            torch.nn.Identity,
            msg="Fused submodule Conv + BN (skipped BN)",
        )
        self.assertEqual(
            type(model.sub2.conv), torch.nn.Conv2d, msg="Non-fused submodule Conv"
        )
        self.assertEqual(
            type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
        )
        model = prepare_qat(model)
        self.checkObservers(model)

        def checkQAT(model):
            self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)

        checkQAT(model)
        test_only_train_fn(model, self.img_data_1d_train)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)
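
        # Only conv1/bn1/relu1 and sub1 were fused above, so ModelForFusion's
        # remaining BatchNorm modules (bn2, bn3) stay unfused. After convert()
        # they presumably receive quantized tensors, for which
        # aten::native_batch_norm has no QuantizedCPU kernel, hence the
        # failure asserted below.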
        with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
        ):
            checkQuantized(model)

        model = ModelForFusion(default_qat_qconfig).train()
        model = fuse_modules_qat(
            model, [["conv1", "bn1", "relu1"], ["sub1.conv", "sub1.bn"]]
        )
        model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
        with self.assertRaisesRegex(
            RuntimeError,
            "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'",
        ):
            checkQuantized(model)

    def test_fuse_module_eval(self):
        model = ModelForFusion(default_qconfig)
        model.eval()
        model = fuse_modules(
            model,
            [
                ["conv3", "bn3", "relu4"],
                ["conv1", "bn1", "relu1"],
                ["conv2", "relu2"],
                ["bn2", "relu3"],
                ["sub1.conv", "sub1.bn"],
            ],
        )
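        # In eval mode, fusion folds each BatchNorm's statistics into the
        # preceding conv's weight and bias, so the fused modules checked
        # below are plain Conv + ReLU pairs (nni.ConvReLU*d) with no BN
        # member, unlike the train-mode nni.ConvBn*d variants above.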
        self.assertEqual(
            type(model.conv1),
            nni.ConvReLU2d,
            msg="Fused Conv + BN + Relu first layer (BN is folded)",
        )
        self.assertEqual(
            type(model.conv1[0]),
            nn.Conv2d,
            msg="Fused Conv + BN + Relu (Conv + folded BN only)",
        )
        self.assertEqual(
            type(model.conv1[1]),
            nn.ReLU,
            msg="Fused Conv + BN + Relu second layer (Relu only)",
        )
        self.assertEqual(
            type(model.bn1),
            nn.Identity,
            msg="Fused Conv + BN + Relu second layer (Skipped BN)",
        )
        self.assertEqual(
            type(model.relu1),
            nn.Identity,
            msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
        )
        self.assertEqual(
            type(model.conv2),
            nni.ConvReLU3d,
            msg="Fused Conv + BN + Relu first layer (BN is folded)",
        )
        self.assertEqual(
            type(model.bn2),
            nni.BNReLU3d,
            msg="Fused BN + Relu first layer (Relu is folded)",
        )
        self.assertEqual(
            type(model.relu3),
            nn.Identity,
            msg="Fused BN + Relu second layer (Skipped Relu)",
        )
        self.assertEqual(
            type(model.conv2[0]),
            nn.Conv3d,
            msg="Fused Conv + BN + Relu (Conv + folded BN only)",
        )
        self.assertEqual(
            type(model.conv2[1]),
            nn.ReLU,
            msg="Fused Conv + BN + Relu second layer (Relu only)",
        )
        self.assertEqual(
            type(model.relu2),
            nn.Identity,
            msg="Fused Conv + BN + Relu second layer (Skipped Relu)",
        )

        self.assertEqual(
            type(model.conv3),
            nni.ConvReLU1d,
            msg="Fused Conv + Relu for Conv1d (folded BN)",
        )
        self.assertEqual(
            type(model.conv3[0]), nn.Conv1d, msg="Fused Conv + Relu for Conv1d"
        )
        self.assertEqual(
            type(model.conv3[1]), nn.ReLU, msg="Fused Conv + Relu for Conv1d"
        )
        self.assertEqual(
            type(model.bn3),
            nn.Identity,
            msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)",
        )

        self.assertEqual(
            type(model.sub1.conv), nn.Conv2d, msg="Fused submodule Conv + folded BN"
        )
        self.assertEqual(
            type(model.sub1.bn), nn.Identity, msg="Fused submodule (skipped BN)"
        )
        self.assertEqual(
            type(model.sub2.conv), nn.Conv2d, msg="Non-fused submodule Conv"
        )
        self.assertEqual(
            type(model.sub2.relu), torch.nn.ReLU, msg="Non-fused submodule ReLU"
        )

        model = prepare(model)
        self.checkObservers(model)
        test_only_eval_fn(model, self.img_data_1d)
        model = convert(model)

        def checkQuantized(model):
            self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
            self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
            self.assertEqual(type(model.bn1), nn.Identity)
            self.assertEqual(type(model.relu1), nn.Identity)
            self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
            self.assertEqual(type(model.sub1.bn), nn.Identity)
            self.assertEqual(type(model.sub2.conv), nn.Conv2d)
            self.assertEqual(type(model.sub2.relu), nn.ReLU)
            self.assertEqual(type(model.bn2), nniq.BNReLU3d)
            test_only_eval_fn(model, self.img_data_1d)
            self.checkNoQconfig(model)

        checkQuantized(model)

        model = ModelForFusion(default_qconfig).eval()
        model = fuse_modules(
            model,
            [
                ["conv1", "bn1", "relu1"],
                ["conv2", "relu2"],
                ["bn2", "relu3"],
                ["sub1.conv", "sub1.bn"],
                ["conv3", "bn3", "relu4"],
            ],
        )
        model = quantize(model, test_only_eval_fn, [self.img_data_1d])
        checkQuantized(model)

    def test_fusion_sequential_model_train(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().train()
                model.to(torch.float)
                fuse_modules_qat(
                    model,
                    [
                        ["conv1", "relu1"],
                        ["features.0.0", "features.0.1", "features.0.2"],
                        ["features.1.0", "features.1.1", "features.1.2"],
                        ["features.2.0", "features.2.1", "features.2.2"],
                        ["classifier.0", "classifier.1"],
                    ],
                    inplace=True,
                )
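                # Dotted names such as "features.0.0" address children inside
                # nn.Sequential containers, so fusion also works on modules
                # that are not direct attributes of the top-level model. With
                # inplace=True the model is mutated rather than deep-copied.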
                self.assertEqual(
                    type(model.conv1),
                    nni.ConvReLU2d,
                    msg="Fused Conv + Relu: nni.ConvReLU2d",
                )
                self.assertEqual(
                    type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
                )
                self.assertEqual(
                    type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
                )
                self.assertEqual(
                    type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
                )
                for i in range(3):
                    self.assertEqual(
                        type(model.features[i][0]),
                        nni.ConvBnReLU2d,
                        msg="Fused submodule Conv + BN + Relu",
                    )
                    self.assertEqual(
                        type(model.features[i][1]),
                        nn.Identity,
                        msg="Fused submodule (skipped BN)",
                    )
                    self.assertEqual(
                        type(model.features[i][2]),
                        nn.Identity,
                        msg="Fused submodule (skipped Relu)",
                    )
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
                prepare_qat(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
                    self.assertEqual(type(model.relu1), nn.Identity)

                    for i in range(3):
                        self.assertEqual(
                            type(model.features[i][0]),
                            nniqat.ConvBnReLU2d,
                            msg="Fused submodule Conv + BN + Relu",
                        )
                        self.assertEqual(
                            type(model.features[i][1]),
                            nn.Identity,
                            msg="Fused submodule (skipped BN)",
                        )
                        self.assertEqual(
                            type(model.features[i][2]),
                            nn.Identity,
                            msg="Fused submodule (skipped Relu)",
                        )
                    self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
                    self.assertEqual(type(model.classifier[1]), nn.Identity)

                checkQAT(model)
                model(self.img_data_2d[1][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

    def test_fusion_sequential_model_eval(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()
                model.to(torch.float)
                fuse_modules(
                    model,
                    [
                        ["conv1", "relu1"],
                        ["features.0.0", "features.0.1", "features.0.2"],
                        ["features.1.0", "features.1.1", "features.1.2"],
                        ["features.2.0", "features.2.1", "features.2.2"],
                        ["classifier.0", "classifier.1"],
                    ],
                    inplace=True,
                )
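                # Same recipe as the train-mode test above, but in eval mode
                # BN is folded into the conv weights, so the expected fused
                # type below is nni.ConvReLU2d rather than nni.ConvBnReLU2d.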
                self.assertEqual(
                    type(model.conv1),
                    nni.ConvReLU2d,
                    msg="Fused Conv + Relu: nni.ConvReLU2d",
                )
                self.assertEqual(
                    type(model.conv1[0]), nn.Conv2d, msg="Fused Conv + Relu: Conv2d"
                )
                self.assertEqual(
                    type(model.conv1[1]), nn.ReLU, msg="Fused Conv + Relu: Relu"
                )
                self.assertEqual(
                    type(model.relu1), nn.Identity, msg="Fused Conv + Relu: Identity"
                )
                for i in range(3):
                    self.assertEqual(
                        type(model.features[i][0]),
                        nni.ConvReLU2d,
                        msg="Fused submodule Conv + folded BN",
                    )
                    self.assertEqual(
                        type(model.features[i][1]),
                        nn.Identity,
                        msg="Fused submodule (skipped BN)",
                    )
                    self.assertEqual(
                        type(model.features[i][2]),
                        nn.Identity,
                        msg="Fused submodule (skipped Relu)",
                    )
                self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
                self.assertEqual(type(model.classifier[1]), nn.Identity)
                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare(model, inplace=True)
                self.checkObservers(model)
                model(self.img_data_2d[0][0])
                convert(model, inplace=True)
                model(self.img_data_2d[1][0])
                self.checkModelWithSequentialQuantized(model)

    def checkModelWithSequentialQuantized(self, model):
        self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
        self.assertEqual(type(model.relu1), nn.Identity)
        for i in range(3):
            self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
            self.assertEqual(type(model.features[i][1]), nn.Identity)
            self.assertEqual(type(model.features[i][2]), nn.Identity)
        self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
        self.assertEqual(type(model.classifier[1]), nn.Identity)

    def test_fusion_conv_with_bias(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model_orig = ModelForFusionWithBias().train()

                # reference model
                model_ref = copy.deepcopy(model_orig)
                # output with no fusion.
                out_ref = model_ref(self.img_data_2d[0][0])

                # fused model
                model_orig.qconfig = QConfig(
                    activation=torch.nn.Identity, weight=torch.nn.Identity
                )
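                # With nn.Identity standing in for both the activation and
                # weight fake-quant, prepare_qat should insert modules that
                # leave numerics untouched, so the fused-and-prepared model
                # can be compared against the unfused reference below.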
                model = fuse_modules_qat(
                    model_orig, [["conv1", "bn1", "relu1"], ["conv2", "bn2"]]
                )
                prep_model = prepare_qat(model, inplace=False)
                # output with fusion but no observers.
                out_fused = prep_model(self.img_data_2d[0][0])

                self.assertEqual(out_ref, out_fused)

                def checkBN(bn_ref, bn):
                    self.assertEqual(bn_ref.weight, bn.weight)
                    self.assertEqual(bn_ref.bias, bn.bias)
                    self.assertEqual(bn_ref.running_mean, bn.running_mean)
                    self.assertEqual(bn_ref.running_var, bn.running_var)

                checkBN(model_ref.bn1, prep_model.conv1.bn)
                checkBN(model_ref.bn2, prep_model.conv2.bn)

                model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
                prepare_qat(model, inplace=True)

                model(self.img_data_2d[0][0])

                def checkQAT(model):
                    self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
                    self.assertEqual(type(model.bn1), nn.Identity)
                    self.assertEqual(type(model.relu1), nn.Identity)
                    self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
                    self.assertEqual(type(model.bn2), nn.Identity)

                checkQAT(model)

    def test_fusion_linear_bn_eval(self):
        model = ModelForLinearBNFusion().train()
        inp1 = torch.randn(8, 20)
        inp2 = torch.randn(8, 20)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)
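
        # Eval-mode fusion folds the BN affine transform into the linear
        # layer (standard BN folding: with s = gamma / sqrt(running_var + eps),
        # the folded parameters are W' = s * W and
        # b' = s * (b - running_mean) + beta), so the fused model's output
        # should match the pre-fusion `golden` output.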
        model = fuse_modules(model, [["fc", "bn"]])
        self.assertEqual(type(model.bn), nn.Identity)
        self.assertEqual(golden, model(inp2))

    def test_fusion_convtranspose_bn_eval(self):
        model = ModelForConvTransposeBNFusion().train()
        inp1 = torch.randn(8, 3, 16)
        inp2 = torch.randn(8, 3, 16)

        # Get some interesting values into the running mean and variance.
        model(inp1)
        model.eval()
        golden = model(inp2)

        model = fuse_modules(
            model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]]
        )
        self.assertEqual(type(model.bn1), nn.Identity)
        self.assertEqual(type(model.bn2), nn.Identity)
        self.assertEqual(type(model.bn3), nn.Identity)

        self.assertEqual(golden, model(inp2))

    def test_fuse_function_customization(self):
        dummy_model = SingleLayerLinearModel().train()
        dummy_model.eval()

        # A custom fuse function
        def custom_fuse_func(module, is_qat, add_fuser_mapping):
            return [torch.nn.Identity()]
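
        # A fuser_func is called with the modules matched by each fusion list
        # (plus the QAT flag and fuser-method mapping) and returns their
        # replacements; here every match simply collapses to nn.Identity.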
        dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func)
        self.assertEqual(type(dummy_model.fc1), nn.Identity)

    def test_forward_hooks_preserved(self):
        r"""Test case that checks whether forward pre hooks of the first module
        and forward hooks of the last module in the list passed to the fusion
        function are preserved.
        (e.g. before fusion: [nn.Conv2d (with pre forward hooks), nn.BatchNorm2d, nn.ReLU (with post forward hooks)]
        after fusion: [nni.ConvBnReLU2d (with pre and post hooks), nn.Identity, nn.Identity])
        """
        model = ModelForFusion(default_qat_qconfig).train()

        counter = {
            "pre_forwards": 0,
            "forwards": 0,
        }
        fused = False

        def fw_pre_hook(fused_module_class, h_module, input):
            if fused:
                self.assertEqual(
                    type(h_module),
                    fused_module_class,
                    "After fusion, the owner of the first module's forward pre hook is not the fused module",
                )
            counter["pre_forwards"] += 1

        def fw_hook(fused_module_class, h_module, input, output):
            if fused:
                self.assertEqual(
                    type(h_module),
                    fused_module_class,
                    "After fusion, the owner of the last module's forward hook is not the fused module",
                )
            counter["forwards"] += 1

        # Registering two pre and two post forward hooks, thus expecting each
        # counter to increase by two per inference
        model.conv1.register_forward_pre_hook(
            lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args)
        )
        model.sub1.conv.register_forward_pre_hook(
            lambda *args: fw_pre_hook(nni.ConvBn2d, *args)
        )
        model.relu1.register_forward_hook(
            lambda *args: fw_hook(nni.ConvBnReLU2d, *args)
        )
        model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))
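
        # relu1 and sub1.bn will be replaced by nn.Identity during fusion;
        # the hooks registered on them (and on conv1 / sub1.conv) should be
        # re-attached to the fused modules that take their place.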
        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(counter["pre_forwards"], 2 * len(self.img_data_1d))
        self.assertEqual(counter["forwards"], 2 * len(self.img_data_1d))

        model = fuse_modules_qat(model, ["conv1", "bn1", "relu1"])
        model = fuse_modules_qat(model, ["sub1.conv", "sub1.bn"])

        fused = True
        before_fusion_pre_count = counter["pre_forwards"]
        before_fusion_post_count = counter["forwards"]
        test_only_eval_fn(model, self.img_data_1d)
        self.assertEqual(
            counter["pre_forwards"] - before_fusion_pre_count, 2 * len(self.img_data_1d)
        )
        self.assertEqual(
            counter["forwards"] - before_fusion_post_count, 2 * len(self.img_data_1d)
        )

    def test_fuse_modules_with_nested_hooks(self):
        r"""Test case that checks whether a nested module whose sub-submodules
        have hooks registered can be safely fused. Safeguard for issues similar
        to https://github.com/pytorch/pytorch/issues/105063 in the future.
        """

        def myhook(*x):
            return ""
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = ModelWithSequentialFusion().eval()

                for sub_model in model.modules():
                    if isinstance(sub_model, nn.Sequential):
                        for layer in sub_model:
                            if hasattr(layer, "register_forward_hook"):
                                layer.register_forward_hook(myhook)

                fuse_modules(
                    model,
                    [["features.0.0", "features.0.1", "features.0.2"]],
                    inplace=True,
                )
                self.assertEqual(
                    type(model.features[0][0]),
                    nni.ConvReLU2d,
                    msg="Fused submodule Conv + folded BN",
                )
                self.assertEqual(
                    type(model.features[0][1]),
                    nn.Identity,
                    msg="Fused submodule (skipped BN)",
                )
                self.assertEqual(
                    type(model.features[0][2]),
                    nn.Identity,
                    msg="Fused submodule (skipped Relu)",
                )


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )