# pytorch/torch/ao/quantization/stubs.py
#
# 75 lines
# 2.2 KiB
# Python

from typing import Any, Optional
import torch
from torch import nn
from torch.ao.quantization import QConfig
# Names exported by `from <this module> import *`: the three classes defined below.
__all__ = ["QuantStub", "DeQuantStub", "QuantWrapper"]
class QuantStub(nn.Module):
    r"""Quantize stub module.

    Before calibration this is the same as an observer; it will be swapped
    as `nnq.Quantize` in `convert`. As defined here, `forward` simply
    returns its input unchanged.

    Args:
        qconfig: quantization configuration for the tensor. If qconfig is
            not provided, we will get qconfig from parent modules.
    """

    def __init__(self, qconfig: Optional[QConfig] = None):
        super().__init__()
        # Only set the attribute when a (truthy) qconfig was supplied, so
        # that a missing `qconfig` attribute signals "not configured here".
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged (placeholder until the stub is swapped)."""
        return x
class DeQuantStub(nn.Module):
    r"""Dequantize stub module.

    Before calibration this is the same as identity; it will be swapped as
    `nnq.DeQuantize` in `convert`. As defined here, `forward` simply
    returns its input unchanged.

    Args:
        qconfig: quantization configuration for the tensor. If qconfig is
            not provided, we will get qconfig from parent modules.
    """

    def __init__(self, qconfig: Optional[Any] = None):
        super().__init__()
        # Only set the attribute when a (truthy) qconfig was supplied, so
        # that a missing `qconfig` attribute signals "not configured here".
        if qconfig:
            self.qconfig = qconfig

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` unchanged (placeholder until the stub is swapped)."""
        return x
class QuantWrapper(nn.Module):
    r"""A wrapper class that wraps the input module, adds QuantStub and
    DeQuantStub and surrounds the call to module with calls to the quant and
    dequant modules.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules. Before the `convert` function, `QuantStub` will just be
    an observer that observes the input tensor; after `convert`, `QuantStub`
    will be swapped to `nnq.Quantize`, which does the actual quantization.
    Similarly for `DeQuantStub`.

    Args:
        module: the module to wrap. Its `qconfig` attribute, if present, is
            passed to both stubs, and its training flag is copied onto the
            wrapper.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module: nn.Module):
        super().__init__()
        # Reuse the wrapped module's qconfig (None when it has none) for
        # both stubs.
        qconfig = getattr(module, "qconfig", None)
        self.add_module("quant", QuantStub(qconfig))
        self.add_module("dequant", DeQuantStub(qconfig))
        self.add_module("module", module)
        # Put the wrapper and all of its children in the same train/eval
        # mode as the wrapped module.
        self.train(module.training)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply ``quant``, then the wrapped module, then ``dequant``."""
        X = self.quant(X)
        X = self.module(X)
        return self.dequant(X)