Revert "Revert D29135358: [quant] Input-Weight Equaliaztion - convert modifications" (#60646)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60646
This reverts commit e60f9cfc58.
Test Plan: Imported from OSS
Reviewed By: supriyar
Differential Revision: D29361191
Pulled By: angelayi
fbshipit-source-id: 275d8691d8e47da4ab80bb21b51d77ec25a0f714
This commit is contained in:
parent 7188d84ccf
commit e13a9587b4
@@ -2,15 +2,16 @@ import torch
 import torch.nn as nn
 from torch.quantization import default_qconfig
 from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
-from torch.quantization.quantize_fx import prepare_fx
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
 from torch.quantization.fx._equalize import (
     _InputEqualizationObserver,
     _WeightEqualizationObserver,
     calculate_equalization_scale,
     default_equalization_qconfig,
+    _convert_equalization_ref
 )

-from torch.testing._internal.common_quantization import NodeSpec as ns
+from torch.testing._internal.common_quantization import NodeSpec as ns, skipIfNoFBGEMM
 from torch.testing._internal.common_quantization import QuantizationTestCase

 # Standard Libraries
@@ -238,3 +239,59 @@ class TestEqualizeFx(QuantizationTestCase):
             m = M().eval()
             prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
             self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
+
+    @skipIfNoFBGEMM
+    def test_input_weight_equalization_convert(self):
+        """
+        """
+        qconfig_dict = {"": None,
+                        "object_type": [(nn.Linear, default_qconfig),
+                                        (nn.functional.linear, default_qconfig)]}
+
+        default_equalization_qconfig_dict = {
+            "": None,
+            "object_type": [(nn.Linear, default_equalization_qconfig),
+                            (nn.functional.linear, default_equalization_qconfig)]
+        }
+
+        # Basic test with one linear layer
+        class LinearModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # Test with two linear layer with a fp32 operation between
+        class Linear2FP32Module(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = nn.Linear(2, 2)
+                self.linear2 = nn.Linear(2, 2)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = torch.add(x, torch.tensor([1, 2]))
+                x = self.linear2(x)
+                return x
+
+
+        tests = [(LinearModule, default_equalization_qconfig_dict),
+                 (Linear2FP32Module, default_equalization_qconfig_dict)]
+
+        for (M, equalization_qconfig_dict) in tests:
+            m = M().eval()
+            x = torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]])
+            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
+            output = prepared(x)
+
+            convert_ref = _convert_equalization_ref(prepared)
+            convert_ref_output = convert_ref(x)
+
+            m = M().eval()
+            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
+            prepared(x)
+            convert_fx(prepared)  # Check if compile?
+
+            self.assertEqual(output, convert_ref_output)
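The property this new test checks, that the equalized-but-not-quantized model matches the original float model, follows from a simple fact: scaling each input column by a positive factor while scaling the matching weight column by its reciprocal leaves a linear layer's output unchanged. A minimal standalone sketch of that invariant (the scale values here are arbitrary, not taken from any observer):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(3, 2)            # batch of inputs, 2 features
    w = torch.randn(4, 2)            # Linear weight of shape (out_features, in_features)
    b = torch.randn(4)
    s = torch.tensor([0.5, 2.0])     # one positive equalization scale per input column

    out = F.linear(x, w, b)
    # Equalization: inputs are multiplied by s, weight columns by 1/s
    out_eq = F.linear(x * s, w * torch.reciprocal(s), b)

    assert torch.allclose(out, out_eq, atol=1e-6)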
@@ -1,15 +1,18 @@
 import torch
 import torch.nn as nn
+from torch.fx import GraphModule
 from torch.fx.graph import Node

+from .utils import get_new_attr_name_with_prefix, maybe_get_next_module
 from ..observer import (
     PerChannelMinMaxObserver,
     _with_args,
+    ObserverBase,
 )
 from ..utils import check_min_max_valid

 from collections import namedtuple
+from typing import Dict, Any, Tuple, Optional
 import warnings

@@ -28,14 +31,8 @@ class _InputEqualizationObserver(nn.Module):
     The running minimum/maximum :math:`x_\text{min/max}` are computed in the
     same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
     with the difference that the running min/max values are stored per column.
-
-    The qparams are calculated by multiplying the min/max input column values
-    with the equalization scale, reducing to find the global min/max input
-    values, and then calculating in the same way as in
-    :class:`~torch.quantization.observer.MinMaxObserver`
-
-    .. note:: If the running minimum equals to the running maximum, the scales
-              and zero_points are set to 1.0 and 0.
+    This observer is intended to be used along with a WeightEqualizationObserver
+    to calculate the equalization scale.
     """

     def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
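A small illustration of what "stored per column" means here, and of the scaled min/max described in the removed paragraph: for a 2-D input of shape (batch, features), one running min and max is kept per input feature, and the values handed to the following input quantization observer come from multiplying those per-column values by the equalization scale and reducing to a single global min and max. This is a rough sketch in plain tensor ops, not the observer's actual implementation, and the scale values are hypothetical:

    import torch

    x = torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]])   # calibration batch (rows = samples)
    equalization_scale = torch.tensor([0.5, 2.0])             # hypothetical per-column scale

    # Per-column running min/max, as a per-channel observer with ch_axis=1 would track
    min_inputs = x.amin(dim=0)
    max_inputs = x.amax(dim=0)

    # Scale each column, then reduce to one global min/max for the quantization observer
    min_input_scaled = torch.min(min_inputs * equalization_scale)
    max_input_scaled = torch.max(max_inputs * equalization_scale)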
@@ -70,8 +67,7 @@ class _InputEqualizationObserver(nn.Module):
         self.equalization_scale = equalization_scale

     def calculate_scaled_minmax(self):
-        r"""
-        Returns the scaled min/max inputs
+        r""" Returns the scaled min/max inputs
         """
         if self.equalization_scale.nelement() == 0:
             warnings.warn(
@@ -104,21 +100,13 @@ class _WeightEqualizationObserver(nn.Module):
         quant_max: Maximum quantization value. If unspecified, it will
             follow the 8-bit setup.

-    This observer is made up of 2 PerChannelMinMaxObservers
-        - weight_col_obs: Used to record the running minimum and maximum of
-          columns of incoming weight tensors
-        - weight_row_obs: Used to record the running minimum and maximum of
-          rows of incoming weight tensors
+    This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
+    to record the running minimum and maximum of columns of incoming weight
+    tensors. This observer is intended to be used along with an
+    InputEqualizationObserver to calculate the equalization scale.

     The running minimum/maximum :math:`w_\text{min/max}` are computed in the
     same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.
-
-    The qparams are calculated by multiplying the min/max weight row values
-    with the inverse of the equalization scale, and then calculating in the same
-    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`
-
-    .. note:: If the running minimum equals to the running maximum, the scales
-              and zero_points are set to 1.0 and 0.
     """

     def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
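A quick sketch of what "columns of incoming weight tensors" means in practice: an nn.Linear weight has shape (out_features, in_features), so observing with ch_axis=1 yields one running min and max per input feature, which is what gets matched against the input observer's per-column ranges. The observer construction below is a standalone illustration that mirrors, but is not, the class's internal weight_col_obs:

    import torch
    from torch.quantization.observer import PerChannelMinMaxObserver

    w = torch.tensor([[ 1.0, -2.0],
                      [ 0.5,  4.0],
                      [-3.0,  0.1]])          # weight of shape (out_features=3, in_features=2)

    col_obs = PerChannelMinMaxObserver(ch_axis=1)  # track ranges along the column (input) axis
    col_obs(w)

    print(col_obs.min_val)  # per-column minima: [-3., -2.]
    print(col_obs.max_val)  # per-column maxima: [ 1.,  4.]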
@@ -127,7 +115,7 @@ class _WeightEqualizationObserver(nn.Module):

         self.dtype = dtype
         self.qscheme = qscheme
-        self.ch_axis = 0
+        self.ch_axis = 1

         self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                        qscheme=qscheme,
@@ -224,3 +212,280 @@ def node_supports_equalization(node: Node, modules) -> bool:
 def is_equalization_observer(observer: nn.Module) -> bool:
     return (isinstance(observer, _InputEqualizationObserver) or
             isinstance(observer, _WeightEqualizationObserver))
+
+def get_op_node_and_weight_eq_obs(
+    input_eq_obs_node: Node,
+    model: GraphModule,
+    modules: Dict[str, nn.Module]
+) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
+    """ Gets the following weight equalization observer. There should always
+    exist a weight equalization observer after an input equalization observer.
+
+    Returns the node containing the weight equalization observer, and the weight
+    equalization observer if it has been newly created
+    """
+
+    # Find the op node that comes directly after the input equaliation observer
+    op_node = None
+    for user in input_eq_obs_node.users.keys():
+        if node_supports_equalization(user, modules):
+            op_node = user
+            break
+
+    assert(op_node is not None)
+    if op_node.op == 'call_module':
+        # If the op_node is a nn.Linear layer, then it must have a
+        # WeightEqualizationObserver configuration
+        equalization_qconfig_map: Dict[str, Any] = model._equalization_qconfig_map  # type: ignore[assignment]
+        assert(equalization_qconfig_map.get(op_node.name, None) is not None)
+        weight_eq_obs = equalization_qconfig_map.get(op_node.name, None).weight()
+
+        assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
+        return op_node, weight_eq_obs
+
+    elif op_node.op == 'call_function':
+        # TODO
+        return None, None
+
+    return None, None
+
+def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]:
+    """ Gets the following input equalization observer if it exists.
+
+    For example, in the case of connecting linear layers:
+        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
+    If the node being passed in is the linear1 node, then we want to return eq_obs2,
+    the following equalization observer for linear2.
+
+    However, if there are no connecting layers:
+        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
+    Then we want to return None.
+    """
+
+    assert(node_supports_equalization(node, modules))
+
+    # Locate the following output observer if it exists
+    maybe_obs_node = maybe_get_next_module(node, modules, ObserverBase)
+    if maybe_obs_node is None:
+        return None
+
+    maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver)
+    if maybe_eq_obs_node is None:
+        return None
+
+    maybe_eq_obs = modules[str(maybe_eq_obs_node)]
+    assert(isinstance(maybe_eq_obs, _InputEqualizationObserver))
+    return maybe_eq_obs
+
+def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
+    """ If the next next node is an InputEqualizationObserver then we want to
+    return its equalization scale, else we return 1
+
+    This is used in the case where there are two connecting linear layers:
+        linear1 -> LinearOutObs -> InputEqObs -> linear2
+    In this case, the node given is linear1 and we want to locate the InputEqObs.
+    """
+    next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
+    if next_inp_eq_obs:
+        return next_inp_eq_obs.equalization_scale
+    return None
+
+def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
+    """ Scales the following input quantization observer's min/max values by
+    updating the values with the scaled min/max values calculated by the input
+    equalization observer
+    """
+    input_eq_obs = modules[str(node.target)]
+    assert(isinstance(input_eq_obs, _InputEqualizationObserver))
+
+    input_quant_obs_node = node.args[0]
+    assert(isinstance(input_quant_obs_node, Node))
+
+    input_quant_obs = modules[str(input_quant_obs_node.target)]
+    if not isinstance(input_quant_obs, ObserverBase):
+        return
+
+    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
+    input_quant_obs.min_val = min_input_scaled
+    input_quant_obs.max_val = max_input_scaled
+
+def scale_weight_node(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    equalization_scale: torch.Tensor,
+    next_equalization_scale: Optional[torch.Tensor],
+) -> None:
+    """ Scale the weights for input-weight equalization by multiplying the
+    weight by 1/equalization_scale and next_equalization_scale
+
+    Args:
+        node: Current node whose weights we want to scale
+        equalization_scale: Current node's calculated equalization scale
+        next_equalization_scale: Next node's calculated equalization scale if
+            the following node needs to be equalized, 1 otherwise
+    """
+    assert(isinstance(node.target, str))
+
+    # Scale the weights for input-weight equalization
+    # If the following layer needs to be equalized then we will multiply its scale
+    weight = modules[node.target].weight
+    assert(isinstance(weight, torch.Tensor))
+
+    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))
+
+    if next_equalization_scale is None:
+        modules[node.target].weight = nn.Parameter(scaled_weight)
+        return
+
+    scaled_weight = torch.mul(scaled_weight, next_equalization_scale)
+    modules[node.target].weight = nn.Parameter(scaled_weight)
+
+    # TODO: The bias may need to be scaled for connecting linear layers
+    bias = modules[node.target].bias
+    assert(isinstance(bias, torch.Tensor))
+
+    scaled_bias = torch.mul(bias, next_equalization_scale)
+    modules[node.target].bias = nn.Parameter(scaled_bias)
+
+def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
+    """ Update all of the observer's equalization scale. For each
+    InputEqualizationObserver, we will find the location of the next
+    WeightEqualizationObserver, create it, and calculate the equalization scale
+    based on the two observers.
+
+    We will then return a dictionary mapping operation node names to
+    the corresponding WeightEqualizationObservers for that operation.
+    """
+    weight_eq_obs_dict = {}
+    for node in model.graph.nodes:
+        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
+            input_eq_obs = modules[node.target]
+            assert(isinstance(input_eq_obs, _InputEqualizationObserver))
+            op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
+
+            if op_node is None or weight_eq_obs is None:
+                continue
+
+            weight_eq_obs(modules[str(op_node.target)].weight)
+
+            # Calculate and set the equalization scale values
+            equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
+            input_eq_obs.set_equalization_scale(equalization_scale)
+            weight_eq_obs.set_equalization_scale(equalization_scale)
+
+            weight_eq_obs_dict[op_node.name] = weight_eq_obs
+
+    return weight_eq_obs_dict
+
+def convert_eq_obs(
+    model: GraphModule,
+    modules: Dict[str, nn.Module],
+    weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
+) -> None:
+    """ Converts the equalization operations and updates the other nodes in the
+    following way:
+        - Removes the input equalization observers and inserts a mul operator
+          along with an equalization scale node wherever applicable (we do not
+          want to insert a mul operator between connecting linear layers).
+        - Updates the input quantization observers with the scaled input min/max
+          values.
+        - Scales the weights by the current and next equalization scales.
+        - Removes the weight equalization observer node if it exists.
+
+    Before (after prepare):
+                                        weight values
+                                              |
+                                        WeightQuantObs
+                                              |
+                                        WeightEqObs
+                                              |
+        x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
+
+    After this function:
+                                        scaled weight values
+                                              |
+       equalization scale                WeightQuantObs
+              |                               |
+        x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
+
+    After convert:
+       equalization scale                 scaled weight values
+              |                                    |
+        x -> mul -> quantize_per_tensor -> quantized::linear
+
+    Note that although the equalization observer appeared after the quantization
+    observer after prepare_fx, the mul node appears before the quantization node
+    after convert_fx. This is because placing the equalization observer after
+    the quantization observer in prepare_fx would allow us to keep the invariant
+    that the graph before the current node inserts its observers is not
+    modified.
+
+    Having the equalization observer before the quantization observer would also
+    cause some inconsistences between the ordering of the quantization and
+    equalization observers.
+    For example, a single linear layer would look like:
+        x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
+    But between two connected linear layers, it would look like:
+        linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
+    """
+    for node in model.graph.nodes:
+        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
+            inp_quant_obs_node = node.args[0]
+            prev_node = inp_quant_obs_node.args[0]
+
+            # TODO: Possible special handling for connected linear layers
+            # Update the following input quantization observer's min/max values
+            scale_input_observer(node, modules)
+
+            # Remove the InputEqualization node and add a mul operator before
+            # the quantization observer node that appears before the equalization node
+            # Before: x -> input_quant_obs -> input_eq_obs -> linear
+            # After: x -> mul -> input_quant_obs -> linear
+
+            # Create a node containing the equalization scale
+            with model.graph.inserting_before(inp_quant_obs_node):
+                get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
+                name = get_new_eq_scale_name(modules)
+                setattr(model, name, modules[node.target].equalization_scale)
+                eq_scale_node = model.graph.create_node('get_attr', name)
+
+            # Create a node multiplying the input with the equalization scale
+            with model.graph.inserting_after(eq_scale_node):
+                inputs = (prev_node, eq_scale_node)
+                mul_node = model.graph.create_node("call_function", torch.mul, inputs)
+
+            # Set the mul nod to be the input_quant_obs_node's input instead of
+            # the previous node
+            inp_quant_obs_node.replace_input_with(prev_node, mul_node)
+
+            # For all of the current node's users, replace the current node with
+            # the input quantization observer node
+            orig_users = list(node.users.keys())
+            for user_node in orig_users:
+                user_node.replace_input_with(node, inp_quant_obs_node)
+
+            # Erase the InputEqualizationObserver node
+            model.graph.erase_node(node)
+
+        elif weight_eq_obs_dict.get(node.name, None) is not None:
+            weight_eq_obs = weight_eq_obs_dict.get(node.name)
+            assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
+
+            equalization_scale = weight_eq_obs.equalization_scale
+
+            # Scales the weights and runs the weight quantization observers
+            maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
+            scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)
+
+def _convert_equalization_ref(model: GraphModule):
+    """ Reference function which applies changes needed for equalization, but
+    does not quantize the nodes
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+
+    # Calculate the equalization scale, update the observers with the scaled
+    # inputs, and scale the weight
+    weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+    convert_eq_obs(model, modules, weight_eq_obs_dict)
+
+    return GraphModule(model, model.graph)
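To see the rewrite that convert_eq_obs describes, one can run the reference conversion on a prepared model and print the resulting graph: the inserted get_attr equalization-scale node and the torch.mul call should appear ahead of the input quantization observer, as in the "After this function" diagram above. A minimal sketch (the module, calibration tensor, and qconfig dicts are illustrative, following the test added earlier in this commit):

    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx
    from torch.quantization.fx._equalize import (
        default_equalization_qconfig,
        _convert_equalization_ref,
    )

    class LinearModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            return self.linear(x)

    qconfig_dict = {"": None, "object_type": [(nn.Linear, default_qconfig)]}
    equalization_qconfig_dict = {"": None, "object_type": [(nn.Linear, default_equalization_qconfig)]}

    m = LinearModule().eval()
    prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
    prepared(torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]]))  # calibrate

    ref = _convert_equalization_ref(prepared)
    # Expect a get_attr '..._equalization_scale' node feeding a torch.mul before the input observer
    print(ref.graph)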
@@ -24,6 +24,7 @@ from .graph_module import (
 from .quantization_patterns import (
     QuantizeHandler,
 )
+from ._equalize import update_obs_for_equalization, convert_eq_obs
 from .utils import (
     is_get_tensor_info_node,
     node_return_type_is_int,
@@ -180,6 +181,13 @@ def convert(model: GraphModule, is_reference: bool = False,
                                        qconfig_map,
                                        custom_module_classes=custom_module_classes)

+    if model._equalization_qconfig_map is not None:
+        # If we want to do equalization then do the following:
+        # Calculate the equalization scale, update the observers with the scaled
+        # inputs, and scale the weight
+        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+        convert_eq_obs(model, modules, weight_eq_obs_dict)
+
     quantized_graph = Graph()
     env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(lambda: defaultdict(Node))  # type: ignore[arg-type]

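With this hook in convert(), the standard convert_fx path applies equalization automatically whenever the model was prepared with an equalization_qconfig_dict; nothing extra needs to be called. A minimal sketch with an illustrative module and calibration data (like the test above, this path assumes the FBGEMM backend is available):

    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    from torch.quantization.fx._equalize import default_equalization_qconfig

    m = nn.Sequential(nn.Linear(2, 2)).eval()
    qconfig_dict = {"": None, "object_type": [(nn.Linear, default_qconfig)]}
    equalization_qconfig_dict = {"": None, "object_type": [(nn.Linear, default_equalization_qconfig)]}

    prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
    prepared(torch.randn(4, 2))       # calibration pass populates both kinds of observers
    quantized = convert_fx(prepared)  # equalization scales are folded in before quantization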
@@ -1,5 +1,6 @@
 import re
 import torch
+import torch.nn as nn
 from ..utils import is_per_tensor, is_per_channel
 from ..quantize import is_activation_post_process
@@ -10,7 +11,7 @@ from torch.fx.graph import (
     Node,
 )

-from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union
+from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
 import operator

 # A dictionary for querying the weight index for a given op
@@ -481,3 +482,23 @@ def is_get_tensor_info_node(node: Node) -> bool:
     result: bool = \
         node.op == "call_function" and node.target == getattr and node.args[1] == "shape"  # type: ignore[assignment]
     return result
+
+def maybe_get_next_module(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    target_module_type: Type[nn.Module],
+) -> Optional[Node]:
+    """ Gets the next module that matches what is needed in
+    is_target_module_type if it exists
+
+    Args:
+        node: The node whose users we want to look at
+        is_target_module_type: Function that returns true if the given module
+            matches the type specified in the function.
+    """
+
+    for user, _ in node.users.items():
+        if user.op == 'call_module' and isinstance(modules[str(user.target)], target_module_type):
+            return user
+
+    return None
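A quick sketch of how this helper behaves on a traced graph: given a node, it returns the first user that is a call_module node whose module is an instance of the requested type, or None. The toy module below is illustrative, and the import path assumes the helper is importable from torch.quantization.fx.utils as added in this commit:

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace
    from torch.quantization.fx.utils import maybe_get_next_module

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(x))

    traced = symbolic_trace(Toy())
    modules = dict(traced.named_modules())

    for node in traced.graph.nodes:
        if node.op == 'call_module' and isinstance(modules[str(node.target)], nn.Linear):
            # The ReLU call_module node directly consumes the linear output, so it is returned
            print(maybe_get_next_module(node, modules, nn.ReLU))        # -> the `relu` node
            # No BatchNorm follows the linear, so this returns None
            print(maybe_get_next_module(node, modules, nn.BatchNorm1d)) # -> None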