pytorch quantization: document the custom module APIs (#67449)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67449

Adds a description of what the current custom module API does
and API examples for Eager mode and FX graph mode to the main
PyTorch quantization documentation page.

Test Plan:
```
cd docs
make html
python -m http.server
# check that the docs page renders correctly
```

Reviewed By: jbschlosser

Differential Revision: D31994641

Pulled By: vkuzo

fbshipit-source-id: d35a62947dd06e71276eb6a0e37950d3cc5abfc1
@@ -578,6 +578,138 @@ Quantization workflows work by adding (e.g. adding observers as
means that the model stays a regular ``nn.Module``-based instance throughout the
process and thus can work with the rest of PyTorch APIs.

Quantization Custom Module API
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Both Eager mode and FX graph mode quantization APIs provide a hook for the user
to specify a module quantized in a custom way, with user-defined logic for
observation and quantization. The user needs to specify:

1. The Python type of the source fp32 module (existing in the model)
2. The Python type of the observed module (provided by user). This module needs
   to define a `from_float` function which defines how the observed module is
   created from the original fp32 module.
3. The Python type of the quantized module (provided by user). This module needs
   to define a `from_observed` function which defines how the quantized module is
   created from the observed module.
4. A configuration describing (1), (2), (3) above, passed to the quantization APIs.

The framework will then do the following:

1. during the `prepare` module swaps, it will convert every module of the type
   specified in (1) to the type specified in (2), using the `from_float` function of
   the class in (2).
2. during the `convert` module swaps, it will convert every module of the type
   specified in (2) to the type specified in (3), using the `from_observed` function
   of the class in (3).

Currently, `ObservedCustomModule` is required to have a single Tensor output,
and an observer will be added by the framework (not by the user) on that
output. The observer will be stored under the `activation_post_process` key
as an attribute of the custom module instance. These restrictions may be
relaxed in the future.
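
For illustration, the framework-attached observer can be inspected directly on
a prepared model (a hypothetical snippet, referring to the Eager mode model
`mp` from the example below)::

    observed = mp[0]  # the ObservedCustomModule instance after prepare
    print(observed.activation_post_process)  # e.g. a MinMaxObserver
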
Example::

    import torch
    import torch.nn.quantized as nnq
    import torch.quantization.quantize_fx

    # original fp32 module to replace
    class CustomModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            return self.linear(x)

    # custom observed module, provided by user
    class ObservedCustomModule(torch.nn.Module):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_float(cls, float_module):
            assert hasattr(float_module, 'qconfig')
            observed = cls(float_module.linear)
            observed.qconfig = float_module.qconfig
            return observed

    # custom quantized module, provided by user
    class StaticQuantCustomModule(torch.nn.Module):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_observed(cls, observed_module):
            assert hasattr(observed_module, 'qconfig')
            assert hasattr(observed_module, 'activation_post_process')
            observed_module.linear.activation_post_process = \
                observed_module.activation_post_process
            quantized = cls(nnq.Linear.from_float(observed_module.linear))
            return quantized

    #
    # example API call (Eager mode quantization)
    #
    m = torch.nn.Sequential(CustomModule()).eval()
    prepare_custom_config_dict = {
        "float_to_observed_custom_module_class": {
            CustomModule: ObservedCustomModule
        }
    }
    convert_custom_config_dict = {
        "observed_to_quantized_custom_module_class": {
            ObservedCustomModule: StaticQuantCustomModule
        }
    }
    m.qconfig = torch.quantization.default_qconfig
    mp = torch.quantization.prepare(
        m, prepare_custom_config_dict=prepare_custom_config_dict)
    # calibration
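    # (a hypothetical sketch: real calibration should feed representative
    # data; the (1, 3) input shape is an assumption matching the toy
    # Linear(3, 3) above)
    for _ in range(4):
        mp(torch.randn(1, 3))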
    mq = torch.quantization.convert(
        mp, convert_custom_config_dict=convert_custom_config_dict)

    #
    # example API call (FX graph mode quantization)
    #
    m = torch.nn.Sequential(CustomModule()).eval()
    qconfig_dict = {'': torch.quantization.default_qconfig}
    prepare_custom_config_dict = {
        "float_to_observed_custom_module_class": {
            "static": {
                CustomModule: ObservedCustomModule,
            }
        }
    }
    convert_custom_config_dict = {
        "observed_to_quantized_custom_module_class": {
            "static": {
                ObservedCustomModule: StaticQuantCustomModule,
            }
        }
    }
    mp = torch.quantization.quantize_fx.prepare_fx(
        m, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)
    # calibration
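    # (a hypothetical sketch, mirroring the Eager mode calibration above;
    # the (1, 3) input shape is an assumption)
    for _ in range(4):
        mp(torch.randn(1, 3))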
    mq = torch.quantization.quantize_fx.convert_fx(
        mp, convert_custom_config_dict=convert_custom_config_dict)
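
After `convert`, a quick smoke test can confirm that the quantized model runs
end to end (a minimal sketch; the `(1, 3)` input shape is an assumption
matching the toy `Linear(3, 3)` above, and `mq` here refers to the FX graph
mode result)::

    out = mq(torch.randn(1, 3))  # output is a regular fp32 torch.Tensor
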
Model Preparation for Quantization (Eager Mode)
-----------------------------------------------