Enable typechecks for torch.nn.quantized.modules.linear (#44154)

Summary:
Also import `Optional` directly from `typing` rather than from `_jit_internal`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44154

Reviewed By: seemethere

Differential Revision: D23511833

Pulled By: malfet

fbshipit-source-id: f78c5fd679c002b218e4d287a9e56fa198171981
This commit is contained in:
Nikita Shulga 2020-09-03 19:50:40 -07:00 committed by Facebook GitHub Bot
parent 538d3bd364
commit b60ffcdfdd
4 changed files with 9 additions and 16 deletions

View File

@ -260,9 +260,6 @@ ignore_errors = True
[mypy-torch.nn.quantized.modules.batchnorm] [mypy-torch.nn.quantized.modules.batchnorm]
ignore_errors = True ignore_errors = True
[mypy-torch.nn.quantized.modules.linear]
ignore_errors = True
[mypy-torch.nn.intrinsic.quantized.modules.conv_relu] [mypy-torch.nn.intrinsic.quantized.modules.conv_relu]
ignore_errors = True ignore_errors = True

View File

@ -1,6 +1,6 @@
import torch import torch
from torch._jit_internal import Tuple, Optional, List # noqa: F401 from typing import Tuple, Optional, List
from torch import Tensor, _VF # noqa: F401 from torch import Tensor, _VF # noqa: F401

View File

@ -1,15 +1,15 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor # noqa: F401 from torch import Tensor
from torch._jit_internal import Optional, List # noqa: F401
from torch.nn.quantized.modules.utils import hide_packed_params_repr from torch.nn.quantized.modules.utils import hide_packed_params_repr
from torch.nn.quantized.modules.utils import _quantize_weight from torch.nn.quantized.modules.utils import _quantize_weight
from torch.quantization.qconfig import float_qparams_dynamic_qconfig from torch.quantization.qconfig import float_qparams_dynamic_qconfig
from typing import Optional
class EmbeddingPackedParams(torch.nn.Module): class EmbeddingPackedParams(torch.nn.Module):
_version = 1 _version = 1
def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8): def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8) -> None:
super(EmbeddingPackedParams, self).__init__() super(EmbeddingPackedParams, self).__init__()
self.dtype = dtype self.dtype = dtype
if self.dtype == torch.quint8: if self.dtype == torch.quint8:
@ -23,8 +23,7 @@ class EmbeddingPackedParams(torch.nn.Module):
raise RuntimeError('Unsupported dtype on dynamic quantized embedding_bag!') raise RuntimeError('Unsupported dtype on dynamic quantized embedding_bag!')
@torch.jit.export @torch.jit.export
def set_weight(self, weight): def set_weight(self, weight: Tensor) -> None:
# type: (torch.Tensor) -> None
if self.dtype == torch.quint8: if self.dtype == torch.quint8:
self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight) self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
else: else:
@ -136,8 +135,7 @@ class EmbeddingBag(torch.nn.Module):
return extra_repr_str return extra_repr_str
def set_weight(self, w): def set_weight(self, w: Tensor) -> None:
# type: (torch.Tensor) -> None
self._packed_params.set_weight(w) self._packed_params.set_weight(w)
def weight(self): def weight(self):

View File

@ -1,9 +1,9 @@
import torch import torch
from torch._jit_internal import Optional # noqa: F401
import torch.nn as nn import torch.nn as nn
import torch.nn.intrinsic as nni import torch.nn.intrinsic as nni
from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr
from typing import Optional
class LinearPackedParams(torch.nn.Module): class LinearPackedParams(torch.nn.Module):
_version = 3 _version = 3
@ -18,8 +18,7 @@ class LinearPackedParams(torch.nn.Module):
self.set_weight_bias(wq, None) self.set_weight_bias(wq, None)
@torch.jit.export @torch.jit.export
def set_weight_bias(self, weight, bias): def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
# type: (torch.Tensor, Optional[torch.Tensor]) -> None
if self.dtype == torch.qint8: if self.dtype == torch.qint8:
self._packed_params = torch.ops.quantized.linear_prepack(weight, bias) self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
elif self.dtype == torch.float16: elif self.dtype == torch.float16:
@ -234,8 +233,7 @@ class Linear(torch.nn.Module):
def bias(self): def bias(self):
return self._weight_bias()[1] return self._weight_bias()[1]
def set_weight_bias(self, w, b): def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
# type: (torch.Tensor, Optional[torch.Tensor]) -> None
self._packed_params.set_weight_bias(w, b) self._packed_params.set_weight_bias(w, b)
@classmethod @classmethod