[reland][quant][graphmode][fx] Add top level APIs (#43581) (#43901)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43901

Add top-level APIs similar to those for eager mode and for graph mode on TorchScript (usage sketch after the list):
- fuse_fx
- quantize_fx (for both post training static and qat)
- quantize_dynamic_fx (for post training dynamic)
- prepare_fx (for both post training static and qat)
- prepare_dynamic_fx (for post training dynamic)
- convert_fx (for all modes)
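
A minimal usage sketch (not taken from the PR itself) of the intended post training static flow with these APIs; the toy model, input shapes, and calibration loop below are placeholders:

```python
import torch
from torch.fx import symbolic_trace
from torch.quantization import get_default_qconfig, fuse_fx, prepare_fx, convert_fx

# placeholder float model; any symbolically traceable eval-mode model works
float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()

graph_module = symbolic_trace(float_model)   # torch.fx.GraphModule
fused = fuse_fx(graph_module)                # fuse patterns such as conv + bn (+ relu)
prepared = prepare_fx(fused, {'': get_default_qconfig('fbgemm')})
for _ in range(4):                           # calibrate with representative data
    prepared(torch.randn(1, 3, 32, 32))
quantized = convert_fx(prepared)
```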

Test Plan:
Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23432430

fbshipit-source-id: fc99eb75cbecd6ee7a3aa6c8ec71cd499ff7e3c1
Jerry Zhang 2020-08-31 18:22:24 -07:00 committed by Facebook GitHub Bot
parent deb5fde51c
commit 7db7da7151
11 changed files with 285 additions and 62 deletions

View File

@@ -74,7 +74,7 @@ ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True
[mypy-torch.quantization._quantize_fx]
[mypy-torch.quantization.quantize_fx]
ignore_errors = True
[mypy-torch.quantization.fx.*]

View File

@@ -10,10 +10,11 @@ import torch.multiprocessing as mp
from torch.fx import symbolic_trace
# graph mode quantization based on fx
from torch.quantization._quantize_fx import (
Quantizer,
fuse,
from torch.quantization import (
QuantType,
fuse_fx,
prepare_fx,
convert_fx,
)
from torch.quantization import (
@@ -654,11 +655,10 @@ class TestQuantizeFxOps(QuantizationTestCase):
m = M()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
quantizer = Quantizer()
qconfig_dict = {'': default_qconfig}
prepared = quantizer.prepare(original, qconfig_dict)
prepared = prepare_fx(original, qconfig_dict)
# not runnable
quantized = quantizer.convert(prepared)
quantized = convert_fx(prepared)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
@@ -750,11 +750,10 @@ class TestQuantizeFxOps(QuantizationTestCase):
m = M()
original = symbolic_trace(m)
# nothing to fuse so skipping the fuse step
quantizer = Quantizer()
qconfig_dict = {'': default_qconfig}
prepared = quantizer.prepare(original, qconfig_dict)
prepared = prepare_fx(original, qconfig_dict)
# not runnable
quantized = quantizer.convert(prepared)
quantized = convert_fx(prepared)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
@@ -817,9 +816,8 @@ class TestQuantizeFxModels(QuantizationTestCase):
if mode != 'static':
model.train()
graph_module = fuse(graph_module)
quantizer = Quantizer()
prepared = quantizer.prepare(graph_module, qconfig_dict)
graph_module = fuse_fx(graph_module)
prepared = prepare_fx(graph_module, qconfig_dict)
if mode == 'ddp':
mp.spawn(run_ddp,
@@ -837,7 +835,7 @@
# print('after observation root:', prepared.root)
qgraph = quantizer.convert(prepared)
qgraph = convert_fx(prepared)
# print('after quantization root:', qgraph.root)
# print('after quantization code:', qgraph.src)
qgraph.eval()

View File

@@ -5,7 +5,9 @@ from .qconfig import *
from .fake_quantize import *
from .fuse_modules import fuse_modules
from .stubs import *
from .quant_type import *
from .quantize_jit import *
from .quantize_fx import *
def default_eval_fn(model, calib_data):
r"""
@@ -20,8 +22,12 @@ _all__ = [
# Top level API for eager mode quantization
'quantize', 'quantize_dynamic', 'quantize_qat',
'prepare', 'convert', 'prepare_qat',
# Top level API for graph mode quantization
# Top level API for graph mode quantization on TorchScript
'quantize_jit', 'quantize_dynamic_jit',
# Top level API for graph mode quantization on GraphModule(torch.fx)
'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx
'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
'QuantType', # quantization type
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
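
For orientation, the names this commit makes importable directly from `torch.quantization`; note that `quantize_dynamic_fx` and `convert_dynamic_fx` come in through the star import even though the list above still carries a TODO for them:

```python
from torch.quantization import (
    QuantType,
    fuse_fx, quantize_fx, quantize_dynamic_fx,
    prepare_fx, prepare_dynamic_fx,
    convert_fx, convert_dynamic_fx,
)
```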

View File

@@ -1,3 +0,0 @@
from .fx import Quantizer # noqa: F401
from .fx import QuantType # noqa: F401
from .fx import fuse # noqa: F401

View File

@@ -1,3 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import Quantizer, QuantType
from .fuse import fuse
from .quantize import Quantizer
from .fuse import Fuser

View File

@@ -124,7 +124,7 @@ class ModuleReLUFusion():
return quantizer.fused_graph.node_copy(self.module_node, load_arg)
class Fuser:
def fuse_conv_bn(self, model, inplace=False):
def fuse(self, model, inplace=False):
input_root = model.root
if not inplace:
input_root = copy.deepcopy(input_root)
@@ -173,7 +173,3 @@ class Fuser:
apply_match(pattern, node, (node, value(self, node)))
return match_map
def fuse(graph_module, inplace=False):
fuser = Fuser()
return fuser.fuse_conv_bn(graph_module, inplace)
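
The free function `fuse` is removed in favor of the top-level `fuse_fx`, which wraps `Fuser().fuse`. A hedged sketch of the new call path (the Conv-BN-ReLU module is a made-up example; exactly which chains fuse is determined by the fusion patterns registered in this file):

```python
import torch
from torch.fx import symbolic_trace
from torch.quantization import fuse_fx

class ConvBNReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = symbolic_trace(ConvBNReLU().eval())
fused = fuse_fx(m)   # the conv + bn + relu chain is expected to fold into a single fused module
```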

View File

@@ -33,16 +33,8 @@ from .utils import _parent_name
from abc import ABC, abstractmethod
import copy
import enum
import operator
# Quantization type (dynamic quantization, static quantization).
# Should match the c++ enum in quantization_type.h
class QuantType(enum.IntEnum):
DYNAMIC = 0
STATIC = 1
QAT = 2
# ------------------------
# Helper Functions
# ------------------------
@@ -383,8 +375,8 @@ class BatchNorm(QuantizeHandler):
load_arg(quantized=False)(self.bn_node.kwargs))
ARGS_TO_SKIP = {
torch.ops.quantized.hardswish: ['inplace'],
torch.ops.quantized.instance_norm:
torch._ops.ops.quantized.hardswish: ['inplace'],
torch._ops.ops.quantized.instance_norm:
['running_mean', 'running_var', 'use_input_stats', 'momentum'],
}
@register_quant_pattern(torch.nn.ELU)
@@ -621,15 +613,16 @@ class Quantizer:
elif node.op == 'call_module':
self.qconfig_map[node.name] = get_qconfig(self.modules[node.target])
def _prepare(self, model, qconfig_dict, inplace, quant_type):
def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
assert not inplace, 'inplace prepare is not supported yet'
input_root = model.root
if not inplace:
input_root = copy.deepcopy(input_root)
input_graph = model.graph
self.quant_type = quant_type
self.is_dynamic_quant = is_dynamic_quant
# TODO: allow user specified patterns
if self.quant_type == QuantType.DYNAMIC:
if self.is_dynamic_quant:
self.patterns = get_dynamic_quant_patterns()
else:
self.patterns = get_quant_patterns()
@@ -688,7 +681,7 @@ class Quantizer:
observed.add(node.name)
# don't need to insert observer for output in dynamic quantization
if self.quant_type == QuantType.DYNAMIC:
if self.is_dynamic_quant:
continue
if isinstance(obj, CopyNode):
@@ -725,22 +718,44 @@ class Quantizer:
observed.add(node.name)
observed_graph.output(load_arg(input_graph.result))
return GraphModule(input_root, observed_graph)
observed = GraphModule(input_root, observed_graph)
self.save_state(observed)
return observed
def save_state(self, observed):
observed._activation_post_process_map = self.activation_post_process_map
observed._patterns = self.patterns
observed._qconfig_map = self.qconfig_map
def restore_state(self, observed):
err_msg = 'please make sure the model is produced by prepare'
assert hasattr(observed, '_activation_post_process_map'), 'did not find ' + \
'_activation_post_process_map attribute ' + err_msg
assert hasattr(observed, '_patterns'), 'did not find ' + \
'_patterns attribute ' + err_msg
assert hasattr(observed, '_qconfig_map'), 'did not find ' + \
'_qconfig_map attribute ' + err_msg
self.activation_post_process_map = observed._activation_post_process_map
self.patterns = observed._patterns
self.qconfig_map = observed._qconfig_map
def prepare(self, model, qconfig_dict, inplace=False):
return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)
return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=False)
def prepare_dynamic(self, model, qconfig_dict, inplace=False):
return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)
return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=True)
def convert(self, observed, inplace=False, debug=False):
assert self.activation_post_process_map is not None
def convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
assert not inplace, 'inplace convert is not supported yet'
self.restore_state(observed)
self.is_dynamic_quant = is_dynamic_quant
# move to cpu since we only have quantized cpu kernels
observed.eval().cpu()
observed_root = observed.root
observed_graph = observed.graph
if not inplace:
observed_root = copy.deepcopy(observed_root)
self.modules = dict(observed_root.named_modules())
matches = self._find_matches(observed.graph, self.modules, self.patterns)
@@ -833,7 +848,8 @@ class Quantizer:
'CopyNode of type ' + node.op + ' is not handled'
quantized = is_quantized(node.args[0])
if self.quant_type == QuantType.DYNAMIC:
# output of dynamic quantization is not quantized
if self.is_dynamic_quant:
quantized = False
if quantized:
@@ -951,7 +967,7 @@ class Quantizer:
for i, node_arg in enumerate(node.args):
if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
is_weight = True
if self.quant_type != QuantType.DYNAMIC or is_weight:
if (not self.is_dynamic_quant) or is_weight:
# overwrite previous quant config
quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
return visit_arg
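
The `save_state`/`restore_state` pair is what allows `prepare` and `convert` to run on different `Quantizer` instances: `prepare` stashes the observer map, matched patterns, and qconfig map on the observed GraphModule, and `convert` reads them back. A minimal sketch, with a toy model and made-up shapes:

```python
import torch
from torch.fx import symbolic_trace
from torch.quantization import default_qconfig
from torch.quantization.fx import Quantizer

model = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4)).eval())

observed = Quantizer().prepare(model, {'': default_qconfig})
# prepare() attached _activation_post_process_map, _patterns and _qconfig_map
# to `observed`, so a brand new Quantizer instance can convert it:
observed(torch.randn(1, 4))                  # calibration
quantized = Quantizer().convert(observed)    # restore_state() runs inside convert()
```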

View File

@@ -0,0 +1,10 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import enum
# Quantization type (dynamic quantization, static quantization).
# Should match the c++ enum in quantization_type.h
class QuantType(enum.IntEnum):
DYNAMIC = 0
STATIC = 1
QAT = 2
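
For illustration, the test helper later in this diff dispatches on this enum to pick the right prepare/convert entry points; a standalone sketch (the helper name `fx_entry_points` is hypothetical, not part of the commit):

```python
from torch.quantization import (
    QuantType, prepare_fx, prepare_dynamic_fx, convert_fx, convert_dynamic_fx,
)

def fx_entry_points(quant_type):
    # dynamic quantization has its own prepare/convert entry points;
    # static post training quantization and QAT share prepare_fx / convert_fx
    if quant_type == QuantType.DYNAMIC:
        return prepare_dynamic_fx, convert_dynamic_fx
    return prepare_fx, convert_fx
```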

View File

@@ -0,0 +1,201 @@
from .fx import Fuser # noqa: F401
from .fx import Quantizer # noqa: F401
from torch.fx import GraphModule
def _check_is_graph_module(model):
if not isinstance(model, GraphModule):
raise ValueError(
'input model must be a GraphModule, ' +
'please run torch.fx.symbolic_trace on your model before using ' +
'quantize_fx. Got type:' + str(type(model)))
def fuse_fx(graph_module, inplace=False):
r""" Fuse modules in preparation for quantization
Args:
graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
"""
_check_is_graph_module(graph_module)
fuser = Fuser()
return fuser.fuse(graph_module, inplace)
def _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant):
_check_is_graph_module(graph_module)
quantizer = Quantizer()
prepare = quantizer.prepare_dynamic if is_dynamic_quant else quantizer.prepare
prepared = prepare(graph_module, qconfig_dict, inplace)
return prepared
def prepare_fx(graph_module, qconfig_dict, inplace=False):
r""" Prepare a model for post training static quantization or
quantization aware training, not for public use.
Args:
graph_module: model from symbolic_tracing (torch.fx.symbolic_trace), must be
an eval model
qconfig_dict: see :func:`~torch.quantization.quantize_fx`
Return:
A GraphModule with observer or fake quant modules, ready for
calibration or quantization aware training
"""
return _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False)
def prepare_static_fx(graph_module, qconfig_dict, inplace=False):
assert not graph_module.training, 'prepare_static_fx only works for models in ' + \
'eval mode'
return prepare_fx(graph_module, qconfig_dict, inplace)
def prepare_qat_fx(graph_module, qconfig_dict, inplace=False):
r""" Prepare a model for quantization aware training
Args:
graph_module: model from symbolic_tracing (torch.fx.symbolic_trace), must be
a train model
qconfig_dict: see :func:`~torch.quantization.quantize_fx`
Return:
A GraphModule with observer or fake quant modules, ready for
calibration or quantization aware training
"""
assert graph_module.training, 'prepare_qat_fx only works for models in ' + \
'train mode'
return prepare_fx(graph_module, qconfig_dict, inplace)
def prepare_dynamic_fx(graph_module, qconfig_dict, inplace=False):
r""" Prepare a model for post training dynamic quantization
"""
return _prepare_fx(graph_module, qconfig_dict, inplace, True)
def _convert_fx(graph_module, inplace=False, debug=False, is_dynamic_quant=False):
_check_is_graph_module(graph_module)
quantizer = Quantizer()
return quantizer.convert(graph_module, inplace, debug, is_dynamic_quant)
def convert_fx(graph_module, inplace=False, debug=False):
r""" Convert a calibrated or trained model to a quantized model
"""
return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=False)
convert_static_fx = convert_fx
convert_qat_fx = convert_fx
def convert_dynamic_fx(graph_module, inplace=False, debug=False):
return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=True)
def _quantize_fx(model, qconfig_dict, run_fn=None, run_args=None, inplace=False,
debug=False, is_dynamic_quant=False):
assert not model.training, 'quantize_fx is only used for post training ' + \
'quantization (eval mode); for quantization aware training please use ' + \
'prepare_qat_fx and convert_qat_fx.'
if is_dynamic_quant:
model = prepare_dynamic_fx(model, qconfig_dict, inplace)
# TODO: change inplace to True since the model is already copied in
# prepare
model = convert_dynamic_fx(model, False, debug)
else:
assert run_fn, "Must provide calibration function for post training static quantization"
assert run_args, "Must provide calibration dataset for post training static quantization"
model = prepare_fx(model, qconfig_dict, inplace)
run_fn(model, *run_args)
# TODO: change inplace to True since the model is already copied in
# prepare
model = convert_fx(model, False, debug)
return model
def quantize_fx(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
r"""Quantize the input float symbolically traced GraphModule model with
post training static quantization
It first prepares the model for calibration, then calls `run_fn` to run the
calibration step, and finally converts the model to a quantized model.
Args:
`model`: input float model, a GraphModule from symbolic tracing (torch.fx.symbolic_trace)
`qconfig_dict`: a dictionary with names of submodules as keys and the qconfig for
that module as the value; the empty key '' applies the qconfig to the whole model
unless it is overridden by a more specific configuration. The qconfig for each
module is either found in the dictionary or falls back to the qconfig of the
parent module.
Right now qconfig_dict is the only way to configure how the model is quantized,
and it works at module granularity: we only support one qconfig per
torch.nn.Module, the qconfig for a submodule overrides the qconfig of its
parent module, and the empty string denotes the global configuration.
`run_fn`: a calibration function for calibrating the prepared model
`run_args`: positional arguments for `run_fn`
`inplace`: carry out model transformations in-place, the original module is
mutated
`debug`: flag for producing a debug friendly model (preserve weight attribute)
Return:
Quantized GraphModule.
Example:
```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization import quantize_fx
graph_module = torch.fx.symbolic_trace(float_model.eval())
qconfig = get_default_qconfig('fbgemm')
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
quantized_model = quantize_fx(
graph_module,
{'': qconfig},
calibrate,
[data_loader_test])
```
"""
return _quantize_fx(
model, qconfig_dict, run_fn, run_args, inplace, debug, is_dynamic_quant=False)
def quantize_dynamic_fx(model, qconfig_dict, inplace=False, debug=False):
r"""Quantize the input float symbolically traced GraphModule model with
post training dynamic quantization.
Currently only qint8 quantization of torch.nn.Linear is supported.
Args:
`model`: input float model, a GraphModule from symbolic tracing (torch.fx.symbolic_trace)
`qconfig_dict`: a dictionary with names of submodules as keys and the qconfig
for that module as the value; see the detailed description in
:func:`~torch.quantization.quantize_fx`
`inplace`: carry out model transformations in-place, the original module is
mutated
`debug`: flag for producing a debug friendly model (preserve weight attribute)
Return:
Quantized GraphModule.
Example:
```python
import torch
from torch.quantization import per_channel_dynamic_qconfig
from torch.quantization import quantize_dynamic_fx
graph_module = torch.fx.symbolic_trace(float_model.eval())
qconfig = per_channel_dynamic_qconfig
quantized_model = quantize_dynamic_fx(
graph_module,
{'': qconfig})
```
"""
return _quantize_fx(
model, qconfig_dict, inplace=inplace, debug=debug, is_dynamic_quant=True)
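
For completeness, a hedged sketch of the two-step dynamic flow that `quantize_dynamic_fx` wraps; the toy model is a placeholder, and per the docstring above only `torch.nn.Linear` is currently quantized dynamically:

```python
import torch
from torch.fx import symbolic_trace
from torch.quantization import per_channel_dynamic_qconfig
from torch.quantization import prepare_dynamic_fx, convert_dynamic_fx

float_model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU()).eval()
graph_module = symbolic_trace(float_model)

# weights are quantized at convert time and activations on the fly at runtime,
# so there is no calibration step between prepare and convert
prepared = prepare_dynamic_fx(graph_module, {'': per_channel_dynamic_qconfig})
quantized = convert_dynamic_fx(prepared)
out = quantized(torch.randn(2, 16))
```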

View File

@@ -1,16 +1,10 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import enum
import torch
from .qconfig import QConfig
from .quant_type import QuantType
from torch.jit._recursive import wrap_cpp_module
# Quantization type (dynamic quantization, static quantization).
# Should match the c++ enum in quantization_type.h
class QuantType(enum.IntEnum):
DYNAMIC = 0
STATIC = 1
def _check_is_script_module(model):
if not isinstance(model, torch.jit.ScriptModule):
raise ValueError('input must be a script module, got: ' + str(type(model)))

View File

@@ -26,10 +26,13 @@ from torch.quantization.default_mappings import (
from torch.fx import symbolic_trace
# graph mode quantization based on fx
from torch.quantization._quantize_fx import (
Quantizer,
from torch.quantization import (
QuantType,
fuse,
fuse_fx,
prepare_fx,
prepare_dynamic_fx,
convert_fx,
convert_dynamic_fx,
)
import copy
@@ -611,18 +614,20 @@ class QuantizationTestCase(TestCase):
else:
model.eval()
original = symbolic_trace(model)
fused = fuse(original)
fused = fuse_fx(original)
quantizer = Quantizer()
qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
if quant_type == QuantType.DYNAMIC:
prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
prepare = prepare_dynamic_fx
convert = convert_dynamic_fx
else:
prepared = quantizer.prepare(fused, qconfig_dict)
prepare = prepare_fx
convert = convert_fx
prepared = prepare(fused, qconfig_dict)
prepared(*inputs)
qgraph = quantizer.convert(prepared)
qgraph_debug = quantizer.convert(prepared, debug=True)
qgraph = convert(prepared)
qgraph_debug = convert(prepared, debug=True)
result = qgraph(*inputs)
result_debug = qgraph_debug(*inputs)