fix type check for torch.quantization._numeric_suite (#46330)

Summary:
fix https://github.com/pytorch/pytorch/issues/42977

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46330

Reviewed By: malfet

Differential Revision: D24320449

Pulled By: walterddr

fbshipit-source-id: f892b5c83cb932aee53245d6c825568b3e05f3c6
This commit is contained in:
Rong Rong 2020-10-15 20:41:22 -07:00 committed by Facebook GitHub Bot
parent 92921c82bb
commit d1745c36dc
3 changed files with 7 additions and 8 deletions

View File

@ -71,10 +71,6 @@ ignore_errors = True
[mypy-torch.quantization.stubs] [mypy-torch.quantization.stubs]
ignore_errors = True ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True
[mypy-torch.quantization.quantize_fx] [mypy-torch.quantization.quantize_fx]
ignore_errors = True ignore_errors = True

View File

@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.quantized as nnq import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare from torch.quantization import prepare
from typing import Dict
from .quantization_mappings import ( from .quantization_mappings import (
get_compare_output_module_list, get_compare_output_module_list,
@ -66,7 +67,7 @@ def compare_weights(float_dict, quantized_dict):
a dictionary with two keys 'float' and 'quantized', containing the float and a dictionary with two keys 'float' and 'quantized', containing the float and
quantized weights quantized weights
""" """
weight_dict = {} weight_dict: Dict[str, Dict] = {}
for key in quantized_dict: for key in quantized_dict:
match_key = _find_match(float_dict, key, "weight") match_key = _find_match(float_dict, key, "weight")
if match_key is not None: if match_key is not None:
@ -142,7 +143,7 @@ def get_logger_dict(mod, prefix=""):
target_dict: the dictionary used to save all logger stats target_dict: the dictionary used to save all logger stats
""" """
target_dict = {} target_dict: Dict[str, Dict] = {}
_get_logger_dict_helper(mod, target_dict, prefix) _get_logger_dict_helper(mod, target_dict, prefix)
return target_dict return target_dict
@ -379,7 +380,7 @@ def get_matching_activations(float_module, q_module):
""" """
float_dict = get_logger_dict(float_module) float_dict = get_logger_dict(float_module)
quantized_dict = get_logger_dict(q_module) quantized_dict = get_logger_dict(q_module)
act_dict = {} act_dict: Dict[str, Dict] = {}
for key in quantized_dict: for key in quantized_dict:
match_key = _find_match(sorted(float_dict, reverse=True), key, "stats") match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
if match_key is not None: if match_key is not None:

View File

@ -1037,9 +1037,11 @@ class TestCase(expecttest.TestCase):
rtol, atol = self._getDefaultRtolAndAtol(torch.float32, torch.float32) rtol, atol = self._getDefaultRtolAndAtol(torch.float32, torch.float32)
else: else:
rtol, atol = 0, 0 rtol, atol = 0, 0
rtol = cast(float, rtol)
atol = cast(float, atol)
atol = max(atol, self.precision) atol = max(atol, self.precision)
return _compare_scalars_internal(a, b, rtol=cast(float, rtol), atol=cast(float, atol), equal_nan=equal_nan) return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def assertEqualIgnoreType(self, *args, **kwargs) -> None: def assertEqualIgnoreType(self, *args, **kwargs) -> None:
# If you are seeing this function used, that means test is written wrongly # If you are seeing this function used, that means test is written wrongly