from __future__ import absolute_import, division, print_function, unicode_literals


def _find_match(str_list, key_str, postfix):
    r"""Find the key in ``str_list`` that corresponds to ``key_str``.

    ``key_str`` is only considered when its last dot-separated component equals
    ``postfix``. Its module prefix (every component except the last) is then
    compared against each candidate ``s2`` in ``str_list``. A candidate matches
    when either:

    * ``s2`` without its last component equals the prefix
      (e.g. ``"fc.weight"`` for ``"fc.weight"``), or
    * ``s2`` without its last two components equals the prefix
      (e.g. ``"features.0.0.weight"`` for ``"features.0.weight"``).

    Among matching candidates, the first one whose last component also equals
    ``postfix`` is returned. This stops e.g. ``"conv.bias"`` from being paired
    with ``"conv.weight"`` just because it is iterated first. If no matching
    candidate ends in ``postfix``, the first matching candidate is returned.

    Args:
        str_list: iterable of candidate dot-separated keys (iterating a dict,
            such as a state dict, yields its keys)
        key_str: dot-separated key to look up
        postfix: required last component of ``key_str``, e.g. ``"weight"``

    Returns:
        The matching key from ``str_list``, or ``None`` if ``key_str`` does not
        end in ``postfix`` or no candidate matches.
    """
    split_str = key_str.split(".")
    if split_str[-1] != postfix:
        return None

    # Compare lists of components rather than strings joined with "":
    # joining without a separator made e.g. "ab.c" and "a.bc" both "abc".
    match_prefix = split_str[:-1]
    fallback = None
    for s2 in str_list:
        parts = s2.split(".")
        if parts[:-1] != match_prefix and parts[:-2] != match_prefix:
            continue
        if parts[-1] == postfix:
            # Best match: same prefix rule and same parameter name.
            return s2
        if fallback is None:
            # Remember the first prefix-only match; this preserves the
            # previous first-match behavior when nothing better exists.
            fallback = s2
    return fallback


def compare_weights(float_dict, quantized_dict):
    r"""Pair up the weights of a float model and its quantized counterpart.

    The returned dict can be used to compare the weights of the float and
    quantized models and to compute their quantization error.

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Returns:
        weight_dict: dict keyed by the quantized model's weight keys. Each
        entry is a dictionary with two keys, 'float' and 'quantized', holding
        the matching float weight and the quantized weight. Quantized keys
        with no matching float key are omitted.
    """
    weight_dict = {}
    for key in quantized_dict:
        match_key = _find_match(float_dict, key, "weight")
        if match_key is not None:
            weight_dict[key] = {
                "float": float_dict[match_key],
                "quantized": quantized_dict[key],
            }
    return weight_dict