fix type check for torch.quantization._numeric_suite (#46330)

Summary:
fix https://github.com/pytorch/pytorch/issues/42977

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46330

Reviewed By: malfet

Differential Revision: D24320449

Pulled By: walterddr

fbshipit-source-id: f892b5c83cb932aee53245d6c825568b3e05f3c6
This commit is contained in:
Rong Rong 2020-10-15 20:41:22 -07:00 committed by Facebook GitHub Bot
parent 92921c82bb
commit d1745c36dc
3 changed files with 7 additions and 8 deletions

View File

@ -71,10 +71,6 @@ ignore_errors = True
[mypy-torch.quantization.stubs]
ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True
[mypy-torch.quantization.quantize_fx]
ignore_errors = True

View File

@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
from torch.quantization import prepare
from typing import Dict
from .quantization_mappings import (
get_compare_output_module_list,
@ -66,7 +67,7 @@ def compare_weights(float_dict, quantized_dict):
a dictionary with two keys 'float' and 'quantized', containing the float and
quantized weights
"""
weight_dict = {}
weight_dict: Dict[str, Dict] = {}
for key in quantized_dict:
match_key = _find_match(float_dict, key, "weight")
if match_key is not None:
@ -142,7 +143,7 @@ def get_logger_dict(mod, prefix=""):
target_dict: the dictionary used to save all logger stats
"""
target_dict = {}
target_dict: Dict[str, Dict] = {}
_get_logger_dict_helper(mod, target_dict, prefix)
return target_dict
@ -379,7 +380,7 @@ def get_matching_activations(float_module, q_module):
"""
float_dict = get_logger_dict(float_module)
quantized_dict = get_logger_dict(q_module)
act_dict = {}
act_dict: Dict[str, Dict] = {}
for key in quantized_dict:
match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
if match_key is not None:

View File

@ -1037,9 +1037,11 @@ class TestCase(expecttest.TestCase):
rtol, atol = self._getDefaultRtolAndAtol(torch.float32, torch.float32)
else:
rtol, atol = 0, 0
rtol = cast(float, rtol)
atol = cast(float, atol)
atol = max(atol, self.precision)
return _compare_scalars_internal(a, b, rtol=cast(float, rtol), atol=cast(float, atol), equal_nan=equal_nan)
return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
def assertEqualIgnoreType(self, *args, **kwargs) -> None:
# If you are seeing this function used, that means test is written wrongly