Enable typechecks for torch.nn.quantized.modules.linear (#44154)

Summary:
Also import `Optional` directly from `typing` rather than from `_jit_internal`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44154

Reviewed By: seemethere

Differential Revision: D23511833

Pulled By: malfet

fbshipit-source-id: f78c5fd679c002b218e4d287a9e56fa198171981
This commit is contained in:
Nikita Shulga 2020-09-03 19:50:40 -07:00 committed by Facebook GitHub Bot
parent 538d3bd364
commit b60ffcdfdd
4 changed files with 9 additions and 16 deletions

View File

@ -260,9 +260,6 @@ ignore_errors = True
[mypy-torch.nn.quantized.modules.batchnorm]
ignore_errors = True
[mypy-torch.nn.quantized.modules.linear]
ignore_errors = True
[mypy-torch.nn.intrinsic.quantized.modules.conv_relu]
ignore_errors = True

View File

@ -1,6 +1,6 @@
import torch
from torch._jit_internal import Tuple, Optional, List # noqa: F401
from typing import Tuple, Optional, List
from torch import Tensor, _VF # noqa: F401

View File

@ -1,15 +1,15 @@
import torch
import torch.nn as nn
from torch import Tensor # noqa: F401
from torch._jit_internal import Optional, List # noqa: F401
from torch import Tensor
from torch.nn.quantized.modules.utils import hide_packed_params_repr
from torch.nn.quantized.modules.utils import _quantize_weight
from torch.quantization.qconfig import float_qparams_dynamic_qconfig
from typing import Optional
class EmbeddingPackedParams(torch.nn.Module):
_version = 1
def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8) -> None:
super(EmbeddingPackedParams, self).__init__()
self.dtype = dtype
if self.dtype == torch.quint8:
@ -23,8 +23,7 @@ class EmbeddingPackedParams(torch.nn.Module):
raise RuntimeError('Unsupported dtype on dynamic quantized embedding_bag!')
@torch.jit.export
def set_weight(self, weight):
# type: (torch.Tensor) -> None
def set_weight(self, weight: Tensor) -> None:
if self.dtype == torch.quint8:
self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
else:
@ -136,8 +135,7 @@ class EmbeddingBag(torch.nn.Module):
return extra_repr_str
def set_weight(self, w):
# type: (torch.Tensor) -> None
def set_weight(self, w: Tensor) -> None:
self._packed_params.set_weight(w)
def weight(self):

View File

@ -1,9 +1,9 @@
import torch
from torch._jit_internal import Optional # noqa: F401
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr
from typing import Optional
class LinearPackedParams(torch.nn.Module):
_version = 3
@ -18,8 +18,7 @@ class LinearPackedParams(torch.nn.Module):
self.set_weight_bias(wq, None)
@torch.jit.export
def set_weight_bias(self, weight, bias):
# type: (torch.Tensor, Optional[torch.Tensor]) -> None
def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
if self.dtype == torch.qint8:
self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
elif self.dtype == torch.float16:
@ -234,8 +233,7 @@ class Linear(torch.nn.Module):
def bias(self):
return self._weight_bias()[1]
def set_weight_bias(self, w, b):
# type: (torch.Tensor, Optional[torch.Tensor]) -> None
def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
self._packed_params.set_weight_bias(w, b)
@classmethod