diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index a5ed2fdb4c6..a30bdfb710d 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -6640,6 +6640,101 @@ class TestQuantizeFx(QuantizationTestCase):
         }
         self.checkGraphModuleNodes(quantized_model, expected_node_occurrence=node_occurrence)
 
+    @skipIfNoFBGEMM
+    def test_keep_original_weights(self):
+        class SubModule(nn.Module):
+            """
+            A simple submodule containing a linear layer.
+            """
+
+            def __init__(self, input_dim, output_dim):
+                super(__class__, self).__init__()
+                self.w = nn.Parameter(torch.randn(input_dim, output_dim))
+                self.b = nn.Parameter(torch.randn(input_dim))
+
+            def forward(self, x):
+                return F.linear(x, self.w, self.b)
+
+        class MainModule(nn.Module):
+            """
+            The main module containing the submodule.
+            """
+
+            def __init__(self, input_dim, hidden_dim, output_dim):
+                super(__class__, self).__init__()
+                self.submodule_1 = SubModule(hidden_dim, input_dim)
+                setattr(self, 'submodule|2', SubModule(hidden_dim, hidden_dim))
+                setattr(self, 'submodule/3', SubModule(hidden_dim, hidden_dim))
+                setattr(self, 'submodule:4', SubModule(hidden_dim, hidden_dim))
+                self._w = nn.Parameter(torch.randn(output_dim, hidden_dim))
+
+            def forward(self, x):
+                x1 = self.submodule_1(x)
+                x2 = getattr(self, 'submodule|2')(x1)
+                x3 = getattr(self, 'submodule/3')(x2)
+                x4 = getattr(self, 'submodule:4')(x3)
+                x5 = F.linear(x4, self._w)
+                return x5
+
+        input_dim = 10
+        hidden_dim = 20
+        output_dim = 5
+        model = MainModule(input_dim, hidden_dim, output_dim)
+        model.eval()
+        example_inputs = torch.randn(1, input_dim)
+        _ = model(*example_inputs)
+        qconfig_mapping = QConfigMapping().set_object_type(nn.functional.linear, float16_dynamic_qconfig)
+        prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
+        prepared_model(example_inputs)
+        quantized_model = convert_fx(prepared_model, keep_original_weights=True)
+
+        self.assertTrue(len(quantized_model.original_weights_lookup) == 5)
+        self.assertTrue("submodule_1_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][0],
+            model.submodule_1.w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_1_packed_weight_0"][1],
+            model.submodule_1.b
+        )
+        self.assertTrue("submodule_2_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][0],
+            getattr(model, "submodule|2").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_2_packed_weight_0"][1],
+            getattr(model, "submodule|2").b
+        )
+        self.assertTrue("submodule_3_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][0],
+            getattr(model, "submodule/3").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_3_packed_weight_0"][1],
+            getattr(model, "submodule/3").b
+        )
+        self.assertTrue("submodule_4_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][0],
+            getattr(model, "submodule:4").w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["submodule_4_packed_weight_0"][1],
+            getattr(model, "submodule:4").b
+        )
+        self.assertTrue("_packed_weight_0" in quantized_model.original_weights_lookup)
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["_packed_weight_0"][0],
+            model._w
+        )
+        torch.testing.assert_close(
+            quantized_model.original_weights_lookup["_packed_weight_0"][1],
+            None
+        )
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     def setUp(self):
diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py
index c38f6146d77..5161f9691b0 100644
--- a/torch/ao/quantization/fx/_lower_to_native_backend.py
+++ b/torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -443,7 +443,9 @@ def _load_packed_weight(
 
 
 def fold_weight(
-    quantized_model: GraphModule, node_name_to_scope: Dict[str, Tuple[str, type]]
+    quantized_model: GraphModule,
+    node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """
     Trace back from the weight node util we hit getattr, reconstruct the
@@ -453,6 +455,8 @@ def fold_weight(
     packed_weights = {}
     # map from folded node name to the prepacked weight name
     folded_nodes = {}
+    original_weights_lookup: Dict[str, List] = {}
+    lookup_counter = 0
     # get packed weights
     for node in quantized_model.graph.nodes:
         if node.op == "call_function" and node.target in WEIGHT_PREPACK_OPS:
@@ -466,6 +470,16 @@
             )
             packed_weight = prepacking_module()
             packed_weights[node.name] = packed_weight
+            if keep_original_weights:
+                original_weights = list(prepacking_module.state_dict().values())
+                original_weights_lookup[str(lookup_counter)] = sorted(
+                    original_weights, key=lambda x: x.numel(), reverse=True
+                )
+                if len(original_weights_lookup[str(lookup_counter)]) == 1:
+                    # bias is None
+                    original_weights_lookup[str(lookup_counter)].append(None)
+                lookup_counter += 1
+    lookup_counter = 0
 
     # remove folded nodes and replace the prepacking node with getattr
     folded_graph = Graph()
@@ -490,6 +504,18 @@
             env[node.name] = folded_graph.create_node(
                 "get_attr", packed_weight_name, (), {}
             )
+            if keep_original_weights:
+                key_name = (
+                    packed_weight_name.replace(":", "_")
+                    .replace("/", "_")
+                    .replace("|", "_")
+                    .lower()
+                )
+                original_weights_lookup[key_name] = original_weights_lookup[
+                    str(lookup_counter)
+                ]
+                del original_weights_lookup[str(lookup_counter)]
+                lookup_counter += 1
         elif prepack_node is not None:
             # remove the foled node
             continue
@@ -500,6 +526,12 @@
     quantized_model = GraphModule(quantized_model, folded_graph)
     quantized_model._register_state_dict_hook(_save_packed_weight)
     quantized_model.register_load_state_dict_pre_hook(_load_packed_weight)
+
+    if keep_original_weights:
+        setattr(  # noqa: B010
+            quantized_model, "original_weights_lookup", original_weights_lookup
+        )
+
     return quantized_model
 
 
@@ -1296,6 +1328,7 @@ def _lower_to_native_backend(
     model: GraphModule,
     qconfig_map: Dict[str, QConfigAny],
     node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """Lower a quantized reference model (with reference quantized operator patterns)
     to the native backend in PyTorch (fbgemm/qnnpack), both backends shares the same
@@ -1312,7 +1345,7 @@
     _lower_get_tensor_info_op(model)
     special_pattern_replacement(model)
     model.graph.eliminate_dead_code()
-    model = fold_weight(model, node_name_to_scope)
+    model = fold_weight(model, node_name_to_scope, keep_original_weights)
     model.graph.eliminate_dead_code()
     model.recompile()
     model.graph.lint()
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 65616e940ad..5f2a3eb17e2 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -992,6 +992,7 @@ def convert(
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
     is_decomposed: bool = False,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """
     We will convert an observed model (a module with observer calls) to a reference
@@ -1243,7 +1244,9 @@
 
     # TODO: maybe move this to quantize_fx.py
     if not is_reference:
-        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)
+        model = lower_to_fbgemm(
+            model, node_name_to_qconfig, node_name_to_scope, keep_original_weights
+        )
 
     # TODO: this looks hacky, we want to check why we need this and see if we can
     # remove this
diff --git a/torch/ao/quantization/fx/lower_to_fbgemm.py b/torch/ao/quantization/fx/lower_to_fbgemm.py
index b40d5cb4cca..4bd24136e16 100644
--- a/torch/ao/quantization/fx/lower_to_fbgemm.py
+++ b/torch/ao/quantization/fx/lower_to_fbgemm.py
@@ -13,8 +13,11 @@ def lower_to_fbgemm(
     model: GraphModule,
     qconfig_map: Dict[str, QConfigAny],
     node_name_to_scope: Dict[str, Tuple[str, type]],
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """Lower a quantized reference model (with reference quantized operator patterns)
     to fbgemm
     """
-    return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
+    return _lower_to_native_backend(
+        model, qconfig_map, node_name_to_scope, keep_original_weights
+    )
diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py
index dd8f3e811a3..072dcf41ca7 100644
--- a/torch/ao/quantization/quantize_fx.py
+++ b/torch/ao/quantization/quantize_fx.py
@@ -515,6 +515,7 @@ def _convert_fx(
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
     is_decomposed: bool = False,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`"""
     if convert_custom_config is None:
@@ -546,6 +547,7 @@
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
         is_decomposed=is_decomposed,
+        keep_original_weights=keep_original_weights,
     )
 
     attach_preserved_attrs_to_model(quantized, preserved_attrs)
@@ -558,6 +560,7 @@ def convert_fx(
     _remove_qconfig: bool = True,
     qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
     backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
+    keep_original_weights: bool = False,
 ) -> GraphModule:
     r"""Convert a calibrated or trained model to a quantized model
 
@@ -616,6 +619,7 @@
         _remove_qconfig=_remove_qconfig,
         qconfig_mapping=qconfig_mapping,
         backend_config=backend_config,
+        keep_original_weights=keep_original_weights,
     )
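
For context, a minimal usage sketch of the keep_original_weights flag wired through above. It mirrors the flow of the new test, but TinyLinear and the printed output are illustrative assumptions of mine, not part of the patch; like the test (@skipIfNoFBGEMM), it assumes a build with FBGEMM available.

# Illustrative sketch (not part of the patch): quantize a small F.linear module
# with convert_fx(..., keep_original_weights=True) and read the original float
# weights back from the lookup dict attached to the quantized GraphModule.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization import QConfigMapping, float16_dynamic_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class TinyLinear(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4, 8))
        self.b = nn.Parameter(torch.randn(4))

    def forward(self, x):
        return F.linear(x, self.w, self.b)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 8),)
qconfig_mapping = QConfigMapping().set_object_type(F.linear, float16_dynamic_qconfig)

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)  # calibration pass
quantized = convert_fx(prepared, keep_original_weights=True)

# Keys look like "<flattened module path>_packed_weight_<n>"; for an F.linear
# called on the root module the path is empty, giving "_packed_weight_0".
# Each value is [weight, bias], with bias set to None when the op had no bias.
for key, (weight, bias) in quantized.original_weights_lookup.items():
    print(key, tuple(weight.shape), None if bias is None else tuple(bias.shape))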