pytorch quantization: document the custom module APIs (#67449)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67449

Adds a description of what the current custom module API does, plus API examples for Eager mode and FX graph mode, to the main PyTorch quantization documentation page.

Test Plan:
```
cd docs
make html
python -m http.server
# check the docs page, it renders correctly
```

Reviewed By: jbschlosser
Differential Revision: D31994641
Pulled By: vkuzo
fbshipit-source-id: d35a62947dd06e71276eb6a0e37950d3cc5abfc1
This commit is contained in:
parent acdc754918
commit 99282126dc

@@ -578,6 +578,138 @@ Quantization workflows work by adding (e.g. adding observers as
means that the model stays a regular ``nn.Module``-based instance throughout the
process and thus can work with the rest of PyTorch APIs.

Quantization Custom Module API
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Both Eager mode and FX graph mode quantization APIs provide a hook for the user
to specify a module quantized in a custom way, with user-defined logic for
observation and quantization. The user needs to specify:

1. The Python type of the source fp32 module (existing in the model)
2. The Python type of the observed module (provided by user). This module needs
   to define a `from_float` function which defines how the observed module is
   created from the original fp32 module.
3. The Python type of the quantized module (provided by user). This module needs
   to define a `from_observed` function which defines how the quantized module is
   created from the observed module.
4. A configuration describing (1), (2), and (3) above, passed to the quantization APIs.

The framework will then do the following:

1. During the `prepare` module swaps, it will convert every module of the type
   specified in (1) to the type specified in (2), using the `from_float` function of
   the class in (2).
2. During the `convert` module swaps, it will convert every module of the type
   specified in (2) to the type specified in (3), using the `from_observed` function
   of the class in (3).

Currently, there is a requirement that `ObservedCustomModule` has a single
Tensor output, and an observer will be added by the framework (not by the user)
on that output. The observer is stored under the `activation_post_process` key
as an attribute of the custom module instance. These restrictions may be
relaxed in the future.

Example::

    import torch
    import torch.nn.quantized as nnq
    import torch.quantization.quantize_fx

    # original fp32 module to replace
    class CustomModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            return self.linear(x)

    # custom observed module, provided by user
    class ObservedCustomModule(torch.nn.Module):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_float(cls, float_module):
            assert hasattr(float_module, 'qconfig')
            observed = cls(float_module.linear)
            observed.qconfig = float_module.qconfig
            return observed

    # custom quantized module, provided by user
    class StaticQuantCustomModule(torch.nn.Module):
        def __init__(self, linear):
            super().__init__()
            self.linear = linear

        def forward(self, x):
            return self.linear(x)

        @classmethod
        def from_observed(cls, observed_module):
            assert hasattr(observed_module, 'qconfig')
            assert hasattr(observed_module, 'activation_post_process')
            # reuse the observer collected on the custom module's output, so
            # that nnq.Linear.from_float can compute output quantization params
            observed_module.linear.activation_post_process = \
                observed_module.activation_post_process
            quantized = cls(nnq.Linear.from_float(observed_module.linear))
            return quantized

    #
    # example API call (Eager mode quantization)
    #

    m = torch.nn.Sequential(CustomModule()).eval()

    prepare_custom_config_dict = {
        "float_to_observed_custom_module_class": {
            CustomModule: ObservedCustomModule
        }
    }
    convert_custom_config_dict = {
        "observed_to_quantized_custom_module_class": {
            ObservedCustomModule: StaticQuantCustomModule
        }
    }

    m.qconfig = torch.quantization.default_qconfig
    mp = torch.quantization.prepare(
        m, prepare_custom_config_dict=prepare_custom_config_dict)
    # calibration (not shown)
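    # one possible calibration sketch (illustrative only; the original example
    # leaves this step out): feed representative data through the prepared
    # model so the attached observers record activation ranges. Random inputs
    # are used here purely as a placeholder.
    with torch.no_grad():
        for _ in range(4):
            mp(torch.randn(4, 3))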
    mq = torch.quantization.convert(
        mp, convert_custom_config_dict=convert_custom_config_dict)

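    # optional sanity check (not part of the original example): after convert,
    # the custom module should have been swapped for its quantized counterpart
    assert isinstance(mq[0], StaticQuantCustomModule)
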
    #
    # example API call (FX graph mode quantization)
    #

    m = torch.nn.Sequential(CustomModule()).eval()

    qconfig_dict = {'': torch.quantization.default_qconfig}
    prepare_custom_config_dict = {
        "float_to_observed_custom_module_class": {
            "static": {
                CustomModule: ObservedCustomModule,
            }
        }
    }
    convert_custom_config_dict = {
        "observed_to_quantized_custom_module_class": {
            "static": {
                ObservedCustomModule: StaticQuantCustomModule,
            }
        }
    }
    mp = torch.quantization.quantize_fx.prepare_fx(
        m, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)
    # calibration (not shown)
    mq = torch.quantization.quantize_fx.convert_fx(
        mp, convert_custom_config_dict=convert_custom_config_dict)

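A quick way to confirm that the swap took effect (an illustrative check, not
part of the original example) is to inspect the converted models; for the FX
graph mode result, the swapped submodule shows up in ``named_modules``::

    # the converted GraphModule should now contain StaticQuantCustomModule
    # in place of the original CustomModule
    print([type(mod) for _, mod in mq.named_modules()])
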
Model Preparation for Quantization (Eager Mode)
-----------------------------------------------