ns for fx: fix bug for user function in weight extraction (#62333)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62333

We incorrectly ignored any custom relationships the user specified
in the `extract_weights` API.  Fixing this and adding a test case.

Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_user_defined_function
```

Imported from OSS

Reviewed By: hx89

Differential Revision: D29963502

fbshipit-source-id: 33ce3d4df1acb6298b6c7dcb6674015c8d14bdf4
This commit is contained in:
Vasiliy Kuznetsov 2021-07-28 16:02:51 -07:00 committed by Facebook GitHub Bot
parent d98b1c400d
commit 72c943a2ac
2 changed files with 4 additions and 2 deletions

View File

@ -1631,7 +1631,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
torch.quantization.ns.weight_utils.get_linear_fun_weight
# test compare weights
results = _extract_weights_impl(
results = extract_weights(
'a', m1, 'b', m2,
base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
op_to_type_to_weight_extraction_fn=op_to_type_to_weight_extraction_fn)

View File

@ -191,7 +191,9 @@ def extract_weights(
op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
) -> NSResultsType:
torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
if base_name_to_sets_of_related_ops is None:
base_name_to_sets_of_related_ops = \
get_base_name_to_sets_of_related_ops()
type_a_related_to_b = \
get_type_a_related_to_b(base_name_to_sets_of_related_ops)