mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22694 Move quantization and quantized utility functions for testing to common_quantized.py and common_quantization.py. Additionally, add a quantized test case base class which contains common methods for checking the results of quantization on modules. As a consequence of the move, fixed the imports at the top of test_quantized.py and test_quantization to use the new utility. Reviewed By: jerryzh168 Differential Revision: D16172012 fbshipit-source-id: 329166af5555fc829f26bf1383d682c25c01a7d9
55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
r"""Importing this file includes common utility methods and base classes for
|
|
checking quantization api and properties of resulting modules.
|
|
"""
|
|
import torch
|
|
import torch.nn.quantized as nnq
|
|
from common_utils import TestCase
|
|
|
|
# QuantizationTestCase is used as a base class for tests that check the
# results of quantization on modules.
class QuantizationTestCase(TestCase):
    def checkNoPrepModules(self, module):
        r"""Checks the module does not contain the ``quant`` and ``dequant``
            child modules used for quantization preparation.

            NOTE(review): an ``observer`` attribute is not checked here;
            confirm whether callers expect it to be.
        """
        self.assertFalse(hasattr(module, 'quant'))
        self.assertFalse(hasattr(module, 'dequant'))

    def checkHasPrepModules(self, module):
        r"""Checks the module has a ``module`` attribute (presumably the
            wrapped float module) and the ``quant`` and ``dequant`` child
            modules used for quantization preparation.
        """
        self.assertTrue(hasattr(module, 'module'))
        self.assertTrue(hasattr(module, 'quant'))
        self.assertTrue(hasattr(module, 'dequant'))

    def checkObservers(self, module):
        r"""Checks that the module, and recursively each of its descendants,
            has an ``observer`` attribute whenever it is a leaf with a
            non-None ``qconfig``.

            A module counts as a leaf when it has no child modules other than
            one named ``observer``. This way, an observer that was registered
            as a submodule does not by itself turn the module into a non-leaf.
        """
        # Ignore an 'observer' child when deciding leafness; any other child
        # makes this a container whose own children get checked instead.
        is_leaf = all(name == 'observer' for name in module._modules)
        if getattr(module, 'qconfig', None) is not None and is_leaf:
            self.assertTrue(hasattr(module, 'observer'),
                            'module: ' + str(type(module)) + ' does not have observer')
        for child in module.children():
            self.checkObservers(child)

    def checkQuantDequant(self, mod):
        r"""Checks that ``mod.quant`` is exactly an nnq.Quantize and
            ``mod.dequant`` is exactly an nnq.DeQuantize (exact type match,
            so subclasses do not pass).
        """
        self.assertEqual(type(mod.quant), nnq.Quantize)
        self.assertEqual(type(mod.dequant), nnq.DeQuantize)

    def checkQuantizedLinear(self, mod):
        r"""Checks that ``mod.module`` has been swapped for an nnq.Linear
            whose bias has dtype torch.qint32, and that ``mod`` has nnq.Quantize
            and nnq.DeQuantize submodules (see checkQuantDequant).
        """
        self.assertEqual(type(mod.module), nnq.Linear)
        self.assertEqual(mod.module.bias.dtype, torch.qint32)
        self.checkQuantDequant(mod)

    def checkLinear(self, mod):
        r"""Checks that ``mod`` is exactly a float torch.nn.Linear (exact type
            match, so a swapped quantized module does not pass).
        """
        self.assertEqual(type(mod), torch.nn.Linear)
|