Update on "Starter code for resnext quantization"

Differential Revision: [D16616819](https://our.internmc.facebook.com/intern/diff/D16616819/)
This commit is contained in:
Raghu Krishnamoorthi 2019-08-22 17:05:36 -07:00
commit 2fbffd4bc6
2 changed files with 2 additions and 6 deletions

View File

@ -487,11 +487,9 @@ class ResNetBase(torch.nn.Module):
def forward(self, x):
print(x.size())
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
identity = self.downsample(x)
out = self.myop.add(out, identity)
out = self.relu2(out)

View File

@ -258,15 +258,12 @@ class PostTrainingQuantTest(QuantizationTestCase):
def test_resnet_base(self):
r"""Test quantization for bottleneck topology used in resnet/resnext
and add coverage for conversion of average pool operator
and add coverage for conversion of average pool and float functional
"""
model = ResNetBase().float().eval()
model = QuantWrapper(model)
model.qconfig = default_qconfig
print(model)
fuse_list = [['module.conv1', 'module.bn1', 'module.relu1']]
print(model.module)
print(model.module.conv1)
fuse_modules(model, fuse_list)
prepare(model)
self.checkObservers(model)
@ -276,6 +273,7 @@ class PostTrainingQuantTest(QuantizationTestCase):
def checkQuantized(model):
self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d)
self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
test_only_eval_fn(model, self.img_data)
checkQuantized(model)