Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59747

Modifies prepare_fx for input-weight equalization. If a current node is being equalized (there exists an EqualizationQConfig), then the EqualizationObserver will be inserted before its quantization observer.

For a singular linear layer, the general flow looks like:
  Original graph: `x0 -> linear -> x1`, `w -> linear`
  After prepare: `x0 -> InpEqObs -> MinMaxObs -> linear1 -> MinMaxObs -> x1`, `w -> WeightEqObs -> MinMaxObs -> linear1`

For two connected linear layers, the general flow looks like:
  Original graph: `x0 -> linear1 -> linear2 -> x1`, `w1 -> linear1`, `w2 -> linear2`
  After prepare: `x0 -> InpEqObs -> MinMaxObs -> linear1 -> MinMaxObs -> InpEqObs -> linear2 -> MinMaxObs -> x1`, `w1 -> WeightEqObs -> MinMaxObs -> linear1`, `w2 -> WeightEqObs -> MinMaxObs -> linear2`

Test Plan:
`python test/test_quantization.py TestEqualizeFx.test_input_equalization_prepare`

Original model with one `nn.Linear` layer:
```
LinearModule(
  (linear): Linear(in_features=1, out_features=1, bias=True)
)
```

Graph after `prepare_fx`:
```
graph():
    %x : [#users=1] = placeholder[target=x]
    %x_equalization_process_0 : [#users=1] = call_module[target=x_equalization_process_0](args = (%x,), kwargs = {})
    %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_00](args = (%x_equalization_process_0,), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%x_activation_post_process_0,), kwargs = {})
    %linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {})
    return linear_activation_post_process_0
```

--------------------------------------

Original model with two connected functional linear layers:
```
FunctionalLinearModule(
  (linear1): Linear()
  (linear2): Linear()
)
```

Graph after `prepare_fx`:
```
graph():
    %x : [#users=1] = placeholder[target=x]
    %x_equalization_process_0 : [#users=1] = call_module[target=x_equalization_process_0](args = (%x,), kwargs = {})
    %x_activation_post_process_0 : [#users=1] = call_module[target=x_activation_post_process_00](args = (%x_equalization_process_0,), kwargs = {})
    %linear1_w : [#users=1] = get_attr[target=linear1.w]
    %linear1_w_equalization_process_0 : [#users=1] = call_module[target=linear1_w_equalization_process_0](args = (%linear1_w,), kwargs = {})
    %linear1_w_activation_post_process_0 : [#users=1] = call_module[target=linear1_w_activation_post_process_00](args = (%linear1_w_equalization_process_0,), kwargs = {})
    %linear1_b : [#users=1] = get_attr[target=linear1.b]
    %linear : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%x_activation_post_process_0, %linear1_w_activation_post_process_0), kwargs = {bias: %linear1_b})
    %linear_activation_post_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0](args = (%linear,), kwargs = {})
    %linear_activation_post_process_0_equalization_process_0 : [#users=1] = call_module[target=linear_activation_post_process_0_equalization_process_0](args = (%linear_activation_post_process_0,), kwargs = {})
    %linear2_w : [#users=1] = get_attr[target=linear2.w]
    %linear2_w_equalization_process_0 : [#users=1] = call_module[target=linear2_w_equalization_process_0](args = (%linear2_w,), kwargs = {})
    %linear2_w_activation_post_process_0 : [#users=1] = call_module[target=linear2_w_activation_post_process_00](args = (%linear2_w_equalization_process_0,), kwargs = {})
    %linear2_b : [#users=1] = get_attr[target=linear2.b]
    %linear_1 : [#users=1] = call_function[target=torch.nn.functional.linear](args = (%linear_activation_post_process_0_equalization_process_0, %linear2_w_activation_post_process_0), kwargs = {bias: %linear2_b})
    %linear_1_activation_post_process_0 : [#users=1] = call_module[target=linear_1_activation_post_process_0](args = (%linear_1,), kwargs = {})
    return linear_1_activation_post_process_0
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D29135316

fbshipit-source-id: 91697e805ede254dbb2a42ee4c23eb1c1c64590e
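For orientation, a minimal sketch of how the equalization observers described above might be requested from `prepare_fx`. This is illustrative only: the `torch.quantization.fx._equalize` import path and the `equalization_qconfig_dict` keyword are assumptions based on this PR series, not a documented stable API.

```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx
# Assumed import path for the equalization qconfig added in this PR series:
from torch.quantization.fx._equalize import default_equalization_qconfig

class LinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = LinearModule().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
# Assumed keyword for passing the equalization configuration:
equalization_qconfig_dict = {"": default_equalization_qconfig}
prepared = prepare_fx(model, qconfig_dict,
                      equalization_qconfig_dict=equalization_qconfig_dict)
# Expected shape of the graph: x -> x_equalization_process_0
#   -> x_activation_post_process_0 -> linear -> linear_activation_post_process_0
print(prepared.graph)
```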
106 lines
5.0 KiB
Python
import torch
import copy
from torch.fx import GraphModule
from torch.fx.graph import Graph
from typing import Union, Dict, Any, Set

class FusedGraphModule(GraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return FusedGraphModule(fake_mod, self.graph, self.preserved_attr_names)

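# Illustrative sketch (not part of the original file): FusedGraphModule keeps
# the named attributes from `root` on the resulting graph module, where a plain
# GraphModule would drop them. `custom_flag` below is a made-up attribute name.
#
#     root = torch.nn.Sequential(torch.nn.Linear(4, 4))
#     root.custom_flag = True
#     traced = torch.fx.symbolic_trace(root)
#     fused = FusedGraphModule(root, traced.graph, {"custom_flag"})
#     assert fused.custom_flag is True
#     assert copy.deepcopy(fused).custom_flag is True  # survives __deepcopy__
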
class ObservedGraphModule(GraphModule):

    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = set([
            '_activation_post_process_map',
            '_activation_post_process_indexes',
            '_patterns',
            '_qconfig_map',
            '_prepare_custom_config_dict',
            '_equalization_qconfig_map',
            '_node_name_to_scope']).union(preserved_attr_names)
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedGraphModule(fake_mod, self.graph, self.preserved_attr_names)


def is_observed_module(module: Any) -> bool:
    return isinstance(module, ObservedGraphModule)

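# Illustrative sketch (not part of the original file): the __deepcopy__
# override above is what keeps the quantization bookkeeping (for example the
# preserved "_qconfig_map" attribute) alive across copy.deepcopy. `prepared`
# stands for the ObservedGraphModule returned by prepare_fx on some model.
#
#     prepared_copy = copy.deepcopy(prepared)
#     assert is_observed_module(prepared_copy)
#     assert hasattr(prepared_copy, "_qconfig_map")
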
class ObservedStandaloneGraphModule(ObservedGraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        preserved_attr_names = preserved_attr_names.union(set([
            "_standalone_module_input_quantized_idxs",
            "_standalone_module_output_quantized_idxs"]))
        super().__init__(root, graph, preserved_attr_names)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedStandaloneGraphModule(fake_mod, self.graph, self.preserved_attr_names)


def is_observed_standalone_module(module: Any) -> bool:
    return isinstance(module, ObservedStandaloneGraphModule)


def _save_packed_weight(self, destination, prefix, keep_vars):
    for attr_name in dir(self):
        if "_packed_weight" in attr_name and \
           isinstance(getattr(self, attr_name), torch._C.ScriptObject):  # type: ignore[attr-defined]
            packed_weight = getattr(self, attr_name)
            destination[prefix + attr_name] = packed_weight

class QuantizedGraphModule(GraphModule):
    """ This class is created to make sure PackedParams
    (e.g. LinearPackedParams, Conv2dPackedParams) appear in the state_dict
    so that we can serialize and deserialize the quantized graph module with
    torch.save(m.state_dict()) and m.load_state_dict(state_dict)
    """
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])
        self._register_state_dict_hook(_save_packed_weight)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        attrs_to_pop = []
        for attr_name in state_dict:
            if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
                setattr(self, attr_name, state_dict[attr_name])
                attrs_to_pop.append(attr_name)

        # pop the packed param attributes
        for attr_name in attrs_to_pop:
            state_dict.pop(attr_name)

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        # Return a QuantizedGraphModule (not ObservedStandaloneGraphModule) so the
        # copy re-registers the packed-weight state_dict hook in __init__.
        return QuantizedGraphModule(fake_mod, self.graph, self.preserved_attr_names)
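
# Illustrative sketch (not part of the original file): the _save_packed_weight
# hook and the _load_from_state_dict override above provide the round trip the
# class docstring describes. `quantized` stands for a QuantizedGraphModule,
# e.g. the output of convert_fx.
#
#     torch.save(quantized.state_dict(), "quantized_state.pt")
#     state_dict = torch.load("quantized_state.pt")
#     quantized.load_state_dict(state_dict)  # packed params are restored too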