fix type check for torch.quantization._numeric_suite (#46330)
Summary: fix https://github.com/pytorch/pytorch/issues/42977

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46330

Reviewed By: malfet

Differential Revision: D24320449

Pulled By: walterddr

fbshipit-source-id: f892b5c83cb932aee53245d6c825568b3e05f3c6
parent 92921c82bb
commit d1745c36dc
--- a/mypy.ini
+++ b/mypy.ini
@@ -71,10 +71,6 @@ ignore_errors = True
 
 [mypy-torch.quantization.stubs]
 ignore_errors = True
 
-[mypy-torch.quantization._numeric_suite]
-ignore_errors = True
-
-
 [mypy-torch.quantization.quantize_fx]
 ignore_errors = True
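Removing the [mypy-torch.quantization._numeric_suite] override means mypy now type-checks that module instead of suppressing its errors. One way to verify the change locally is mypy's programmatic API; this is a minimal sketch, assuming a pytorch checkout with this commit applied and mypy installed:

    from mypy import api

    # Run mypy on the newly un-ignored module, using the repo's config.
    # api.run returns a (stdout, stderr, exit_status) tuple.
    stdout, stderr, exit_status = api.run(
        ["--config-file", "mypy.ini", "torch/quantization/_numeric_suite.py"]
    )
    print(stdout)
    assert exit_status == 0, "mypy reported type errors"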
--- a/torch/quantization/_numeric_suite.py
+++ b/torch/quantization/_numeric_suite.py
@@ -4,6 +4,7 @@ import torch.nn as nn
 import torch.nn.quantized as nnq
 import torch.nn.quantized.dynamic as nnqd
 from torch.quantization import prepare
+from typing import Dict
 
 from .quantization_mappings import (
     get_compare_output_module_list,
@@ -66,7 +67,7 @@ def compare_weights(float_dict, quantized_dict):
         a dictionary with two keys 'float' and 'quantized', containing the float and
         quantized weights
     """
-    weight_dict = {}
+    weight_dict: Dict[str, Dict] = {}
     for key in quantized_dict:
         match_key = _find_match(float_dict, key, "weight")
         if match_key is not None:
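This is the standard fix for mypy's "Need type annotation" error: an empty dict literal gives the checker nothing to infer key and value types from, so the variable must be annotated explicitly. The two hunks below apply the same pattern. A standalone sketch (the names are illustrative, not taken from the PyTorch sources):

    from typing import Dict

    # A bare `weight_dict = {}` makes mypy report:
    #     error: Need type annotation for "weight_dict"
    # Annotating the variable fixes it; each value is itself a dict,
    # mirroring the 'float'/'quantized' entries described in the docstring.
    weight_dict: Dict[str, Dict] = {}
    weight_dict["conv1.weight"] = {"float": 1.0, "quantized": 1.0}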
@@ -142,7 +143,7 @@ def get_logger_dict(mod, prefix=""):
         target_dict: the dictionary used to save all logger stats
     """
 
-    target_dict = {}
+    target_dict: Dict[str, Dict] = {}
     _get_logger_dict_helper(mod, target_dict, prefix)
     return target_dict
 
@@ -379,7 +380,7 @@ def get_matching_activations(float_module, q_module):
     """
     float_dict = get_logger_dict(float_module)
     quantized_dict = get_logger_dict(q_module)
-    act_dict = {}
+    act_dict: Dict[str, Dict] = {}
     for key in quantized_dict:
         match_key = _find_match(sorted(float_dict, reverse=True), key, "stats")
         if match_key is not None:
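For context, a hedged usage sketch of the module these hunks annotate, following the compare_weights docstring quoted above; building the models is elided, and float_model / qmodel are assumed to be a float module and its quantized counterpart:

    from torch.quantization._numeric_suite import compare_weights

    # Assumed to exist: float_model (an nn.Module) and qmodel (its
    # quantized version). compare_weights matches weights by key across
    # the two state dicts and returns the annotated Dict[str, Dict].
    wt_compare_dict = compare_weights(float_model.state_dict(), qmodel.state_dict())
    for key, entry in wt_compare_dict.items():
        # each entry holds matched 'float' and 'quantized' weight tensors
        print(key, entry["float"].shape, entry["quantized"].shape)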
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1037,9 +1037,11 @@ class TestCase(expecttest.TestCase):
             rtol, atol = self._getDefaultRtolAndAtol(torch.float32, torch.float32)
         else:
             rtol, atol = 0, 0
+        rtol = cast(float, rtol)
+        atol = cast(float, atol)
         atol = max(atol, self.precision)
 
-        return _compare_scalars_internal(a, b, rtol=cast(float, rtol), atol=cast(float, atol), equal_nan=equal_nan)
+        return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
     def assertEqualIgnoreType(self, *args, **kwargs) -> None:
         # If you are seeing this function used, that means test is written wrongly
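Note that typing.cast performs no runtime conversion; it only tells the checker to treat the value as the given type, which is why hoisting the casts above the max(...) call satisfies mypy without changing runtime behavior. A minimal sketch with a hypothetical helper (clamp_atol is not from the PyTorch sources):

    from typing import Optional, cast

    def clamp_atol(atol: Optional[float], precision: float) -> float:
        # Without the cast, mypy rejects max() on an Optional[float].
        # cast() returns its argument unchanged at runtime, so callers
        # must still guarantee atol is not actually None here.
        atol = cast(float, atol)
        return max(atol, precision)

    print(clamp_atol(1e-05, 1e-07))  # 1e-05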