pytorch/torch/quantization/qconfig.py
Charles David Hernandez 32d0c3e8ee Support for reference convert_fx working on gpu
Summary:
This PR enables gpu only quantization, best used with is_reference since
there are not many gpu kernels for ops as of now.

This PR mainly changes how qconfigs and their obs constructors operate once they
are attached to a module as its qconfig. The function add_module_to_qconfig_obs_ctr takes the obs constructors on the original
qconfig, and configures them so that when invoked, the created obs will
be on whatever device the module occupies. (Once observers are created,
module.to(device) is already set up so that it moves any observers). To do this,
a new method and a few small changes were added to the _PartialWrapper class that
our observers already use to create constructors (without changing the
existing functionality). These changes work in
concert with changes to the prepare flow such that when the qconfigs are
propagated to the modules (in quantize.py and qconfig_utils.py) they are configured using add_module_to_qconfig_obs_ctr.

Ideally this would work on other models, but is_reference support for
a lot of modules isn't there yet; those tests should be added in a
future PR.

Test Plan:
python test/test_quantization.py TestQuantizeFxModels.test_static_gpu_convert_basic

python test/test_quantization.py TestQuantizeFxModels.test_switch_device_prepare_convert

python test/test_quantization.py TestQuantizeFxModels.test_prepare_serialize_switch_device_convert

python test/test_quantization.py TestQuantizeFx.test_qconfig_precedence

Reviewed By: vkuzo

Differential Revision: D29684114

fbshipit-source-id: 19fefb8e1998eaf212723e836276ccf39467f2e7
2021-07-23 10:30:38 -07:00

190 lines
9.2 KiB
Python

from collections import namedtuple
from .observer import (HistogramObserver, MovingAverageMinMaxObserver,
PlaceholderObserver, default_debug_observer,
default_dynamic_quant_observer,
default_float_qparams_observer, default_observer,
default_per_channel_weight_observer,
default_placeholder_observer, default_weight_observer)
from .fake_quantize import (FakeQuantize, default_fake_quant,
default_per_channel_weight_fake_quant,
default_weight_fake_quant)
import torch
import torch.nn as nn
from typing import Union, Optional, Any
class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
    """
    Pairs the observer constructors used to quantize a layer (or a part of the
    network): one for activations and one for weights.

    Both fields must hold observer **classes** (e.g. MinMaxObserver) or other
    callables that return a new observer each time they are called -- never an
    already-constructed observer instance. The prepare step instantiates
    observers repeatedly, once per layer.

    Default constructor arguments can be overridden with the ``with_args``
    method, which behaves like functools.partial:

      my_qconfig = QConfig(activation=MinMaxObserver.with_args(dtype=torch.qint8),
      weight=default_observer.with_args(dtype=torch.qint8))
    """
    def __new__(cls, activation, weight):
        # Reject the frequent mistake of handing over observer *instances*
        # (nn.Module objects) instead of constructors.
        if any(isinstance(ctr, nn.Module) for ctr in (activation, weight)):
            raise ValueError(
                "QConfig received observer instance, please pass observer class instead. "
                "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super().__new__(cls, activation, weight)
# Static qconfigs assembled from the default observers (positional order is
# activation, weight).
default_qconfig = QConfig(default_observer, default_weight_observer)
default_debug_qconfig = QConfig(default_debug_observer, default_weight_observer)
default_per_channel_qconfig = QConfig(default_observer,
                                      default_per_channel_weight_observer)
class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
    """
    Pairs the observer constructors used when a layer (or a part of the network)
    is quantized dynamically. It is the dynamic-quantization counterpart of QConfig.

    Both fields must hold observer **classes** (e.g. MinMaxObserver) or other
    callables that return a new observer each time they are called, not an
    already-constructed observer instance, because observers are instantiated
    once per layer. Only ``weight`` is checked for that mistake here. Both
    fields default to ``torch.nn.Identity``.

    Default constructor arguments can be overridden with ``with_args``, which
    behaves like functools.partial:

      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
    """
    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
        # Only the weight constructor is validated; an observer instance passed
        # as ``weight`` is a common mistake.
        if not isinstance(weight, nn.Module):
            return super().__new__(cls, activation, weight)
        raise ValueError(
            "QConfigDynamic received observer instance, please pass observer class instead. "
            "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
# QConfigDynamic presets (positional order is activation, weight).
default_dynamic_qconfig = QConfigDynamic(default_dynamic_quant_observer,
                                         default_weight_observer)
float16_dynamic_qconfig = QConfigDynamic(
    PlaceholderObserver.with_args(dtype=torch.float32),
    PlaceholderObserver.with_args(dtype=torch.float16))
float16_static_qconfig = QConfigDynamic(
    PlaceholderObserver.with_args(dtype=torch.float16),
    PlaceholderObserver.with_args(dtype=torch.float16))
per_channel_dynamic_qconfig = QConfigDynamic(default_dynamic_quant_observer,
                                             default_per_channel_weight_observer)

# TODO: this is weight only quant, change this to QConfigWeightOnly
# or remove the QConfigDynamic later
float_qparams_weight_only_qconfig = QConfigDynamic(default_placeholder_observer,
                                                   default_float_qparams_observer)

# QConfig presets built from fake-quantize constructors; torch.nn.Identity
# stands in for the side that is left unobserved.
default_qat_qconfig = QConfig(default_fake_quant, default_weight_fake_quant)
default_weight_only_qconfig = QConfig(torch.nn.Identity, default_weight_fake_quant)
default_activation_only_qconfig = QConfig(default_fake_quant, torch.nn.Identity)
def get_default_qconfig(backend='fbgemm'):
    """
    Returns the default post-training static quantization qconfig for ``backend``.

    Args:
        backend: ``'fbgemm'`` or ``'qnnpack'``. Any other value falls back to
            ``default_qconfig``.

    Return:
        qconfig: for the two known backends, a fresh QConfig that uses a
        HistogramObserver for activations (``reduce_range`` is enabled only
        for fbgemm). fbgemm gets the per-channel default weight observer and
        qnnpack the per-tensor one.
    """
    if backend not in ('fbgemm', 'qnnpack'):
        return default_qconfig
    is_fbgemm = backend == 'fbgemm'
    weight_ctr = default_per_channel_weight_observer if is_fbgemm else default_weight_observer
    return QConfig(activation=HistogramObserver.with_args(reduce_range=is_fbgemm),
                   weight=weight_ctr)
def get_default_qat_qconfig(backend='fbgemm'):
    """
    Returns the default quantization-aware-training qconfig for ``backend``.

    Args:
        backend: ``'fbgemm'`` or ``'qnnpack'``. Any other value falls back to
            ``default_qat_qconfig``.

    Return:
        qconfig: for the two known backends, a fresh QConfig whose activation is
        a FakeQuantize over MovingAverageMinMaxObserver with range [0, 255]
        (``reduce_range`` is enabled only for fbgemm). fbgemm gets the per-channel
        weight fake-quant and qnnpack the per-tensor one.
    """
    # Histogram observer is too slow for quantization aware training
    if backend != 'fbgemm' and backend != 'qnnpack':
        return default_qat_qconfig
    use_fbgemm = backend == 'fbgemm'
    activation_ctr = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                            quant_min=0,
                                            quant_max=255,
                                            reduce_range=use_fbgemm)
    weight_ctr = default_per_channel_weight_fake_quant if use_fbgemm else default_weight_fake_quant
    return QConfig(activation=activation_ctr, weight=weight_ctr)
def assert_valid_qconfig(qconfig: Optional[Union[QConfig, QConfigDynamic]],
                         mod: torch.nn.Module) -> None:
    """
    Checks that ``qconfig`` is usable for ``mod``.

    Currently the only rule is that ConvTranspose1d/2d/3d modules may not use a
    per-channel weight observer. A ``None`` qconfig is always accepted.

    Args:
        qconfig: the qconfig to validate, or ``None``.
        mod: the module the qconfig would be attached to.

    Raises:
        AssertionError: if ``mod`` is a ConvTranspose module and ``qconfig.weight()``
            builds a PerChannelMinMaxObserver or MovingAveragePerChannelMinMaxObserver.
    """
    if qconfig is None:
        return
    conv_transpose_types = (torch.nn.ConvTranspose1d,
                            torch.nn.ConvTranspose2d,
                            torch.nn.ConvTranspose3d)
    if not isinstance(mod, conv_transpose_types):
        return
    # Build a throwaway observer just to inspect its type.
    probe = qconfig.weight()
    per_channel_types = (torch.quantization.PerChannelMinMaxObserver,
                         torch.quantization.MovingAveragePerChannelMinMaxObserver)
    assert not isinstance(probe, per_channel_types), \
        'Per channel weight observer is not supported yet for ConvTranspose{n}d.'
# Type alias for any qconfig flavour defined in this file, or None when no
# qconfig is set. Equivalent to Union[QConfig, QConfigDynamic, None].
QConfigAny = Optional[Union[QConfig, QConfigDynamic]]
def add_module_to_qconfig_obs_ctr(
        qconfig: QConfigAny,
        module: Union[nn.Module, None]) -> Any:
    r"""This is a helper function for use in quantization prepare that updates a qconfig so that
    the constructors stored in the qconfig will create observers on the same device that
    'module' is on. This is intended to be used when the qconfigs are propagated to each
    module in order to avoid potential device alignment issues.

    Args:
        qconfig: QConfig or QConfigDynamic with obs constructors stored in activation and weight
        module: module which the qconfig is related to

    Return:
        qconfig: configured so that obs constructors construct observers on the same device
        as ``module``. The input ``qconfig`` is returned unchanged when ``module`` or
        ``qconfig`` is ``None``, or when ``qconfig``'s fields are not exactly
        ``('activation', 'weight')``. A constructor that cannot accept ``factory_kwargs``
        is kept as-is.

    Raises:
        AssertionError: (when an observer is later constructed) if ``module``'s parameters
            and buffers live on more than one device.
    """
    if module is None or qconfig is None or qconfig._fields != ('activation', 'weight'):
        return qconfig

    def get_factory_kwargs_based_on_module_device():
        # Handed to ``with_callable_args`` below, so the device is looked up when the
        # observer is built rather than now (per the _PartialWrapper change this relies on).
        assert isinstance(module, torch.nn.Module)
        devices = {p.device for p in module.parameters()} | \
            {p.device for p in module.buffers()}
        # Taking an arbitrary element of a multi-device set would put the observer on an
        # unpredictable device, so require the module to live on a single device.
        assert len(devices) <= 1, (
            "add_module_to_qconfig_obs_ctr only supports modules on a single device, "
            "but got devices {}".format(devices))
        device = next(iter(devices)) if len(devices) > 0 else None
        return None if device is None else {'device': device}

    def configure_constructor_to_put_obs_on_module_device(original_constructor):
        try:
            # Probe whether the constructor accepts factory_kwargs by actually building one.
            check = original_constructor.with_args(factory_kwargs=None)
            check()
            return original_constructor.with_callable_args(factory_kwargs=get_factory_kwargs_based_on_module_device)
        except AttributeError:
            # The constructor lacks with_args / with_callable_args (e.g. a plain class
            # such as torch.nn.Identity), so it cannot be reconfigured.
            return original_constructor
        except TypeError:
            # The constructor does not accept a factory_kwargs argument.
            return original_constructor

    activation = configure_constructor_to_put_obs_on_module_device(qconfig.activation)
    weight = configure_constructor_to_put_obs_on_module_device(qconfig.weight)
    if isinstance(qconfig, QConfig):
        return QConfig(activation, weight)
    else:
        return QConfigDynamic(activation, weight)