pytorch/torch/ao/quantization/stubs.py
Supriya Rao 9d52651d4e torch.ao migration: stubs.py phase 1 (#64861)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64861

1. move the file
   ```
   hg mv caffe2/torch/quantization/stubs.py caffe2/torch/ao/quantization/
   ```
2. create a new file in the old location and copy the imports (see the sketch after this list)
3. fix all call sites inside `torch`
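
A minimal sketch of what the compatibility file at the old location could look like; the exact import list in the PR may differ:

```
# torch/quantization/stubs.py (old location, hypothetical sketch)
# Re-export from the new location so existing imports keep working.
from torch.ao.quantization.stubs import (
    QuantStub,
    DeQuantStub,
    QuantWrapper,
)
```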
ghstack-source-id: 137885365

Test Plan: buck test mode/dev //caffe2/test:quantization

Reviewed By: jerryzh168

Differential Revision: D30879678

fbshipit-source-id: a2d24f25d01064212aca15e94e8c78240ba48953
2021-09-13 08:40:29 -07:00


from torch import nn

class QuantStub(nn.Module):
    r"""Quantize stub module. Before calibration this is the same as an
    observer; it will be swapped to `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor,
            if qconfig is not provided, we will get qconfig from the parent modules
    """
    def __init__(self, qconfig=None):
        super().__init__()
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x):
        # Identity until `convert` swaps this module for `nnq.Quantize`.
        return x

class DeQuantStub(nn.Module):
    r"""Dequantize stub module. Before calibration this is the same as
    identity; it will be swapped to `nnq.DeQuantize` in `convert`.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Identity until `convert` swaps this module for `nnq.DeQuantize`.
        return x

class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub, and surrounds the call to the module with calls to the quant
    and dequant modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules. Before `convert`, `QuantStub` is just an observer: it
    observes the input tensor. After `convert`, `QuantStub` is swapped to
    `nnq.Quantize`, which does the actual quantization. Similarly for
    `DeQuantStub`.
    """
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        qconfig = module.qconfig if hasattr(module, 'qconfig') else None
        self.add_module('quant', QuantStub(qconfig))
        self.add_module('dequant', DeQuantStub())
        self.add_module('module', module)
        # Match the wrapped module's training/eval mode.
        self.train(module.training)

    def forward(self, X):
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)
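
A minimal usage sketch of these stubs in eager-mode post-training static quantization. The toy model and calibration input are hypothetical; `prepare`, `convert`, and `get_default_qconfig` are the eager-mode entry points in `torch.ao.quantization`:

```
import torch
from torch import nn
from torch.ao.quantization import QuantWrapper, get_default_qconfig, prepare, convert

# Hypothetical float model; any nn.Module works.
float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
float_model.qconfig = get_default_qconfig('fbgemm')

# QuantWrapper puts a QuantStub before and a DeQuantStub after the module.
wrapped = QuantWrapper(float_model)

# prepare() attaches observers (the QuantStub observes its input); then
# run representative calibration data through the model.
prepared = prepare(wrapped)
prepared(torch.randn(1, 3, 32, 32))

# convert() swaps QuantStub -> nnq.Quantize and DeQuantStub -> nnq.DeQuantize.
quantized = convert(prepared)
out = quantized(torch.randn(1, 3, 32, 32))
```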