[ao][docs] tests for quantization docs (#79923)
Summary: Per https://github.com/pytorch/pytorch/issues/79135, the code snippets in the docs don't run. This has been a recurring problem, since there was previously no unit test checking that these snippets actually ran. This PR adds support for such a test: each snippet is imported as a string and evaluated to make sure it runs. If a snippet relies on user-defined code, dummy versions can be passed in via `global_inputs`. The imports of the code snippets sometimes behave oddly, but they can be passed in the same way, as in `test_quantization_doc_custom`, where `nnq` is passed in.

Test Plan: python test/test_quantization.py TestQuantizationDocs

Also see https://github.com/pytorch/pytorch/pull/79994 for what shows up in CI when the docs get broken.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79923
Approved by: https://github.com/z-a-f, https://github.com/vspenubarthi

This commit is contained in: commit ffdc5eebc7 (parent da33c93169)
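The core mechanism is small. A minimal sketch of the approach, where the snippet text and the `training_loop` stub are invented for illustration::

  # a docs snippet, held as a string, that references a user-defined name
  SNIPPET = """
  import torch
  model = torch.nn.Linear(4, 4)
  training_loop(model)
  """

  # dummy stand-ins for user-defined code, injected as globals
  def _dummy_training_loop(*args, **kwargs):
      return None

  global_inputs = {"training_loop": _dummy_training_loop}

  # compile and run the snippet; this raises if the snippet is broken
  exec(compile(SNIPPET, "snippet", "exec"), global_inputs)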
docs/source/quantization.rst

@@ -69,31 +69,31 @@ Diagram::
                     /
    linear_weight_int8

-API example::
+PTDQ API Example::

  import torch

  # define a floating point model
  class M(torch.nn.Module):
      def __init__(self):
-         super(M, self).__init__()
+         super().__init__()
          self.fc = torch.nn.Linear(4, 4)

      def forward(self, x):
          x = self.fc(x)
          return x

  # create a model instance
  model_fp32 = M()
  # create a quantized model instance
  model_int8 = torch.quantization.quantize_dynamic(
      model_fp32,  # the original model
      {torch.nn.Linear},  # a set of layers to dynamically quantize
      dtype=torch.qint8)  # the target dtype for quantized weights

  # run the model
  input_fp32 = torch.randn(4, 4, 4, 4)
  res = model_int8(input_fp32)

To learn more about dynamic quantization please see our `dynamic quantization tutorial
<https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html>`_.
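A quick check, illustrative and not part of the diff: after `quantize_dynamic`, the float `Linear` submodule has been swapped for its dynamically quantized counterpart::

  # the printed class path varies by PyTorch version,
  # e.g. torch.nn.quantized.dynamic.modules.linear.Linear
  print(type(model_int8.fc))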
@@ -124,14 +124,14 @@ Diagram::
                     /
    linear_weight_int8

-API Example::
+PTSQ API Example::

  import torch

  # define a floating point model where some layers could be statically quantized
  class M(torch.nn.Module):
      def __init__(self):
-         super(M, self).__init__()
+         super().__init__()
          # QuantStub converts tensors from floating point to quantized
          self.quant = torch.quantization.QuantStub()
          self.conv = torch.nn.Conv2d(1, 1, 1)
@@ -222,14 +222,14 @@ Diagram::
                     /
    linear_weight_int8

-API Example::
+QAT API Example::

  import torch

  # define a floating point model where some layers could benefit from QAT
  class M(torch.nn.Module):
      def __init__(self):
-         super(M, self).__init__()
+         super().__init__()
          # QuantStub converts tensors from floating point to quantized
          self.quant = torch.quantization.QuantStub()
          self.conv = torch.nn.Conv2d(1, 1, 1)
@@ -249,8 +249,8 @@ API Example::
  # create a model instance
  model_fp32 = M()

- # model must be set to train mode for QAT logic to work
- model_fp32.train()
+ # model must be set to eval for fusion to work
+ model_fp32.eval()

  # attach a global qconfig, which contains information about what kind
  # of observers to attach. Use 'fbgemm' for server inference and
@@ -265,8 +265,9 @@ API Example::
      [['conv', 'bn', 'relu']])

  # Prepare the model for QAT. This inserts observers and fake_quants in
+ # the model needs to be set to train for QAT logic to work
  # the model that will observe weight and activation tensors during calibration.
- model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused)
+ model_fp32_prepared = torch.quantization.prepare_qat(model_fp32_fused.train())

  # run the training loop (not shown)
  training_loop(model_fp32_prepared)
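For context, the next step in the same docs example, paraphrased here rather than quoted from this hunk, converts the trained model while back in eval mode::

  # after the training loop, switch to eval and convert;
  # convert swaps in quantized modules using the observed statistics
  model_fp32_prepared.eval()
  model_int8 = torch.quantization.convert(model_fp32_prepared)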
@@ -324,13 +325,14 @@ to do the following in addition:

 There are multiple quantization types in post training quantization (weight only, dynamic and static) and the configuration is done through `qconfig_mapping` (an argument of the `prepare_fx` function).

-API Example::
+FXPTQ API Example::

-  from torch.quantization import QConfigMapping
+  import torch
+  from torch.ao.quantization import QConfigMapping
   import torch.quantization.quantize_fx as quantize_fx
   import copy

-  model_fp = UserModel(...)
+  model_fp = UserModel()

   #
   # post training dynamic/weight_only quantization
@@ -340,9 +342,11 @@ API Example::
   model_to_quantize = copy.deepcopy(model_fp)
   model_to_quantize.eval()
   qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_dynamic_qconfig)
+  # a tuple of one or more example inputs is needed to trace the model
+  example_inputs = (input_fp32,)
   # prepare
-  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping)
-  # no calibration needed when we only have dynamic/weight_only quantization
+  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
+  # no calibration needed when we only have dynamic/weight_only quantization
   # quantize
   model_quantized = quantize_fx.convert_fx(model_prepared)
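The snippet above assumes a user-defined `UserModel` and `input_fp32`; the new test supplies both through `global_inputs`. A self-contained sketch, with a made-up single-layer model standing in for `UserModel`, might look like::

  import copy
  import torch
  import torch.quantization.quantize_fx as quantize_fx
  from torch.ao.quantization import QConfigMapping

  # stand-in for the user-defined model referenced by the docs
  class UserModel(torch.nn.Module):
      def __init__(self):
          super().__init__()
          self.fc = torch.nn.Linear(4, 4)

      def forward(self, x):
          return self.fc(x)

  model_fp = UserModel()
  input_fp32 = torch.randn(1, 4)

  # post training dynamic quantization, end to end
  model_to_quantize = copy.deepcopy(model_fp).eval()
  qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_dynamic_qconfig)
  example_inputs = (input_fp32,)
  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
  model_quantized = quantize_fx.convert_fx(model_prepared)
  res = model_quantized(input_fp32)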
@@ -354,7 +358,7 @@ API Example::
   qconfig_mapping = QConfigMapping().set_global(torch.quantization.get_default_qconfig('qnnpack'))
   model_to_quantize.eval()
   # prepare
-  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping)
+  model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
   # calibrate (not shown)
   # quantize
   model_quantized = quantize_fx.convert_fx(model_prepared)
@@ -367,7 +371,7 @@ API Example::
   qconfig_mapping = QConfigMapping().set_global(torch.quantization.get_default_qat_qconfig('qnnpack'))
   model_to_quantize.train()
   # prepare
-  model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping)
+  model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
   # training loop (not shown)
   # quantize
   model_quantized = quantize_fx.convert_fx(model_prepared)
@@ -790,106 +794,101 @@ on that output. The observer will be stored under the `activation_post_process` key
 as an attribute of the custom module instance. Relaxing these restrictions may
 be done at a future time.

-Example::
+Custom API Example::

   import torch
   import torch.nn.quantized as nnq
-  from torch.quantization import QConfigMapping
-  import torch.quantization.quantize_fx
+  from torch.ao.quantization import QConfigMapping
+  import torch.ao.quantization.quantize_fx

   # original fp32 module to replace
   class CustomModule(torch.nn.Module):
       def __init__(self):
           super().__init__()
           self.linear = torch.nn.Linear(3, 3)

       def forward(self, x):
           return self.linear(x)

   # custom observed module, provided by user
   class ObservedCustomModule(torch.nn.Module):
       def __init__(self, linear):
           super().__init__()
           self.linear = linear

       def forward(self, x):
           return self.linear(x)

       @classmethod
       def from_float(cls, float_module):
           assert hasattr(float_module, 'qconfig')
           observed = cls(float_module.linear)
           observed.qconfig = float_module.qconfig
           return observed

   # custom quantized module, provided by user
   class StaticQuantCustomModule(torch.nn.Module):
       def __init__(self, linear):
           super().__init__()
           self.linear = linear

       def forward(self, x):
           return self.linear(x)

       @classmethod
       def from_observed(cls, observed_module):
           assert hasattr(observed_module, 'qconfig')
           assert hasattr(observed_module, 'activation_post_process')
           observed_module.linear.activation_post_process = \
               observed_module.activation_post_process
           quantized = cls(nnq.Linear.from_float(observed_module.linear))
           return quantized

   #
   # example API call (Eager mode quantization)
   #
   m = torch.nn.Sequential(CustomModule()).eval()
   prepare_custom_config_dict = {
       "float_to_observed_custom_module_class": {
           CustomModule: ObservedCustomModule
       }
   }
   convert_custom_config_dict = {
       "observed_to_quantized_custom_module_class": {
           ObservedCustomModule: StaticQuantCustomModule
       }
   }
-  m.qconfig = torch.quantization.default_qconfig
-  mp = torch.quantization.prepare(
+  m.qconfig = torch.ao.quantization.default_qconfig
+  mp = torch.ao.quantization.prepare(
       m, prepare_custom_config_dict=prepare_custom_config_dict)
   # calibration (not shown)
-  mq = torch.quantization.convert(
+  mq = torch.ao.quantization.convert(
       mp, convert_custom_config_dict=convert_custom_config_dict)
   #
   # example API call (FX graph mode quantization)
   #
   m = torch.nn.Sequential(CustomModule()).eval()
-  qconfig_mapping = QConfigMapping().set_global(torch.quantization.default_qconfig)
+  qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
   prepare_custom_config_dict = {
       "float_to_observed_custom_module_class": {
           "static": {
               CustomModule: ObservedCustomModule,
           }
       }
   }
   convert_custom_config_dict = {
       "observed_to_quantized_custom_module_class": {
           "static": {
               ObservedCustomModule: StaticQuantCustomModule,
           }
       }
   }
-  mp = torch.quantization.quantize_fx.prepare_fx(
-      m, qconfig_mapping, prepare_custom_config_dict=prepare_custom_config_dict)
+  mp = torch.ao.quantization.quantize_fx.prepare_fx(
+      m, qconfig_mapping, torch.randn(3, 3), prepare_custom_config=prepare_custom_config_dict)
   # calibration (not shown)
-  mq = torch.quantization.quantize_fx.convert_fx(
-      mp, convert_custom_config_dict=convert_custom_config_dict)
+  mq = torch.ao.quantization.quantize_fx.convert_fx(
+      mp, convert_custom_config=convert_custom_config_dict)

 Best Practices
 --------------
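One subtlety worth noting, shown with an illustrative call that is not part of the diff: because the example wraps `CustomModule` directly in `nn.Sequential` without a `QuantStub`, the converted module's `nnq.Linear` expects an already quantized input::

  # quantize the input manually before calling the converted module
  q_in = torch.quantize_per_tensor(
      torch.randn(3, 3), scale=1.0, zero_point=0, dtype=torch.quint8)
  out = mq(q_in)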
test/quantization/core/test_docs.py (new file, 151 lines)
@@ -0,0 +1,151 @@
# Owner(s): ["oncall: quantization"]

import re
from pathlib import Path

import torch

# import torch.nn.quantized as nnq
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    SingleLayerLinearModel,
)


class TestQuantizationDocs(QuantizationTestCase):
    r"""
    The tests in this section import code from the quantization docs and check that
    it actually runs without errors. In cases where objects are undefined in the code snippet,
    they must be provided in the test. The imports seem to behave a bit inconsistently;
    they can either be imported in the test file or passed as a global input.
    """

    def _get_code(
        self, path_from_pytorch, unique_identifier, offset=2, short_snippet=False
    ):
        r"""
        This function reads in the code from the docs given a unique identifier.
        Most code snippets have a 2 space indentation; for other indentation levels,
        change the `offset` arg. The `short_snippet` arg can be set to allow for testing
        of smaller snippets; the check that this arg controls is used to make sure that
        we are not accidentally only importing a blank line or something.
        """

        def get_correct_path(path_from_pytorch):
            r"""
            The current working directory when CI is running the test seems to vary;
            this function looks for docs, and if it finds them, looks for the path to the
            file; if the file exists it returns that path, otherwise it keeps looking. Will
            only work if cwd contains pytorch or docs, or a parent contains docs.
            """
            # get cwd
            cur_dir_path = Path(".").resolve()

            # check if cwd contains pytorch, use that if it does
            if (cur_dir_path / "pytorch").is_dir():
                cur_dir_path = (cur_dir_path / "pytorch").resolve()

            # need to find the file, so we check current directory
            # and all parent directories to see if the path leads to it
            check_dir = cur_dir_path
            while not check_dir == check_dir.parent:
                file_path = (check_dir / path_from_pytorch).resolve()
                if file_path.is_file():
                    return file_path
                check_dir = check_dir.parent.resolve()

            # fail (rather than silently pass) when the file is not found
            raise FileNotFoundError("could not find {}".format(path_from_pytorch))

        path_to_file = get_correct_path(path_from_pytorch)
        if path_to_file:
            file = open(path_to_file)
            content = file.readlines()

            # lines read from the file register as having a trailing newline in python
            if "\n" not in unique_identifier:
                unique_identifier += "\n"

            assert unique_identifier in content, "could not find {} in {}".format(
                unique_identifier, path_to_file
            )

            # get index of first line of code
            line_num_start = content.index(unique_identifier) + 1

            # next find where the code chunk ends.
            # this regex will match lines that don't start
            # with a \n or " " with number of spaces=offset
            r = re.compile("^[^\n," + " " * offset + "]")
            # this will return the first line that matches the regex
            line_after_code = next(filter(r.match, content[line_num_start:]))
            last_line_num = content.index(line_after_code)

            # remove the first `offset` chars of each line and gather it all together
            code = "".join(
                [x[offset:] for x in content[line_num_start + 1 : last_line_num]]
            )

            # want to make sure we are actually getting some code
            assert last_line_num - line_num_start > 3 or short_snippet, (
                "The code in {} identified by {} seems suspiciously short:"
                "\n\n###code-start####\n{}###code-end####".format(
                    path_to_file, unique_identifier, code
                )
            )
            return code

        return None

    def _test_code(self, code, global_inputs=None):
        r"""
        This function runs `code` using any vars in `global_inputs`
        """
        # if the code snippet was not found, there is nothing to run
        if code is not None:
            expr = compile(code, "test", "exec")
            exec(expr, global_inputs)

    def test_quantization_doc_ptdq(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "PTDQ API Example::"
        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code)

    def test_quantization_doc_ptsq(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "PTSQ API Example::"
        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code)

    def test_quantization_doc_qat(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "QAT API Example::"

        def _dummy_func(*args, **kwargs):
            return None

        input_fp32 = torch.randn(1, 1, 1, 1)
        global_inputs = {"training_loop": _dummy_func, "input_fp32": input_fp32}

        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code, global_inputs)

    def test_quantization_doc_fx(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "FXPTQ API Example::"

        input_fp32 = SingleLayerLinearModel().get_example_inputs()
        global_inputs = {"UserModel": SingleLayerLinearModel, "input_fp32": input_fp32}

        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code, global_inputs)

    def test_quantization_doc_custom(self):
        path_from_pytorch = "docs/source/quantization.rst"
        unique_identifier = "Custom API Example::"

        global_inputs = {"nnq": torch.nn.quantized}

        code = self._get_code(path_from_pytorch, unique_identifier)
        self._test_code(code, global_inputs)
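Extending coverage to a newly documented snippet follows the same pattern; a hypothetical additional test, with an invented `PTSQ Fusion Example::` identifier, would look like::

  def test_quantization_doc_new_snippet(self):
      path_from_pytorch = "docs/source/quantization.rst"
      # must match a literal line in the .rst file
      unique_identifier = "PTSQ Fusion Example::"

      # any names the snippet leaves undefined are injected as globals
      global_inputs = {"input_fp32": torch.randn(1, 1, 1, 1)}

      code = self._get_code(path_from_pytorch, unique_identifier)
      self._test_code(code, global_inputs)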
test/test_quantization.py

@@ -37,7 +37,7 @@ from quantization.core.test_workflow_module import TestHistogramObserver  # noqa
 from quantization.core.test_workflow_module import TestDistributed  # noqa: F401
 from quantization.core.test_workflow_module import TestFusedObsFakeQuantModule  # noqa: F401
 from quantization.core.test_utils import TestUtils  # noqa: F401
+from quantization.core.test_docs import TestQuantizationDocs  # noqa: F401

 # Eager Mode Workflow. Tests for the functionality of APIs and different features implemented
 # using eager mode.