mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
ns for fx: fix bug for user function in weight extraction (#62333)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62333 We incorrectly ignored any custom relationships the user specified in the `extract_weights` API. Fixing this and adding a test case. Test Plan: ``` python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_user_defined_function ``` Imported from OSS Reviewed By: hx89 Differential Revision: D29963502 fbshipit-source-id: 33ce3d4df1acb6298b6c7dcb6674015c8d14bdf4
This commit is contained in:
parent
d98b1c400d
commit
72c943a2ac
|
|
@ -1631,7 +1631,7 @@ class TestFXNumericSuiteCoreAPIs(FXNumericSuiteQuantizationTestCase):
|
|||
torch.quantization.ns.weight_utils.get_linear_fun_weight
|
||||
|
||||
# test compare weights
|
||||
results = _extract_weights_impl(
|
||||
results = extract_weights(
|
||||
'a', m1, 'b', m2,
|
||||
base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
|
||||
op_to_type_to_weight_extraction_fn=op_to_type_to_weight_extraction_fn)
|
||||
|
|
|
|||
|
|
@ -191,7 +191,9 @@ def extract_weights(
|
|||
op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
|
||||
) -> NSResultsType:
|
||||
torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
|
||||
base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
|
||||
if base_name_to_sets_of_related_ops is None:
|
||||
base_name_to_sets_of_related_ops = \
|
||||
get_base_name_to_sets_of_related_ops()
|
||||
type_a_related_to_b = \
|
||||
get_type_a_related_to_b(base_name_to_sets_of_related_ops)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user