Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43901

Add APIs for graph mode quantization on FX GraphModules, mirroring the eager mode and TorchScript graph mode APIs:
- fuse_fx
- quantize_fx (for post training static quantization and quantization aware training)
- quantize_dynamic_fx (for post training dynamic quantization)
- prepare_fx (for post training static quantization and quantization aware training)
- prepare_dynamic_fx (for post training dynamic quantization)
- convert_fx (for all modes)

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23432430

fbshipit-source-id: fc99eb75cbecd6ee7a3aa6c8ec71cd499ff7e3c1
This commit is contained in:
parent: deb5fde51c
commit: 7db7da7151
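For orientation, here is a minimal usage sketch of the post training static flow these APIs enable, following the signatures added in this PR; the toy model, qconfig choice and calibration loop below are illustrative placeholders, not part of the commit:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.quantization import get_default_qconfig, fuse_fx, prepare_fx, convert_fx

# placeholder float model; any symbolically traceable module works
float_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
graph_module = symbolic_trace(float_model)

graph_module = fuse_fx(graph_module)               # optional fusion pass
qconfig_dict = {'': get_default_qconfig('fbgemm')}
prepared = prepare_fx(graph_module, qconfig_dict)

# calibration: run representative data through the observed model
with torch.no_grad():
    for _ in range(4):
        prepared(torch.randn(2, 8))

quantized = convert_fx(prepared)                   # quantized GraphModule
print(quantized(torch.randn(2, 8)).shape)
```

The dynamic variants (prepare_dynamic_fx / convert_dynamic_fx, or quantize_dynamic_fx in one step) follow the same shape but skip the calibration loop.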
mypy.ini (2 changes)

@@ -74,7 +74,7 @@ ignore_errors = True
 [mypy-torch.quantization._numeric_suite]
 ignore_errors = True
 
-[mypy-torch.quantization._quantize_fx]
+[mypy-torch.quantization.quantize_fx]
 ignore_errors = True
 
 [mypy-torch.quantization.fx.*]
@@ -10,10 +10,11 @@ import torch.multiprocessing as mp
 from torch.fx import symbolic_trace
 
 # graph mode quantization based on fx
-from torch.quantization._quantize_fx import (
-    Quantizer,
-    fuse,
+from torch.quantization import (
+    QuantType,
+    fuse_fx,
+    prepare_fx,
+    convert_fx,
 )
 
 from torch.quantization import (
@@ -654,11 +655,10 @@ class TestQuantizeFxOps(QuantizationTestCase):
         m = M()
         original = symbolic_trace(m)
         # nothing to fuse so skipping the fuse step
-        quantizer = Quantizer()
         qconfig_dict = {'': default_qconfig}
-        prepared = quantizer.prepare(original, qconfig_dict)
+        prepared = prepare_fx(original, qconfig_dict)
         # not runnable
-        quantized = quantizer.convert(prepared)
+        quantized = convert_fx(prepared)
 
         # This checks that the dequantize from the output of first conv
         # is being propagated to the end, so that we don't insert extra
@@ -750,11 +750,10 @@ class TestQuantizeFxOps(QuantizationTestCase):
         m = M()
         original = symbolic_trace(m)
         # nothing to fuse so skipping the fuse step
-        quantizer = Quantizer()
         qconfig_dict = {'': default_qconfig}
-        prepared = quantizer.prepare(original, qconfig_dict)
+        prepared = prepare_fx(original, qconfig_dict)
         # not runnable
-        quantized = quantizer.convert(prepared)
+        quantized = convert_fx(prepared)
 
         # This checks that the dequantize from the output of first conv
         # is being propagated to the end, so that we don't insert extra
@@ -817,9 +816,8 @@ class TestQuantizeFxModels(QuantizationTestCase):
         if mode != 'static':
             model.train()
 
-        graph_module = fuse(graph_module)
-        quantizer = Quantizer()
-        prepared = quantizer.prepare(graph_module, qconfig_dict)
+        graph_module = fuse_fx(graph_module)
+        prepared = prepare_fx(graph_module, qconfig_dict)
 
         if mode == 'ddp':
             mp.spawn(run_ddp,
@@ -837,7 +835,7 @@ class TestQuantizeFxModels(QuantizationTestCase):
 
         # print('after observation root:', prepared.root)
 
-        qgraph = quantizer.convert(prepared)
+        qgraph = convert_fx(prepared)
         # print('after quantization root:', qgraph.root)
         # print('after quantization code:', qgraph.src)
         qgraph.eval()
@@ -5,7 +5,9 @@ from .qconfig import *
 from .fake_quantize import *
 from .fuse_modules import fuse_modules
 from .stubs import *
+from .quant_type import *
 from .quantize_jit import *
+from .quantize_fx import *
 
 def default_eval_fn(model, calib_data):
     r"""
@@ -20,8 +22,12 @@ _all__ = [
     # Top level API for eager mode quantization
     'quantize', 'quantize_dynamic', 'quantize_qat',
     'prepare', 'convert', 'prepare_qat',
-    # Top level API for graph mode quantization
+    # Top level API for graph mode quantization on TorchScript
     'quantize_jit', 'quantize_dynamic_jit',
+    # Top level API for graph mode quantization on GraphModule(torch.fx)
+    'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx
+    'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
+    'QuantType',  # quantization type
     # Sub functions for `prepare` and `swap_module`
     'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
     'default_eval_fn', 'get_observer_dict',
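As a quick sanity check of the export list above, the new graph mode entry points can be imported straight from the package (a sketch; assumes a build containing this commit):

```python
# the FX graph mode names newly added to torch.quantization's public surface
from torch.quantization import (
    fuse_fx, quantize_fx,
    prepare_fx, prepare_dynamic_fx, convert_fx,
    QuantType,
)

print(QuantType.DYNAMIC, QuantType.STATIC, QuantType.QAT)
```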
@@ -1,3 +0,0 @@
-from .fx import Quantizer  # noqa: F401
-from .fx import QuantType  # noqa: F401
-from .fx import fuse  # noqa: F401
@@ -1,3 +1,3 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
-from .quantize import Quantizer, QuantType
-from .fuse import fuse
+from .quantize import Quantizer
+from .fuse import Fuser
@@ -124,7 +124,7 @@ class ModuleReLUFusion():
         return quantizer.fused_graph.node_copy(self.module_node, load_arg)
 
 class Fuser:
-    def fuse_conv_bn(self, model, inplace=False):
+    def fuse(self, model, inplace=False):
         input_root = model.root
         if not inplace:
             input_root = copy.deepcopy(input_root)
@@ -173,7 +173,3 @@ class Fuser:
             apply_match(pattern, node, (node, value(self, node)))
 
         return match_map
-
-def fuse(graph_module, inplace=False):
-    fuser = Fuser()
-    return fuser.fuse_conv_bn(graph_module, inplace)
@@ -33,16 +33,8 @@ from .utils import _parent_name
 
 from abc import ABC, abstractmethod
 import copy
-import enum
 import operator
 
-# Quantization type (dynamic quantization, static quantization).
-# Should match the c++ enum in quantization_type.h
-class QuantType(enum.IntEnum):
-    DYNAMIC = 0
-    STATIC = 1
-    QAT = 2
-
 # ------------------------
 # Helper Functions
 # ------------------------
@@ -383,8 +375,8 @@ class BatchNorm(QuantizeHandler):
                                load_arg(quantized=False)(self.bn_node.kwargs))
 
 ARGS_TO_SKIP = {
-    torch.ops.quantized.hardswish: ['inplace'],
-    torch.ops.quantized.instance_norm:
+    torch._ops.ops.quantized.hardswish: ['inplace'],
+    torch._ops.ops.quantized.instance_norm:
     ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
 }
 @register_quant_pattern(torch.nn.ELU)
@@ -621,15 +613,16 @@ class Quantizer:
         elif node.op == 'call_module':
             self.qconfig_map[node.name] = get_qconfig(self.modules[node.target])
 
-    def _prepare(self, model, qconfig_dict, inplace, quant_type):
+    def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
+        assert not inplace, 'inplace prepare is not supported yet'
         input_root = model.root
         if not inplace:
             input_root = copy.deepcopy(input_root)
 
         input_graph = model.graph
-        self.quant_type = quant_type
+        self.is_dynamic_quant = is_dynamic_quant
         # TODO: allow user specified patterns
-        if self.quant_type == QuantType.DYNAMIC:
+        if self.is_dynamic_quant:
             self.patterns = get_dynamic_quant_patterns()
         else:
             self.patterns = get_quant_patterns()
@@ -688,7 +681,7 @@ class Quantizer:
                 observed.add(node.name)
 
             # don't need to insert observer for output in dynamic quantization
-            if self.quant_type == QuantType.DYNAMIC:
+            if self.is_dynamic_quant:
                 continue
 
             if isinstance(obj, CopyNode):
@@ -725,22 +718,44 @@ class Quantizer:
                 observed.add(node.name)
         observed_graph.output(load_arg(input_graph.result))
 
-        return GraphModule(input_root, observed_graph)
+        observed = GraphModule(input_root, observed_graph)
+        self.save_state(observed)
+        return observed
+
+    def save_state(self, observed):
+        observed._activation_post_process_map = self.activation_post_process_map
+        observed._patterns = self.patterns
+        observed._qconfig_map = self.qconfig_map
+
+    def restore_state(self, observed):
+        err_msg = 'please make sure the model is produced by prepare'
+        assert hasattr(observed, '_activation_post_process_map'), 'did not find ' + \
+            '_activation_post_process_map attribute ' + err_msg
+        assert hasattr(observed, '_patterns'), 'did not find ' + \
+            '_patterns attribute ' + err_msg
+        assert hasattr(observed, '_qconfig_map'), 'did not find ' + \
+            '_qconfig_map attribute ' + err_msg
+        self.activation_post_process_map = observed._activation_post_process_map
+        self.patterns = observed._patterns
+        self.qconfig_map = observed._qconfig_map
+
     def prepare(self, model, qconfig_dict, inplace=False):
-        return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)
+        return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=False)
 
     def prepare_dynamic(self, model, qconfig_dict, inplace=False):
-        return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)
+        return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=True)
 
-    def convert(self, observed, inplace=False, debug=False):
-        assert self.activation_post_process_map is not None
+    def convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
+        assert not inplace, 'inplace convert is not supported yet'
+        self.restore_state(observed)
+        self.is_dynamic_quant = is_dynamic_quant
         # move to cpu since we only have quantized cpu kernels
         observed.eval().cpu()
         observed_root = observed.root
         observed_graph = observed.graph
         if not inplace:
             observed_root = copy.deepcopy(observed_root)
 
         self.modules = dict(observed_root.named_modules())
 
         matches = self._find_matches(observed.graph, self.modules, self.patterns)
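To make the save_state/restore_state round trip above concrete, here is a hypothetical miniature (not the real Quantizer): prepare attaches its bookkeeping to the returned module so that the fresh Quantizer constructed later inside convert can recover it:

```python
# Hypothetical miniature of the state round trip (illustration only).
class TinyQuantizer:
    def prepare(self, model, qconfig_dict):
        self.qconfig_map = dict(qconfig_dict)   # state computed while preparing
        model._qconfig_map = self.qconfig_map   # save_state: stash it on the model
        return model

    def convert(self, observed):
        # restore_state: a brand new instance can pick the state back up
        assert hasattr(observed, '_qconfig_map'), \
            'please make sure the model is produced by prepare'
        self.qconfig_map = observed._qconfig_map
        return observed

class FakeGraphModule:  # stands in for a torch.fx.GraphModule
    pass

m = TinyQuantizer().prepare(FakeGraphModule(), {'': 'some_qconfig'})
out = TinyQuantizer().convert(m)  # works even though it is a different instance
```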
@@ -833,7 +848,8 @@
                 'CopyNode of type ' + node.op + ' is not handled'
             quantized = is_quantized(node.args[0])
 
-            if self.quant_type == QuantType.DYNAMIC:
+            # output of dynamic quantization is not quantized
+            if self.is_dynamic_quant:
                 quantized = False
 
             if quantized:
@@ -951,7 +967,7 @@
                 for i, node_arg in enumerate(node.args):
                     if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                         is_weight = True
-            if self.quant_type != QuantType.DYNAMIC or is_weight:
+            if (not self.is_dynamic_quant) or is_weight:
                 # overwrite previous quant config
                 quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
         return visit_arg
torch/quantization/quant_type.py (new file, 10 lines)

@@ -0,0 +1,10 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import enum
+
+# Quantization type (dynamic quantization, static quantization).
+# Should match the c++ enum in quantization_type.h
+class QuantType(enum.IntEnum):
+    DYNAMIC = 0
+    STATIC = 1
+    QAT = 2
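A small sketch of how the shared enum can drive mode specific behaviour, mirroring the dynamic-versus-static checks being replaced in fx/quantize.py; the helper name below is made up for illustration:

```python
from torch.quantization.quant_type import QuantType  # module added by this commit

def needs_activation_observers(quant_type):
    # dynamic quantization only observes weights, so activation observers are skipped
    return quant_type in (QuantType.STATIC, QuantType.QAT)

assert needs_activation_observers(QuantType.STATIC)
assert not needs_activation_observers(QuantType.DYNAMIC)
```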
torch/quantization/quantize_fx.py (new file, 201 lines)

@@ -0,0 +1,201 @@
+from .fx import Fuser  # noqa: F401
+from .fx import Quantizer  # noqa: F401
+from torch.fx import GraphModule
+
+def _check_is_graph_module(model):
+    if not isinstance(model, GraphModule):
+        raise ValueError(
+            'input model must be a GraphModule, ' +
+            'please run torch.fx.symbolic_trace on your model before using ' +
+            'quantize_fx. Got type: ' + str(type(model)))
+
+def fuse_fx(graph_module, inplace=False):
+    r""" Fuse modules in preparation for quantization
+
+    Args:
+        graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
+    """
+    _check_is_graph_module(graph_module)
+    fuser = Fuser()
+    return fuser.fuse(graph_module, inplace)
+
+def _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant):
+    _check_is_graph_module(graph_module)
+
+    quantizer = Quantizer()
+    prepare = quantizer.prepare_dynamic if is_dynamic_quant else quantizer.prepare
+    prepared = prepare(graph_module, qconfig_dict, inplace)
+    return prepared
+
+def prepare_fx(graph_module, qconfig_dict, inplace=False):
+    r""" Prepare a model for post training static quantization or
+    quantization aware training, not for public use.
+
+    Args:
+        graph_module: model from symbolic tracing (torch.fx.symbolic_trace), must be
+            an eval model
+        qconfig_dict: see :func:`~torch.quantization.quantize_fx`
+
+    Return:
+        A GraphModule with observer or fake quant modules, ready for
+        calibration or quantization aware training
+    """
+    return _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False)
+
+def prepare_static_fx(graph_module, qconfig_dict, inplace=False):
+    assert not graph_module.training, 'prepare_static_fx only works for models in ' + \
+        'eval mode'
+    return prepare_fx(graph_module, qconfig_dict, inplace)
+
+def prepare_qat_fx(graph_module, qconfig_dict, inplace=False):
+    r""" Prepare a model for quantization aware training
+    Args:
+        graph_module: model from symbolic tracing (torch.fx.symbolic_trace), must be
+            a train model
+        qconfig_dict: see :func:`~torch.quantization.quantize_fx`
+
+    Return:
+        A GraphModule with observer or fake quant modules, ready for
+        calibration or quantization aware training
+    """
+    assert graph_module.training, 'prepare_qat_fx only works for models in ' + \
+        'train mode'
+    return prepare_fx(graph_module, qconfig_dict, inplace)
+
+def prepare_dynamic_fx(graph_module, qconfig_dict, inplace=False):
+    r""" Prepare a model for post training dynamic quantization
+    """
+    return _prepare_fx(graph_module, qconfig_dict, inplace, True)
+
+def _convert_fx(graph_module, inplace=False, debug=False, is_dynamic_quant=False):
+    _check_is_graph_module(graph_module)
+    quantizer = Quantizer()
+    return quantizer.convert(graph_module, inplace, debug, is_dynamic_quant)
+
+def convert_fx(graph_module, inplace=False, debug=False):
+    r""" Convert a calibrated or trained model to a quantized model
+    """
+    return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=False)
+
+convert_static_fx = convert_fx
+convert_qat_fx = convert_fx
+
+def convert_dynamic_fx(graph_module, inplace=False, debug=False):
+    return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=True)
+
+def _quantize_fx(model, qconfig_dict, run_fn=None, run_args=None, inplace=False,
+                 debug=False, is_dynamic_quant=False):
+    assert not model.training, 'quantize_fx is only used for post training ' + \
+        'quantization (eval mode), for quantization aware training please use ' + \
+        'prepare_qat_fx and convert_qat_fx.'
+
+    if is_dynamic_quant:
+        model = prepare_dynamic_fx(model, qconfig_dict, inplace)
+        # TODO: change inplace to True since the model is already copied in
+        # prepare
+        model = convert_dynamic_fx(model, False, debug)
+    else:
+        assert run_fn, "Must provide calibration function for post training static quantization"
+        assert run_args, "Must provide calibration dataset for post training static quantization"
+        model = prepare_fx(model, qconfig_dict, inplace)
+        run_fn(model, *run_args)
+        # TODO: change inplace to True since the model is already copied in
+        # prepare
+        model = convert_fx(model, False, debug)
+
+    return model
+
+
+def quantize_fx(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
+    r"""Quantize the input float symbolically traced GraphModule model with
+    post training static quantization
+
+    First it will prepare the model for calibration, then it calls
+    `run_fn` which will run the calibration step, after that we will
+    convert the model to a quantized model.
+
+    Args:
+        `model`: input float symbolically traced GraphModule model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+            qconfig for that module as value, empty key means the qconfig will be applied
+            to the whole model unless it is overwritten by more specific configurations, the
+            qconfig for each module is either found in the dictionary or falls back to
+            the qconfig of the parent module.
+
+            Right now qconfig_dict is the only way to configure how the model is quantized,
+            and it is done at the granularity of module, that is, we only support one type
+            of qconfig for each torch.nn.Module, and the qconfig for a sub module will
+            override the qconfig for the parent module, empty string means global configuration.
+        `run_fn`: a calibration function for calibrating the prepared model
+        `run_args`: positional arguments for `run_fn`
+        `inplace`: carry out model transformations in-place, the original module is
+            mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized model (GraphModule).
+
+    Example:
+    ```python
+    import torch
+    from torch.quantization import get_default_qconfig
+    from torch.quantization import quantize_fx
+
+    graph_module = torch.fx.symbolic_trace(float_model.eval())
+    qconfig = get_default_qconfig('fbgemm')
+    def calibrate(model, data_loader):
+        model.eval()
+        with torch.no_grad():
+            for image, target in data_loader:
+                model(image)
+
+    quantized_model = quantize_fx(
+        graph_module,
+        {'': qconfig},
+        calibrate,
+        [data_loader_test])
+    ```
+    """
+    return _quantize_fx(
+        model, qconfig_dict, run_fn, run_args, inplace, debug, is_dynamic_quant=False)
+
+def quantize_dynamic_fx(model, qconfig_dict, inplace=False, debug=False):
+    r"""Quantize the input float symbolically traced GraphModule model with
+    post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float symbolically traced GraphModule model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+            qconfig for that module as value, please see detailed
+            descriptions in :func:`~torch.quantization.quantize_fx`
+        `inplace`: carry out model transformations in-place, the original module is
+            mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized model (GraphModule).
+
+    Example:
+    ```python
+    import torch
+    from torch.quantization import per_channel_dynamic_qconfig
+    from torch.quantization import quantize_dynamic_fx
+
+    graph_module = torch.fx.symbolic_trace(float_model.eval())
+    # dynamic quantization does not require a calibration step
+    quantized_model = quantize_dynamic_fx(
+        graph_module,
+        {'': per_channel_dynamic_qconfig})
+    ```
+    """
+    return _quantize_fx(
+        model, qconfig_dict, inplace=inplace, debug=debug, is_dynamic_quant=True)
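A runnable sketch of the dynamic path defined above, using the signatures from this file; the two layer model is a placeholder and per_channel_dynamic_qconfig is one reasonable choice for Linear layers:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.quantization import per_channel_dynamic_qconfig
from torch.quantization import prepare_dynamic_fx, convert_dynamic_fx

float_model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).eval()
graph_module = symbolic_trace(float_model)

prepared = prepare_dynamic_fx(graph_module, {'': per_channel_dynamic_qconfig})
quantized = convert_dynamic_fx(prepared)   # no calibration step for dynamic

print(quantized(torch.randn(2, 16)).shape)
```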
@@ -1,16 +1,10 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-import enum
 import torch
 from .qconfig import QConfig
+from .quant_type import QuantType
 from torch.jit._recursive import wrap_cpp_module
 
-# Quantization type (dynamic quantization, static quantization).
-# Should match the c++ enum in quantization_type.h
-class QuantType(enum.IntEnum):
-    DYNAMIC = 0
-    STATIC = 1
-
 def _check_is_script_module(model):
     if not isinstance(model, torch.jit.ScriptModule):
         raise ValueError('input must be a script module, got: ' + str(type(model)))
@@ -26,10 +26,13 @@ from torch.quantization.default_mappings import (
 from torch.fx import symbolic_trace
 
 # graph mode quantization based on fx
-from torch.quantization._quantize_fx import (
-    Quantizer,
+from torch.quantization import (
     QuantType,
-    fuse,
+    fuse_fx,
+    prepare_fx,
+    prepare_dynamic_fx,
+    convert_fx,
+    convert_dynamic_fx,
 )
 
 import copy
@@ -611,18 +614,20 @@ class QuantizationTestCase(TestCase):
         else:
             model.eval()
         original = symbolic_trace(model)
-        fused = fuse(original)
+        fused = fuse_fx(original)
 
-        quantizer = Quantizer()
         qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
         if quant_type == QuantType.DYNAMIC:
-            prepared = quantizer.prepare_dynamic(fused, qconfig_dict)
+            prepare = prepare_dynamic_fx
+            convert = convert_dynamic_fx
         else:
-            prepared = quantizer.prepare(fused, qconfig_dict)
+            prepare = prepare_fx
+            convert = convert_fx
 
+        prepared = prepare(fused, qconfig_dict)
         prepared(*inputs)
-        qgraph = quantizer.convert(prepared)
-        qgraph_debug = quantizer.convert(prepared, debug=True)
+        qgraph = convert(prepared)
+        qgraph_debug = convert(prepared, debug=True)
 
         result = qgraph(*inputs)
         result_debug = qgraph_debug(*inputs)
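The updated test helper boils down to the following flow; this is a hedged sketch with a placeholder model and inputs, not the actual checkGraphModeFxOp code:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.quantization import (
    QuantType, get_default_qconfig, fuse_fx,
    prepare_fx, convert_fx, prepare_dynamic_fx, convert_dynamic_fx,
)

def check_fx_quant(model, inputs, quant_type):
    # placeholder re-creation of the test flow shown in the diff above
    model.eval()
    fused = fuse_fx(symbolic_trace(model))
    qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
    if quant_type == QuantType.DYNAMIC:
        prepare, convert = prepare_dynamic_fx, convert_dynamic_fx
    else:
        prepare, convert = prepare_fx, convert_fx
    prepared = prepare(fused, qconfig_dict)
    prepared(*inputs)                        # calibration run
    quantized = convert(prepared)
    quantized_debug = convert(prepared, debug=True)
    return quantized(*inputs), quantized_debug(*inputs)

out, out_debug = check_fx_quant(
    nn.Sequential(nn.Linear(4, 4)), (torch.randn(1, 4),), QuantType.STATIC)
```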