Revert "Revert D29135358: [quant] Input-Weight Equaliaztion - convert modifications" (#60646)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60646

This reverts commit e60f9cfc58.

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D29361191

Pulled By: angelayi

fbshipit-source-id: 275d8691d8e47da4ab80bb21b51d77ec25a0f714
Author: angelayi, 2021-06-25 15:35:44 -07:00 (committed by Facebook GitHub Bot)
Parent: 7188d84ccf
Commit: e13a9587b4
4 changed files with 378 additions and 27 deletions


@@ -2,15 +2,16 @@ import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
-from torch.quantization.quantize_fx import prepare_fx
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization.fx._equalize import (
    _InputEqualizationObserver,
    _WeightEqualizationObserver,
    calculate_equalization_scale,
    default_equalization_qconfig,
    _convert_equalization_ref
)
-from torch.testing._internal.common_quantization import NodeSpec as ns
+from torch.testing._internal.common_quantization import NodeSpec as ns, skipIfNoFBGEMM
from torch.testing._internal.common_quantization import QuantizationTestCase
# Standard Libraries
@@ -238,3 +239,59 @@ class TestEqualizeFx(QuantizationTestCase):
        m = M().eval()
        prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)

    @skipIfNoFBGEMM
    def test_input_weight_equalization_convert(self):
        """ Tests that the equalized model (before quantization) returns the
        same output as the original prepared model
        """
        qconfig_dict = {"": None,
                        "object_type": [(nn.Linear, default_qconfig),
                                        (nn.functional.linear, default_qconfig)]}

        default_equalization_qconfig_dict = {
            "": None,
            "object_type": [(nn.Linear, default_equalization_qconfig),
                            (nn.functional.linear, default_equalization_qconfig)]
        }

        # Basic test with one linear layer
        class LinearModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        # Test with two linear layers and a fp32 operation between them
        class Linear2FP32Module(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = nn.Linear(2, 2)
                self.linear2 = nn.Linear(2, 2)

            def forward(self, x):
                x = self.linear1(x)
                x = torch.add(x, torch.tensor([1, 2]))
                x = self.linear2(x)
                return x

        tests = [(LinearModule, default_equalization_qconfig_dict),
                 (Linear2FP32Module, default_equalization_qconfig_dict)]

        for (M, equalization_qconfig_dict) in tests:
            m = M().eval()
            x = torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]])
            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
            output = prepared(x)

            convert_ref = _convert_equalization_ref(prepared)
            convert_ref_output = convert_ref(x)

            m = M().eval()
            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
            prepared(x)
            convert_fx(prepared)  # Check that the model compiles through convert_fx
            self.assertEqual(output, convert_ref_output)
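
For context, here is a minimal sketch of the end-to-end flow the test above exercises: prepare with an equalization qconfig dict, calibrate, then convert. The module and calibration data are illustrative, not taken from the test suite, and an FBGEMM-capable backend is assumed (matching the skipIfNoFBGEMM guard).

    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    from torch.quantization.fx._equalize import default_equalization_qconfig

    model = nn.Sequential(nn.Linear(2, 2)).eval()
    qconfig_dict = {"": None, "object_type": [(nn.Linear, default_qconfig)]}
    equalization_qconfig_dict = {"": None, "object_type": [(nn.Linear, default_equalization_qconfig)]}

    prepared = prepare_fx(model, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
    prepared(torch.randn(16, 2))      # calibration run; populates the equalization observers
    quantized = convert_fx(prepared)  # applies the equalization scales, then quantizes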


@@ -1,15 +1,18 @@
import torch
import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Node
from .utils import get_new_attr_name_with_prefix, maybe_get_next_module
from ..observer import (
    PerChannelMinMaxObserver,
    _with_args,
    ObserverBase,
)
from ..utils import check_min_max_valid
from collections import namedtuple
from typing import Dict, Any, Tuple, Optional
import warnings
@@ -28,14 +31,8 @@ class _InputEqualizationObserver(nn.Module):
    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.
-    The qparams are calculated by multiplying the min/max input column values
-    with the equalization scale, reducing to find the global min/max input
-    values, and then calculating in the same way as in
-    :class:`~torch.quantization.observer.MinMaxObserver`
-    .. note:: If the running minimum equals to the running maximum, the scales
-       and zero_points are set to 1.0 and 0.

    This observer is intended to be used along with a WeightEqualizationObserver
    to calculate the equalization scale.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
@@ -70,8 +67,7 @@ class _InputEqualizationObserver(nn.Module):
        self.equalization_scale = equalization_scale

    def calculate_scaled_minmax(self):
-        r"""
-        Returns the scaled min/max inputs
+        r""" Returns the scaled min/max inputs
        """
        if self.equalization_scale.nelement() == 0:
            warnings.warn(
@@ -104,21 +100,13 @@ class _WeightEqualizationObserver(nn.Module):
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

-    This observer is made up of 2 PerChannelMinMaxObservers
-        - weight_col_obs: Used to record the running minimum and maximum of
-          columns of incoming weight tensors
-        - weight_row_obs: Used to record the running minimum and maximum of
-          rows of incoming weight tensors
+    This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
+    to record the running minimum and maximum of columns of incoming weight
+    tensors. This observer is intended to be used along with an
+    InputEqualizationObserver to calculate the equalization scale.

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.
-    The qparams are calculated by multiplying the min/max weight row values
-    with the inverse of the equalization scale, and then calculating in the same
-    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`
-    .. note:: If the running minimum equals to the running maximum, the scales
-       and zero_points are set to 1.0 and 0.
    """

    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
@@ -127,7 +115,7 @@ class _WeightEqualizationObserver(nn.Module):
        self.dtype = dtype
        self.qscheme = qscheme
-        self.ch_axis = 0
+        self.ch_axis = 1

        self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                       qscheme=qscheme,
@@ -224,3 +212,280 @@ def node_supports_equalization(node: Node, modules) -> bool:
def is_equalization_observer(observer: nn.Module) -> bool:
    return (isinstance(observer, _InputEqualizationObserver) or
            isinstance(observer, _WeightEqualizationObserver))

def get_op_node_and_weight_eq_obs(
    input_eq_obs_node: Node,
    model: GraphModule,
    modules: Dict[str, nn.Module]
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
    """ Gets the following weight equalization observer. There should always
    exist a weight equalization observer after an input equalization observer.

    Returns the op node that follows the input equalization observer, along
    with its newly created weight equalization observer.
    """

    # Find the op node that comes directly after the input equalization observer
    op_node = None
    for user in input_eq_obs_node.users.keys():
        if node_supports_equalization(user, modules):
            op_node = user
            break

    assert(op_node is not None)
    if op_node.op == 'call_module':
        # If the op_node is a nn.Linear layer, then it must have a
        # WeightEqualizationObserver configuration
        equalization_qconfig_map: Dict[str, Any] = model._equalization_qconfig_map  # type: ignore[assignment]
        assert(equalization_qconfig_map.get(op_node.name, None) is not None)
        weight_eq_obs = equalization_qconfig_map.get(op_node.name, None).weight()

        assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
        return op_node, weight_eq_obs

    elif op_node.op == 'call_function':
        # TODO
        return None, None

    return None, None

def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]:
    """ Gets the following input equalization observer if it exists.

    For example, in the case of connecting linear layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
    If the node being passed in is the linear1 node, then we want to return eq_obs2,
    the following equalization observer for linear2.

    However, if there are no connecting layers:
        x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
    Then we want to return None.
    """

    assert(node_supports_equalization(node, modules))

    # Locate the following output observer if it exists
    maybe_obs_node = maybe_get_next_module(node, modules, ObserverBase)
    if maybe_obs_node is None:
        return None

    maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver)
    if maybe_eq_obs_node is None:
        return None

    maybe_eq_obs = modules[str(maybe_eq_obs_node)]
    assert(isinstance(maybe_eq_obs, _InputEqualizationObserver))
    return maybe_eq_obs

def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
    """ If the next-next node is an InputEqualizationObserver then we want to
    return its equalization scale, otherwise we return None.

    This is used in the case where there are two connecting linear layers:
        linear1 -> LinearOutObs -> InputEqObs -> linear2
    In this case, the node given is linear1 and we want to locate the InputEqObs.
    """
    next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
    if next_inp_eq_obs:
        return next_inp_eq_obs.equalization_scale
    return None

def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
    """ Scales the following input quantization observer's min/max values by
    updating the values with the scaled min/max values calculated by the input
    equalization observer
    """
    input_eq_obs = modules[str(node.target)]
    assert(isinstance(input_eq_obs, _InputEqualizationObserver))

    input_quant_obs_node = node.args[0]
    assert(isinstance(input_quant_obs_node, Node))

    input_quant_obs = modules[str(input_quant_obs_node.target)]
    if not isinstance(input_quant_obs, ObserverBase):
        return

    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
    input_quant_obs.min_val = min_input_scaled
    input_quant_obs.max_val = max_input_scaled

def scale_weight_node(
    node: Node,
    modules: Dict[str, nn.Module],
    equalization_scale: torch.Tensor,
    next_equalization_scale: Optional[torch.Tensor],
) -> None:
    """ Scale the weights for input-weight equalization by multiplying the
    weight by 1/equalization_scale and next_equalization_scale

    Args:
        node: Current node whose weights we want to scale
        equalization_scale: Current node's calculated equalization scale
        next_equalization_scale: Next node's calculated equalization scale if
            the following node needs to be equalized, None otherwise
    """
    assert(isinstance(node.target, str))

    # Scale the weights for input-weight equalization
    # If the following layer needs to be equalized then we will multiply its scale
    weight = modules[node.target].weight
    assert(isinstance(weight, torch.Tensor))

    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))

    if next_equalization_scale is None:
        modules[node.target].weight = nn.Parameter(scaled_weight)
        return

    scaled_weight = torch.mul(scaled_weight, next_equalization_scale)
    modules[node.target].weight = nn.Parameter(scaled_weight)

    # TODO: The bias may need to be scaled for connecting linear layers
    bias = modules[node.target].bias
    assert(isinstance(bias, torch.Tensor))

    scaled_bias = torch.mul(bias, next_equalization_scale)
    modules[node.target].bias = nn.Parameter(scaled_bias)

def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
    """ Updates the equalization scale for all of the observers. For each
    InputEqualizationObserver, we will find the location of the next
    WeightEqualizationObserver, create it, and calculate the equalization scale
    based on the two observers.

    We will then return a dictionary mapping operation node names to
    the corresponding WeightEqualizationObservers for that operation.
    """
    weight_eq_obs_dict = {}
    for node in model.graph.nodes:
        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
            input_eq_obs = modules[node.target]
            assert(isinstance(input_eq_obs, _InputEqualizationObserver))
            op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)

            if op_node is None or weight_eq_obs is None:
                continue

            # Run the weight equalization observer on the op node's weight
            weight_eq_obs(modules[str(op_node.target)].weight)

            # Calculate and set the equalization scale values
            equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
            input_eq_obs.set_equalization_scale(equalization_scale)
            weight_eq_obs.set_equalization_scale(equalization_scale)

            weight_eq_obs_dict[op_node.name] = weight_eq_obs

    return weight_eq_obs_dict

def convert_eq_obs(
    model: GraphModule,
    modules: Dict[str, nn.Module],
    weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
) -> None:
    """ Converts the equalization operations and updates the other nodes in the
    following way:
        - Removes the input equalization observers and inserts a mul operator
          along with an equalization scale node wherever applicable (we do not
          want to insert a mul operator between connecting linear layers).
        - Updates the input quantization observers with the scaled input min/max
          values.
        - Scales the weights by the current and next equalization scales.
        - Removes the weight equalization observer node if it exists.

    Before (after prepare):
                                       weight values
                                             |
                                       WeightQuantObs
                                             |
                                        WeightEqObs
                                             |
        x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs

    After this function:
                                            scaled weight values
                                                    |
        equalization scale                    WeightQuantObs
               |                                    |
        x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs

    After convert:
        equalization scale               scaled weight values
               |                                  |
        x -> mul -> quantize_per_tensor -> quantized::linear

    Note that although the equalization observer appeared after the quantization
    observer after prepare_fx, the mul node appears before the quantization node
    after convert_fx. This is because placing the equalization observer after
    the quantization observer in prepare_fx allows us to keep the invariant
    that the graph before the current node inserts its observers is not
    modified.

    Having the equalization observer before the quantization observer would also
    cause some inconsistencies between the ordering of the quantization and
    equalization observers.
    For example, a single linear layer would look like:
        x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
    But between two connected linear layers, it would look like:
        linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
    """
    for node in model.graph.nodes:
        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
            inp_quant_obs_node = node.args[0]
            prev_node = inp_quant_obs_node.args[0]

            # TODO: Possible special handling for connected linear layers

            # Update the following input quantization observer's min/max values
            scale_input_observer(node, modules)

            # Remove the InputEqualization node and add a mul operator before
            # the quantization observer node that appears before the equalization node
            # Before: x -> input_quant_obs -> input_eq_obs -> linear
            # After: x -> mul -> input_quant_obs -> linear

            # Create a node containing the equalization scale
            with model.graph.inserting_before(inp_quant_obs_node):
                get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
                name = get_new_eq_scale_name(modules)
                setattr(model, name, modules[node.target].equalization_scale)
                eq_scale_node = model.graph.create_node('get_attr', name)

            # Create a node multiplying the input with the equalization scale
            with model.graph.inserting_after(eq_scale_node):
                inputs = (prev_node, eq_scale_node)
                mul_node = model.graph.create_node("call_function", torch.mul, inputs)

            # Set the mul node to be the input_quant_obs_node's input instead of
            # the previous node
            inp_quant_obs_node.replace_input_with(prev_node, mul_node)

            # For all of the current node's users, replace the current node with
            # the input quantization observer node
            orig_users = list(node.users.keys())
            for user_node in orig_users:
                user_node.replace_input_with(node, inp_quant_obs_node)

            # Erase the InputEqualizationObserver node
            model.graph.erase_node(node)

        elif weight_eq_obs_dict.get(node.name, None) is not None:
            weight_eq_obs = weight_eq_obs_dict.get(node.name)
            assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
            equalization_scale = weight_eq_obs.equalization_scale

            # Scale the weights by the current and next equalization scales
            maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
            scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)

def _convert_equalization_ref(model: GraphModule):
    """ Reference function which applies changes needed for equalization, but
    does not quantize the nodes
    """
    modules = dict(model.named_modules(remove_duplicate=False))

    # Calculate the equalization scale, update the observers with the scaled
    # inputs, and scale the weight
    weight_eq_obs_dict = update_obs_for_equalization(model, modules)
    convert_eq_obs(model, modules, weight_eq_obs_dict)

    return GraphModule(model, model.graph)
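
For intuition, the weight scaling done by scale_weight_node and the mul node inserted by convert_eq_obs rely on a simple cancellation identity: multiplying the input columns by s and the weight columns by 1/s leaves a linear layer's output unchanged. A small self-contained sketch, with made-up values for the scale s:

    import torch

    x = torch.tensor([[1.0, 2.0], [2.0, 2.5]])
    W = torch.tensor([[0.5, -1.0], [1.5, 0.25]])
    s = torch.tensor([2.0, 0.5])  # illustrative per-column equalization scale

    out = x @ W.t()
    out_eq = (x * s) @ (W * torch.reciprocal(s)).t()  # inserted mul + scaled weight
    assert torch.allclose(out, out_eq)

The observers pick s per input column from the ratio of the weight and input column ranges, so that after scaling the two tensors have comparable ranges and quantize with less error.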


@@ -24,6 +24,7 @@ from .graph_module import (
from .quantization_patterns import (
    QuantizeHandler,
)
+from ._equalize import update_obs_for_equalization, convert_eq_obs
from .utils import (
    is_get_tensor_info_node,
    node_return_type_is_int,
@@ -180,6 +181,13 @@ def convert(model: GraphModule, is_reference: bool = False,
        qconfig_map,
        custom_module_classes=custom_module_classes)

    if model._equalization_qconfig_map is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    quantized_graph = Graph()
    env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(lambda: defaultdict(Node))  # type: ignore[arg-type]
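
The same two helper calls are also exposed through _convert_equalization_ref (defined in the file above), which applies equalization without quantizing. A small sketch checking that equalization alone is numerically neutral; the model and data are illustrative:

    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx
    from torch.quantization.fx._equalize import (
        default_equalization_qconfig,
        _convert_equalization_ref,
    )

    m = nn.Sequential(nn.Linear(2, 2)).eval()
    qconfig_dict = {"": None, "object_type": [(nn.Linear, default_qconfig)]}
    eq_dict = {"": None, "object_type": [(nn.Linear, default_equalization_qconfig)]}
    prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=eq_dict)

    x = torch.randn(8, 2)
    y = prepared(x)  # calibrate and record the float output
    ref = _convert_equalization_ref(prepared)
    torch.testing.assert_allclose(ref(x), y)  # input and weight scaling cancel out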


@@ -1,5 +1,6 @@
import re
import torch
+import torch.nn as nn
from ..utils import is_per_tensor, is_per_channel
from ..quantize import is_activation_post_process
@@ -10,7 +11,7 @@ from torch.fx.graph import (
    Node,
)
-from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union
+from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
import operator
# A dictionary for querying the weight index for a given op
@@ -481,3 +482,23 @@ def is_get_tensor_info_node(node: Node) -> bool:
    result: bool = \
        node.op == "call_function" and node.target == getattr and node.args[1] == "shape"  # type: ignore[assignment]
    return result

def maybe_get_next_module(
    node: Node,
    modules: Dict[str, nn.Module],
    target_module_type: Type[nn.Module],
) -> Optional[Node]:
    """ Gets the next module that matches the given module type if it exists

    Args:
        node: The node whose users we want to look at
        target_module_type: Module type to match against the users of the
            given node
    """

    for user, _ in node.users.items():
        if user.op == 'call_module' and isinstance(modules[str(user.target)], target_module_type):
            return user

    return None
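
A small usage sketch for maybe_get_next_module, with an illustrative traced module (assuming the helper is imported from this internal utils module):

    import torch.nn as nn
    from torch.fx import symbolic_trace
    from torch.quantization.fx.utils import maybe_get_next_module

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.linear(x))

    gm = symbolic_trace(Tiny())
    modules = dict(gm.named_modules())
    linear_node = next(n for n in gm.graph.nodes
                       if n.op == 'call_module' and isinstance(modules[str(n.target)], nn.Linear))

    # The relu call_module node directly follows the linear node, so it is returned
    relu_node = maybe_get_next_module(linear_node, modules, nn.ReLU)
    assert relu_node is not None and str(relu_node.target) == 'relu'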