diff --git a/mypy.ini b/mypy.ini
index 6c579ee9399..0c99a9c62d1 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -91,16 +91,7 @@ ignore_errors = True
 [mypy-torch.nn.modules.pooling]
 ignore_errors = True
 
-[mypy-torch.nn.qat.modules.activations]
-ignore_errors = True
-
-[mypy-torch.nn.qat.modules.conv]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.dynamic.modules.linear]
-ignore_errors = True
-
-[mypy-torch.nn.quantized.modules.conv]
+[mypy-torch.nn.parallel._functions]
 ignore_errors = True
 
 [mypy-torch._appdirs]
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index 00ceba7ab36..b3bc78ff694 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -1,7 +1,7 @@
 # coding=utf-8
 r"""Quantized convolution modules."""
 
-from typing import Optional, List
+from typing import Optional, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -16,11 +16,17 @@ from torch.nn.utils import fuse_conv_bn_weights
 
 class _ConvNd(nn.Module):
 
-    def __init__(self, in_channels, out_channels, kernel_size, stride,
-                 padding, dilation,
-                 transposed, output_padding,
-                 groups, bias,
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True,
                  padding_mode='zeros'):
+        # All subclasses have this signature - see PR #49702
+        raise NotImplementedError
+
+    def _init(self, in_channels, out_channels, kernel_size, stride,
+              padding, dilation,
+              transposed, output_padding,
+              groups, bias,
+              padding_mode='zeros'):
         super(_ConvNd, self).__init__()
         if padding_mode != 'zeros':
             raise NotImplementedError(
@@ -54,6 +60,15 @@ class _ConvNd(nn.Module):
         self.scale = 1.0
         self.zero_point = 0
 
+    def set_weight_bias(self, qweight, bias_float):
+        raise NotImplementedError
+
+    def bias(self):
+        raise NotImplementedError
+
+    def _weight_bias(self):
+        raise NotImplementedError
+
     def extra_repr(self):
         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
              ', stride={stride}, scale={scale}, zero_point={zero_point}')
@@ -155,7 +170,8 @@ class _ConvNd(nn.Module):
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
@@ -233,7 +249,9 @@ class Conv1d(_ConvNd):
         padding = _pair_from_first(padding)
         dilation = _pair_from_first(dilation)
 
-        super(Conv1d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv1d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _single(0), groups, bias, padding_mode)
 
@@ -319,7 +337,9 @@ class Conv2d(_ConvNd):
         stride = _pair(stride)
         padding = _pair(padding)
         dilation = _pair(dilation)
-        super(Conv2d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv2d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _pair(0), groups, bias, padding_mode)
 
@@ -403,7 +423,9 @@ class Conv3d(_ConvNd):
         stride = _triple(stride)
         padding = _triple(padding)
         dilation = _triple(dilation)
-        super(Conv3d, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(Conv3d, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             False, _triple(0), groups, bias, padding_mode)
 
@@ -450,15 +472,20 @@ class Conv3d(_ConvNd):
         return cls.get_qconv(mod, activation_post_process)
 
 #  === Transposed Convolutions ===
+MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
 
 class _ConvTransposeNd(_ConvNd):
+
+    _FLOAT_MODULE = MOD
+
     def __init__(self, in_channels, out_channels, kernel_size, stride,
                  padding, dilation, transposed, output_padding,
                  groups, bias, padding_mode):
         if padding_mode != 'zeros':
             raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
-
-        super(_ConvTransposeNd, self).__init__(
+        # Subclasses of _ConvNd need to call _init rather than __init__. See
+        # discussion on PR #49702
+        super(_ConvTransposeNd, self)._init(
             in_channels, out_channels, kernel_size, stride, padding, dilation,
             transposed, output_padding, groups, bias, padding_mode)
 
@@ -477,9 +504,10 @@ class _ConvTransposeNd(_ConvNd):
             mod (Module): a float module, either produced by torch.quantization
               utilities or provided by the user
         """
-        assert type(mod) == cls._FLOAT_MODULE, \
-            ' nnq.' + cls.__name__ + '.from_float only works for ' + \
-            cls._FLOAT_MODULE.__name__
+        # derived classes override cls._FLOAT_MODULE attribute
+        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert type(mod) == cls._FLOAT_MODULE, msg
         assert hasattr(mod, 'qconfig'), \
             'Input float module must have qconfig defined.'
         weight_post_process = mod.qconfig.weight()
@@ -488,7 +516,8 @@
         assert weight_post_process.dtype == torch.qint8, \
             'Weight observer must have a dtype of qint8'
         qweight = _quantize_weight(mod.weight.float(), weight_post_process)
-        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
+        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
+        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                     mod.stride, mod.padding, mod.output_padding, mod.groups,
                     mod.bias is not None, mod.dilation, mod.padding_mode)
         qconv.set_weight_bias(qweight, mod.bias)
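For readers outside the PyTorch tree, the pattern this diff adopts can be shown in a minimal, self-contained sketch. The names below (_Base, Conv2dLike, mod) are hypothetical and not the PyTorch source: the base class declares a single __init__ whose signature matches every subclass constructor, so mypy can type-check `cls(...)` calls made from classmethods, while the real setup lives in _init, which subclasses invoke directly because the base __init__ only raises.

    from typing import Type, TypeVar

    T = TypeVar('T', bound='_Base')

    class _Base:
        # Declares the one constructor signature shared by all subclasses,
        # so mypy can type-check `cls(...)` calls made from classmethods.
        def __init__(self, in_channels, out_channels, kernel_size, stride=1):
            raise NotImplementedError

        # The real initializer. Its signature differs from __init__'s, which
        # is exactly what mypy would reject if this were named __init__.
        def _init(self, in_channels, out_channels, kernel_size, stride,
                  transposed, output_padding):
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.transposed = transposed
            self.output_padding = output_padding

        @classmethod
        def from_float(cls: Type[T], mod) -> T:
            # mypy checks this call against _Base.__init__'s declared
            # signature; at runtime the subclass __init__ actually runs.
            return cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                       mod.stride)

    class Conv2dLike(_Base):
        def __init__(self, in_channels, out_channels, kernel_size, stride=1):
            # Bypass _Base.__init__ (which raises) and do the real setup.
            super()._init(in_channels, out_channels, kernel_size, stride,
                          transposed=False, output_padding=0)

    q = Conv2dLike(3, 16, 3)  # routes through _init; _Base(3, 16, 3) would raise

The `_FLOAT_MODULE = MOD` assignment in the diff plays a similar role: the base class binds a TypeVar as a placeholder value so that references such as `cls._FLOAT_MODULE.__name__` type-check, and each concrete subclass overrides the attribute with a real float module class (e.g. nn.ConvTranspose2d).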