Enable typechecks for torch.nn.quantized.modules.linear (#44154)
Summary: Also import `Optional` directly from `typing` rather than from `_jit_internal`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44154
Reviewed By: seemethere
Differential Revision: D23511833
Pulled By: malfet
fbshipit-source-id: f78c5fd679c002b218e4d287a9e56fa198171981
This commit is contained in:
parent 538d3bd364
commit b60ffcdfdd
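The commit follows the standard pattern for making a module mypy-clean: Python-2-era `# type:` comments on methods are replaced with inline PEP 484 annotations, and the typing names come from `typing` itself instead of the private `torch._jit_internal` re-exports. A minimal before/after sketch of that pattern (illustrative function, not code from the PR):

from typing import Optional

import torch

# Before: the contract lives in a type comment (a Python 2 holdover).
def set_weight_bias_old(weight, bias):
    # type: (torch.Tensor, Optional[torch.Tensor]) -> None
    pass

# After: the same contract as inline annotations, which mypy checks
# and TorchScript's frontend also understands.
def set_weight_bias_new(weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
    pass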
mypy.ini (3 lines changed)

@@ -260,9 +260,6 @@ ignore_errors = True
 [mypy-torch.nn.quantized.modules.batchnorm]
 ignore_errors = True
 
-[mypy-torch.nn.quantized.modules.linear]
-ignore_errors = True
-
 [mypy-torch.nn.intrinsic.quantized.modules.conv_relu]
 ignore_errors = True
 
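Deleting the `[mypy-torch.nn.quantized.modules.linear]` override is what actually turns checking on: `ignore_errors = True` suppresses every error mypy finds in that module. A toy illustration of the class of error that now gets reported (hypothetical function, not from the module):

from typing import Optional

def scale_bias(bias: Optional[float]) -> float:
    # A typed body that mypy now checks instead of silently skipping.
    return 0.0 if bias is None else bias * 2.0

scale_bias(None)      # OK: Optional[float] accepts None
# scale_bias("oops")  # now flagged: Argument 1 has incompatible type "str"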
torch/nn/quantized/dynamic/modules/rnn.py (file name not shown in this capture; inferred from the `_VF` import)

@@ -1,6 +1,6 @@
 import torch
 
-from torch._jit_internal import Tuple, Optional, List  # noqa: F401
+from typing import Tuple, Optional, List
 
 from torch import Tensor, _VF  # noqa: F401
 
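At runtime the two spellings are interchangeable, since `torch._jit_internal` merely re-exports the standard `typing` names; importing from `typing` is the canonical form, and it lets the `# noqa: F401` suppressions go away where the names are genuinely used. A quick check (assumes `torch._jit_internal` still re-exports `Optional`, which is a private implementation detail):

from typing import Optional as TypingOptional

# Private module: its re-exports are an implementation detail, not API.
from torch._jit_internal import Optional as JitOptional

assert TypingOptional is JitOptional  # same typing object either way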
torch/nn/quantized/modules/embedding_ops.py (file name not shown in this capture; inferred from the EmbeddingPackedParams/EmbeddingBag context)

@@ -1,15 +1,15 @@
 import torch
 import torch.nn as nn
-from torch import Tensor  # noqa: F401
-from torch._jit_internal import Optional, List  # noqa: F401
+from torch import Tensor
 from torch.nn.quantized.modules.utils import hide_packed_params_repr
 from torch.nn.quantized.modules.utils import _quantize_weight
 from torch.quantization.qconfig import float_qparams_dynamic_qconfig
+from typing import Optional
 
 class EmbeddingPackedParams(torch.nn.Module):
     _version = 1
 
-    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
+    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8) -> None:
         super(EmbeddingPackedParams, self).__init__()
         self.dtype = dtype
         if self.dtype == torch.quint8:
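One detail worth noting in that hunk: the bare `-> None` added to `__init__` is not cosmetic. By default mypy treats a `def` with no annotations at all as dynamically typed and skips its body entirely (absent `--check-untyped-defs`), so the return annotation is what opts the constructor into checking. A sketch with hypothetical classes:

class Unchecked:
    def __init__(self):
        self.dim: int = "oops"  # mypy stays silent: the def is untyped

class Checked:
    def __init__(self) -> None:
        self.dim: int = "oops"  # mypy: incompatible types in assignment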
@@ -23,8 +23,7 @@ class EmbeddingPackedParams(torch.nn.Module):
             raise RuntimeError('Unsupported dtype on dynamic quantized embedding_bag!')
 
     @torch.jit.export
-    def set_weight(self, weight):
-        # type: (torch.Tensor) -> None
+    def set_weight(self, weight: Tensor) -> None:
         if self.dtype == torch.quint8:
             self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
         else:
@@ -136,8 +135,7 @@ class EmbeddingBag(torch.nn.Module):
 
         return extra_repr_str
 
-    def set_weight(self, w):
-        # type: (torch.Tensor) -> None
+    def set_weight(self, w: Tensor) -> None:
         self._packed_params.set_weight(w)
 
     def weight(self):
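Dropping the type comments is safe for TorchScript as well: `@torch.jit.export` methods are compiled from inline annotations just as they were from `# type:` comments. A minimal scriptable sketch (hypothetical module, mirroring the shape of `set_weight`):

import torch

class PackedParams(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.zeros(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        # TorchScript's frontend reads the inline annotations directly.
        self.weight = weight

scripted = torch.jit.script(PackedParams())
scripted.set_weight(torch.randn(3))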
torch/nn/quantized/modules/linear.py (file name not shown in this capture; inferred from the LinearPackedParams/Linear context and the mypy.ini entry removed above)

@@ -1,9 +1,9 @@
 import torch
 
-from torch._jit_internal import Optional  # noqa: F401
 import torch.nn as nn
 import torch.nn.intrinsic as nni
 from torch.nn.quantized.modules.utils import _quantize_weight, hide_packed_params_repr
+from typing import Optional
 
 class LinearPackedParams(torch.nn.Module):
     _version = 3
@@ -18,8 +18,7 @@ class LinearPackedParams(torch.nn.Module):
         self.set_weight_bias(wq, None)
 
     @torch.jit.export
-    def set_weight_bias(self, weight, bias):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
         if self.dtype == torch.qint8:
             self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
         elif self.dtype == torch.float16:
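The `Optional[torch.Tensor]` annotation documents that the packed params accept a missing bias, and mypy now verifies that `None` flows through correctly. A usage sketch of the underlying op (needs a quantized engine such as fbgemm; values are arbitrary):

import torch

# Quantize a float weight to qint8, then prepack it with no bias; the
# bias parameter is Optional[Tensor], so passing None must typecheck.
w = torch.randn(4, 3)
wq = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
packed = torch.ops.quantized.linear_prepack(wq, None)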
@@ -234,8 +233,7 @@ class Linear(torch.nn.Module):
     def bias(self):
         return self._weight_bias()[1]
 
-    def set_weight_bias(self, w, b):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
+    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
         self._packed_params.set_weight_bias(w, b)
 
     @classmethod