Summary: This pull request includes fixes for six quantization doc bugs.
https://github.com/pytorch/pytorch/issues/30283 - rendering issue in the QConfig doc
https://github.com/pytorch/pytorch/issues/26305 - minor doc issue in fuse_modules()
https://github.com/pytorch/pytorch/issues/27451 - doc issues in ConvReLU2d, ConvReLU3d, and LinearReLU
https://github.com/pytorch/pytorch/issues/26899 - missing docstrings in torch.nn.intrinsic fused functions
https://github.com/pytorch/pytorch/issues/29735 - add discussion of QNNPACK to the quantization doc page
https://github.com/pytorch/pytorch/issues/27938 - some of the quantized functions lack documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30288
Differential Revision: D18653368
Pulled By: gottbrath
fbshipit-source-id: 410b3dd81ff10909a7f1a7736ca42d7cabf0beb1
from __future__ import absolute_import, division, print_function, unicode_literals

from collections import namedtuple

from .observer import *
from .fake_quantize import *
import torch.nn as nn

class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
    """
    Describes how to quantize a layer or a part of the network by providing
    settings (observer classes) for activations and weights respectively.

    Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization preparation function will instantiate the observers multiple times, once for each layer.

    Observer classes usually have reasonable default arguments, but they can be overridden with the `with_args`
    method (which behaves like functools.partial):

      my_qconfig = QConfig(activation=MinMaxObserver.with_args(dtype=torch.qint8),
                           weight=default_observer.with_args(dtype=torch.qint8))
    """
    def __new__(cls, activation, weight):
        # catch common mistakes
        if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
            raise ValueError("QConfig received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super(QConfig, cls).__new__(cls, activation, weight)
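
# Illustrative sketch, not part of the original file: building a custom QConfig with
# observer *classes* via `with_args`, mirroring the docstring example above. The name
# `_example_int8_qconfig` is hypothetical; MinMaxObserver and default_observer come
# from .observer via the star import. Passing observer *instances* instead (e.g.
# activation=MinMaxObserver()) would trigger the ValueError raised in __new__ above.
_example_int8_qconfig = QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.qint8),
    weight=default_observer.with_args(dtype=torch.qint8))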

default_qconfig = QConfig(activation=default_observer,
                          weight=default_weight_observer)

default_debug_qconfig = QConfig(weight=default_weight_observer,
                                activation=default_debug_observer)

default_per_channel_qconfig = QConfig(activation=default_observer,
                                      weight=default_per_channel_weight_observer)
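
def _example_static_ptq(model):
    # Illustrative sketch, not part of the original file: eager-mode post-training
    # static quantization with this module's default_qconfig. Assumes `model` is a
    # float nn.Module already set to eval() and that torch.quantization.prepare /
    # convert are used; the calibration loop is elided.
    from torch.quantization import prepare, convert
    model.qconfig = default_qconfig
    prepare(model, inplace=True)
    # ... feed representative calibration batches through `model` here ...
    convert(model, inplace=True)
    return model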

class QConfigDynamic(namedtuple('QConfigDynamic', ['weight'])):
    """
    Describes how to dynamically quantize a layer or a part of the network by providing
    settings (observer classes) for weights.

    It's like QConfig, but for dynamic quantization.

    Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable that returns
    instances on invocation, not the concrete observer instances themselves.
    The quantization function will instantiate the observers multiple times, once for each layer.

    Observer classes usually have reasonable default arguments, but they can be overridden with the `with_args`
    method (which behaves like functools.partial):

      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
    """
    def __new__(cls, weight):
        # catch common mistakes
        if isinstance(weight, nn.Module):
            raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super(QConfigDynamic, cls).__new__(cls, weight)

default_dynamic_qconfig = QConfigDynamic(weight=default_weight_observer)
float16_dynamic_qconfig = QConfigDynamic(weight=NoopObserver.with_args(dtype=torch.float16))
per_channel_dynamic_qconfig = QConfigDynamic(weight=default_per_channel_weight_observer)
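
def _example_dynamic_quantization(model):
    # Illustrative sketch, not part of the original file: post-training dynamic
    # quantization driven by QConfigDynamic. Assumes torch.quantization.quantize_dynamic
    # accepts a qconfig_spec dict mapping module types to QConfigDynamic instances;
    # here only nn.Linear weights are quantized to int8.
    from torch.quantization import quantize_dynamic
    return quantize_dynamic(model, {nn.Linear: default_dynamic_qconfig}, dtype=torch.qint8)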

default_qat_qconfig = QConfig(activation=default_fake_quant,
                              weight=default_weight_fake_quant)

default_weight_only_qconfig = QConfig(activation=torch.nn.Identity,
                                      weight=default_weight_fake_quant)
default_activation_only_qconfig = QConfig(activation=default_fake_quant,
                                          weight=torch.nn.Identity)
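# Because one slot is torch.nn.Identity, default_weight_only_qconfig and
# default_activation_only_qconfig fake-quantize only the weights or only the
# activations, which can help isolate quantization effects when debugging accuracy.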

def get_default_qconfig(backend='fbgemm'):
    if backend == 'fbgemm':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                          weight=default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                          weight=default_weight_observer)
    else:
        raise ValueError("Unknown backend, please specify qconfig manually")
    return qconfig
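
def _example_backend_qconfig(backend='qnnpack'):
    # Illustrative sketch, not part of the original file: select the qconfig that
    # matches the active quantized engine. Assumes torch.backends.quantized.engine
    # accepts 'fbgemm' (x86 server) and 'qnnpack' (ARM / mobile) in this build.
    torch.backends.quantized.engine = backend
    return get_default_qconfig(backend)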

def get_default_qat_qconfig(backend='fbgemm'):
    # Histogram observer is too slow for quantization aware training
    if backend == 'fbgemm':
        qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                            quant_min=0,
                                                            quant_max=255,
                                                            reduce_range=True),
                          weight=default_per_channel_weight_fake_quant)
    elif backend == 'qnnpack':
        qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                            quant_min=0,
                                                            quant_max=255,
                                                            reduce_range=False),
                          weight=default_weight_fake_quant)
    else:
        raise ValueError("Unknown backend, please specify qconfig manually")

    return qconfig
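
def _example_qat_flow(model):
    # Illustrative sketch, not part of the original file: eager-mode quantization
    # aware training with the fbgemm QAT qconfig. Assumes torch.quantization.prepare_qat
    # and convert; the fine-tuning loop is elided.
    from torch.quantization import prepare_qat, convert
    model.train()
    model.qconfig = get_default_qat_qconfig('fbgemm')
    prepare_qat(model, inplace=True)
    # ... fine-tune `model` with fake-quantization modules inserted ...
    model.eval()
    convert(model, inplace=True)
    return model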