diff --git a/mypy.ini b/mypy.ini
index caba585fe29..efaa8b12e96 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -260,9 +260,6 @@ ignore_errors = True
 [mypy-torch.nn.quantized.modules.batchnorm]
 ignore_errors = True
 
-[mypy-torch.nn.quantized.modules.linear]
-ignore_errors = True
-
 [mypy-torch.nn.intrinsic.quantized.modules.conv_relu]
 ignore_errors = True
 
diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py
index 697d739b852..615741f38da 100644
--- a/torch/jit/quantized.py
+++ b/torch/jit/quantized.py
@@ -1,6 +1,6 @@
 import torch
 
-from torch._jit_internal import Tuple, Optional, List  # noqa: F401
+from typing import Tuple, Optional, List
 
 from torch import Tensor, _VF  # noqa: F401
 
diff --git a/torch/nn/quantized/dynamic/modules/embeddingbag.py b/torch/nn/quantized/dynamic/modules/embeddingbag.py
index 43b3dc59287..78a426abb37 100644
--- a/torch/nn/quantized/dynamic/modules/embeddingbag.py
+++ b/torch/nn/quantized/dynamic/modules/embeddingbag.py
@@ -1,15 +1,15 @@
 import torch
 import torch.nn as nn
-from torch import Tensor  # noqa: F401
-from torch._jit_internal import Optional, List  # noqa: F401
+from torch import Tensor
 from torch.nn.quantized.modules.utils import hide_packed_params_repr
 from torch.nn.quantized.modules.utils import _quantize_weight
 from torch.quantization.qconfig import float_qparams_dynamic_qconfig
+from typing import Optional
 
 class EmbeddingPackedParams(torch.nn.Module):
     _version = 1
 
-    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
+    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8) -> None:
         super(EmbeddingPackedParams, self).__init__()
         self.dtype = dtype
         if self.dtype == torch.quint8:
@@ -23,8 +23,7 @@ class EmbeddingPackedParams(torch.nn.Module):
             raise RuntimeError('Unsupported dtype on dynamic quantized embedding_bag!')
 
     @torch.jit.export
-    def set_weight(self, weight):
-        # type: (torch.Tensor) -> None
+    def set_weight(self, weight: Tensor) -> None:
         if self.dtype == torch.quint8:
             self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
         else:
@@ -136,8 +135,7 @@ class EmbeddingBag(torch.nn.Module):
 
         return extra_repr_str
 
-    def set_weight(self, w):
-        # type: (torch.Tensor) -> None
+    def set_weight(self, w: Tensor) -> None:
         self._packed_params.set_weight(w)
 
     def weight(self):
diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py
index ba35105c414..4d27dad07bc 100644
--- a/torch/nn/quantized/modules/linear.py
+++ b/torch/nn/quantized/modules/linear.py
@@ -1,9 +1,9 @@
 import torch
-from torch._jit_internal import Optional  # noqa: F401
 import torch.nn as nn
 import torch.nn.intrinsic as nni
 from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr
+from typing import Optional
 
 class LinearPackedParams(torch.nn.Module):
     _version = 3
 
@@ -18,8 +18,7 @@ class LinearPackedParams(torch.nn.Module):
         self.set_weight_bias(wq, None)
 
     @torch.jit.export
-    def set_weight_bias(self, weight, bias):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
         if self.dtype == torch.qint8:
             self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
         elif self.dtype == torch.float16:
@@ -234,8 +233,7 @@ class Linear(torch.nn.Module):
     def bias(self):
         return self._weight_bias()[1]
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params.set_weight_bias(w, b)
 
     @classmethod
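
Note: every hunk above applies the same migration — `# type:` comments become inline PEP 484 annotations, and `Optional`/`Tuple`/`List` are imported from `typing` rather than re-exported from `torch._jit_internal`, which is what allows the `[mypy-torch.nn.quantized.modules.linear]` ignore block to be dropped from mypy.ini. For reference, a minimal standalone sketch of the before/after pattern; the `ExamplePackedParams` class here is hypothetical and not part of the diff:

    import torch
    from typing import Optional  # imported from typing, not torch._jit_internal


    class ExamplePackedParams(torch.nn.Module):
        # Before: the signature was typed with a TorchScript-style comment:
        #
        #     def set_weight_bias(self, weight, bias):
        #         # type: (torch.Tensor, Optional[torch.Tensor]) -> None
        #
        # After: inline annotations. TorchScript accepts both forms, so the
        # scripted behavior is unchanged; inline annotations are simply the
        # form that mypy and IDE tooling handle best.
        def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
            self.weight = weight
            self.bias = bias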