[quant][fx] Fully align convert with the reference model design and simplify the implementation (#73863)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73863

This PR fully aligns the convert function with the design: https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md
and simplifies the implementation of the convert function by always producing a reference quantized model (with reference patterns) first,
and then lowering the model to a quantized model that is runnable with the PyTorch native backends (fbgemm/qnnpack).

This PR makes convert.py much easier to understand than the previous implementation, and it lets us remove the majority of the code
in quantization_patterns.py as well (in follow-up PRs).
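
As an illustration of the flow this PR standardizes on, here is a minimal sketch using the FX APIs exercised in the tests below; the toy model, calibration input, and qconfig choice are illustrative only, not part of this PR:

```
import copy

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

prepared = prepare_fx(model, qconfig_dict)   # insert observers
prepared(torch.randn(1, 4))                  # calibrate

# step 1: produce a reference quantized model
# (quantize/dequantize ops + reference modules such as nnqr.Linear)
reference = convert_fx(copy.deepcopy(prepared), is_reference=True)

# step 2: lower the reference patterns to ops runnable on fbgemm/qnnpack;
# with this PR, convert_fx always goes through step 1 internally before lowering
quantized = convert_fx(prepared)
```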

Test Plan:
```
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
python test/test_quantization.py TestFXNumericSuiteCoreAPIsModels
```
and other internal/OSS regression tests.

Imported from OSS

Reviewed By: andrewor14

Differential Revision: D34778506

fbshipit-source-id: 0678b66addf736039a8749b352f6f569caca962b
(cherry picked from commit 33ec9caf23f3ab373d827117efbd9db0668b2437)
Jerry Zhang 2022-03-11 09:05:14 -08:00 committed by PyTorch MergeBot
parent 7070fe4d15
commit 7ddf212f33
27 changed files with 1275 additions and 228 deletions

View File

@ -15,8 +15,11 @@ Tensor quantize_per_tensor_dynamic(
const Tensor& self,
ScalarType dtype,
bool reduce_range) {
TORCH_CHECK( (dtype == ScalarType::QInt8 || dtype == ScalarType::QUInt8), "dtype ", dtype, "not supported");
TORCH_CHECK( (dtype == ScalarType::QInt8 || dtype == ScalarType::QUInt8 || dtype == ScalarType::Half), "dtype ", dtype, "not supported");
auto input_contig = self.contiguous();
if (dtype == ScalarType::Half) {
return input_contig.to(ScalarType::Half);
}
float x_min = input_contig.min().item<float>();
float x_max = input_contig.max().item<float>();
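
For context, a small sketch of what the new fp16 branch enables from Python, assuming the `torch.quantize_per_tensor_dynamic` binding used later in this PR exposes the same `(input, dtype, reduce_range)` signature as this ATen function:

```
import torch

x = torch.randn(2, 3)

# int8 path: qparams are computed dynamically from the min/max of x
xq = torch.quantize_per_tensor_dynamic(x, torch.quint8, False)

# new fp16 path: "dynamic quantization" to fp16 is just a cast to half
x_fp16 = torch.quantize_per_tensor_dynamic(x, torch.float16, False)
```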

View File

@ -238,6 +238,9 @@ class QLinearPackWeightFp16 final {
c10::optional<Tensor> bias) {
auto& ctx = at::globalContext();
#ifdef USE_FBGEMM
// temporarily convert weight back to fp32, needs to be fixed
// after fbgemm fixes the interface for their prepacking op (take fp16 input)
weight = weight.to(ScalarType::Float);
if (ctx.qEngine() == at::QEngine::FBGEMM) {
return PackedLinearWeightFp16::prepack(
std::move(weight), std::move(bias));

View File

@ -376,6 +376,7 @@ class TestSerialization(TestCase):
self._test_obs(ref_model, input_size=[5, 5], generate=False, check_numerics=False)
@skipIfNoFBGEMM
@unittest.skip("temporarily skipping, adding a fix in next PR")
def test_linear_relu_package_quantization_transforms(self):
m = LinearReluFunctional(4).eval()
self._test_package(m, input_size=(1, 1, 4, 4), generate=False)

View File

@ -1545,22 +1545,7 @@ class TestReferenceQuantizedModule(QuantizationTestCase):
hidden_size = 7
num_layers = 2
bias = True
weight_keys = []
bias_keys = []
for bidirectional in [True, False]:
num_directions = 2 if bidirectional else 1
for layer in range(num_layers):
for direction in range(num_directions):
suffix = '_reverse' if direction == 1 else ''
key_name1 = 'weight_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
key_name2 = 'weight_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
weight_keys.append(key_name1)
weight_keys.append(key_name2)
key_name1 = 'bias_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
key_name2 = 'bias_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
bias_keys.append(key_name1)
bias_keys.append(key_name2)
x = torch.randn(seq_len, batch, input_size)
h = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
c = torch.randn(num_layers * (bidirectional + 1), batch, hidden_size)
@ -1575,11 +1560,11 @@ class TestReferenceQuantizedModule(QuantizationTestCase):
# initialize ref rnn module
weight_qparams = {
'qscheme': torch.per_tensor_affine,
'dtype': torch.quint8,
'dtype': torch.qint8,
'scale': 2.0,
'zero_point': 5
}
weight_qparams_dict = {key: weight_qparams for key in fp32_rnn._flat_weights_names}
weight_qparams_dict = {key: weight_qparams for key in fp32_rnn._flat_weights_names if key.startswith("weight")}
ref_rnn = nnqr.LSTM(
input_size=input_size,
hidden_size=hidden_size,
@ -1589,10 +1574,20 @@ class TestReferenceQuantizedModule(QuantizationTestCase):
dropout=0.0,
bidirectional=bidirectional,
weight_qparams_dict=weight_qparams_dict)
ref_rnn._flat_weights = fp32_rnn._flat_weights
for wn in fp32_rnn._flat_weights_names:
setattr(ref_rnn, wn, copy.deepcopy(getattr(fp32_rnn, wn)))
ref_rnn._flat_weights = copy.deepcopy(fp32_rnn._flat_weights)
# quantize and dequantize the weights for fp32_rnn module
fp32_rnn._flat_weights = [self._quant_dequant_weight(w, weight_qparams) for w in fp32_rnn._flat_weights]
flat_weights = []
for wn in fp32_rnn._flat_weights_names:
if wn.startswith("weight"):
weight = self._quant_dequant_weight(getattr(fp32_rnn, wn), weight_qparams)
else:
weight = getattr(fp32_rnn, wn)
flat_weights.append(weight)
fp32_rnn._flat_weights = flat_weights
fp32_res = fp32_rnn(x, (h, c))
ref_res = ref_rnn(x, (h, c))

View File

@ -947,7 +947,7 @@ class TestQuantizeFx(QuantizationTestCase):
qconfig_dict = {'': qconfig}
prepared = prepare_fx(m, qconfig_dict)
quantized = convert_fx(prepared, is_reference=True)
qparams = (quantized._input_scale_0, quantized._input_zero_point_0)
qparams = (quantized._scale_0, quantized._zero_point_0)
weight_obs = qconfig.weight()
weight_obs(quantized.weight)
# Get the actual value to avoid tensor size mismatch error, torch.Size([]) vs torch.Size([1])
@ -1033,6 +1033,8 @@ class TestQuantizeFx(QuantizationTestCase):
fuse_modules(m_eager, fuse_list, inplace=True)
m_eager.qconfig = qconfig
m_eager = prepare_fn(m_eager)
prepared_fx = result_dict["prepared"]
m_eager(*self.img_data_dict[dim][0])
m_eager = convert(m_eager)
result_eager = m_eager(*self.img_data_dict[dim][0])
@ -2469,12 +2471,13 @@ class TestQuantizeFx(QuantizationTestCase):
self.assertTrue(
set(scripted_keys) == set(non_packed_weight_keys),
"Expected the scripted model to preserve the state_dict for non-packed weight attributes")
# TODO: probably don't want to hardcode the attribute names, since they are generated
for attr_name in [
"mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
"mods1_0_scale_0", "mods1_0_zero_point_0",
"mods1_1_scale_0", "mods1_1_zero_point_0",
"mods2_scale_0", "mods2_zero_point_0"]:
self.assertTrue(hasattr(m, attr_name))
"mods1_0_scale_1", "mods1_0_zero_point_1",
"mods1_1_scale_1", "mods1_1_zero_point_1",
"mods2_scale_1", "mods2_zero_point_1"]:
self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")
@skipIfNoFBGEMM
def test_packed_weight_fused_op(self):
@ -2861,11 +2864,12 @@ class TestQuantizeFx(QuantizationTestCase):
m = convert_fx(m)
keys = m.state_dict().keys()
m(torch.randn(5, 5))
# TODO: probably don't want to hardcode the attribute names, since they are generated
for attr_name in [
"mods1_0_input_scale_0", "mods1_0_input_zero_point_0",
"mods1_0_scale_0", "mods1_0_zero_point_0",
"mods1_1_scale_0", "mods1_1_zero_point_0"]:
self.assertTrue(hasattr(m, attr_name))
self.assertTrue(hasattr(m, attr_name), attr_name + " not found.")
def test_no_obs_between_unmatched_node_and_copy_node(self):
"""
@ -3275,23 +3279,22 @@ class TestQuantizeFx(QuantizationTestCase):
x = self.relu(x)
return x
model = M().eval()
dynamic_quantized_ops = {
float16_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic_fp16,
default_dynamic_qconfig: torch.ops.quantized.linear_relu_dynamic
}
for config in [float16_dynamic_qconfig, default_dynamic_qconfig]:
qconfig = {
"": config
for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
model = M().eval()
qconfig_dict = {
"": qconfig
}
m = prepare_fx(model, qconfig)
m = prepare_fx(model, qconfig_dict)
m = convert_fx(m)
m(torch.rand(5, 5))
node_list = [
ns.call_module(nniqd.LinearReLU),
ns.call_module(nniqd.LinearReLU),
ns.call_function(dynamic_quantized_ops[config]),
ns.call_function(dynamic_quantized_ops[qconfig]),
]
self.checkGraphModuleNodes(m, expected_node_list=node_list)
@ -3543,6 +3546,7 @@ class TestQuantizeFx(QuantizationTestCase):
ns.call_function(torch.quantize_per_tensor): 1,
ns.call_function(torch.ops.quantized.linear): 2,
ns.call_function(torch.ops.quantized.add): 1,
ns.call_function(torch.mul): 1,
ns.call_method("dequantize"): 1
}
order_check = [
@ -3551,6 +3555,7 @@ class TestQuantizeFx(QuantizationTestCase):
ns.call_function(torch.ops.quantized.linear),
ns.call_function(torch.ops.quantized.add),
ns.call_method("dequantize"),
ns.call_function(torch.mul),
ns.call_module(nn.Linear),
]
@ -3837,6 +3842,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
if quant_type in self.static_quant_types:
self.assertEqual(result_dict["quantized_output"], result_dict["quantized_reference_output"])
# TODO: enable test for dynamic quant
# Test linear-relu
for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
model = LinearReLUModel(f_relu)
@ -3919,10 +3925,18 @@ class TestQuantizeFxOps(QuantizationTestCase):
else:
qlinear_fun = quant_type_to_qlinear_fun[quant_type]
if quant_type != QuantType.DYNAMIC:
num_dequantize = 1
else:
# we will have an extra quantize_per_tensor_dynamic + dequantize for
# nn.Identity right now, but it will be fixed after we use
# backend_config_dict to configure the default pt backend
num_dequantize = int(not has_relu)
convert_node_occurrence = {
ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0,
qlinear_fun: 1,
ns.call_method("dequantize"): 1 if quant_type != QuantType.DYNAMIC else 0
ns.call_method("dequantize"): num_dequantize,
}
prepare_expected_node_occurrence = \
quant_type_to_prepare_expected_node_occurrence[quant_type]
@ -3975,8 +3989,11 @@ class TestQuantizeFxOps(QuantizationTestCase):
else:
qlinear_fun = ns.call_function(torch.ops.quantized.linear_dynamic_fp16)
prepare_node_occurrence = {
# weight
ns.call_module(torch.ao.quantization.PlaceholderObserver): 1
# activation and weight
# TODO: this is temporary behavior, should be fixed after we use
# backend_config_dict to configure default pt quantization behavior
# activation for nn.Identity (not has_relu)
ns.call_module(torch.ao.quantization.PlaceholderObserver): 2 + int(not has_relu)
}
convert_node_occurrence = {
qlinear_fun: 1,
@ -5394,7 +5411,8 @@ class TestQuantizeFxOps(QuantizationTestCase):
m = M().eval()
m = prepare_fx(m, {"": default_reuse_input_qconfig})
m = convert_fx(m)
print(m)
# make sure it runs
m(torch.rand(1))
def test_getitem(self):
""" Make sure we only insert observer for getitem if the following node is matched
@ -5536,7 +5554,7 @@ class TestQuantizeFxOps(QuantizationTestCase):
reference_count_check = {
ns.call_function(torch.quantize_per_tensor) : 13,
ns.call_method('dequantize') : 11
ns.call_method('dequantize') : 13
}
reference_order_check = [
ns.call_function(torch.quantize_per_tensor),

View File

@ -16,7 +16,6 @@ def _convert_fx_do_not_use(
Please do not use, this is a temporary function to migrate convert_fx
to a new implementation
"""
assert is_reference
if convert_custom_config_dict is None:
convert_custom_config_dict = {}

View File

@ -1,21 +1,37 @@
from typing import Any, Dict, List, Optional, Set, Callable
from typing import Any, Dict, List, Optional, Set, Callable, Tuple
import torch
import copy
from torch.fx import (
GraphModule,
)
from torch.fx.graph import (
Graph,
Node,
Argument,
)
from ..qconfig import QConfigAny
from ..utils import (
activation_is_int8_quantized,
weight_is_statically_quantized,
activation_is_statically_quantized,
weight_is_quantized,
get_qparam_dict,
_parent_name,
get_swapped_custom_module_class,
get_quant_type,
)
from ..qconfig import (
QConfigAny,
qconfig_equals
)
from ..qconfig_dict_utils import (
convert_dict_to_ordered_dict,
update_qconfig_for_qat,
)
from .qconfig_utils import (
generate_qconfig_map,
compare_prepare_convert_qconfig_dict,
update_qconfig_for_fusion,
)
from ..quantization_mappings import DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS
from .backend_config.utils import get_quantized_reference_module_mapping
from .graph_module import (
QuantizedGraphModule,
is_observed_standalone_module,
@ -25,18 +41,23 @@ from .utils import (
get_custom_module_class_keys,
get_quantize_node_info,
create_getattr_from_value,
collect_producer_nodes,
graph_module_from_producer_nodes,
WEIGHT_INDEX_DICT,
)
from ..quant_type import QuantType
from torch.ao.quantization.quantize import (
_remove_qconfig,
is_activation_post_process,
)
from .lower_to_fbgemm import lower_to_fbgemm
from .convert import restore_state
# these are tuples so that they can work with isinstance(module, tuple_of_classes)
FUSED_MODULE_CLASSES = (
torch.nn.intrinsic.LinearReLU,
torch.nn.intrinsic.LinearBn1d,
torch.nn.intrinsic.ConvReLU1d,
torch.nn.intrinsic.ConvReLU2d,
torch.nn.intrinsic.ConvReLU3d,
@ -47,6 +68,7 @@ QAT_MODULE_CLASSES = (
torch.nn.qat.Conv2d,
torch.nn.qat.Conv3d,
torch.nn.intrinsic.qat.LinearReLU,
torch.nn.intrinsic.qat.LinearBn1d,
torch.nn.intrinsic.qat.ConvBn2d,
torch.nn.intrinsic.qat.ConvBnReLU2d,
torch.nn.intrinsic.qat.ConvReLU2d,
@ -55,6 +77,153 @@ QAT_MODULE_CLASSES = (
torch.nn.intrinsic.qat.ConvReLU3d
)
WEIGHT_ONLY_MODULE_CLASSES = (
torch.nn.Embedding,
torch.nn.EmbeddingBag,
)
DYNAMIC_MODULE_CLASSES = (
torch.nn.GRUCell,
torch.nn.LSTMCell,
torch.nn.RNNCell,
torch.nn.LSTM,
)
def has_none_qconfig(node: Argument, qconfig_map: Dict[str, QConfigAny]) -> bool:
""" Check if a node has a qconfig of None, i.e. user requested to not quantize
the node
"""
return isinstance(node, Node) and node.name in qconfig_map and qconfig_map[node.name] is None
def run_weight_observers(observed: GraphModule) -> None:
""" Extract the subgraph that produces the weight for dynamic quant
or weight only quant node and run the subgraph to observe the weight.
Note that the observers of dynamic quant or weight only quant ops are
run during the convert step.
"""
for node in observed.graph.nodes:
if node.op != 'call_function' or node.target not in WEIGHT_INDEX_DICT:
continue
for i, node_arg in enumerate(node.args):
if i not in WEIGHT_INDEX_DICT[node.target]:
continue
# node_arg is weight
weight_observer_nodes = collect_producer_nodes(node_arg)
if weight_observer_nodes is None:
continue
weight_observer_module = \
graph_module_from_producer_nodes(
observed, weight_observer_nodes)
# run the weight observer
weight_observer_module()
def duplicate_dequantize_node(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
"""
If a dequantize node has multiple uses, duplicate it and create one dequantize node for each use.
This is to enable the pattern matching to map from individual quant - dequant - ref_module to
final quantized module.
"""
quantized_root = quantized
for node in quantized.graph.nodes:
if (node.op == "call_method" and node.target == "dequantize" or
(node.op == "call_function" and node.target == torch.dequantize)):
users = list(node.users)
if len(users) > 1:
for user in users:
with quantized.graph.inserting_before(node):
new_node = quantized.graph.create_node("call_method", "dequantize", node.args, {})
user.replace_input_with(node, new_node)
quantized.graph.erase_node(node)
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
return quantized
def remove_extra_dequantize(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
"""
Removes duplicate dequant nodes in the graph: for an operator that has multiple dequant nodes as users,
replace them with a single dequant node that can be shared across all the uses.
"""
quantized_root = quantized
for node in quantized.graph.nodes:
users = list(node.users)
dequant_users = [user for user in node.users if user.op == "call_method" and user.target == "dequantize" or
(user.op == "call_function" and user.target == torch.dequantize)]
if len(dequant_users) > 1:
with quantized.graph.inserting_after(node):
unique_dq = quantized.graph.create_node("call_method", "dequantize", users[0].args, {})
for dequant in dequant_users:
dequant.replace_all_uses_with(unique_dq)
quantized.graph.erase_node(dequant)
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
return quantized
def remove_quant_dequant_pairs(quantized: QuantizedGraphModule) -> QuantizedGraphModule:
quantized_root = quantized
for node in quantized.graph.nodes:
if node.op == "call_function" and node.target in [torch.quantize_per_tensor, torch.quantize_per_channel]:
users = list(node.users)
user = users[0] if users else None
if len(users) == 1 and user.op == "call_method" and user.target == "dequantize":
user.replace_all_uses_with(node.args[0])
quantized.graph.erase_node(user)
orig_args = list(node.args)
quantized.graph.erase_node(node)
for arg in orig_args:
if isinstance(arg, Node) and len(list(arg.users)) == 0:
quantized.graph.erase_node(arg)
quantized = QuantizedGraphModule(quantized_root, quantized.graph, quantized_root.preserved_attr_names)
return quantized
def get_module_path_and_prefix(
obs_node: Node,
node_name_to_scope: Dict[str, Tuple[str, type]],
qconfig_map: Dict[str, QConfigAny]):
""" Given and observer node, get the `Scope` or the fully qualified name for
the submodule containing the observed node, also return a prefix of "_input"
when the observed node is an input of a F.linear op, and not the output of another
quantized op.
TODO: this logic is hacky, we should think about how to remove it or make it more
general
"""
observed_node = obs_node.args[0]
# an observer can be inserted for both input of the next operator or output of the previous
# operator (they can be the same)
# this flag identifies if the observer is inserted only because the observed node is
# the input of the next operator
assert isinstance(observed_node, Node), \
f"Expecting observed node to be a Node, but got {observed_node}"
is_input_observer_only = qconfig_map[observed_node.name] is None if observed_node.name in qconfig_map else None
if is_input_observer_only:
# if the quantize function is at the input of op, then we find the first user of the observer_node
# to get the path. If a linear call_function is in the user list, we return the first instance
# of linear node to get the FQN.
users = list(obs_node.users)
first_linear_use_or_first_use = users[0] if users else None
linear_node = None
for n in users:
if n.op == "call_function" and n.target == torch.nn.functional.linear:
linear_node = n
break
if linear_node:
first_linear_use_or_first_use = linear_node
prefix = "_input"
else:
# if the quantize function is at the output of the op, we use the observer input node to get the path
first_linear_use_or_first_use = observed_node
prefix = ""
if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
else:
# TODO: it's not used, so actually we can skip quantization
# but this requires changing return type of quantize_node
# we can fix it later if needed
module_path = ""
return module_path, prefix
def insert_dequantize_node(
node: Node,
graph: Graph):
@ -66,14 +235,41 @@ def insert_dequantize_node(
if user_node is not dequantize_node:
user_node.replace_input_with(node, dequantize_node)
def maybe_get_observer_for_node(
node: Node,
modules: Dict[str, torch.nn.Module]
) -> Optional[torch.nn.Module]:
"""
If the node is observed, return the observer
instance. Otherwise, return None.
"""
for maybe_obs_node, _ in node.users.items():
if maybe_obs_node.op == 'call_module':
maybe_obs = modules[str(maybe_obs_node.target)]
if is_activation_post_process(maybe_obs):
return maybe_obs
return None
def convert_standalone_module(
node: Node,
modules: Dict[str, torch.nn.Module],
model: torch.fx.GraphModule,
is_reference: bool,
backend_config_dict: Dict[str, Any]):
convert = torch.ao.quantization._quantize_fx_do_not_use._convert_do_not_use # type: ignore[attr-defined]
backend_config_dict: Optional[Dict[str, Any]]):
""" Converts a observed standalone module to a quantized standalone module by calling
the fx convert api, currently using the same `is_reference` flag as parent, but we may
changing this behavior in the future (e.g. separating quantization and lowering for
standalone module as well)
Args:
- node: The call_module node of the observed standalone module
- modules: named_module of original model
- model: original model
- is_reference: a flag from parent provided by user to decide if we want to
produce a reference model or a fbgemm/qnnpack model
- backend_config_dict: backend configuration of the target backend of quantization
"""
convert = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined]
# We know that observed standalone module is a GraphModule since
# it's produced by us
observed_standalone_module : GraphModule = modules[str(node.target)] # type: ignore[assignment]
@ -106,9 +302,10 @@ def convert_standalone_module(
# TODO: allow convert_custom_config_dict to override backend_config_dict
# for standalone module
# TODO: think about how to handle `is_reference` here
quantized_standalone_module = convert(
observed_standalone_module,
is_reference=True,
is_reference=is_reference,
backend_config_dict=backend_config_dict)
parent_name, name = _parent_name(node.target)
# update the modules dict
@ -119,66 +316,181 @@ def convert_weighted_module(
node: Node,
modules: Dict[str, torch.nn.Module],
observed_node_names: Set[str],
quantized_reference_module_mapping: Dict[Callable, Any]):
quantized_reference_module_mapping: Dict[Callable, Any],
qconfig_map: Dict[str, QConfigAny]):
""" Convert a weighted module to reference quantized module in the model
If the QConfig of a QAT module is not set, the module will still be converted to
a float module.
Args:
- node: The call_module node of the observed standalone module
- modules: named_module of original model
- observed_node_names: names for the set of observed fx node, we can skip
this conversion if the node is not observed
- quantized_reference_module_mapping: module mapping from floating point module class
to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
"""
original_module = modules[str(node.target)]
qconfig = original_module.qconfig
is_observed = node.name in observed_node_names
is_activation_quantized = activation_is_int8_quantized(qconfig)
is_weight_quantized = weight_is_statically_quantized(qconfig)
# TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
if qconfig is None or \
not is_observed or \
not is_weight_quantized or \
not is_activation_quantized:
return
float_module = original_module
fused_module = None
weight_post_process = None
if isinstance(
original_module,
QAT_MODULE_CLASSES):
# case 1. converting qat module to
# a float module, we need to attach
# weight fake_quant to the module,
# weight fake_quant is assumed to be run during
# Converting qat module to a float module, we need to attach
# weight fake_quant to the module, weight fake_quant is assumed to be run during
# QAT so we don't need to run it again here
float_module = original_module.to_float() # type: ignore[operator]
# change qat conv to conv
# change qat module to float module
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, float_module)
weight_post_process = original_module.weight_fake_quant
qconfig = original_module.qconfig
is_observed = node.name in observed_node_names
# If a qconfig is not defined for this node, then skip converting to a reference module
if qconfig is None or has_none_qconfig(node, qconfig_map) or not is_observed:
return
# TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
is_weight_quantized = weight_is_quantized(qconfig)
quant_type = get_quant_type(qconfig)
# skip reference module swapping for embedding when quantization mode does not
# match
# TODO: we need a more systematic way to handle this after we migrate to use
# backend_config_dict everywhere
if isinstance(original_module, WEIGHT_ONLY_MODULE_CLASSES) and \
quant_type != QuantType.WEIGHT_ONLY:
return
if isinstance(original_module, DYNAMIC_MODULE_CLASSES) and \
quant_type != QuantType.DYNAMIC:
return
# the condition for swapping the module to reference quantized module is:
# weights need to be quantized
if not is_weight_quantized:
return
fused_module = None
# extract the individual float_module and fused module
if isinstance(float_module, torch.nn.intrinsic._FusedModule):
fused_module = float_module
float_module = fused_module[0]
weight_post_process = original_module.weight_fake_quant
else:
# case 2. converting a float module/fused float module
# to float module, we need to attach
# weight observer to the conv module and run it
# with conv weight
if isinstance(original_module, torch.nn.intrinsic._FusedModule):
fused_module = original_module
float_module = fused_module[0] # type: ignore[index]
assert qconfig is not None
# TODO: expose this through backend_config_dict
# weight_qparams or weight_qparams dict
wq_or_wq_dict = {}
if isinstance(float_module, torch.nn.RNNCellBase):
weight_post_process_ih = qconfig.weight() # type: ignore[union-attr, operator]
weight_post_process_hh = qconfig.weight() # type: ignore[union-attr, operator]
weight_post_process_ih(float_module.weight_ih)
weight_post_process_hh(float_module.weight_hh)
weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
wq_or_wq_dict = {
"weight_ih": weight_qparams_ih,
"weight_hh": weight_qparams_hh,
}
elif isinstance(float_module, torch.nn.LSTM):
# format for wq_or_wq_dict (flattened attributes):
# {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
for wn in float_module._flat_weights_names:
if hasattr(float_module, wn) and wn.startswith("weight"):
weight = getattr(float_module, wn)
weight_post_process = qconfig.weight() # type: ignore[union-attr, operator]
if weight_post_process.dtype == torch.qint8:
weight_post_process(weight)
wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
else:
# weight_post_process is None means the original module is not a QAT module
# we need to get weight_post_process from qconfig in this case
if weight_post_process is None:
weight_post_process = qconfig.weight() # type: ignore[union-attr, operator]
# run weight observer
# TODO: This is currently a hack for QAT to get the right shapes for scale and zero point.
# In the future, we should require the user to calibrate the model after calling prepare
# Issue: https://github.com/pytorch/pytorch/issues/73941
weight_post_process(float_module.weight) # type: ignore[operator]
weight_qparams = get_qparam_dict(weight_post_process)
# TODO: may need to change the mapping when we support dynamic quantization
wq_or_wq_dict = get_qparam_dict(weight_post_process)
# We use the same reference module for all modes of quantization: static, dynamic, weight_only
ref_qmodule_cls = quantized_reference_module_mapping.get(type(float_module), None)
assert ref_qmodule_cls is not None, f"No reference quantized module class configured for {type(float_module)}"
ref_qmodule = ref_qmodule_cls.from_float(float_module, weight_qparams) # type: ignore[attr-defined]
ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined]
if fused_module is not None:
fused_module[0] = ref_qmodule
else:
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, ref_qmodule)
def convert_custom_module(
node: Node,
graph: Graph,
modules: Dict[str, torch.nn.Module],
custom_module_class_mapping: Dict[Callable, Callable],
statically_quantized_custom_module_nodes: Set[Node]):
""" Converts an observed custom module to a quantized custom module based on
`custom_module_class_mapping`
For static quantization, we'll also remove the previous `dequantize` node and
attach the output observer to the module; that observer will be converted
to a dequantize node (instead of a quantize-dequantize pair) later in the graph.
In the end we have a quantized custom module that has the same interface as a
default quantized module in the nn.quantized namespace,
i.e. quantized input and quantized output.
Args:
- node: The call_module node of the observed standalone module
- graph: The graph containing the node
- modules: named_module of original model
- custom_module_class_mapping: mapping from observed custom module class to
quantized custom module class, used to swap custom modules
- statically_quantized_custom_module_nodes: we'll add the custom module node
if we find it is statically quantized. This will be used later when converting
observers to quant/dequant node pairs: if the observed node is a statically
quantized custom module node, we'll convert the observer to a dequantize node.
This is to keep the interface the same as the default quantized module.
TODO: maybe we want to redesign this part to align with the reference model design
as well, but there have been some discussions around the interface, so we can do
it later.
"""
observed_custom_module = modules[str(node.target)]
maybe_obs = maybe_get_observer_for_node(node, modules)
qconfig = observed_custom_module.qconfig
if activation_is_statically_quantized(qconfig):
statically_quantized_custom_module_nodes.add(node)
# remove the previous dequant node
prev_node = node.args[0]
# expecting the input node for a custom module node to be a Node
assert isinstance(prev_node, Node), \
f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
if prev_node.op == "call_method" and prev_node.target == "dequantize":
assert len(prev_node.users) == 1, "dequantize node before custom module is used "
"multiple times, this is currently not supported yet, but it can be "
"supported by duplicating the dequantize nodes in these cases"
prev_node.replace_all_uses_with(prev_node.args[0])
graph.erase_node(prev_node)
# absorb the following observer into the module conversion
activation_post_process = maybe_get_observer_for_node(node, modules)
assert activation_post_process is not None
observed_custom_module.activation_post_process = activation_post_process
# swap the observed custom module to quantized custom module
quantized_custom_module_class = get_swapped_custom_module_class(
observed_custom_module, custom_module_class_mapping, qconfig)
quantized_custom_module = \
quantized_custom_module_class.from_observed(observed_custom_module)
parent_name, name = _parent_name(node.target)
setattr(modules[parent_name], name, quantized_custom_module)
def _convert_do_not_use(
model: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False,
_remove_qconfig_flag: bool = True,
convert_qconfig_dict: Dict[str, Any] = None,
backend_config_dict: Optional[Dict[str, Any]] = None) -> torch.nn.Module:
"""
We will convert an observed model (a module with observer calls) to a reference
@ -204,7 +516,13 @@ def _convert_do_not_use(
patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
qconfig_map: Dict[str, QConfigAny] = model._qconfig_map # type: ignore[assignment]
assert is_reference, "_convert_do_not_use only supports reference option"
# TODO this should be removed now that gpu support for quantization is being supported.
# however in practice, as of 7/22/2021, certain functions that get called by convert expect
# only cpu arguments.
# As an example, in TestQuantizeFxModels.test_qat_functional_linear when device='cuda',
# fold_weight will call quantized::linear_prepack which doesn't support QuantizedCuda backend.
if not is_reference:
model.cpu()
# mapping from fully qualified module name to module instance
# for example,
@ -217,9 +535,33 @@ def _convert_do_not_use(
# the same activation_post_process module instance but different names
modules = dict(model.named_modules(remove_duplicate=False))
# TODO refactor this code once we update the prepare logic to have additional information on
# which graph nodes have been observed and share that with convert to decide which observers to ignore.
if convert_qconfig_dict:
prepare_qconfig_dict: Dict[str, Dict[Any, Any]] = model._qconfig_dict # type: ignore[assignment]
modules_copy = copy.deepcopy(modules)
convert_dict_to_ordered_dict(convert_qconfig_dict)
if model._is_qat:
additional_qat_module_mapping = prepare_custom_config_dict.get(
"additional_qat_module_mapping", {})
convert_qconfig_dict = update_qconfig_for_qat(convert_qconfig_dict, additional_qat_module_mapping)
convert_qconfig_dict = update_qconfig_for_fusion(model, convert_qconfig_dict)
compare_prepare_convert_qconfig_dict(prepare_qconfig_dict, convert_qconfig_dict) # type: ignore[arg-type]
convert_qconfig_map = generate_qconfig_map(model, modules_copy, model.graph, convert_qconfig_dict, node_name_to_scope)
# check the convert_qconfig_map generated and ensure that all the values either match what was set in prepare qconfig_map
# or are set to None in the convert_qconfig_map.
for k, v in qconfig_map.items():
assert k in convert_qconfig_map, 'Expected key {} in convert qconfig_map'.format(k)
if convert_qconfig_map[k] is not None:
assert qconfig_equals(v, convert_qconfig_map[k]), 'Expected k {} to have the same value in prepare qconfig_dict \
and convert qconfig_dict, found {} updated to {}.'.format(k, v, convert_qconfig_map[k])
qconfig_map = convert_qconfig_map
custom_module_classes = get_custom_module_class_keys(
convert_custom_config_dict,
"observed_to_quantized_custom_module_class")
custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
if model._equalization_qconfig_map is not None:
# If we want to do equalization then do the following:
@ -228,12 +570,23 @@ def _convert_do_not_use(
weight_eq_obs_dict = update_obs_for_equalization(model, modules)
convert_eq_obs(model, modules, weight_eq_obs_dict)
# always run weight observers in the top level forward method
# for dynamic quant ops or weight only quant ops
run_weight_observers(model)
graph_inputs: List[str] = []
for node in model.graph.nodes:
if node.op == 'placeholder':
graph_inputs.append(node.name)
def replace_observer_with_quantize_dequantize_node(graph: Graph, node: Node, modules: Dict[str, torch.nn.Module]) -> None:
# TODO: move this outside of this function
def replace_observer_with_quantize_dequantize_node(
model: torch.nn.Module,
graph: Graph,
node: Node,
modules: Dict[str, torch.nn.Module],
node_name_to_scope: Dict[str, Tuple[str, type]],
qconfig_map: Dict[str, QConfigAny]) -> None:
""" Replace activation_post_process module call node with quantize and
dequantize node
@ -244,25 +597,34 @@ def _convert_do_not_use(
"""
assert modules is not None
assert isinstance(node.target, str)
module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, qconfig_map)
observer_module = modules[node.target]
root_module = modules[""]
if observer_module.dtype == torch.float32:
# remove the node for now
# TODO: support dynamic quant
maybe_quantize_node_info = get_quantize_node_info(observer_module)
# Skip replacing observers to quant/dequant nodes if the qconfigs of all
# consumers and producers of this observer are None
skip_replacement = all([
has_none_qconfig(n, qconfig_map) for n in
list(node.args) + list(node.users.keys())])
if skip_replacement or maybe_quantize_node_info is None:
# didn't find corresponding quantize op and info for the observer_module
# so we just remove the observer
with graph.inserting_before(node):
node.replace_all_uses_with(node.args[0])
graph.erase_node(node)
elif observer_module.dtype in [torch.quint8, torch.qint8, torch.float16]:
node_type, quantize_op, qparams = get_quantize_node_info(observer_module)
else:
# otherwise, we can convert the observer module call to quantize/dequantize node
node_type, quantize_op, qparams = maybe_quantize_node_info
# replace observer node with quant - dequant node
with graph.inserting_before(node):
input_node = node.args[0]
inputs = [input_node]
for key, value in qparams.items():
# TODO: we can add the information of whether a value needs to
# be registered as an attribute in qparams dict itself
if key in ['_scale_', '_zero_point_']:
# For scale and zero_point values we register them as buffers in the root module.
# TODO: maybe need more complex attr name here
qparam_node = create_getattr_from_value(root_module, graph, key, value)
qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
inputs.append(qparam_node)
else:
# for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
@ -273,6 +635,15 @@ def _convert_do_not_use(
node.replace_all_uses_with(dequantized_node)
graph.erase_node(node)
# this is a temporary hack for custom module, we may want to implement
# this properly after the custom module class design is finalized
def replace_observer_with_dequantize_node(node: Node, graph: Graph):
call_custom_module_node = node.args[0]
assert isinstance(call_custom_module_node, Node), \
f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
node.replace_all_uses_with(call_custom_module_node)
graph.erase_node(node)
insert_dequantize_node(call_custom_module_node, graph)
# additional state to override inputs to be quantized, if specified
# by the user
@ -284,10 +655,12 @@ def _convert_do_not_use(
"output_quantized_idxs", [])
if backend_config_dict is None:
backend_config_dict = {}
quantized_reference_module_mapping = copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
else:
quantized_reference_module_mapping = get_quantized_reference_module_mapping(backend_config_dict)
# convert tuples so that it can work with isinstance(module, tuple_of_classes)
weighted_module_classes = tuple(quantized_reference_module_mapping.keys())
statically_quantized_custom_module_nodes: Set[Node] = set()
for node in list(model.graph.nodes):
if node.op == 'placeholder':
@ -315,18 +688,36 @@ def _convert_do_not_use(
model.graph.erase_node(maybe_dequantize_node)
elif node.op == "call_module":
if is_activation_post_process(modules[node.target]):
replace_observer_with_quantize_dequantize_node(model.graph, node, modules)
observed_node = node.args[0]
if observed_node in statically_quantized_custom_module_nodes:
replace_observer_with_dequantize_node(node, model.graph)
else:
replace_observer_with_quantize_dequantize_node(
model, model.graph, node, modules, node_name_to_scope,
qconfig_map)
elif is_observed_standalone_module(modules[node.target]):
# TODO: move this to a separate function
convert_standalone_module(node, modules, model, is_reference, backend_config_dict)
convert_standalone_module(
node, modules, model, is_reference, backend_config_dict)
elif type(modules[node.target]) in set(
weighted_module_classes).union(QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES):
convert_weighted_module(node, modules, observed_node_names, quantized_reference_module_mapping)
convert_weighted_module(
node, modules, observed_node_names, quantized_reference_module_mapping, qconfig_map)
elif type(modules[node.target]) in custom_module_classes:
convert_custom_module(
node, model.graph, modules, custom_module_class_mapping,
statically_quantized_custom_module_nodes)
preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
model = QuantizedGraphModule(model, model.graph, preserved_attributes)
# TODO: maybe move this to quantize_fx.py
if not is_reference:
model = duplicate_dequantize_node(model)
model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
model = remove_quant_dequant_pairs(model)
model = remove_extra_dequantize(model)
# TODO: this looks hacky, we want to check why we need this and see if we can
# remove this
# removes qconfig and activation_post_process modules
if _remove_qconfig_flag:
_remove_qconfig(model)
preserved_attributes = set(convert_custom_config_dict.get("preserved_attributes", []))
model = QuantizedGraphModule(model, model.graph, preserved_attributes)
return model
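
To make the module-swapping path in `convert_weighted_module` concrete, here is a hedged sketch of the kind of `weight_qparams` dict it builds from the weight observer and hands to the reference module's `from_float`; the observer choice and toy module are illustrative:

```
import torch
import torch.nn.quantized._reference as nnqr
from torch.ao.quantization import default_weight_observer
from torch.ao.quantization.utils import get_qparam_dict

float_linear = torch.nn.Linear(4, 4)

# run the weight observer, as convert_weighted_module does for non-QAT modules
weight_obs = default_weight_observer()
weight_obs(float_linear.weight)

# e.g. {"qscheme": torch.per_tensor_symmetric, "dtype": torch.qint8,
#       "scale": ..., "zero_point": ..., ...}
weight_qparams = get_qparam_dict(weight_obs)

# swap in the reference module; lowering later matches
# dequantize - nnqr.Linear - quantize_per_tensor and folds it into nnq.Linear
ref_linear = nnqr.Linear.from_float(float_linear, weight_qparams)
```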

View File

@ -5,7 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.quantized.dynamic as nniqd
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.quantized._reference as nnqr
from torch.nn.quantized.modules.utils import WeightedQuantizedModule
from . import subgraph_rewriter_FORKED_DO_NOT_USE
@ -24,7 +26,7 @@ from ..utils import _parent_name
from ..qconfig import QConfigAny
from ..quantization_mappings import get_quantized_operator
from .utils import create_node_from_old_node_preserve_meta
from typing import Dict, Tuple, Type, List, Callable, Any, Union, Set
from typing import Dict, Tuple, Type, List, Callable, Any, Union, Set, Optional
from torch.fx import Node
import operator
@ -85,6 +87,10 @@ def is_default_node(node, modules):
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.Dropout,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.intrinsic.BNReLU2d,
torch.nn.intrinsic.BNReLU3d,
]
return _is_node_in_list(node, modules, func_list, method_list, module_type_list)
@ -179,10 +185,14 @@ def is_special_pattern_node(node, modules):
res_module = res_module or is_call_module
return res_function, res_method, res_module
def is_dequantize_node(node):
return isinstance(node, Node) and node.op == 'call_method' and node.target == 'dequantize'
def is_getattr_tensor_metadata_node(node):
return node.op == "call_function" and \
node.target == getattr and \
node.args[1] in ["shape"]
def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigAny]):
"""
Return True if the op is configured with a None qconfig, False otherwise.
@ -192,16 +202,32 @@ def should_skip_lowering(op: torch.fx.node.Node, qconfig_map: Dict[str, QConfigA
"""
return op.name in qconfig_map and qconfig_map[op.name] is None
# Mapping from reference module class to the replacement quantized module class for lowering
LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[WeightedQuantizedModule]] = {
# Mapping from reference module class to the replacement static quantized module class for lowering
STATIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[WeightedQuantizedModule]] = {
nnqr.Linear: nnq.Linear,
nnqr.Conv1d: nnq.Conv1d,
nnqr.Conv2d: nnq.Conv2d,
nnqr.Conv3d: nnq.Conv3d,
}
# TODO: merge with LOWER_MODULE_MAP after we merge
# _lower_weighted_ref_module and special_pattern_replacement
# Mapping from reference module class to the replacement dynamic quantized module class for lowering
DYNAMIC_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
nnqr.Linear: nnqd.Linear,
nnqr.GRUCell: nnqd.GRUCell,
nnqr.LSTMCell: nnqd.LSTMCell,
nnqr.RNNCell: nnqd.RNNCell,
nnqr.LSTM: nnqd.LSTM,
}
# Mapping from reference module class to the replacement weight only quantized module class for lowering
# TODO: correct the namespace for these modules
WEIGHT_ONLY_LOWER_MODULE_MAP: Dict[Type[nn.Module], Type[nn.Module]] = {
nnqr.Embedding: nnq.Embedding,
nnqr.EmbeddingBag: nnq.EmbeddingBag,
}
# TODO: merge with STATIC_LOWER_MODULE_MAP after we merge
# _lower_static_weighted_ref_module and special_pattern_replacement
SPECIAL_PATTERN_LOWER_MODULE_MAP = {
nn.BatchNorm2d: nnq.BatchNorm2d,
nn.BatchNorm3d: nnq.BatchNorm3d,
@ -215,22 +241,31 @@ SPECIAL_PATTERN_LOWER_MODULE_MAP = {
nn.InstanceNorm3d: nnq.InstanceNorm3d,
nn.LayerNorm: nnq.LayerNorm,
nn.Dropout: nnq.Dropout,
nni.BNReLU2d: nniq.BNReLU2d,
nni.BNReLU3d: nniq.BNReLU3d,
}
# Mapping from fused module class to a 2-tuple of:
# 1) The inner reference module class
# 2) The replacement quantized module class for lowering
LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
# 2) The replacement static quantized module class for lowering
STATIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[WeightedQuantizedModule]]] = {
nni.LinearReLU: (nnqr.Linear, nniq.LinearReLU),
nni.ConvReLU1d: (nnqr.Conv1d, nniq.ConvReLU1d),
nni.ConvReLU2d: (nnqr.Conv2d, nniq.ConvReLU2d),
nni.ConvReLU3d: (nnqr.Conv3d, nniq.ConvReLU3d),
}
# Mapping from fused module class to a 2-tuple of:
# 1) The inner reference module class
# 2) The replacement dynamic quantized module class for lowering
DYNAMIC_LOWER_FUSED_MODULE_MAP: Dict[Type[nn.Module], Tuple[Type[nn.Module], Type[nn.Module]]] = {
nni.LinearReLU: (nnqr.Linear, nniqd.LinearReLU),
}
# Mapping from a functional to lower to a 2-tuple of
# 1) The quantized version of the op
# 2) The quantized version of the op fused with relu, if it exists, else None
LOWER_FUNCTIONAL_MAP: Dict[Callable, Tuple[Callable, Callable]] = {
STATIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Tuple[Callable, Callable]] = {
F.linear: (torch.ops.quantized.linear, torch.ops.quantized.linear_relu),
F.conv1d: (torch.ops.quantized.conv1d, torch.ops.quantized.conv1d_relu),
F.conv2d: (torch.ops.quantized.conv2d, torch.ops.quantized.conv2d_relu),
@ -245,6 +280,29 @@ WEIGHT_PREPACK_OPS: Set[Callable] = {
torch._ops.ops.quantized.conv3d_prepack,
}
# Mapping from a functional to a dictionary, where the key is a 2-tuple of
# (activation_compute_dtype, weight_dtype) and the value is a 2-tuple of
# 1) The dynamically quantized version of the op
# 2) The dynamically quantized version of the op fused with relu, if it exists, else None
DYNAMIC_LOWER_FUNCTIONAL_MAP: Dict[Callable, Dict[Tuple[torch.dtype, torch.dtype], Tuple[Callable, Optional[Callable]]]] = {
F.linear: {
(torch.quint8, torch.qint8): (torch.ops.quantized.linear_dynamic,
torch.ops.quantized.linear_relu_dynamic),
(torch.float16, torch.float16): (torch.ops.quantized.linear_dynamic_fp16,
torch.ops.quantized.linear_relu_dynamic_fp16)
},
# dynamic conv + relu is not available yet
F.conv1d: {
(torch.quint8, torch.qint8): (torch.ops.quantized.conv1d_dynamic, None),
},
F.conv2d: {
(torch.quint8, torch.qint8): (torch.ops.quantized.conv2d_dynamic, None),
},
F.conv3d: {
(torch.quint8, torch.qint8): (torch.ops.quantized.conv3d_dynamic, None),
},
}
CONV_FUNCTIONAL_OPS: Set[Callable] = {
F.conv1d,
F.conv2d,
@ -307,12 +365,12 @@ def fold_weight(
quantized = QuantizedGraphModule(quantized_root, folded_graph, quantized_root.preserved_attr_names)
return quantized
def _lower_weighted_ref_module(model: QuantizedGraphModule) -> QuantizedGraphModule:
def _lower_static_weighted_ref_module(model: QuantizedGraphModule) -> QuantizedGraphModule:
"""
Traverse the graph and find dequantize - ref module - quantize patterns
and replace them with the quantized version of the ref module.
"""
for ref_class in list(LOWER_MODULE_MAP.keys()) + list(LOWER_FUSED_MODULE_MAP.keys()):
for ref_class in list(STATIC_LOWER_MODULE_MAP.keys()) + list(STATIC_LOWER_FUSED_MODULE_MAP.keys()):
pattern = (torch.quantize_per_tensor,
(ref_class, "dequantize"),
MatchAllNode, MatchAllNode, MatchAllNode)
@ -348,12 +406,12 @@ def _lower_weighted_ref_module(model: QuantizedGraphModule) -> QuantizedGraphMod
output_zero_point = getattr(model, zero_point_node.target)
# For fused modules, we also check whether the inner module is a reference module
# If so, we replace the entire fused module with the corresponding quantized module
if ref_class in LOWER_FUSED_MODULE_MAP:
inner_ref_class, q_class = LOWER_FUSED_MODULE_MAP[ref_class]
if ref_class in STATIC_LOWER_FUSED_MODULE_MAP:
inner_ref_class, q_class = STATIC_LOWER_FUSED_MODULE_MAP[ref_class]
if type(ref_module[0]) != inner_ref_class:
continue
else:
q_class = LOWER_MODULE_MAP[type(ref_module)]
q_class = STATIC_LOWER_MODULE_MAP[type(ref_module)]
assert issubclass(q_class, WeightedQuantizedModule) # suppress mypy warnings
q_module = q_class.from_reference(ref_module, output_scale, output_zero_point)
@ -373,7 +431,94 @@ def _lower_weighted_ref_module(model: QuantizedGraphModule) -> QuantizedGraphMod
model.graph.erase_node(zero_point_node)
return model
def _lower_weighted_ref_functional(
def _lower_dynamic_weighted_ref_module(model: QuantizedGraphModule) -> QuantizedGraphModule:
"""
Traverse the graph and find quantize_per_tensor_dynamic - dequantize - ref_module patterns
and replace them with the dynamically quantized version of the ref module.
"""
named_modules = dict(model.named_modules(remove_duplicate=False))
for n in model.graph.nodes:
if n.op != "call_module" or \
type(named_modules[str(n.target)]) not in \
set(DYNAMIC_LOWER_MODULE_MAP.keys()).union(
set(DYNAMIC_LOWER_FUSED_MODULE_MAP.keys())):
continue
ref_node = n
dq_node = ref_node.args[0]
if dq_node.op != "call_method" or dq_node.target != "dequantize":
continue
# don't support lowering the pattern when the result of dequantize is used by
# multiple nodes
if len(dq_node.users) > 1:
continue
input_dynamic_q_node = dq_node.args[0]
# don't support lowering the pattern when the result of quantize is used by
# multiple nodes
if len(input_dynamic_q_node.users) > 1:
continue
if input_dynamic_q_node.op != "call_function" or \
input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic:
continue
activation_compute_dtype = input_dynamic_q_node.args[1]
is_fp16 = activation_compute_dtype == torch.float16
is_int8 = activation_compute_dtype in [torch.quint8, torch.qint8]
if not is_int8 and not is_fp16:
continue
ref_module = named_modules[str(ref_node.target)]
ref_class = type(ref_module)
if ref_class in DYNAMIC_LOWER_FUSED_MODULE_MAP:
inner_ref_class, q_class = DYNAMIC_LOWER_FUSED_MODULE_MAP[ref_class]
if type(ref_module[0]) != inner_ref_class:
continue
else:
q_class = DYNAMIC_LOWER_MODULE_MAP.get(ref_class) # type: ignore[assignment]
# TODO: maybe define a WeightedDynamicallyQuantizedModule
q_module = q_class.from_reference(ref_module) # type: ignore[attr-defined]
# replace reference module with dynamically quantized module
parent_name, module_name = _parent_name(ref_node.target)
setattr(named_modules[parent_name], module_name, q_module)
# remove q - dq node
dq_node.replace_all_uses_with(input_dynamic_q_node)
model.graph.erase_node(dq_node)
input_dynamic_q_node.replace_all_uses_with(input_dynamic_q_node.args[0])
model.graph.erase_node(input_dynamic_q_node)
return model
def _lower_weight_only_weighted_ref_module(model: QuantizedGraphModule) -> QuantizedGraphModule:
"""
Traverse the graph and find ref_module patterns
and replace them with the weight only quantized version of the ref module.
"""
named_modules = dict(model.named_modules(remove_duplicate=False))
for n in model.graph.nodes:
if n.op != "call_module" or \
type(named_modules[str(n.target)]) not in \
set(WEIGHT_ONLY_LOWER_MODULE_MAP.keys()):
continue
ref_node = n
ref_module = named_modules[str(ref_node.target)]
ref_class = type(ref_module)
q_class = WEIGHT_ONLY_LOWER_MODULE_MAP.get(ref_class)
# TODO: WeightedQuantizedModule is currently assuming static quant apis
# with output_scale, output_zero_point in from_reference, we may want to
# relax that, or rename this
# TODO: maybe define a WeightedWeightOnlyQuantizedModule
q_module = q_class.from_reference(ref_module) # type: ignore[union-attr]
# replace reference module with weight only quantized module
parent_name, module_name = _parent_name(ref_node.target)
setattr(named_modules[parent_name], module_name, q_module)
return model
def _lower_static_weighted_ref_functional(
model: QuantizedGraphModule,
qconfig_map: Dict[str, QConfigAny]
) -> QuantizedGraphModule:
@ -392,7 +537,9 @@ def _lower_weighted_ref_functional(
q_node = n
(func_node, output_scale_node, output_zp_node, _) = q_node.args
# Handle cases where the functional op is wrapped in a ReLU
if func_node.target == F.relu:
if func_node.op == "call_function" and func_node.target == F.relu or \
func_node.op == "call_module" and \
type(modules[str(func_node.target)]) == torch.nn.ReLU:
relu_node = func_node
func_node = relu_node.args[0]
else:
@ -401,7 +548,7 @@ def _lower_weighted_ref_functional(
continue
# Linear args: (dequantized inputs, dequantized weights[, bias])
# Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups])
if func_node.op != "call_function" or func_node.target not in LOWER_FUNCTIONAL_MAP:
if func_node.op != "call_function" or func_node.target not in STATIC_LOWER_FUNCTIONAL_MAP:
continue
(input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
if input_dq_node.target != "dequantize" or weight_dq_node.target != "dequantize":
@ -433,7 +580,7 @@ def _lower_weighted_ref_functional(
packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), {})
# Step 2: Replace reference pattern with the corresponding quantized op
(q_func, q_relu_func) = LOWER_FUNCTIONAL_MAP[func_node.target]
(q_func, q_relu_func) = STATIC_LOWER_FUNCTIONAL_MAP[func_node.target]
func_node.target = q_relu_func if relu_node is not None else q_func
func_node.args = (input_dq_node.args[0], packed_weight, output_scale_node, output_zp_node)
q_node.replace_all_uses_with(func_node)
@ -450,6 +597,120 @@ def _lower_weighted_ref_functional(
model.graph.erase_node(relu_node)
return model
def _lower_dynamic_weighted_ref_functional(
model: QuantizedGraphModule,
qconfig_map: Dict[str, QConfigAny]
) -> QuantizedGraphModule:
"""
Traverse the graph and replace functional reference patterns with their dynamically
quantized versions.
Examples:
quantize_per_tensor_dynamic - dequantize - functional linear --> linear_dynamic
to(torch.float16) - dequantize - functional linear --> linear_dynamic_fp16
"""
modules = dict(model.named_modules(remove_duplicate=False))
nodes = list(model.graph.nodes)
# we want to search in reversed order so that we can match the larger patterns first
# e.g. we want to match linear - relu before linear.
for n in reversed(model.graph.nodes):
# Step 0: Find nodes that match this pattern
# (quantize_per_tensor_dynamic - dequantize - dynamically quantized op)
# We search for the pattern backwards, starting with the quantize node
# Quantize node args: (func, scale, zp, dtype)
func_node = n
# Handle cases where the functional op is wrapped in a ReLU
if func_node.op == "call_function" and func_node.target == F.relu or \
func_node.op == "call_module" and \
type(modules[str(func_node.target)]) == torch.nn.ReLU:
relu_node = func_node
func_node = relu_node.args[0]
else:
relu_node = None
if should_skip_lowering(func_node, qconfig_map):
continue
# Linear args: (dequantized inputs, dequantized weights[, bias])
# Conv args: (dequantized inputs, dequantized weights[, bias, stride, padding, dilation, groups])
if func_node.op != "call_function" or func_node.target not in DYNAMIC_LOWER_FUNCTIONAL_MAP:
continue
(input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
if input_dq_node.op != "call_method" or input_dq_node.target != "dequantize" or \
weight_dq_node.op != "call_method" or weight_dq_node.target != "dequantize":
continue
input_dynamic_q_node = input_dq_node.args[0]
# don't support lowering the pattern when the result of quantize is used by
# multiple nodes
if len(input_dynamic_q_node.users) > 1:
continue
if input_dynamic_q_node.op != "call_function" or \
input_dynamic_q_node.target != torch.quantize_per_tensor_dynamic:
continue
reduce_range_node = None
(pattern_input, activation_compute_dtype, reduce_range_node) = input_dynamic_q_node.args
is_fp16 = activation_compute_dtype == torch.float16
is_int8 = activation_compute_dtype in [torch.quint8, torch.qint8]
if not is_int8 and not is_fp16:
continue
quantized_weight = weight_dq_node.args[0]
weight_dtype = quantized_weight.args[-1]
# Step 1: Try to select reference pattern with the corresponding quantized op
dynamic_quant_dtype_key = (activation_compute_dtype, weight_dtype)
if dynamic_quant_dtype_key not in DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target]:
print(f"Didn't find dtype combination {dynamic_quant_dtype_key} during "
f"dynamic quantized op lowering for {func_node.target}")
continue
(q_func, q_relu_func) = DYNAMIC_LOWER_FUNCTIONAL_MAP[func_node.target][dynamic_quant_dtype_key]
if q_func is None or q_relu_func is None:
print("Didn't find corresponding quantized function or quantized relu function "
f"for {func_node.target}, {dynamic_quant_dtype_key}")
continue
# Step 2: Replace quantized weights with packed weights, which will be folded later
# Use the right prepack op and prepare the corresponding args
# Linear prepack args: (quantized weights[, bias])
# Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups])
prepack_args = [quantized_weight] + remaining_func_args
if func_node.target == F.linear:
prepack_op = get_linear_prepack_op_for_dtype(weight_dtype)
elif func_node.target in CONV_FUNCTIONAL_OPS:
prepack_op = get_qconv_prepack_op(func_node.target)
# For conv1d, the stride, padding, and dilation args may be ints,
# in which case we need to convert them to tuples
if func_node.target == F.conv1d:
for i in [2, 3, 4]:
if len(prepack_args) > i and isinstance(prepack_args[i], int):
prepack_args[i] = (prepack_args[i],)
else:
raise ValueError("Lowering is not supported for op '%s'" % func_node.target)
with model.graph.inserting_before(func_node):
packed_weight = model.graph.create_node("call_function", prepack_op, tuple(prepack_args), {})
# Step 3: Replace reference pattern with the corresponding quantized op
func_node.target = q_relu_func if relu_node is not None else q_func
if is_int8:
func_node.args = (pattern_input, packed_weight, reduce_range_node)
else:
func_node.args = (pattern_input, packed_weight)
if relu_node is not None:
relu_node.replace_all_uses_with(func_node)
# Step 4: Remove the dequantize and quantize nodes (and the standalone relu node, if any)
for dqn in [input_dq_node, weight_dq_node]:
dqn_input = dqn.args[0]
dqn.replace_all_uses_with(dqn_input)
model.graph.erase_node(dqn)
model.graph.erase_node(input_dynamic_q_node)
if relu_node is not None:
model.graph.erase_node(relu_node)
return model
def _lower_quantized_binary_op(
model: QuantizedGraphModule,
qconfig_map: Dict[str, QConfigAny]
@ -683,6 +944,22 @@ def special_pattern_replacement(model: QuantizedGraphModule) -> QuantizedGraphMo
return model
def _lower_getattr_tensor_metadta_op(
model: QuantizedGraphModule
) -> None:
""" Modified the graph of the model inplace, to skip extra dequantize op before
the general tensor shape ops when possible
"""
for n in model.graph.nodes:
if is_getattr_tensor_metadata_node(n):
maybe_dq = n.args[0]
if maybe_dq.op != "call_method" or maybe_dq.target != "dequantize":
continue
# skip the dequantize node
args = list(n.args)
args[0] = n.args[0].args[0]
n.args = tuple(args)
def _lower_to_native_backend(
model: QuantizedGraphModule,
qconfig_map: Dict[str, QConfigAny],
@ -692,13 +969,21 @@ def _lower_to_native_backend(
to the native backend in PyTorch (fbgemm/qnnpack), both backends share the same
operator signature so they can be lowered with the same function
"""
model = _lower_weighted_ref_module(model)
model = _lower_weighted_ref_functional(model, qconfig_map)
# TODO: these transformations are just inplace modification of graphs, we don't
# need to return a model
model = _lower_static_weighted_ref_module(model)
model = _lower_dynamic_weighted_ref_module(model)
model = _lower_weight_only_weighted_ref_module(model)
model = _lower_static_weighted_ref_functional(model, qconfig_map)
model = _lower_dynamic_weighted_ref_functional(model, qconfig_map)
# TODO: remove this
for pattern, replacement in get_fbgemm_patterns_and_replacements():
subgraph_rewriter_FORKED_DO_NOT_USE.replace_pattern(model, pattern, replacement)
_lower_quantized_binary_op(model, qconfig_map)
_lower_getattr_tensor_metadta_op(model)
special_pattern_replacement(model)
model = fold_weight(model, node_name_to_scope)
model.graph.eliminate_dead_code()
model.recompile()
model.graph.lint()
return model

View File

@ -18,7 +18,7 @@ class FusedGraphModule(GraphModule):
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return FusedGraphModule(fake_mod, self.graph, self.preserved_attr_names)
return FusedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
class ObservedGraphModule(GraphModule):
@ -45,7 +45,7 @@ class ObservedGraphModule(GraphModule):
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedGraphModule(fake_mod, self.graph, self.preserved_attr_names)
return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
def is_observed_module(module: Any) -> bool:
return isinstance(module, ObservedGraphModule)
@ -60,7 +60,7 @@ class ObservedStandaloneGraphModule(ObservedGraphModule):
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedStandaloneGraphModule(fake_mod, self.graph, self.preserved_attr_names)
return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
def is_observed_standalone_module(module: Any) -> bool:
return isinstance(module, ObservedStandaloneGraphModule)
@ -104,4 +104,4 @@ class QuantizedGraphModule(GraphModule):
def __deepcopy__(self, memo):
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return QuantizedGraphModule(fake_mod, self.graph, self.preserved_attr_names)
return QuantizedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))

View File

@ -337,6 +337,7 @@ def get_target_activation_dtype_for_node(
else torch.float
return {
"input_activation_dtype": act_dtype,
"input_activation_compute_dtype": act_compute_dtype,
"weight_dtype": weight_dtype,
"bias_dtype": bias_dtype,
"output_activation_dtype": act_dtype,
@ -410,6 +411,24 @@ def get_arg_target_dtype_as_input_to_node(
else:
return node_name_to_target_dtype[node.name]["bias_dtype"]
def get_arg_target_compute_dtype_as_input_to_node(
arg: Node,
node: Node,
modules: Dict[str, torch.nn.Module],
node_name_to_target_dtype: Dict[str, Dict[str, Optional[torch.dtype]]],
) -> Optional[torch.dtype]:
""" Get the target argument dtype for the argument `arg`, as input
to node `node`
"""
assert isinstance(arg, Node)
is_weight = node_arg_is_weight(node, arg)
is_bias = node_arg_is_bias(node, arg)
is_activation = not is_weight and not is_bias
if is_activation and \
"input_activation_compute_dtype" in node_name_to_target_dtype[node.name]:
return node_name_to_target_dtype[node.name]["input_activation_compute_dtype"]
else:
return None
def maybe_insert_input_observer_for_arg_or_kwarg(
node: Union[Node, Any],
@ -461,6 +480,9 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
arg_as_output_target_dtype = get_arg_target_dtype_as_output(arg, modules, node_name_to_target_dtype)
arg_as_input_target_dtype = get_arg_target_dtype_as_input_to_node(arg, node, modules, node_name_to_target_dtype)
arg_as_input_target_compute_dtype = \
get_arg_target_compute_dtype_as_input_to_node(
arg, node, modules, node_name_to_target_dtype)
needs_obs = (
# if the dtypes are different, we need an observer
(arg_as_output_target_dtype != arg_as_input_target_dtype) and
@ -472,7 +494,13 @@ def maybe_insert_input_observer_for_arg_or_kwarg(
# if arg is a bool tensor or not a tensor, do not insert observer
(arg_as_output_target_dtype not in (torch.bool, None)) and
# if qconfig is reuse_input qconfig, we won't insert extra observer for input
not is_reuse_input_qconfig_
not is_reuse_input_qconfig_ or
# need to add input observer for dynamic quantization
# only add observer for first input for now, we may need to extend
# qconfig_dict and backend_config_dict to support more general configurations
# of dynamic quantization, e.g. dynamically quantizing second input, third
# input etc.
(arg_as_input_target_compute_dtype in [torch.quint8, torch.qint8, torch.float16]) and arg is node.args[0]
)
else:
@ -1128,6 +1156,11 @@ def insert_observers_for_model(
if user != node and is_user_quantized:
is_quantized_branch = True
# TODO: this only works for sequential fusion right now, extend it
# to automatically detect all input nodes based on the pattern
# need to change find_matches function to return this information
is_input_node_of_the_pattern = matched_nodes[-1] is node
if is_input_node_of_the_pattern:
# this modifies node inplace
maybe_insert_input_observers_for_node(
node, qconfig, model, modules, graph,

View File

@ -1248,9 +1248,6 @@ class RNNDynamicQuantizeHandler(QuantizeHandler):
modules: Dict[str, torch.nn.Module]):
super().__init__(node, modules)
def input_output_observed(self) -> bool:
return False
def convert(self,
node: Node,
qconfig: QConfigAny,

View File

@ -13,6 +13,7 @@ from torch.fx.graph import (
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
import operator
import warnings
# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
@ -111,7 +112,7 @@ def get_per_tensor_qparams(activation_post_process):
dtype = activation_post_process.dtype
return scale, zero_point, dtype
def get_quantize_node_info(activation_post_process: Callable) -> Tuple[str, Union[Callable, str], Dict[str, Any]]:
def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
''' Given an activation_post_process module,
return node_type(e.g. call_function), quantize op(e.g. quantize_per_tensor) and a dictionary
of extracted qparams from the module
@ -137,14 +138,17 @@ def get_quantize_node_info(activation_post_process: Callable) -> Tuple[str, Unio
node_type = "call_method"
quantize_op = "to"
qparams = {"_dtype_": dtype}
elif dtype == torch.float32 and compute_dtype in [torch.quint8, torch.qint8]:
elif dtype == torch.float32 and compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
# dynamic quantization
node_type = "call_function"
quantize_op = torch.quantize_per_tensor_dynamic
# TODO: get reduce range from observer
# reduce_range = activation_post_process.reduce_range
reduce_range = torch.backends.quantized.engine == "fbgemm"
qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range}
else:
raise Exception("Unsupported dtype in get_quantize_node_info:" + str(dtype))
assert quantize_op is not None
warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
return None
return node_type, quantize_op, qparams
def quantize_node(
@ -193,7 +197,10 @@ def quantize_node(
module_path = ""
root_module = modules['']
graph = quantized_graph
node_type, quantize_op, qparams = get_quantize_node_info(obs_module)
maybe_quantize_node_info = get_quantize_node_info(obs_module)
assert maybe_quantize_node_info is not None, \
f"Expecting quantize node info not to be None, observer: {obs_module}"
node_type, quantize_op, qparams = maybe_quantize_node_info
inputs = [in_node]
for key, value in qparams.items():

View File

@ -113,7 +113,7 @@ default_dynamic_qconfig = QConfig(activation=default_dynamic_quant_observer,
Default dynamic qconfig.
"""
float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float32),
float16_dynamic_qconfig = QConfig(activation=PlaceholderObserver.with_args(dtype=torch.float32, compute_dtype=torch.float16),
weight=PlaceholderObserver.with_args(dtype=torch.float16))
"""
Dynamic qconfig with weights quantized to `torch.float16`.

View File

@ -35,6 +35,12 @@ DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
nn.ConvTranspose1d: nnqr.ConvTranspose1d,
nn.ConvTranspose2d: nnqr.ConvTranspose2d,
nn.ConvTranspose3d: nnqr.ConvTranspose3d,
nn.Embedding: nnqr.Embedding,
nn.EmbeddingBag: nnqr.EmbeddingBag,
nn.GRUCell: nnqr.GRUCell,
nn.LSTMCell: nnqr.LSTMCell,
nn.RNNCell: nnqr.RNNCell,
nn.LSTM: nnqr.LSTM,
}
# Default map for swapping float module to quantized ones

View File

@ -6,7 +6,8 @@ from torch.fx._symbolic_trace import Tracer
from torch.fx.node import Target, Node, Argument
from torch.nn.intrinsic import _FusedModule
from .fx import fuse # noqa: F401
from .fx import prepare, convert # noqa: F401
from .fx import prepare # noqa: F401
from .fx._convert_do_not_use import _convert_do_not_use as convert
from .fx import get_tensorrt_backend_config_dict # noqa: F401
from .fx.graph_module import ObservedGraphModule
from .fx.qconfig_utils import (
@ -577,6 +578,7 @@ def _convert_fx(
is_standalone_module: bool = False,
_remove_qconfig: bool = True,
qconfig_dict: Dict[str, Any] = None,
backend_config_dict: Dict[str, Any] = None,
) -> torch.nn.Module:
""" `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
"""
@ -593,6 +595,7 @@ def _convert_fx(
is_standalone_module,
_remove_qconfig_flag=_remove_qconfig,
convert_qconfig_dict=qconfig_dict,
backend_config_dict=backend_config_dict,
)
preserved_attributes = convert_custom_config_dict.get("preserved_attributes", [])
@ -607,6 +610,7 @@ def convert_fx(
convert_custom_config_dict: Optional[Dict[str, Any]] = None,
_remove_qconfig: bool = True,
qconfig_dict: Dict[str, Any] = None,
backend_config_dict: Dict[str, Any] = None,
) -> torch.nn.Module:
r""" Convert a calibrated or trained model to a quantized model
@ -677,6 +681,11 @@ def convert_fx(
],
}
* `backend_config_dict`: A configuration for the backend which describes how
operators should be quantized in the backend; this includes quantization
mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
and observer placement for each operator and fused operator. Detailed
documentation can be found in torch/ao/quantization/fx/backend_config/README.md
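For example (a sketch; the exact contents of `backend_config_dict` are
backend specific and documented in the README above):
quantized_model = convert_fx(prepared_model, backend_config_dict=backend_config_dict)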
Return:
A quantized model (GraphModule)
@ -694,6 +703,7 @@ def convert_fx(
convert_custom_config_dict,
_remove_qconfig=_remove_qconfig,
qconfig_dict=qconfig_dict,
backend_config_dict=backend_config_dict,
)

View File

@ -184,6 +184,16 @@ def activation_is_statically_quantized(qconfig):
"""
return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]
def activation_is_dynamically_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
dynamically quantized or not, this includes dynamically quantizing to
quint8, qint8 and float16
"""
activation_dtype, _, activation_compute_dtype = \
get_qconfig_dtypes(qconfig)
return activation_dtype == torch.float and \
activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16]
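# For example, a qconfig whose activation observer is a PlaceholderObserver with
# dtype=torch.float and compute_dtype=torch.quint8 (int8 dynamic quantization), or
# compute_dtype=torch.float16 (e.g. float16_dynamic_qconfig), returns True here,
# while a statically quantized qconfig (activation dtype quint8/qint8/float16)
# returns False because its activation dtype is not torch.float.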
def activation_is_int8_quantized(qconfig):
""" Given a qconfig, decide if the activation needs to be
quantized to int8 or not, this includes quantizing to quint8, qint8
@ -200,7 +210,7 @@ def weight_is_quantized(qconfig):
""" Given a qconfig, decide if the weight needs to be
quantized or not
"""
return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16]
return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16, torch.quint4x2]
def weight_is_statically_quantized(qconfig):
""" Given a qconfig, decide if the weight needs to be statically
@ -235,7 +245,7 @@ def get_quant_type(qconfig):
assert qconfig is not None
activation = qconfig.activation()
weight = qconfig.weight()
static_dtypes = [torch.quint8, torch.qint8]
static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2]
if weight.dtype in static_dtypes:
if activation.dtype in static_dtypes:
return QuantType.STATIC

View File

@ -44,3 +44,7 @@ class LinearReLU(nnqd.Linear):
@classmethod
def from_float(cls, mod):
return super(LinearReLU, cls).from_float(mod)
@classmethod
def from_reference(cls, ref_qlinear_relu):
return super().from_reference(ref_qlinear_relu[0])

View File

@ -17,8 +17,8 @@ class BNReLU2d(nnq.BatchNorm2d):
"""
_FLOAT_MODULE = torch.nn.intrinsic.BNReLU2d
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(BNReLU2d, self).__init__(num_features, eps=eps, momentum=momentum)
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
super(BNReLU2d, self).__init__(num_features, eps=eps, momentum=momentum, device=device, dtype=dtype)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
@ -37,6 +37,9 @@ class BNReLU2d(nnq.BatchNorm2d):
# TODO: Add qat support for BNReLU2d
return super(BNReLU2d, cls).from_float(mod)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):
return super().from_reference(bn_relu[0], output_scale, output_zero_point)
class BNReLU3d(nnq.BatchNorm3d):
r"""
@ -50,8 +53,8 @@ class BNReLU3d(nnq.BatchNorm3d):
"""
_FLOAT_MODULE = torch.nn.intrinsic.BNReLU3d
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(BNReLU3d, self).__init__(num_features, eps=eps, momentum=momentum)
def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
super(BNReLU3d, self).__init__(num_features, eps=eps, momentum=momentum, device=device, dtype=dtype)
def forward(self, input):
# Temporarily using len(shape) instead of ndim due to JIT issue
@ -69,3 +72,7 @@ class BNReLU3d(nnq.BatchNorm3d):
def from_float(cls, mod):
# TODO: Add qat support for BNReLU3d
return super(BNReLU3d, cls).from_float(mod)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):
return super().from_reference(bn_relu[0], output_scale, output_zero_point)

View File

@ -2,6 +2,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from .utils import _quantize_and_dequantize_weight
from .utils import _quantize_weight
from typing import Optional, Dict, Any, Tuple
from torch import _VF
from torch.nn.utils.rnn import PackedSequence
@ -9,6 +10,31 @@ from torch.nn.utils.rnn import PackedSequence
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation)
def get_weight_and_quantization_params(module, wn):
weight = getattr(module, wn)
params = [weight]
for param_name in [wn + n for n in ["_qscheme", "_dtype", "_scale", "_zero_point", "_axis"]]:
if hasattr(module, param_name):
param = getattr(module, param_name)
else:
param = None
params.append(param)
return params
def get_quantized_weight(module, wn):
if not hasattr(module, wn):
return None
params = get_weight_and_quantization_params(module, wn)
weight = _quantize_weight(*params)
return weight
def get_quantize_and_dequantized_weight(module, wn):
if not hasattr(module, wn):
return None
params = get_weight_and_quantization_params(module, wn)
weight = _quantize_and_dequantize_weight(*params)
return weight
class RNNCellBase(nn.RNNCellBase):
def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int,
device=None, dtype=None, weight_qparams_dict=None) -> None:
@ -56,27 +82,17 @@ class RNNCellBase(nn.RNNCellBase):
def _get_name(self):
return "QuantizedRNNCellBase(Reference)"
def get_quantized_weight_ih(self):
return get_quantized_weight(self, "weight_ih")
def get_quantized_weight_hh(self):
return get_quantized_weight(self, "weight_hh")
def get_weight_ih(self):
wn = "weight_ih"
weight = self.weight_ih
weight_qscheme = getattr(self, wn + "_qscheme")
weight_dtype = getattr(self, wn + "_dtype")
weight_scale = getattr(self, wn + "_scale")
weight_zero_point = getattr(self, wn + "_zero_point")
weight_axis = getattr(self, wn + "_axis")
weight = _quantize_and_dequantize_weight(weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis)
return weight
return get_quantize_and_dequantized_weight(self, "weight_ih")
def get_weight_hh(self):
wn = "weight_hh"
weight = self.weight_hh
weight_qscheme = getattr(self, wn + "_qscheme")
weight_dtype = getattr(self, wn + "_dtype")
weight_scale = getattr(self, wn + "_scale")
weight_zero_point = getattr(self, wn + "_zero_point")
weight_axis = getattr(self, wn + "_axis")
weight = _quantize_and_dequantize_weight(weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis)
return weight
return get_quantize_and_dequantized_weight(self, "weight_hh")
class RNNCell(RNNCellBase):
"""
@ -129,6 +145,21 @@ class RNNCell(RNNCellBase):
return ret
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.bias,
mod.nonlinearity,
mod.weight_ih.device,
mod.weight_ih.dtype,
weight_qparams_dict)
ref_mod.weight_ih = mod.weight_ih
ref_mod.weight_hh = mod.weight_hh
ref_mod.bias_ih = mod.bias_ih
ref_mod.bias_hh = mod.bias_hh
return ref_mod
class LSTMCell(RNNCellBase):
"""
@ -136,11 +167,10 @@ class LSTMCell(RNNCellBase):
we need to pass in a `weight_qparams_dict` that maps from weight name,
e.g. weight_ih, to the weight_qparams for that weight
"""
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Dict[str, Any]]] = None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
self.nonlinearity = nonlinearity
def _get_name(self):
return "QuantizedLSTMCell(Reference)"
@ -168,7 +198,20 @@ class LSTMCell(RNNCellBase):
ret = (ret[0].squeeze(0), ret[1].squeeze(0))
return ret
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.bias,
mod.weight_ih.device,
mod.weight_ih.dtype,
weight_qparams_dict)
ref_mod.weight_ih = mod.weight_ih
ref_mod.weight_hh = mod.weight_hh
ref_mod.bias_ih = mod.bias_ih
ref_mod.bias_hh = mod.bias_hh
return ref_mod
class GRUCell(RNNCellBase):
"""
@ -176,11 +219,10 @@ class GRUCell(RNNCellBase):
we need to pass in a `weight_qparams_dict` that maps from weight name,
e.g. weight_ih, to the weight_qparams for that weight
"""
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh",
def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
device=None, dtype=None, weight_qparams_dict: Optional[Dict[str, Dict[str, Any]]] = None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype, 'weight_qparams_dict': weight_qparams_dict}
super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
self.nonlinearity = nonlinearity
def _get_name(self):
return "QuantizedGRUCell(Reference)"
@ -208,6 +250,20 @@ class GRUCell(RNNCellBase):
return ret
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.bias,
mod.weight_ih.device,
mod.weight_ih.dtype,
weight_qparams_dict)
ref_mod.weight_ih = mod.weight_ih
ref_mod.weight_hh = mod.weight_hh
ref_mod.bias_ih = mod.bias_ih
ref_mod.bias_hh = mod.bias_hh
return ref_mod
class RNNBase(nn.RNNBase):
def __init__(self, mode: str, input_size: int, hidden_size: int,
@ -228,8 +284,8 @@ class RNNBase(nn.RNNBase):
}
weight_qparams_dict = dict()
for wn in self._flat_weights_names:
if wn.startswith("weight"):
weight_qparams_dict[wn] = weight_qparams
self._init_weight_qparams_dict(weight_qparams_dict, device)
def _init_weight_qparams_dict(self, weight_qparams_dict, device):
@ -263,8 +319,7 @@ class LSTM(RNNBase):
to the weight_qparams for that weight
"""
def __init__(self, *args, **kwargs):
assert "weight_qparams_dict" in kwargs
super(LSTM, self).__init__('LSTM', *args, **kwargs)
super().__init__('LSTM', *args, **kwargs)
# Same as above, see torch/nn/modules/module.py::_forward_unimplemented
def permute_hidden(self, # type: ignore[override]
@ -298,19 +353,35 @@ class LSTM(RNNBase):
self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
'Expected hidden[1] size {}, got {}')
def get_quantized_weight_bias_dict(self):
""" dictionary from flat_weight_name to quantized weight or (unquantized) bias
e.g.
{
"weight_ih_l0": quantized_weight,
"bias_ih_l0": unquantized_bias,
...
}
"""
quantized_weight_bias_dict = {}
for wn in self._flat_weights_names:
if hasattr(self, wn):
if wn.startswith("weight"):
weight_or_bias = get_quantized_weight(self, wn)
else:
weight_or_bias = getattr(self, wn)
else:
weight_or_bias = None
quantized_weight_bias_dict[wn] = weight_or_bias
return quantized_weight_bias_dict
def get_flat_weights(self):
flat_weights = []
for wn in self._flat_weights_names:
if hasattr(self, wn):
weight = getattr(self, wn)
weight_qscheme = getattr(self, wn + "_qscheme")
weight_dtype = getattr(self, wn + "_dtype")
weight_scale = getattr(self, wn + "_scale")
weight_zero_point = getattr(self, wn + "_zero_point")
weight_axis = getattr(self, wn + "_axis")
weight = _quantize_and_dequantize_weight(
weight, weight_qscheme, weight_dtype, weight_scale,
weight_zero_point, weight_axis)
if wn.startswith("weight"):
params = get_weight_and_quantization_params(self, wn)
weight = _quantize_and_dequantize_weight(*params)
else:
weight = None
flat_weights.append(weight)
@ -383,3 +454,18 @@ class LSTM(RNNBase):
def _get_name(self):
return "QuantizedLSTM(Reference)"
@classmethod
def from_float(cls, mod, weight_qparams_dict):
ref_mod = cls(
mod.input_size,
mod.hidden_size,
mod.num_layers,
mod.bias,
mod.batch_first,
mod.dropout,
mod.bidirectional,
weight_qparams_dict=weight_qparams_dict)
for wn in mod._flat_weights_names:
setattr(ref_mod, wn, getattr(mod, wn))
return ref_mod

View File

@ -29,6 +29,21 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
input, weight_quant_dequant, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
@classmethod
def from_float(cls, mod, weight_qparams):
return cls(
mod.num_embeddings,
mod.embedding_dim,
mod.padding_idx,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.sparse,
mod.weight,
mod.weight.device,
mod.weight.dtype,
weight_qparams)
class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
""" A reference quantized EmbeddingBag module that fits into the
FX Graph Mode Quantization workflow, activation will be floating point Tensor,
@ -57,3 +72,21 @@ class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset,
self.padding_idx)
@classmethod
def from_float(cls, mod, weight_qparams):
return cls(
mod.num_embeddings,
mod.embedding_dim,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.mode,
mod.sparse,
mod.weight,
mod.include_last_offset,
mod.padding_idx,
mod.weight.device,
mod.weight.dtype,
weight_qparams
)

View File

@ -16,13 +16,16 @@ class ReferenceQuantizedModule(torch.nn.Module):
None, torch.per_tensor_affine, torch.per_channel_affine,
torch.per_channel_affine_float_qparams], \
Exception(f"qscheme: {self.weight_qscheme} is not support in reference quantized {self._get_name()}")
if self.weight_qscheme is not None:
if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
zero_point_dtype = weight_qparams["zero_point"].dtype if \
isinstance(weight_qparams["zero_point"], torch.Tensor) else \
torch.int
self.register_buffer(
"weight_scale",
torch.tensor(weight_qparams["scale"], dtype=torch.float, device=device))
self.register_buffer(
"weight_zero_point",
torch.tensor(weight_qparams["zero_point"], dtype=torch.int, device=device))
torch.tensor(weight_qparams["zero_point"], dtype=zero_point_dtype, device=device))
if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
self.register_buffer(
"weight_axis",
@ -31,6 +34,13 @@ class ReferenceQuantizedModule(torch.nn.Module):
# added for TorchScriptability, not used
self.register_buffer(
"weight_axis", torch.tensor(0, dtype=torch.int, device=device))
else:
# added for TorchScriptability, not used
self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device))
self.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int, device=device))
self.register_buffer(
"weight_axis", torch.tensor(0, dtype=torch.int, device=device))
def get_weight(self):
"""
@ -83,24 +93,21 @@ def _quantize_weight(
weight_scale: torch.Tensor,
weight_zero_point: torch.Tensor,
weight_axis: torch.Tensor):
if weight_dtype == torch.float16:
weight = weight.to(weight_dtype)
return weight
if weight_qscheme == torch.per_tensor_affine:
if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
elif weight_dtype == torch.float16:
weight = weight.to(weight_dtype)
else:
raise Exception(f"Unsupported dtype: {weight_dtype} for {weight_qscheme}")
return weight
elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
if weight_dtype in [torch.quint8, torch.qint8]:
if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2]:
weight = torch.quantize_per_channel(
weight, weight_scale,
weight_zero_point, weight_axis.item(), weight_dtype) # type: ignore[arg-type]
else:
raise Exception(f"Unsupported dtype: {weight_dtype} for {weight_qscheme}")
else:
raise Exception(f"Unsupported qscheme: {weight_qscheme}")
return weight
raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
def _quantize_and_dequantize_weight(
weight: torch.Tensor,

View File

@ -110,3 +110,17 @@ class Linear(nnq.Linear):
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
return qlinear
@classmethod
def from_reference(cls, ref_qlinear):
""" Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized
module
Args:
ref_qlinear (Module): a reference quantized module, either produced by
torch.ao.quantization functions or provided by the user
"""
qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features, dtype=ref_qlinear.weight_dtype)
qweight = ref_qlinear.get_quantized_weight()
bias = ref_qlinear.bias
qlinear.set_weight_bias(qweight, bias)
return qlinear
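# Example usage (a sketch; `ref_qlinear` is assumed to be a reference dynamic
# quantized Linear produced by the FX convert flow, with an int8 or fp16 weight):
#   qlinear = Linear.from_reference(ref_qlinear)
#   out = qlinear(torch.randn(2, ref_qlinear.in_features))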

View File

@ -11,6 +11,27 @@ from torch.nn.quantized.modules.utils import _quantize_weight
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation)
def pack_weight_bias(qweight, bias, dtype):
if dtype == torch.qint8:
# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# w_ih, w_hh
packed_weight = \
torch.ops.quantized.linear_prepack(qweight, bias)
return packed_weight
else:
# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# packed_ih, packed_hh, b_ih, b_hh
packed_weight = torch.ops.quantized.linear_prepack_fp16(
qweight, bias)
return packed_weight
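# Usage sketch (assuming `qweight` is a weight tensor already quantized/converted to
# the target dtype and `bias` is a float tensor):
#   packed = pack_weight_bias(qweight, bias, torch.qint8)
# The packed parameter is later wrapped into cell params via
# torch.ops.quantized.make_quantized_cell_params_dynamic (int8) or
# make_quantized_cell_params_fp16 (fp16), as done in set_weight_bias below.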
class PackedParameter(torch.nn.Module):
def __init__(self, param):
super(PackedParameter, self).__init__()
@ -92,9 +113,7 @@ class RNNBase(torch.nn.Module):
else:
cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
packed_ih, packed_hh, b_ih, b_hh, True)
else:
packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
@ -197,6 +216,43 @@ class RNNBase(torch.nn.Module):
super(RNNBase, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
def set_weight_bias(self, weight_bias_dict):
def weight_bias_name(ihhh, layer, suffix):
weight_name = "weight_{}_l{}{}".format(ihhh, layer, suffix)
bias_name = "bias_{}_l{}{}".format(ihhh, layer, suffix)
return weight_name, bias_name
num_directions = 2 if self.bidirectional else 1
# TODO: dedup with __init__ of RNNBase
_all_weight_values = []
for layer in range(self.num_layers):
for direction in range(num_directions):
suffix = "_reverse" if direction == 1 else ""
w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix)
w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix)
w_ih = weight_bias_dict[w_ih_name]
b_ih = weight_bias_dict[b_ih_name]
w_hh = weight_bias_dict[w_hh_name]
b_hh = weight_bias_dict[b_hh_name]
if w_ih.dtype == torch.qint8:
packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
if self.version is None or self.version < 2:
cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
packed_ih, packed_hh, b_ih, b_hh)
else:
cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
packed_ih, packed_hh, b_ih, b_hh, True)
else:
packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
packed_ih, packed_hh)
_all_weight_values.append(PackedParameter(cell_params))
self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
@classmethod
def from_float(cls, mod):
assert type(mod) in set(
@ -429,6 +485,24 @@ class LSTM(RNNBase):
def from_float(cls, mod):
return super(LSTM, cls).from_float(mod)
@classmethod
def from_reference(cls, ref_mod):
assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
"We are assuming weight_ih_l0 "
"exists in LSTM, may need to relax the assumption to support the use case")
qmod = cls(
ref_mod.input_size,
ref_mod.hidden_size,
ref_mod.num_layers,
ref_mod.bias,
ref_mod.batch_first,
ref_mod.dropout,
ref_mod.bidirectional,
# assuming there is layer 0, which should be OK
ref_mod.weight_ih_l0_dtype,
)
qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
return qmod
class GRU(RNNBase):
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
@ -652,6 +726,7 @@ class RNNCellBase(torch.nn.Module):
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.weight_dtype = dtype
if bias:
self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
@ -750,42 +825,60 @@ class RNNCellBase(torch.nn.Module):
raise NotImplementedError('Only LSTMCell, GRUCell and RNNCell \
are supported for QuantizedRNN for now')
assert mod.bias
def process_weights(weight, bias, dtype):
def _observe_and_quantize_weight(weight):
if dtype == torch.qint8:
# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# w_ih, w_hh
weight_observer = weight_observer_method()
weight_observer(weight)
qweight = _quantize_weight(weight.float(), weight_observer)
packed_weight = \
torch.ops.quantized.linear_prepack(qweight, bias)
return packed_weight
return qweight
else:
# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# packed_ih, packed_hh, b_ih, b_hh
packed_weight = torch.ops.quantized.linear_prepack_fp16(
weight.float(), bias)
return weight.float()
return packed_weight
qRNNCellBase._packed_weight_ih = process_weights(mod.weight_ih, mod.bias_ih, dtype)
qRNNCellBase._packed_weight_hh = process_weights(mod.weight_hh, mod.bias_hh, dtype)
qRNNCellBase._packed_weight_ih = pack_weight_bias(_observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype)
qRNNCellBase._packed_weight_hh = pack_weight_bias(_observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype)
return qRNNCellBase
@classmethod
def from_reference(cls, ref_mod):
assert hasattr(ref_mod, "weight_ih_dtype"), (
"We are assuming weight_ih "
"exists in reference module, may need to relax the assumption to support the use case")
if hasattr(ref_mod, "nonlinearity"):
qmod = cls(
ref_mod.input_size,
ref_mod.hidden_size,
ref_mod.bias,
ref_mod.nonlinearity,
dtype=ref_mod.weight_ih_dtype
)
else:
qmod = cls(
ref_mod.input_size,
ref_mod.hidden_size,
ref_mod.bias,
dtype=ref_mod.weight_ih_dtype
)
weight_bias_dict = {
"weight": {
"weight_ih": ref_mod.get_quantized_weight_ih(),
"weight_hh": ref_mod.get_quantized_weight_hh(),
},
"bias": {
"bias_ih": ref_mod.bias_ih,
"bias_hh": ref_mod.bias_hh,
}
}
qmod.set_weight_bias(weight_bias_dict)
return qmod
def _weight_bias(self):
# Returns a dict of weights and biases
weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
w1, b1 = self._packed_weight_ih.__getstate__()[0]
w2, b2 = self._packed_weight_hh.__getstate__()[0]
# TODO: these can be simplified to one level? e.g. using weight_ih as key
# directly
weight_bias_dict['weight']['weight_ih'] = w1
weight_bias_dict['weight']['weight_hh'] = w2
weight_bias_dict['bias']['bias_ih'] = b1
@ -798,12 +891,23 @@ class RNNCellBase(torch.nn.Module):
def get_bias(self):
return self._weight_bias()['bias']
def set_weight_bias(self, weight_bias_dict):
# TODO: these can be simplified to one level? e.g. using weight_ih as key
# directly
self._packed_weight_ih = pack_weight_bias(
weight_bias_dict["weight"]["weight_ih"],
weight_bias_dict["bias"]["bias_ih"],
self.weight_dtype)
self._packed_weight_hh = pack_weight_bias(
weight_bias_dict["weight"]["weight_hh"],
weight_bias_dict["bias"]["bias_hh"],
self.weight_dtype)
def _save_to_state_dict(self, destination, prefix, keep_vars):
super(RNNCellBase, self)._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + '_packed_weight_ih'] = self._packed_weight_ih
destination[prefix + '_packed_weight_hh'] = self._packed_weight_hh
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
self._packed_weight_ih = state_dict.pop(prefix + '_packed_weight_ih')

View File

@ -34,6 +34,10 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
device=bn.weight.device,
dtype=bn.weight.dtype
)
qbn.weight = bn.weight
qbn.bias = bn.bias
qbn.running_mean = bn.running_mean
qbn.running_var = bn.running_var
qbn.scale = output_scale
qbn.zero_point = output_zero_point
return qbn

View File

@ -19,7 +19,7 @@ class EmbeddingPackedParams(torch.nn.Module):
axis=0, dtype=self.dtype)
self.set_weight(wq)
else:
raise NotImplementedError('Unsupported dtype on quantized embedding! Supports quint8 and quint4x2.')
raise NotImplementedError(f'Unsupported dtype on quantized embedding! Supports quint8 and quint4x2. Got dtype: {dtype}')
@torch.jit.export
def set_weight(self, weight: torch.Tensor) -> None:
@ -174,6 +174,20 @@ class Embedding(torch.nn.Module):
qembedding.set_weight(qweight)
return qembedding
@classmethod
def from_reference(cls, ref_embedding):
qembedding = cls(
ref_embedding.num_embeddings,
ref_embedding.embedding_dim,
ref_embedding.padding_idx,
ref_embedding.max_norm,
ref_embedding.norm_type,
ref_embedding.scale_grad_by_freq,
ref_embedding.sparse,
ref_embedding.get_quantized_weight(),
ref_embedding.weight_dtype,
)
return qembedding
class EmbeddingBag(Embedding):
r"""
@ -260,3 +274,19 @@ class EmbeddingBag(Embedding):
qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
qembedding_bag.set_weight(qweight)
return qembedding_bag
@classmethod
def from_reference(cls, ref_embedding_bag):
qembedding_bag = cls(
ref_embedding_bag.num_embeddings,
ref_embedding_bag.embedding_dim,
ref_embedding_bag.max_norm,
ref_embedding_bag.norm_type,
ref_embedding_bag.scale_grad_by_freq,
ref_embedding_bag.mode,
ref_embedding_bag.sparse,
ref_embedding_bag.get_quantized_weight(),
ref_embedding_bag.include_last_offset,
ref_embedding_bag.weight_dtype,
)
return qembedding_bag

View File

@ -279,7 +279,7 @@ class Linear(WeightedQuantizedModule):
r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
Args:
ref_module (Module): a reference quantized module, either produced by torch.ao.quantization
ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
utilities or provided by the user
output_scale (float): scale for output Tensor
zero_point (int): zero point for output Tensor

View File

@ -872,8 +872,8 @@ class QuantizationTestCase(TestCase):
prepare_expected_node_occurrence, prepare_expected_node_list)
prepared_copy = copy.deepcopy(prepared)
qgraph = convert_fx(prepared)
qgraph_reference = convert_fx(prepared_copy, is_reference=True)
qgraph = convert_fx(copy.deepcopy(prepared))
qgraph_reference = convert_fx(copy.deepcopy(prepared), is_reference=True)
result = qgraph(*inputs)
result_reference = qgraph_reference(*inputs)
qgraph_copy = copy.deepcopy(qgraph)