add type annotations to torch.nn.quantized.modules.conv (#49702)

Summary:
closes gh-49700

No mypy issues were found in the first three entries deleted from `mypy.ini`:
```
[mypy-torch.nn.qat.modules.activations]
ignore_errors = True

[mypy-torch.nn.qat.modules.conv]
ignore_errors = True

[mypy-torch.nn.quantized.dynamic.modules.linear]
ignore_errors = True
```
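
As a side note (not part of the PR itself), one way to confirm locally that those modules type-check cleanly is to run mypy on them from the repository root, so that `mypy.ini` is picked up. The sketch below is only an assumption of how one might check it: it presumes a PyTorch checkout with `mypy` installed and uses mypy's programmatic API, and the file list simply mirrors the entries removed above plus the module annotated in this PR:

```
# Sketch only: re-check the modules whose ignore_errors entries were removed.
# Assumes it is run from the PyTorch repo root, where mypy.ini lives.
from mypy import api

stdout, stderr, exit_status = api.run([
    "torch/nn/qat/modules/activations.py",
    "torch/nn/qat/modules/conv.py",
    "torch/nn/quantized/dynamic/modules/linear.py",
    "torch/nn/quantized/modules/conv.py",
])
print(stdout or stderr)
assert exit_status == 0, "mypy reported errors"
```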

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49702

Reviewed By: walterddr, zou3519

Differential Revision: D25767119

Pulled By: ezyang

fbshipit-source-id: cb83e53549a299538e1b154cf8b79e3280f7392a
Author: Guilherme Leobas
Date: 2021-01-08 07:29:43 -08:00
Committed by: Facebook GitHub Bot
Commit: 55919a4758 (parent: 54ce171f16)
2 changed files with 45 additions and 25 deletions

mypy.ini

@@ -91,16 +91,7 @@ ignore_errors = True
 [mypy-torch.nn.modules.pooling]
 ignore_errors = True
 
-[mypy-torch.nn.qat.modules.activations]
-ignore_errors = True
-
-[mypy-torch.nn.qat.modules.conv]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.dynamic.modules.linear]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.modules.conv]
-ignore_errors = True
-
 [mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
 [mypy-torch._appdirs]

torch/nn/quantized/modules/conv.py

@@ -1,7 +1,7 @@
 # coding=utf-8
 r"""Quantized convolution modules."""
-from typing import Optional, List
+from typing import Optional, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -16,11 +16,17 @@ from torch.nn.utils import fuse_conv_bn_weights
 class _ConvNd(nn.Module):
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation,
-                 transposed, output_padding,
-                 groups, bias,
-                 padding_mode='zeros'):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
+                 padding_mode='zeros'):
+        # All subclasses have this signature - See PR #49702
+        raise NotImplementedError
+
+    def _init(self, in_channels, out_channels, kernel_size, stride,
+              padding, dilation,
+              transposed, output_padding,
+              groups, bias,
+              padding_mode='zeros'):
         super(_ConvNd, self).__init__()
         if padding_mode != 'zeros':
             raise NotImplementedError(
@@ -54,6 +60,15 @@ class _ConvNd(nn.Module):
         self.scale = 1.0
         self.zero_point = 0
 
+    def set_weight_bias(self, qweight, bias_float):
+        raise NotImplementedError
+
+    def bias(self):
+        raise NotImplementedError
+
+    def _weight_bias(self):
+        raise NotImplementedError
+
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
              ', stride={stride}, scale={scale}, zero_point={zero_point}')
@@ -155,7 +170,8 @@ class _ConvNd(nn.Module):
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
@@ -233,7 +249,9 @@ class Conv1d(_ConvNd):
         padding = _pair_from_first(padding)
         dilation = _pair_from_first(dilation)
 
-        super(Conv1d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv1d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
@@ -319,7 +337,9 @@ class Conv2d(_ConvNd):
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
-        super(Conv2d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv2d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
@@ -403,7 +423,9 @@ class Conv3d(_ConvNd):
         stride = _triple(stride)
         padding = _triple(padding)
         dilation = _triple(dilation)
-        super(Conv3d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv3d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _triple(0), groups, bias, padding_mode)
@@ -450,15 +472,20 @@ class Conv3d(_ConvNd):
         return cls.get_qconv(mod, activation_post_process)
 
 # === Transposed Convolutions ===
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
 
 class _ConvTransposeNd(_ConvNd):
+
+    _FLOAT_MODULE = MOD
+
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
                  groups, bias, padding_mode):
         if padding_mode != 'zeros':
             raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
-
-        super(_ConvTransposeNd, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(_ConvTransposeNd, self)._init(
             in_channels, out_channels, kernel_size, stride,
             padding, dilation, transposed, output_padding,
             groups, bias, padding_mode)
@@ -477,9 +504,10 @@ class _ConvTransposeNd(_ConvNd):
             mod (Module): a float module, either produced by torch.quantization
               utilities or provided by the user
         """
-        assert type(mod) == cls._FLOAT_MODULE, \
-            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__
+        # derived classes override cls._FLOAT_MODULE attribute
+        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert type(mod) == cls._FLOAT_MODULE, msg
         assert hasattr(mod, 'qconfig'), \
             'Input float module must have qconfig defined.'
         weight_post_process = mod.qconfig.weight()
@@ -488,7 +516,8 @@ class _ConvTransposeNd(_ConvNd):
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.output_padding, mod.groups,
                     mod.bias is not None, mod.dilation, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
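
Stepping back from the diff: the core of the change is a pattern for giving every `_ConvNd` subclass a constructor that mypy can check. Below is a minimal, self-contained sketch of that pattern with made-up class and method names (it is not the torch code): the base `__init__` keeps one shared signature and only raises, the real initialization lives in `_init`, subclasses call `_init` instead of `super().__init__()`, and a factory classmethod silences the resulting `call-arg` complaint because `cls` is always a concrete subclass at runtime, just as `from_float` does above.

```
# Minimal sketch (illustrative names, not the torch code) of the
# __init__/_init split applied in this PR.


class _BaseNd:
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1) -> None:
        # Concrete subclasses share this public signature; the base class is
        # never constructed directly (mirrors _ConvNd.__init__ above).
        raise NotImplementedError

    def _init(self, in_channels: int, out_channels: int, kernel_size: int,
              stride: int, output_padding: int) -> None:
        # Shared initialization; subclasses call this instead of __init__.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.output_padding = output_padding

    @classmethod
    def from_spec(cls, in_channels: int, out_channels: int, kernel_size: int,
                  output_padding: int) -> "_BaseNd":
        # At runtime cls is a concrete subclass whose __init__ accepts
        # output_padding, but mypy checks this call against _BaseNd.__init__,
        # hence the explicit ignore, just like from_float above.
        return cls(in_channels, out_channels,  # type: ignore[call-arg]
                   kernel_size, 1, output_padding)


class Transpose1d(_BaseNd):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, output_padding: int = 0) -> None:
        # Call _init rather than super().__init__(), which only raises.
        super()._init(in_channels, out_channels, kernel_size, stride,
                      output_padding)


conv = Transpose1d.from_spec(3, 8, kernel_size=3, output_padding=1)
assert isinstance(conv, Transpose1d) and conv.output_padding == 1
```

The `MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)` placeholder assigned to `_FLOAT_MODULE` in the diff plays a related role: concrete transposed classes override `_FLOAT_MODULE` with their float counterpart, and the bound documents that the attribute is expected to be a float conv module.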