pytorch/torch/quantization/_numeric_suite.py
Alban Desmaison 35b9c89dc1 Revert D21045393: [PyTorch Numeric Suite] Add module level comparison
Test Plan: revert-hammer

Differential Revision:
D21045393

Original commit changeset: 4303805f732c

fbshipit-source-id: 06d8a234eda800eb14bc3aa58ff14b0d3cf86d86
2020-04-24 07:03:04 -07:00

42 lines
1.5 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
def _find_match(str_list, key_str, postfix):
split_str = key_str.split(".")
if split_str[-1] == postfix:
match_string = "".join(key_str.split(".")[0:-1])
for s2 in str_list:
pattern1 = "".join(s2.split(".")[0:-1])
pattern2 = "".join(s2.split(".")[0:-2])
if match_string == pattern1:
return s2
if match_string == pattern2:
return s2
else:
return None
def compare_weights(float_dict, quantized_dict):
    r"""Pair up the weights of a float model with those of its quantized version.

    Every key of ``quantized_dict`` whose last dot-separated component is
    ``weight`` is looked up in ``float_dict`` via ``_find_match``; keys with no
    float counterpart are skipped. The resulting pairs can be used to compute
    the quantization error of each weight.

    Args:
        float_dict: state dict of the float model
        quantized_dict: state dict of the quantized model

    Return:
        weight_dict: dict keyed by the matched keys of ``quantized_dict`` (full
        state-dict keys such as ``conv.weight``, not bare module names). Each
        value is a dict with entries ``'float'`` (the matched float weight) and
        ``'quantized'`` (the quantized weight).
    """
    weight_dict = {}
    for q_key, q_value in quantized_dict.items():
        float_key = _find_match(float_dict, q_key, "weight")
        if float_key is None:
            continue
        weight_dict[q_key] = {
            "float": float_dict[float_key],
            "quantized": q_value,
        }
    return weight_dict