mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Update on "Starter code for resnext quantization"
Differential Revision: [D16616819](https://our.internmc.facebook.com/intern/diff/D16616819/)
This commit is contained in:
commit
2fbffd4bc6
|
|
@ -487,11 +487,9 @@ class ResNetBase(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
print(x.size())
|
|
||||||
out = self.conv1(x)
|
out = self.conv1(x)
|
||||||
out = self.bn1(out)
|
out = self.bn1(out)
|
||||||
out = self.relu1(out)
|
out = self.relu1(out)
|
||||||
|
|
||||||
identity = self.downsample(x)
|
identity = self.downsample(x)
|
||||||
out = self.myop.add(out, identity)
|
out = self.myop.add(out, identity)
|
||||||
out = self.relu2(out)
|
out = self.relu2(out)
|
||||||
|
|
|
||||||
|
|
@ -258,15 +258,12 @@ class PostTrainingQuantTest(QuantizationTestCase):
|
||||||
|
|
||||||
def test_resnet_base(self):
|
def test_resnet_base(self):
|
||||||
r"""Test quantization for bottleneck topology used in resnet/resnext
|
r"""Test quantization for bottleneck topology used in resnet/resnext
|
||||||
and add coverage for conversion of average pool operator
|
and add coverage for conversion of average pool and float functional
|
||||||
"""
|
"""
|
||||||
model = ResNetBase().float().eval()
|
model = ResNetBase().float().eval()
|
||||||
model = QuantWrapper(model)
|
model = QuantWrapper(model)
|
||||||
model.qconfig = default_qconfig
|
model.qconfig = default_qconfig
|
||||||
print(model)
|
|
||||||
fuse_list = [['module.conv1', 'module.bn1', 'module.relu1']]
|
fuse_list = [['module.conv1', 'module.bn1', 'module.relu1']]
|
||||||
print(model.module)
|
|
||||||
print(model.module.conv1)
|
|
||||||
fuse_modules(model, fuse_list)
|
fuse_modules(model, fuse_list)
|
||||||
prepare(model)
|
prepare(model)
|
||||||
self.checkObservers(model)
|
self.checkObservers(model)
|
||||||
|
|
@ -276,6 +273,7 @@ class PostTrainingQuantTest(QuantizationTestCase):
|
||||||
def checkQuantized(model):
|
def checkQuantized(model):
|
||||||
self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d)
|
self.assertEqual(type(model.module.conv1), nn._intrinsic.quantized.ConvReLU2d)
|
||||||
self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
|
self.assertEqual(type(model.module.myop), nn.quantized.QFunctional)
|
||||||
|
self.assertEqual(type(model.module.avgpool), nn.AdaptiveAvgPool2d)
|
||||||
test_only_eval_fn(model, self.img_data)
|
test_only_eval_fn(model, self.img_data)
|
||||||
|
|
||||||
checkQuantized(model)
|
checkQuantized(model)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user