pytorch/torch/quantization/_numeric_suite.py
Haixin Liu 455d4aab64 [PyTorch Numeric Suite] Add weight compare API (#36186)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36186

Start PyTorch Numeric Suite under PyTorch quantization and add weight compare API to it.
ghstack-source-id: 102062165

Test Plan: buck test mode/dev caffe2/test:quantization -- 'test_compare_weights'

Differential Revision: D20903395

fbshipit-source-id: 125d84569837142626a0e2119b3b7657a32dbf4e
2020-04-13 19:02:00 -07:00

42 lines
1.5 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
def _find_match(str_list, key_str, postfix):
split_str = key_str.split(".")
if split_str[-1] == postfix:
match_string = "".join(key_str.split(".")[0:-1])
for s2 in str_list:
pattern1 = "".join(s2.split(".")[0:-1])
pattern2 = "".join(s2.split(".")[0:-2])
if match_string == pattern1:
return s2
if match_string == pattern2:
return s2
else:
return None
def compare_weights(float_dict, quantized_dict):
    r"""Pairs up the weights of a float model and its quantized counterpart.

    Every key of ``quantized_dict`` is looked up in ``float_dict`` through
    ``_find_match`` using the ``"weight"`` postfix. Each key that has a match
    becomes an entry of the result, holding the float value and the quantized
    value side by side, so that the quantization error of the weights of the
    float and quantized models can be compared and computed.

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict keyed by the matched keys of ``quantized_dict``; each
        entry is a dictionary with two keys, 'float' and 'quantized', containing
        the float and quantized weights
    """
    paired_weights = {}
    for quantized_key in quantized_dict:
        float_key = _find_match(float_dict, quantized_key, "weight")
        if float_key is None:
            # Not a weight entry, or no float counterpart was found.
            continue
        paired_weights[quantized_key] = {
            "float": float_dict[float_key],
            "quantized": quantized_dict[quantized_key],
        }
    return paired_weights