mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: The pattern ``` X.Y if hasattr(X, "Y") else Z ``` can be replaced with ``` getattr(X, "Y", Z) ``` The [getattr](https://www.w3schools.com/python/ref_func_getattr.asp) function gives more succinct code than the [hasattr](https://www.w3schools.com/python/ref_func_hasattr.asp) function. Please use it when appropriate. **This diff is very low risk. Green tests indicate that you can safely Accept & Ship.** Test Plan: Sandcastle Reviewed By: vkuzo Differential Revision: D44886422 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100597 Approved by: https://github.com/Skylion007
65 lines
2.0 KiB
Python
|
|
from torch import nn
|
|
|
|
class QuantStub(nn.Module):
|
|
r"""Quantize stub module, before calibration, this is same as an observer,
|
|
it will be swapped as `nnq.Quantize` in `convert`.
|
|
|
|
Args:
|
|
qconfig: quantization configuration for the tensor,
|
|
if qconfig is not provided, we will get qconfig from parent modules
|
|
"""
|
|
def __init__(self, qconfig=None):
|
|
super().__init__()
|
|
if qconfig:
|
|
self.qconfig = qconfig
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
|
|
class DeQuantStub(nn.Module):
|
|
r"""Dequantize stub module, before calibration, this is same as identity,
|
|
this will be swapped as `nnq.DeQuantize` in `convert`.
|
|
|
|
Args:
|
|
qconfig: quantization configuration for the tensor,
|
|
if qconfig is not provided, we will get qconfig from parent modules
|
|
"""
|
|
def __init__(self, qconfig=None):
|
|
super().__init__()
|
|
if qconfig:
|
|
self.qconfig = qconfig
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
|
|
class QuantWrapper(nn.Module):
|
|
r"""A wrapper class that wraps the input module, adds QuantStub and
|
|
DeQuantStub and surround the call to module with call to quant and dequant
|
|
modules.
|
|
|
|
This is used by the `quantization` utility functions to add the quant and
|
|
dequant modules, before `convert` function `QuantStub` will just be observer,
|
|
it observes the input tensor, after `convert`, `QuantStub`
|
|
will be swapped to `nnq.Quantize` which does actual quantization. Similarly
|
|
for `DeQuantStub`.
|
|
"""
|
|
quant: QuantStub
|
|
dequant: DeQuantStub
|
|
module: nn.Module
|
|
|
|
def __init__(self, module):
|
|
super().__init__()
|
|
qconfig = getattr(module, "qconfig", None)
|
|
self.add_module('quant', QuantStub(qconfig))
|
|
self.add_module('dequant', DeQuantStub(qconfig))
|
|
self.add_module('module', module)
|
|
self.train(module.training)
|
|
|
|
def forward(self, X):
|
|
X = self.quant(X)
|
|
X = self.module(X)
|
|
return self.dequant(X)
|