From e13a9587b47bb5f82b9762ef0bb38ab3d3fd5234 Mon Sep 17 00:00:00 2001
From: angelayi
Date: Fri, 25 Jun 2021 15:35:44 -0700
Subject: [PATCH] Revert "Revert D29135358: [quant] Input-Weight Equaliaztion - convert modifications" (#60646)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60646

This reverts commit e60f9cfc58fb2fe3e2e7f65fcdbbf350e5b55a75.

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D29361191

Pulled By: angelayi

fbshipit-source-id: 275d8691d8e47da4ab80bb21b51d77ec25a0f714
---
 test/quantization/fx/test_equalize_fx.py |  61 ++++-
 torch/quantization/fx/_equalize.py       | 313 +++++++++++++++++++++--
 torch/quantization/fx/convert.py         |   8 +
 torch/quantization/fx/utils.py           |  23 +-
 4 files changed, 378 insertions(+), 27 deletions(-)

diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py
index 99b74379c9d..25a76518f5a 100644
--- a/test/quantization/fx/test_equalize_fx.py
+++ b/test/quantization/fx/test_equalize_fx.py
@@ -2,15 +2,16 @@ import torch
 import torch.nn as nn
 from torch.quantization import default_qconfig
 from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
-from torch.quantization.quantize_fx import prepare_fx
+from torch.quantization.quantize_fx import prepare_fx, convert_fx
 from torch.quantization.fx._equalize import (
     _InputEqualizationObserver,
     _WeightEqualizationObserver,
     calculate_equalization_scale,
     default_equalization_qconfig,
+    _convert_equalization_ref
 )

-from torch.testing._internal.common_quantization import NodeSpec as ns
+from torch.testing._internal.common_quantization import NodeSpec as ns, skipIfNoFBGEMM
 from torch.testing._internal.common_quantization import QuantizationTestCase

 # Standard Libraries
@@ -238,3 +239,59 @@ class TestEqualizeFx(QuantizationTestCase):
         m = M().eval()
         prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
         self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
+
+    @skipIfNoFBGEMM
+    def test_input_weight_equalization_convert(self):
+        """ Tests that the equalization reference conversion preserves the
+        model's output, and that an equalized model runs through convert_fx """
+        qconfig_dict = {"": None,
+                        "object_type": [(nn.Linear, default_qconfig),
+                                        (nn.functional.linear, default_qconfig)]}
+
+        default_equalization_qconfig_dict = {
+            "": None,
+            "object_type": [(nn.Linear, default_equalization_qconfig),
+                            (nn.functional.linear, default_equalization_qconfig)]
+        }
+
+        # Basic test with one linear layer
+        class LinearModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        # Test with two linear layers with a fp32 operation in between
+        class Linear2FP32Module(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = nn.Linear(2, 2)
+                self.linear2 = nn.Linear(2, 2)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = torch.add(x, torch.tensor([1, 2]))
+                x = self.linear2(x)
+                return x
+
+
+        tests = [(LinearModule, default_equalization_qconfig_dict),
+                 (Linear2FP32Module, default_equalization_qconfig_dict)]
+
+        for (M, equalization_qconfig_dict) in tests:
+            m = M().eval()
+            x = torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]])
+            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
+            output = prepared(x)
+
+            convert_ref = _convert_equalization_ref(prepared)
+            convert_ref_output = convert_ref(x)
+
+            m = M().eval()
+            prepared = prepare_fx(m, qconfig_dict, equalization_qconfig_dict=equalization_qconfig_dict)
+            prepared(x)
+            convert_fx(prepared)  # Check that the equalized model converts without errors
+
+            self.assertEqual(output, convert_ref_output)
diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py
index 9fbe1fc9bcb..db3c95ef48b 100644
--- a/torch/quantization/fx/_equalize.py
+++ b/torch/quantization/fx/_equalize.py
@@ -1,15 +1,18 @@
 import torch
 import torch.nn as nn
-
+from torch.fx import GraphModule
 from torch.fx.graph import Node
+from .utils import get_new_attr_name_with_prefix, maybe_get_next_module

 from ..observer import (
     PerChannelMinMaxObserver,
     _with_args,
+    ObserverBase,
 )
 from ..utils import check_min_max_valid

 from collections import namedtuple
+from typing import Dict, Any, Tuple, Optional
 import warnings


@@ -28,14 +31,8 @@ class _InputEqualizationObserver(nn.Module):
     The running minimum/maximum :math:`x_\text{min/max}` are computed in the
     same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
     with the difference that the running min/max values are stored per column.
-
-    The qparams are calculated by multiplying the min/max input column values
-    with the equalization scale, reducing to find the global min/max input
-    values, and then calculating in the same way as in
-    :class:`~torch.quantization.observer.MinMaxObserver`
-
-    .. note:: If the running minimum equals to the running maximum, the scales
-              and zero_points are set to 1.0 and 0.
+    This observer is intended to be used along with a WeightEqualizationObserver
+    to calculate the equalization scale.
     """

     def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
@@ -70,8 +67,7 @@ class _InputEqualizationObserver(nn.Module):
         self.equalization_scale = equalization_scale

     def calculate_scaled_minmax(self):
-        r"""
-        Returns the scaled min/max inputs
+        r""" Returns the scaled min/max inputs
         """
         if self.equalization_scale.nelement() == 0:
             warnings.warn(
@@ -104,21 +100,13 @@ class _WeightEqualizationObserver(nn.Module):
         quant_max: Maximum quantization value. If unspecified, it will follow
         the 8-bit setup.

-    This observer is made up of 2 PerChannelMinMaxObservers
-        - weight_col_obs: Used to record the running minimum and maximum of
-            columns of incoming weight tensors
-        - weight_row_obs: Used to record the running minimum and maximum of
-            rows of incoming weight tensors
+    This observer is made up of 1 PerChannelMinMaxObserver `weight_col_obs` used
+    to record the running minimum and maximum of columns of incoming weight
+    tensors. This observer is intended to be used along with an
+    InputEqualizationObserver to calculate the equalization scale.

     The running minimum/maximum :math:`w_\text{min/max}` are computed in the same
     way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.
-
-    The qparams are calculated by multiplying the min/max weight row values
-    with the inverse of the equalization scale, and then calculating in the same
-    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`
-
-    .. note:: If the running minimum equals to the running maximum, the scales
-              and zero_points are set to 1.0 and 0.
""" def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None, @@ -127,7 +115,7 @@ class _WeightEqualizationObserver(nn.Module): self.dtype = dtype self.qscheme = qscheme - self.ch_axis = 0 + self.ch_axis = 1 self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype, qscheme=qscheme, @@ -224,3 +212,280 @@ def node_supports_equalization(node: Node, modules) -> bool: def is_equalization_observer(observer: nn.Module) -> bool: return (isinstance(observer, _InputEqualizationObserver) or isinstance(observer, _WeightEqualizationObserver)) + +def get_op_node_and_weight_eq_obs( + input_eq_obs_node: Node, + model: GraphModule, + modules: Dict[str, nn.Module] +) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]: + """ Gets the following weight equalization observer. There should always + exist a weight equalization observer after an input equalization observer. + + Returns the node containing the weight equalization observer, and the weight + equalization observer if it has been newly created + """ + + # Find the op node that comes directly after the input equaliation observer + op_node = None + for user in input_eq_obs_node.users.keys(): + if node_supports_equalization(user, modules): + op_node = user + break + + assert(op_node is not None) + if op_node.op == 'call_module': + # If the op_node is a nn.Linear layer, then it must have a + # WeightEqualizationObserver configuration + equalization_qconfig_map: Dict[str, Any] = model._equalization_qconfig_map # type: ignore[assignment] + assert(equalization_qconfig_map.get(op_node.name, None) is not None) + weight_eq_obs = equalization_qconfig_map.get(op_node.name, None).weight() + + assert(isinstance(weight_eq_obs, _WeightEqualizationObserver)) + return op_node, weight_eq_obs + + elif op_node.op == 'call_function': + # TODO + return None, None + + return None, None + +def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]: + """ Gets the following input equalization observer if it exists. + + For example, in the case of connecting linear layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2 + If the node being passed in is the linear1 node, then we want to return eq_obs2, + the following equalization observer for linear2. + + However, if there are no connecting layers: + x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add + Then we want to return None. + """ + + assert(node_supports_equalization(node, modules)) + + # Locate the following output observer if it exists + maybe_obs_node = maybe_get_next_module(node, modules, ObserverBase) + if maybe_obs_node is None: + return None + + maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver) + if maybe_eq_obs_node is None: + return None + + maybe_eq_obs = modules[str(maybe_eq_obs_node)] + assert(isinstance(maybe_eq_obs, _InputEqualizationObserver)) + return maybe_eq_obs + +def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]: + """ If the next next node is an InputEqualizationObserver then we want to + return its equalization scale, else we return 1 + + This is used in the case where there are two connecting linear layers: + linear1 -> LinearOutObs -> InputEqObs -> linear2 + In this case, the node given is linear1 and we want to locate the InputEqObs. 
+    """
+    next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
+    if next_inp_eq_obs:
+        return next_inp_eq_obs.equalization_scale
+    return None
+
+def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
+    """ Scales the following input quantization observer's min/max values by
+    updating the values with the scaled min/max values calculated by the input
+    equalization observer
+    """
+    input_eq_obs = modules[str(node.target)]
+    assert(isinstance(input_eq_obs, _InputEqualizationObserver))
+
+    input_quant_obs_node = node.args[0]
+    assert(isinstance(input_quant_obs_node, Node))
+
+    input_quant_obs = modules[str(input_quant_obs_node.target)]
+    if not isinstance(input_quant_obs, ObserverBase):
+        return
+
+    min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
+    input_quant_obs.min_val = min_input_scaled
+    input_quant_obs.max_val = max_input_scaled
+
+def scale_weight_node(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    equalization_scale: torch.Tensor,
+    next_equalization_scale: Optional[torch.Tensor],
+) -> None:
+    """ Scale the weights for input-weight equalization by multiplying the
+    weight by 1/equalization_scale and next_equalization_scale
+
+    Args:
+        node: Current node whose weights we want to scale
+        equalization_scale: Current node's calculated equalization scale
+        next_equalization_scale: Next node's calculated equalization scale if
+            the following node needs to be equalized, None otherwise
+    """
+    assert(isinstance(node.target, str))
+
+    # Scale the weights for input-weight equalization
+    # If the following layer needs to be equalized then we also multiply by its scale
+    weight = modules[node.target].weight
+    assert(isinstance(weight, torch.Tensor))
+
+    scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))
+
+    if next_equalization_scale is None:
+        modules[node.target].weight = nn.Parameter(scaled_weight)
+        return
+
+    scaled_weight = torch.mul(scaled_weight, next_equalization_scale)
+    modules[node.target].weight = nn.Parameter(scaled_weight)
+
+    # TODO: The bias may need to be scaled for connecting linear layers
+    bias = modules[node.target].bias
+    assert(isinstance(bias, torch.Tensor))
+
+    scaled_bias = torch.mul(bias, next_equalization_scale)
+    modules[node.target].bias = nn.Parameter(scaled_bias)
+
+def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
+    """ Update the equalization scale of all the observers. For each
+    InputEqualizationObserver, we will find the location of the next
+    WeightEqualizationObserver, create it, and calculate the equalization scale
+    based on the two observers.
+
+    We will then return a dictionary mapping operation node names to
+    the corresponding WeightEqualizationObservers for that operation.
+    """
+    weight_eq_obs_dict = {}
+    for node in model.graph.nodes:
+        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
+            input_eq_obs = modules[node.target]
+            assert(isinstance(input_eq_obs, _InputEqualizationObserver))
+            op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
+
+            if op_node is None or weight_eq_obs is None:
+                continue
+
+            weight_eq_obs(modules[str(op_node.target)].weight)
+
+            # Calculate and set the equalization scale values
+            equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
+            input_eq_obs.set_equalization_scale(equalization_scale)
+            weight_eq_obs.set_equalization_scale(equalization_scale)
+
+            weight_eq_obs_dict[op_node.name] = weight_eq_obs
+
+    return weight_eq_obs_dict
+
+def convert_eq_obs(
+    model: GraphModule,
+    modules: Dict[str, nn.Module],
+    weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
+) -> None:
+    """ Converts the equalization operations and updates the other nodes in the
+    following way:
+        - Removes the input equalization observers and inserts a mul operator
+          along with an equalization scale node wherever applicable (we do not
+          want to insert a mul operator between connecting linear layers).
+        - Updates the input quantization observers with the scaled input min/max
+          values.
+        - Scales the weights by the current and next equalization scales.
+        - Removes the weight equalization observer node if it exists.
+
+    Before (after prepare):
+                                        weight values
+                                              |
+                                        WeightQuantObs
+                                              |
+                                         WeightEqObs
+                                              |
+        x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
+
+    After this function:
+                                       scaled weight values
+                                              |
+         equalization scale             WeightQuantObs
+                 |                            |
+        x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
+
+    After convert:
+        equalization scale                 scaled weight values
+                 |                                  |
+        x -> mul -> quantize_per_tensor -> quantized::linear
+
+    Note that although the equalization observer appeared after the quantization
+    observer after prepare_fx, the mul node appears before the quantization node
+    after convert_fx. This is because placing the equalization observer after
+    the quantization observer in prepare_fx allows us to keep the invariant that
+    the graph before the current node is not modified when the current node's
+    observers are inserted.
+
+    Having the equalization observer before the quantization observer would also
+    cause some inconsistencies between the ordering of the quantization and
+    equalization observers.
+    For example, a single linear layer would look like:
+        x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
+    But between two connected linear layers, it would look like:
+        linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
+    """
+    for node in model.graph.nodes:
+        if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
+            inp_quant_obs_node = node.args[0]
+            prev_node = inp_quant_obs_node.args[0]
+
+            # TODO: Possible special handling for connected linear layers
+            # Update the following input quantization observer's min/max values
+            scale_input_observer(node, modules)
+
+            # Remove the InputEqualization node and add a mul operator before
+            # the quantization observer node that appears before the equalization node
+            # Before: x -> input_quant_obs -> input_eq_obs -> linear
+            # After: x -> mul -> input_quant_obs -> linear
+
+            # Create a node containing the equalization scale
+            with model.graph.inserting_before(inp_quant_obs_node):
+                get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
+                name = get_new_eq_scale_name(modules)
+                setattr(model, name, modules[node.target].equalization_scale)
+                eq_scale_node = model.graph.create_node('get_attr', name)
+
+            # Create a node multiplying the input with the equalization scale
+            with model.graph.inserting_after(eq_scale_node):
+                inputs = (prev_node, eq_scale_node)
+                mul_node = model.graph.create_node("call_function", torch.mul, inputs)
+
+            # Set the mul node to be the input_quant_obs_node's input instead of
+            # the previous node
+            inp_quant_obs_node.replace_input_with(prev_node, mul_node)
+
+            # For all of the current node's users, replace the current node with
+            # the input quantization observer node
+            orig_users = list(node.users.keys())
+            for user_node in orig_users:
+                user_node.replace_input_with(node, inp_quant_obs_node)
+
+            # Erase the InputEqualizationObserver node
+            model.graph.erase_node(node)
+
+        elif weight_eq_obs_dict.get(node.name, None) is not None:
+            weight_eq_obs = weight_eq_obs_dict.get(node.name)
+            assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
+
+            equalization_scale = weight_eq_obs.equalization_scale
+
+            # Scale the weights using the current and next equalization scales
+            maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
+            scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)
+
+def _convert_equalization_ref(model: GraphModule):
+    """ Reference function which applies changes needed for equalization, but
+    does not quantize the nodes
+    """
+    modules = dict(model.named_modules(remove_duplicate=False))
+
+    # Calculate the equalization scale, update the observers with the scaled
+    # inputs, and scale the weight
+    weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+    convert_eq_obs(model, modules, weight_eq_obs_dict)
+
+    return GraphModule(model, model.graph)
diff --git a/torch/quantization/fx/convert.py b/torch/quantization/fx/convert.py
index 641e738e807..e911471be8a 100644
--- a/torch/quantization/fx/convert.py
+++ b/torch/quantization/fx/convert.py
@@ -24,6 +24,7 @@ from .graph_module import (
 from .quantization_patterns import (
     QuantizeHandler,
 )
+from ._equalize import update_obs_for_equalization, convert_eq_obs
 from .utils import (
     is_get_tensor_info_node,
     node_return_type_is_int,
@@ -180,6 +181,13 @@ def convert(model: GraphModule, is_reference: bool = False,
         qconfig_map,
         custom_module_classes=custom_module_classes)

+    if model._equalization_qconfig_map is not None:
+        # If we want to do equalization then do the following:
+        #   Calculate the equalization scale, update the observers with the scaled
+        #   inputs, and scale the weight
+        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
+        convert_eq_obs(model, modules, weight_eq_obs_dict)
+
     quantized_graph = Graph()
     env: Dict[str, Dict[Optional[torch.dtype], Node]] = defaultdict(lambda: defaultdict(Node))  # type: ignore[arg-type]

diff --git a/torch/quantization/fx/utils.py b/torch/quantization/fx/utils.py
index 76b481f8b38..9ce36bacad5 100644
--- a/torch/quantization/fx/utils.py
+++ b/torch/quantization/fx/utils.py
@@ -1,5 +1,6 @@
 import re
 import torch
+import torch.nn as nn
 from ..utils import is_per_tensor, is_per_channel
 from ..quantize import is_activation_post_process

@@ -10,7 +11,7 @@ from torch.fx.graph import (
     Node,
 )

-from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union
+from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
 import operator

 # A dictionary for querying the weight index for a given op
@@ -481,3 +482,23 @@ def is_get_tensor_info_node(node: Node) -> bool:
     result: bool = \
         node.op == "call_function" and node.target == getattr and node.args[1] == "shape"  # type: ignore[assignment]
     return result
+
+def maybe_get_next_module(
+    node: Node,
+    modules: Dict[str, nn.Module],
+    target_module_type: Type[nn.Module],
+) -> Optional[Node]:
+    """ Gets the next module in node's users that is an instance of
+    target_module_type, if one exists
+
+    Args:
+        node: The node whose users we want to look at
+        target_module_type: Module type (e.g. an observer class) that the
+            next module should be an instance of
+    """
+
+    for user, _ in node.users.items():
+        if user.op == 'call_module' and isinstance(modules[str(user.target)], target_module_type):
+            return user
+
+    return None
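
The end-to-end flow exercised by the new test_input_weight_equalization_convert test can be sketched outside the test harness roughly as follows. This is an illustrative sketch, not part of the patch: the TwoLinear module and the qconfig dictionaries are stand-ins mirroring the test above, and it assumes the fbgemm backend is available (the test itself is guarded by @skipIfNoFBGEMM).

    import torch
    import torch.nn as nn
    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx
    from torch.quantization.fx._equalize import (
        default_equalization_qconfig,
        _convert_equalization_ref,
    )

    # Toy module standing in for the modules used in the test
    class TwoLinear(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(2, 2)
            self.linear2 = nn.Linear(2, 2)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    qconfig_dict = {"": None, "object_type": [(nn.Linear, default_qconfig)]}
    equalization_qconfig_dict = {"": None, "object_type": [(nn.Linear, default_equalization_qconfig)]}

    x = torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]])

    # Insert quantization and equalization observers, then calibrate
    prepared = prepare_fx(TwoLinear().eval(), qconfig_dict,
                          equalization_qconfig_dict=equalization_qconfig_dict)
    float_out = prepared(x)

    # Reference path: apply the equalization rewrites (input mul + scaled weights)
    # without quantizing; the output should match the calibrated float model
    ref = _convert_equalization_ref(prepared)
    assert torch.allclose(ref(x), float_out, atol=1e-5)

    # Full path: prepare and calibrate a fresh copy, then quantize with convert_fx
    prepared = prepare_fx(TwoLinear().eval(), qconfig_dict,
                          equalization_qconfig_dict=equalization_qconfig_dict)
    prepared(x)
    quantized = convert_fx(prepared)
    print(quantized(x))

As in the test, the model is re-prepared before convert_fx because _convert_equalization_ref scales the weights of the prepared module in place.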
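
The reason the reference-equalized model preserves the float output is visible in scale_weight_node and convert_eq_obs above: the inserted mul scales each input column by the equalization scale, while the corresponding weight columns are multiplied by its reciprocal, so the linear product is unchanged. A small tensor-level sanity check of that identity, with made-up scale values (the real per-column scales come from calculate_equalization_scale, which is not shown in this diff):

    import torch

    x = torch.tensor([[1.0, 2.0], [2.0, 2.5], [4.5, 6.0]])  # (batch, in_features)
    W = torch.tensor([[0.5, -1.0], [2.0, 0.25]])             # (out_features, in_features)
    b = torch.tensor([0.1, -0.2])

    # Hypothetical per-input-column equalization scale
    s = torch.tensor([4.0, 0.5])

    baseline = torch.nn.functional.linear(x, W, b)

    # What the rewritten graph computes: the mul node scales the input, and
    # scale_weight_node multiplies the weight by 1/s along the input columns
    x_eq = x * s
    W_eq = W * torch.reciprocal(s)
    equalized = torch.nn.functional.linear(x_eq, W_eq, b)

    assert torch.allclose(baseline, equalized, atol=1e-6)

For two connected linear layers, scale_weight_node additionally folds the next layer's equalization scale into the current layer's weight and bias, in line with the docstring's note that a mul operator should not be inserted between connecting linear layers.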