pytorch/torch/ao/quantization/stubs.py
Richard Barnes 6370ac0251 [codemod] Replace hasattr with getattr in caffe2/torch/ao/quantization/stubs.py (#100597)
Summary:
The pattern
```
X.Y if hasattr(X, "Y") else Z
```
can be replaced with
```
getattr(X, "Y", Z)
```

The [getattr](https://www.w3schools.com/python/ref_func_getattr.asp) function gives more succinct code than the [hasattr](https://www.w3schools.com/python/ref_func_hasattr.asp) function. Please use it when appropriate.

**This diff is very low risk. Green tests indicate that you can safely Accept & Ship.**

Test Plan: Sandcastle

Reviewed By: vkuzo

Differential Revision: D44886422

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100597
Approved by: https://github.com/Skylion007
2023-05-04 16:36:23 +00:00

65 lines
2.0 KiB
Python

from torch import nn
class QuantStub(nn.Module):
    r"""Placeholder that marks where a tensor should be quantized.

    As written here the module simply hands its input back unchanged. Before
    calibration it plays the role of an observer, and `convert` replaces it
    with `nnq.Quantize`.

    Args:
        qconfig: quantization configuration for the tensor. When it is not
            provided, the qconfig is taken from the parent modules instead.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Any falsy value (not only None) is treated as "not provided": the
        # attribute is then left unset on this instance.
        if not qconfig:
            return
        self.qconfig = qconfig

    def forward(self, x):
        # Pass-through: the input object is returned as-is.
        return x
class DeQuantStub(nn.Module):
    r"""Placeholder that marks where a tensor should be dequantized.

    Before calibration this behaves as an identity (the input is returned
    unchanged), and `convert` replaces it with `nnq.DeQuantize`.

    Args:
        qconfig: quantization configuration for the tensor. When it is not
            provided, the qconfig is taken from the parent modules instead.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Only a truthy qconfig is recorded; otherwise no `qconfig` attribute
        # is created on this instance at all.
        has_config = bool(qconfig)
        if has_config:
            self.qconfig = qconfig

    def forward(self, x):
        # Identity: hand the very same object back to the caller.
        return x
class QuantWrapper(nn.Module):
    r"""Wrap a module between a `QuantStub` and a `DeQuantStub`.

    The wrapped module's input is first passed through `quant`, and its
    output is passed through `dequant` before being returned.

    The `quantization` utility functions use this wrapper to add the quant
    and dequant modules. Before `convert`, `QuantStub` is just an observer of
    the input tensor; after `convert` it is swapped to `nnq.Quantize`, which
    does the actual quantization. `DeQuantStub` is handled similarly.
    """

    # Submodules registered in __init__, declared here for type checkers.
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        # Both stubs share the wrapped module's qconfig, or None if it has none.
        shared_qconfig = getattr(module, "qconfig", None)
        children = (
            ('quant', QuantStub(shared_qconfig)),
            ('dequant', DeQuantStub(shared_qconfig)),
            ('module', module),
        )
        for child_name, child in children:
            self.add_module(child_name, child)
        # Mirror the wrapped module's train/eval mode on the whole wrapper.
        self.train(module.training)

    def forward(self, X):
        # quant -> wrapped module -> dequant
        return self.dequant(self.module(self.quant(X)))