[quant] Add ConvTranspose reference module (#73031)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73031
Add ConvTranspose1d/2d/3d reference modules to torch.nn.quantized._reference, and route reference modules through from_float(mod, weight_qparams) in swap_module.
Test Plan:
python3 test/test_quantization.py TestQuantizeEagerOps.test_conv_transpose_2d
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D34313425
fbshipit-source-id: 3eeec1b24a51c7951c4d4b0c7dca43a012468b85
(cherry picked from commit 0ee7c1cc39)
parent 51b04f27c7
commit 710f12f58e
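As context for the diff below, here is a minimal sketch (condensed from the new test) of how the reference module is meant to be used in eager mode: wrap the float ConvTranspose with quant/dequant stub pairs, calibrate, then convert with an explicit mapping that points nn.ConvTranspose2d at the new reference class. The toy module, tensor shapes, and qconfig choice are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
from torch.ao.quantization import QuantStub, DeQuantStub, prepare, convert

class RefToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant1 = QuantStub()
        self.dequant1 = DeQuantStub()
        self.conv = nn.ConvTranspose2d(1, 1, 1)
        self.quant2 = QuantStub()
        self.dequant2 = DeQuantStub()

    def forward(self, x):
        x = self.dequant1(self.quant1(x))    # quant/dequant pair around the input
        x = self.conv(x)                     # float conv on the dequantized activation
        return self.dequant2(self.quant2(x))

m = RefToy().eval()
m.qconfig = torch.quantization.default_qconfig
m = prepare(m)
m(torch.randn(16, 1, 10, 10))                # calibration pass
reference_module_mapping = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.ConvTranspose2d: nnqr.ConvTranspose2d,  # the module added by this PR
}
m = convert(m, mapping=reference_module_mapping)
out = m(torch.randn(16, 1, 10, 10))

convert swaps self.conv to nnqr.ConvTranspose2d via the new _IS_REFERENCE branch in swap_module (see the quantize.py hunk below), so the weight stays unpacked instead of being packed for fbgemm/qnnpack.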
@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
from torch.nn.utils.rnn import PackedSequence
from torch.ao.quantization import (
    quantize,
@@ -74,6 +75,130 @@ import unittest
import numpy as np

class TestQuantizeEagerOps(QuantizationTestCase):
    def _test_reference_module_impl(self,
                                    float_module_class,
                                    quantized_module_class,
                                    extra_module_kwargs,
                                    input_size):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = float_module_class(**extra_module_kwargs)
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.dequant(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = float_module_class(**extra_module_kwargs)
                self.quant1 = QuantStub()
                self.dequant1 = DeQuantStub()
                self.quant2 = QuantStub()
                self.dequant2 = DeQuantStub()

            def forward(self, x):
                x = self.quant1(x)
                x = self.dequant1(x)
                x = self.conv(x)
                x = self.quant2(x)
                x = self.dequant2(x)
                return x

        qengine = 'fbgemm'
        with override_quantized_engine(qengine):
            data = torch.randn(*input_size, dtype=torch.float)
            original_m = M()
            original_ref_m = RefM()

            original_ref_m.conv.weight = torch.nn.Parameter(original_m.conv.weight.detach())
            original_ref_m.conv.bias = torch.nn.Parameter(original_m.conv.bias.detach())

            original_m.qconfig = torch.quantization.default_qconfig

            m = prepare(original_m)
            # calibration
            m(data)
            m = convert(m)
            # check if the module is properly quantized
            self.assertEqual(type(m.quant), nnq.Quantize)
            self.assertEqual(type(m.conv), quantized_module_class)
            self.assertEqual(type(m.dequant), nnq.DeQuantize)
            res = m(data)

            # quantize the reference model
            original_ref_m.eval()
            original_ref_m.qconfig = torch.quantization.default_qconfig

            ref_m = prepare(original_ref_m)
            ref_m(data)
            reference_module_mapping = {
                QuantStub: nnq.Quantize,
                DeQuantStub: nnq.DeQuantize,
                nn.Conv1d: nnqr.Conv1d,
                nn.Conv2d: nnqr.Conv2d,
                nn.Conv3d: nnqr.Conv3d,
                nn.ConvTranspose1d: nnqr.ConvTranspose1d,
                nn.ConvTranspose2d: nnqr.ConvTranspose2d,
                nn.ConvTranspose3d: nnqr.ConvTranspose3d,
            }
            ref_m = convert(ref_m, mapping=reference_module_mapping)
            ref_res = ref_m(data)
            self.assertEqual(res, ref_res)

    def test_conv_1d(self):
        self._test_reference_module_impl(
            nn.Conv1d,
            nnq.Conv1d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 1)
        )

    def test_conv_2d(self):
        self._test_reference_module_impl(
            nn.Conv2d,
            nnq.Conv2d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10)
        )

    def test_conv_3d(self):
        self._test_reference_module_impl(
            nn.Conv3d,
            nnq.Conv3d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10, 10)
        )

    def test_conv_transpose_1d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose1d,
            nnq.ConvTranspose1d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 1)
        )

    def test_conv_transpose_2d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose2d,
            nnq.ConvTranspose2d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10)
        )

    def test_conv_transpose_3d(self):
        self._test_reference_module_impl(
            nn.ConvTranspose3d,
            nnq.ConvTranspose3d,
            {'in_channels': 1, 'out_channels': 1, 'kernel_size': 1},
            (16, 1, 10, 10, 10)
        )

    def _test_activation_op_impl(
            self, float_module_class, quantized_module_class, extra_module_kwargs):
        """ Implementation for testing common activation ops like leaky relu
@@ -16,7 +16,7 @@ from torch.ao.quantization.quantization_mappings import (
    _has_special_act_post_process,
    _get_special_act_post_process,
)

from .utils import get_qparam_dict
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    add_module_to_qconfig_obs_ctr,
@@ -565,7 +565,15 @@ def swap_module(mod, mapping, custom_module_class_mapping):
            new_mod = custom_module_class_mapping[type(mod)].from_observed(mod)
            swapped = True
        elif type(mod) in mapping:
            new_mod = mapping[type(mod)].from_float(mod)
            qmod = mapping[type(mod)]
            if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                new_mod = qmod.from_float(mod)
            swapped = True

    if swapped:
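For readers unfamiliar with the reference path above: the weight_qparams dict handed to from_float is built by running the qconfig's weight observer over the float weight. The hand-rolled sketch below illustrates that step; the actual keys come from get_qparam_dict in the quantization utils, so the dict layout here is an assumption for illustration only.

import torch
from torch.ao.quantization.qconfig import default_qconfig

def sketch_weight_qparams(float_mod, qconfig=default_qconfig):
    # Instantiate and run the weight observer prescribed by the qconfig.
    observer = qconfig.weight()
    observer(float_mod.weight)
    scale, zero_point = observer.calculate_qparams()
    # Assumed shape of the qparam dict; the real code calls get_qparam_dict(observer).
    return {
        "qscheme": observer.qscheme,
        "dtype": observer.dtype,
        "scale": scale,
        "zero_point": zero_point,
    }

conv = torch.nn.ConvTranspose2d(1, 1, 1)
weight_qparams = sketch_weight_qparams(conv)   # would be passed to nnqr.ConvTranspose2d.from_float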
@@ -1,9 +1,12 @@
from .linear import Linear
from .conv import Conv1d, Conv2d, Conv3d
from .conv import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d

__all__ = [
    'Linear',
    'Conv1d',
    'Conv2d',
    'Conv3d',
    'ConvTranspose1d',
    'ConvTranspose2d',
    'ConvTranspose3d',
]
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
from torch.nn.common_types import _size_1_t
from .utils import _quantize_weight, _quantize_and_dequantize_weight
from .utils import _save_weight_qparams
@@ -14,6 +14,7 @@ class _ConvNd(torch.nn.modules.conv._ConvNd):
        this is useful when user want to use this module in other backends like Glow.
    """
    __annotations__ = {"bias": Optional[torch.Tensor]}
    _IS_REFERENCE = True

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
@@ -217,3 +218,169 @@ class Conv3d(_ConvNd, nn.Conv3d):
    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)

class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd):
    """ A reference version of nn.quantized.ConvTranspose2d
        we will not pack the parameters in this module, since weight packing is an
        optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
        this is useful when user want to use this module in other backends like Glow.
    """
    @staticmethod
    def from_float(cls, float_conv, weight_qparams):
        qref_conv = cls(
            float_conv.in_channels,
            float_conv.out_channels,
            float_conv.kernel_size,  # type: ignore[arg-type]
            float_conv.stride,  # type: ignore[arg-type]
            float_conv.padding,  # type: ignore[arg-type]
            float_conv.output_padding,  # type: ignore[arg-type]
            float_conv.groups,
            float_conv.bias is not None,  # type: ignore[arg-type]
            float_conv.dilation,  # type: ignore[arg-type]
            float_conv.padding_mode,
            device=float_conv.weight.device,
            dtype=float_conv.weight.dtype,
            weight_qparams=weight_qparams)
        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
        if float_conv.bias is not None:
            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
        return qref_conv


class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 output_padding: _size_1_t = 0,
                 groups: int = 1,
                 bias: bool = True,
                 dilation: _size_1_t = 1,
                 padding_mode: str = "zeros",
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.ConvTranspose1d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.convTranspose1d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.convTranspose1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv1d
        """

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        output_padding = self._output_padding(
            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]

        weight_dequant = self.get_weight()
        result = F.conv_transpose1d(
            x, weight_dequant, self.bias, self.stride,
            self.padding, output_padding, self.groups, self.dilation)
        return result

    def _get_name(self):
        return "QuantizedConvTranspose1d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)

class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0,
                 groups=1, bias=True, dilation=1,
                 padding_mode='zeros',
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):

        nn.ConvTranspose2d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.convTranspose2d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.convTranspose2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv2d
        """
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.

        output_padding = self._output_padding(
            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]

        weight_dequant = self.get_weight()
        result = F.conv_transpose2d(
            x, weight_dequant, self.bias, self.stride,
            self.padding, output_padding, self.groups, self.dilation)

        return result

    def _get_name(self):
        return "QuantizedConvTranspose2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)

class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0,
                 groups=1, bias=True, dilation=1,
                 padding_mode="zeros",
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.ConvTranspose3d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.convTranspose3d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant --  *F.convTranspose3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv3d
        """

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        output_padding = self._output_padding(
            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]

        weight_dequant = self.get_weight()
        result = F.conv_transpose3d(
            x, weight_dequant, self.bias, self.stride,
            self.padding, output_padding, self.groups, self.dilation)
        return result

    def _get_name(self):
        return "QuantizedConvTranspose3d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
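To make the quant → dequant weight pattern described in the forward docstrings concrete, here is a small numeric sketch of what a reference forward reduces to in the per-tensor case: fake-quantize the float weight, dequantize it (roughly what get_weight() returns), then run the ordinary float conv_transpose kernel. The scale, zero_point, and tensor shapes are made-up illustration values; the real module reads them from its stored weight_qparams.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 5, 5)          # float activation
w = torch.randn(1, 1, 3, 3)          # float ConvTranspose2d weight (in, out/groups, kH, kW)

# w(float) -- quant -- dequant, then the plain float kernel consumes the result.
w_q = torch.quantize_per_tensor(w, scale=0.02, zero_point=0, dtype=torch.qint8)
w_dq = w_q.dequantize()              # roughly what self.get_weight() hands back
out = F.conv_transpose2d(x, w_dq)    # no packed weights, so any backend can consume this graph

A backend such as Glow can then pattern-match the dequant → conv_transpose → quant subgraph and fuse the starred ops into its own quantized ConvTranspose, which is the point of keeping the reference module unpacked.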