From 5c6c2cf8767599d674a60e5a1d018d199a523995 Mon Sep 17 00:00:00 2001 From: Raghu Krishnamoorthi Date: Thu, 22 Aug 2019 17:05:32 -0700 Subject: [PATCH] Update on "Update mapping dictionary to support functionalmodules and pooling operations" Differential Revision: [D16879132](https://our.internmc.facebook.com/intern/diff/D16879132/) --- test/common_quantization.py | 2 -- test/test_quantization.py | 6 ++---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/test/common_quantization.py b/test/common_quantization.py index f0f44e08d05..7868fc2e78e 100644 --- a/test/common_quantization.py +++ b/test/common_quantization.py @@ -487,11 +487,9 @@ class ResNetBase(torch.nn.Module): def forward(self, x): - print(x.size()) out = self.conv1(x) out = self.bn1(out) out = self.relu1(out) - identity = self.downsample(x) out = self.myop.add(out, identity) out = self.relu2(out) diff --git a/test/test_quantization.py b/test/test_quantization.py index 0d8b0533692..05f3acfeb6d 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -258,15 +258,12 @@ class PostTrainingQuantTest(QuantizationTestCase): def test_resnet_base(self): r"""Test quantization for bottleneck topology used in resnet/resnext - and add coverage for conversion of average pool operator + and add coverage for conversion of average pool and float functional """ model = ResNetBase().float().eval() model = QuantWrapper(model) model.qconfig = default_qconfig - print(model) fuse_list = [['module.conv1', 'module.bn1', 'module.relu1']] - print(model.module) - print(model.module.conv1) fuse_modules(model, fuse_list) prepare(model) self.checkObservers(model) @@ -276,6 +273,7 @@ class PostTrainingQuantTest(QuantizationTestCase): def checkQuantized(model): self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d) self.assertEqual(type(model.module.myop), nn.quantized.QFunctional) + self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d) test_only_eval_fn(model, self.img_data) checkQuantized(model)