mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
add type annotations to torch.nn.quantized.modules.conv (#49702)
Summary: closes gh-49700 No mypy issues were found in the first three entries deleted from `mypy.ini`: ``` [mypy-torch.nn.qat.modules.activations] ignore_errors = True [mypy-torch.nn.qat.modules.conv] ignore_errors = True [mypy-torch.nn.quantized.dynamic.modules.linear] ignore_errors = True ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49702 Reviewed By: walterddr, zou3519 Differential Revision: D25767119 Pulled By: ezyang fbshipit-source-id: cb83e53549a299538e1b154cf8b79e3280f7392a
This commit is contained in:
parent
54ce171f16
commit
55919a4758
11
mypy.ini
11
mypy.ini
|
|
@ -91,16 +91,7 @@ ignore_errors = True
|
||||||
[mypy-torch.nn.modules.pooling]
|
[mypy-torch.nn.modules.pooling]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-torch.nn.qat.modules.activations]
|
[mypy-torch.nn.parallel._functions]
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.nn.qat.modules.conv]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.nn.quantized.dynamic.modules.linear]
|
|
||||||
ignore_errors = True
|
|
||||||
|
|
||||||
[mypy-torch.nn.quantized.modules.conv]
|
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-torch._appdirs]
|
[mypy-torch._appdirs]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
r"""Quantized convolution modules."""
|
r"""Quantized convolution modules."""
|
||||||
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -16,11 +16,17 @@ from torch.nn.utils import fuse_conv_bn_weights
|
||||||
|
|
||||||
class _ConvNd(nn.Module):
|
class _ConvNd(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
||||||
padding, dilation,
|
padding=0, dilation=1, groups=1, bias=True,
|
||||||
transposed, output_padding,
|
|
||||||
groups, bias,
|
|
||||||
padding_mode='zeros'):
|
padding_mode='zeros'):
|
||||||
|
# All subclasses have this signature - See PR #49702s
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _init(self, in_channels, out_channels, kernel_size, stride,
|
||||||
|
padding, dilation,
|
||||||
|
transposed, output_padding,
|
||||||
|
groups, bias,
|
||||||
|
padding_mode='zeros'):
|
||||||
super(_ConvNd, self).__init__()
|
super(_ConvNd, self).__init__()
|
||||||
if padding_mode != 'zeros':
|
if padding_mode != 'zeros':
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|
@ -54,6 +60,15 @@ class _ConvNd(nn.Module):
|
||||||
self.scale = 1.0
|
self.scale = 1.0
|
||||||
self.zero_point = 0
|
self.zero_point = 0
|
||||||
|
|
||||||
|
def set_weight_bias(self, qweight, bias_float):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def bias(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _weight_bias(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
|
s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
|
||||||
', stride={stride}, scale={scale}, zero_point={zero_point}')
|
', stride={stride}, scale={scale}, zero_point={zero_point}')
|
||||||
|
|
@ -155,7 +170,8 @@ class _ConvNd(nn.Module):
|
||||||
assert weight_post_process.dtype == torch.qint8, \
|
assert weight_post_process.dtype == torch.qint8, \
|
||||||
'Weight observer must have a dtype of qint8'
|
'Weight observer must have a dtype of qint8'
|
||||||
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
|
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
|
||||||
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
|
# the __init__ call used is the one from derived classes and not the one from _ConvNd
|
||||||
|
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg]
|
||||||
mod.stride, mod.padding, mod.dilation, mod.groups,
|
mod.stride, mod.padding, mod.dilation, mod.groups,
|
||||||
mod.bias is not None, mod.padding_mode)
|
mod.bias is not None, mod.padding_mode)
|
||||||
qconv.set_weight_bias(qweight, mod.bias)
|
qconv.set_weight_bias(qweight, mod.bias)
|
||||||
|
|
@ -233,7 +249,9 @@ class Conv1d(_ConvNd):
|
||||||
padding = _pair_from_first(padding)
|
padding = _pair_from_first(padding)
|
||||||
dilation = _pair_from_first(dilation)
|
dilation = _pair_from_first(dilation)
|
||||||
|
|
||||||
super(Conv1d, self).__init__(
|
# Subclasses of _ConvNd needs to call _init rather than __init__. See
|
||||||
|
# discussion on PR #49702
|
||||||
|
super(Conv1d, self)._init(
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
False, _single(0), groups, bias, padding_mode)
|
False, _single(0), groups, bias, padding_mode)
|
||||||
|
|
||||||
|
|
@ -319,7 +337,9 @@ class Conv2d(_ConvNd):
|
||||||
stride = _pair(stride)
|
stride = _pair(stride)
|
||||||
padding = _pair(padding)
|
padding = _pair(padding)
|
||||||
dilation = _pair(dilation)
|
dilation = _pair(dilation)
|
||||||
super(Conv2d, self).__init__(
|
# Subclasses of _ConvNd need to call _init rather than __init__. See
|
||||||
|
# discussion on PR #49702
|
||||||
|
super(Conv2d, self)._init(
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
False, _pair(0), groups, bias, padding_mode)
|
False, _pair(0), groups, bias, padding_mode)
|
||||||
|
|
||||||
|
|
@ -403,7 +423,9 @@ class Conv3d(_ConvNd):
|
||||||
stride = _triple(stride)
|
stride = _triple(stride)
|
||||||
padding = _triple(padding)
|
padding = _triple(padding)
|
||||||
dilation = _triple(dilation)
|
dilation = _triple(dilation)
|
||||||
super(Conv3d, self).__init__(
|
# Subclasses of _ConvNd need to call _init rather than __init__. See
|
||||||
|
# discussion on PR #49702
|
||||||
|
super(Conv3d, self)._init(
|
||||||
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
in_channels, out_channels, kernel_size, stride, padding, dilation,
|
||||||
False, _triple(0), groups, bias, padding_mode)
|
False, _triple(0), groups, bias, padding_mode)
|
||||||
|
|
||||||
|
|
@ -450,15 +472,20 @@ class Conv3d(_ConvNd):
|
||||||
return cls.get_qconv(mod, activation_post_process)
|
return cls.get_qconv(mod, activation_post_process)
|
||||||
|
|
||||||
# === Transposed Convolutions ===
|
# === Transposed Convolutions ===
|
||||||
|
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
|
||||||
|
|
||||||
class _ConvTransposeNd(_ConvNd):
|
class _ConvTransposeNd(_ConvNd):
|
||||||
|
|
||||||
|
_FLOAT_MODULE = MOD
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
def __init__(self, in_channels, out_channels, kernel_size, stride,
|
||||||
padding, dilation, transposed, output_padding,
|
padding, dilation, transposed, output_padding,
|
||||||
groups, bias, padding_mode):
|
groups, bias, padding_mode):
|
||||||
if padding_mode != 'zeros':
|
if padding_mode != 'zeros':
|
||||||
raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
|
raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
|
||||||
|
# Subclasses of _ConvNd need to call _init rather than __init__. See
|
||||||
super(_ConvTransposeNd, self).__init__(
|
# discussion on PR #49702
|
||||||
|
super(_ConvTransposeNd, self)._init(
|
||||||
in_channels, out_channels, kernel_size, stride,
|
in_channels, out_channels, kernel_size, stride,
|
||||||
padding, dilation, transposed, output_padding,
|
padding, dilation, transposed, output_padding,
|
||||||
groups, bias, padding_mode)
|
groups, bias, padding_mode)
|
||||||
|
|
@ -477,9 +504,10 @@ class _ConvTransposeNd(_ConvNd):
|
||||||
mod (Module): a float module, either produced by torch.quantization
|
mod (Module): a float module, either produced by torch.quantization
|
||||||
utilities or provided by the user
|
utilities or provided by the user
|
||||||
"""
|
"""
|
||||||
assert type(mod) == cls._FLOAT_MODULE, \
|
# derived classes override cls._FLOAT_MODULE attribute
|
||||||
' nnq.' + cls.__name__ + '.from_float only works for ' + \
|
msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
|
||||||
cls._FLOAT_MODULE.__name__
|
cls._FLOAT_MODULE.__name__
|
||||||
|
assert type(mod) == cls._FLOAT_MODULE, msg
|
||||||
assert hasattr(mod, 'qconfig'), \
|
assert hasattr(mod, 'qconfig'), \
|
||||||
'Input float module must have qconfig defined.'
|
'Input float module must have qconfig defined.'
|
||||||
weight_post_process = mod.qconfig.weight()
|
weight_post_process = mod.qconfig.weight()
|
||||||
|
|
@ -488,7 +516,8 @@ class _ConvTransposeNd(_ConvNd):
|
||||||
assert weight_post_process.dtype == torch.qint8, \
|
assert weight_post_process.dtype == torch.qint8, \
|
||||||
'Weight observer must have a dtype of qint8'
|
'Weight observer must have a dtype of qint8'
|
||||||
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
|
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
|
||||||
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
|
# the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
|
||||||
|
qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, # type: ignore[call-arg]
|
||||||
mod.stride, mod.padding, mod.output_padding, mod.groups,
|
mod.stride, mod.padding, mod.output_padding, mod.groups,
|
||||||
mod.bias is not None, mod.dilation, mod.padding_mode)
|
mod.bias is not None, mod.dilation, mod.padding_mode)
|
||||||
qconv.set_weight_bias(qweight, mod.bias)
|
qconv.set_weight_bias(qweight, mod.bias)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user