mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Update on "Starter code for resnext quantization"
Differential Revision: [D16616819](https://our.internmc.facebook.com/intern/diff/D16616819/)
This commit is contained in:
commit
2fbffd4bc6
|
|
@ -487,11 +487,9 @@ class ResNetBase(torch.nn.Module):
|
|||
|
||||
|
||||
def forward(self, x):
|
||||
print(x.size())
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu1(out)
|
||||
|
||||
identity = self.downsample(x)
|
||||
out = self.myop.add(out, identity)
|
||||
out = self.relu2(out)
|
||||
|
|
|
|||
|
|
@ -258,15 +258,12 @@ class PostTrainingQuantTest(QuantizationTestCase):
|
|||
|
||||
def test_resnet_base(self):
|
||||
r"""Test quantization for bottleneck topology used in resnet/resnext
|
||||
and add coverage for conversion of average pool operator
|
||||
and add coverage for conversion of average pool and float functional
|
||||
"""
|
||||
model = ResNetBase().float().eval()
|
||||
model = QuantWrapper(model)
|
||||
model.qconfig = default_qconfig
|
||||
print(model)
|
||||
fuse_list = [['module.conv1', 'module.bn1', 'module.relu1']]
|
||||
print(model.module)
|
||||
print(model.module.conv1)
|
||||
fuse_modules(model, fuse_list)
|
||||
prepare(model)
|
||||
self.checkObservers(model)
|
||||
|
|
@ -276,6 +273,7 @@ class PostTrainingQuantTest(QuantizationTestCase):
|
|||
def checkQuantized(model):
|
||||
self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d)
|
||||
self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
|
||||
self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
|
||||
test_only_eval_fn(model, self.img_data)
|
||||
|
||||
checkQuantized(model)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user