Support min/max carry over for eager mode from_float method (#127309)

Summary:
After QAT has completed, or when a pre-tuned weight observer has been supplied by a tunable PTQ algorithm, the observer's min/max values should not be overwritten by re-observing the weight at convert time; for static QAT this re-observation is never needed.

By design, dynamic QAT likewise does not require the weight observer to be re-run.

This change fixes that by threading a use_precomputed_fake_quant flag through the eager-mode convert()/from_float() path, so that precomputed min/max values are carried over instead of being recomputed from the current weight.
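
For illustration, a minimal sketch (not taken from the PR) of how the new flag is meant to be used from the eager-mode API; the toy model, the default_qat_qconfig choice, and the random calibration batch are assumptions made for this example:

    import torch
    import torch.nn as nn
    from torch.ao.quantization import convert, default_qat_qconfig, prepare_qat

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    model.train()
    model.qconfig = default_qat_qconfig
    prepared = prepare_qat(model)

    prepared(torch.randn(8, 4))  # stand-in for QAT training / observer tuning

    prepared.eval()
    # convert() now forwards use_precomputed_fake_quant to each module's
    # from_float(), which skips re-running the weight observer and keeps the
    # min/max values already stored in the fake-quant/observer modules.
    quantized = convert(prepared, use_precomputed_fake_quant=True)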

Test Plan: Signals

Differential Revision: D57747749

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127309
Approved by: https://github.com/jerryzh168
Author: Kwanghoon An
Date: 2024-05-29 19:33:26 +00:00
Committed by: PyTorch MergeBot
Parent: 82a370ae3a
Commit: c404b2968c
29 changed files with 178 additions and 131 deletions

View File

@@ -2,62 +2,63 @@
import copy
import math
import torch
-import torch.nn as nn
-import torch.backends.mkldnn
-from torch.nn import Conv2d, BatchNorm2d, ReLU, init
-from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
-from torch.nn.modules.utils import _pair
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
-import torch.ao.nn.qat as nnqat
-import torch.ao.nn.intrinsic.qat as nniqat
-import torch.ao.nn.qat.dynamic as nnqatd
+import torch.backends.mkldnn
+import torch.nn as nn
+import torch.testing._internal.hypothesis_utils as hu
+from hypothesis import given, strategies as st
+from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.ao.quantization import (
-prepare,
convert,
-prepare_qat,
-quantize_qat,
-QuantStub,
-DeQuantStub,
-default_qconfig,
-default_qat_qconfig,
default_embedding_qat_qconfig,
+default_qat_qconfig,
+default_qconfig,
default_symmetric_qnnpack_qat_qconfig,
-get_default_qat_qconfig,
+DeQuantStub,
FixedQParamsFakeQuantize,
FusedMovingAvgObsFakeQuantize,
+get_default_qat_qconfig,
get_embedding_qat_module_mappings,
get_embedding_static_quant_module_mappings,
NoopObserver,
+prepare,
+prepare_qat,
+quantize_qat,
+QuantStub,
)
from torch.ao.quantization.qconfig import qconfig_equals
+from torch.nn import BatchNorm2d, Conv2d, init, ReLU
+from torch.nn.modules.utils import _pair
from torch.testing._internal.common_quantization import (
DeFusedEmbeddingBagLinear,
-QuantizationTestCase,
-QuantStubModel,
-ManualLinearQATModel,
-ManualDropoutQATModel,
-ManualLinearDynamicQATModel,
ManualConvLinearQATModel,
ManualConvLinearSymmQATModel,
+ManualDropoutQATModel,
ManualEmbeddingBagLinear,
-TwoLayerLinearModel,
+ManualLinearDynamicQATModel,
+ManualLinearQATModel,
+QuantizationTestCase,
+QuantStubModel,
test_only_eval_fn,
test_only_train_fn,
+TwoLayerLinearModel,
)
from torch.testing._internal.common_quantized import (
+override_qengines,
override_quantized_engine,
supported_qengines,
-override_qengines,
)
from torch.testing._internal.common_utils import skipIfNoXNNPACK
-from hypothesis import given
-from hypothesis import strategies as st
-import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from functools import reduce

@@ -1099,6 +1100,33 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity)
+@skipIfNoXNNPACK
+@override_qengines
+def test_linear_precomputed_fake_quant(self):
+qengine = torch.backends.quantized.engine
+if qengine != "qnnpack":
+return # Only qnnpack support symmetric quantization
+m_ref = nn.Linear(4, 4)
+m_ref_copy = copy.deepcopy(m_ref)
+qconfig = default_qconfig
+m_ref_copy.qconfig = qconfig
+weight_post_process = copy.deepcopy(qconfig.weight())
+activation = copy.deepcopy(qconfig.activation())
+activation(torch.randn(4, 4))
+m_ref_copy.activation_post_process = activation
+m_ref_copy = nnq.Linear.from_float(m_ref_copy)
+weight_post_process = qconfig.weight()
+weight_post_process.min_val = torch.tensor(-1)
+weight_post_process.max_val = torch.tensor(1)
+m_ref.weight_post_process = weight_post_process
+m_ref.activation_post_process = activation
+m_ref.qconfig = qconfig
+m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
+self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"

View File

@@ -289,7 +289,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
missing_keys, unexpected_keys, error_msgs)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities

@@ -453,8 +453,8 @@ class ConvBnReLU1d(ConvBn1d):
return F.relu(ConvBn1d._forward(self, input))
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant)
class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with

@@ -490,8 +490,8 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class ConvBn2d(_ConvBnNd, nn.Conv2d):
r"""

@@ -585,8 +585,8 @@ class ConvBnReLU2d(ConvBn2d):
return F.relu(ConvBn2d._forward(self, input))
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant)
class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with

@@ -622,8 +622,8 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class ConvBn3d(_ConvBnNd, nn.Conv3d):
r"""

@@ -758,8 +758,8 @@ class ConvBnReLU3d(ConvBn3d):
return F.relu(ConvBn3d._forward(self, input))
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with

@@ -813,8 +813,8 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
)
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
def update_bn_stats(mod):
if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:

View File

@@ -133,7 +133,7 @@ class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
return self
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod' a float module, either produced by torch.ao.quantization

View File

@@ -36,8 +36,8 @@ class LinearReLU(nnqat.Linear, nni._FusedModule):
return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant)
def to_float(self):
linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)

View File

@@ -47,8 +47,8 @@ class LinearReLU(nnqd.Linear):
return 'DynamicQuantizedLinearReLU'
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qlinear_relu):

View File

@@ -37,9 +37,9 @@ class BNReLU2d(nnq.BatchNorm2d):
return 'QuantizedBNReLU2d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU2d
-return super().from_float(mod)
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):

@@ -73,9 +73,9 @@ class BNReLU3d(nnq.BatchNorm3d):
return 'QuantizedBNReLU3d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU3d
-return super().from_float(mod)
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):

View File

@@ -42,8 +42,8 @@ class ConvAdd2d(nnq.Conv2d):
return 'QuantizedConvAdd2d'
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):

@@ -85,8 +85,8 @@ class ConvAddReLU2d(nnq.Conv2d):
return 'QuantizedConvAddReLU2d'
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):

View File

@@ -53,13 +53,13 @@ class ConvReLU1d(nnq.Conv1d):
return 'QuantizedConvReLU1d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
-return super().from_float(mod)
+return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):

@@ -103,13 +103,13 @@ class ConvReLU2d(nnq.Conv2d):
return 'QuantizedConvReLU2d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
-return super().from_float(mod)
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):

@@ -154,7 +154,7 @@ class ConvReLU3d(nnq.Conv3d):
return 'QuantizedConvReLU3d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(

@@ -166,7 +166,7 @@ class ConvReLU3d(nnq.Conv3d):
mod.bn.weight,
mod.bn.bias,
)
-return super().from_float(mod)
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):

View File

@@ -40,8 +40,8 @@ class LinearReLU(nnq.Linear):
return 'QuantizedLinearReLU'
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):

@@ -77,7 +77,7 @@ class LinearLeakyReLU(nnq.Linear):
return 'QuantizedLinearLeakyReLU'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU'
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
activation_post_process = mod.activation_post_process

@@ -144,7 +144,7 @@ class LinearTanh(nnq.Linear):
return 'QuantizedLinearTanh'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh'
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
activation_post_process = mod.activation_post_process

View File

@@ -44,7 +44,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@staticmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args:

@@ -150,8 +150,8 @@ class Conv1d(_ConvNd, nn.Conv1d):
dtype=dtype)
@classmethod
-def from_float(cls, mod):
-return super().from_float(cls, mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv2d(_ConvNd, nn.Conv2d):
r"""

@@ -208,8 +208,8 @@ class Conv2d(_ConvNd, nn.Conv2d):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
-def from_float(cls, mod):
-return super().from_float(cls, mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv3d(_ConvNd, nn.Conv3d):
r"""

@@ -266,5 +266,5 @@ class Conv3d(_ConvNd, nn.Conv3d):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
-def from_float(cls, mod):
-return super().from_float(cls, mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

View File

@@ -42,7 +42,7 @@ class Embedding(nn.Embedding):
self.sparse)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities

@@ -112,7 +112,7 @@ class EmbeddingBag(nn.EmbeddingBag):
self.padding_idx)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities

View File

@@ -41,7 +41,7 @@ class Linear(nn.Linear):
return F.linear(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user

View File

@@ -122,7 +122,7 @@ class LSTMCell(torch.nn.Module):
return cell
@classmethod
-def from_float(cls, other):
+def from_float(cls, other, use_precomputed_fake_quant=False):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
observed = cls.from_params(other.weight_ih, other.weight_hh,

View File

@@ -77,7 +77,7 @@ class Linear(nnq.Linear):
missing_keys, unexpected_keys, error_msgs)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a dynamic quantized module from a float module or qparams_dict
Args:

View File

@@ -268,7 +268,7 @@ class RNNBase(torch.nn.Module):
self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) in {torch.nn.LSTM,
torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
assert hasattr(

@@ -495,8 +495,8 @@ class LSTM(RNNBase):
return self.forward_tensor(input, hx)
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_mod):

@@ -747,8 +747,8 @@ class GRU(RNNBase):
return self.forward_tensor(input, hx)
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_mod):

@@ -839,7 +839,7 @@ class RNNCellBase(torch.nn.Module):
f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}")
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) in {torch.nn.LSTMCell,
torch.nn.GRUCell,
torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \

@@ -1012,8 +1012,8 @@ class RNNCell(RNNCellBase):
return ret
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class LSTMCell(RNNCellBase):

@@ -1055,8 +1055,8 @@ class LSTMCell(RNNCellBase):
self.bias_ih, self.bias_hh)
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class GRUCell(RNNCellBase):

@@ -1096,5 +1096,5 @@ class GRUCell(RNNCellBase):
)
@classmethod
-def from_float(cls, mod):
-return super().from_float(mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

View File

@@ -98,7 +98,7 @@ class Quantize(torch.nn.Module):
int(self.zero_point), self.dtype)
@staticmethod
-def from_float(mod):
+def from_float(mod, use_precomputed_fake_quant=False):
assert hasattr(mod, 'activation_post_process')
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Quantize(scale.float().item(), zero_point.long().item(), mod.activation_post_process.dtype)

@@ -127,5 +127,5 @@ class DeQuantize(torch.nn.Module):
return Xq.dequantize()
@staticmethod
-def from_float(mod):
+def from_float(mod, use_precomputed_fake_quant=False):
return DeQuantize()

View File

@@ -46,7 +46,7 @@ class ReLU6(torch.nn.ReLU):
return 'QuantizedReLU6'
@staticmethod
-def from_float(mod):
+def from_float(mod, use_precomputed_fake_quant=False):
return ReLU6(mod.inplace)
class Hardswish(torch.nn.Hardswish):

@@ -69,7 +69,7 @@ class Hardswish(torch.nn.Hardswish):
return 'QuantizedHardswish'
@staticmethod
-def from_float(mod):
+def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Hardswish(float(scale), int(zero_point))

@@ -98,7 +98,7 @@ class ELU(torch.nn.ELU):
return 'QuantizedELU'
@staticmethod
-def from_float(mod):
+def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return ELU(float(scale), int(zero_point), mod.alpha)

@@ -129,7 +129,7 @@ class LeakyReLU(torch.nn.LeakyReLU):
return 'QuantizedLeakyReLU'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)

@@ -154,7 +154,7 @@ class Sigmoid(torch.nn.Sigmoid):
return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
return cls(float(output_scale), int(output_zero_point))

@@ -187,7 +187,7 @@ class Softmax(torch.nn.Softmax):
return 'QuantizedSoftmax'
@staticmethod
-def from_float(mod):
+def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Softmax(mod.dim, float(scale), int(zero_point))

@@ -269,7 +269,7 @@ class PReLU(torch.nn.Module):
return 'QuantizedPReLU'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
float_wt = mod.weight.float()

View File

@@ -14,7 +14,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs))
@staticmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
activation_post_process = mod.activation_post_process
if type(mod) == cls._NNI_BN_RELU_MODULE:
mod = mod[0]

@@ -72,8 +72,8 @@ class BatchNorm2d(_BatchNorm):
self.running_var, self.eps, self.scale, self.zero_point)
@classmethod
-def from_float(cls, mod):
-return _BatchNorm.from_float(cls, mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class BatchNorm3d(_BatchNorm):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.

@@ -102,5 +102,5 @@ class BatchNorm3d(_BatchNorm):
self.running_var, self.eps, self.scale, self.zero_point)
@classmethod
-def from_float(cls, mod):
-return _BatchNorm.from_float(cls, mod)
+def from_float(cls, mod, use_precomputed_fake_quant=False):
+return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

View File

@@ -215,7 +215,7 @@ class _ConvNd(WeightedQuantizedModule):
return qconv
@staticmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
if hasattr(mod, "weight_fake_quant"):
# assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
# ".from_float only works for " + cls.__QAT_MODULE.__name__

@@ -368,14 +368,14 @@ class Conv1d(_ConvNd):
return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
-return _ConvNd.from_float(cls, mod)
+return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv2d(_ConvNd):

@@ -469,14 +469,14 @@ class Conv2d(_ConvNd):
input, self._packed_params, self.scale, self.zero_point)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
-return _ConvNd.from_float(cls, mod)
+return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv3d(_ConvNd):

@@ -571,14 +571,14 @@ class Conv3d(_ConvNd):
input, self._packed_params, self.scale, self.zero_point)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
-return _ConvNd.from_float(cls, mod)
+return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
# === Transposed Convolutions ===
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)

@@ -609,7 +609,7 @@ class _ConvTransposeNd(_ConvNd):
return res
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization

View File

@@ -19,7 +19,7 @@ class Dropout(torch.nn.Dropout):
return 'QuantizedDropout'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
return cls(mod.p, mod.inplace)
@classmethod

View File

@@ -137,7 +137,7 @@ class Embedding(torch.nn.Module):
return self._packed_params._weight()
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding module from a float module
Args:

@@ -241,7 +241,7 @@ class EmbeddingBag(Embedding):
return 'QuantizedEmbeddingBag'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding_bag module from a float module
Args:

View File

@@ -239,7 +239,7 @@ class QFunctional(torch.nn.Module):
return r
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == FloatFunctional, \
"QFunctional.from_float expects an instance of FloatFunctional"
scale, zero_point = mod.activation_post_process.calculate_qparams()  # type: ignore[operator]

View File

@@ -240,12 +240,14 @@ class Linear(WeightedQuantizedModule):
self._packed_params.set_weight_bias(w, b)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized module from an observed float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
+use_precomputed_fake_quant (bool): if True, the module will reuse min/max
+values from the precomputed fake quant module.
"""
if hasattr(mod, 'weight_fake_quant'):
if type_before_parametrizations(mod) == nniqat.LinearBn1d:

@@ -267,8 +269,12 @@ class Linear(WeightedQuantizedModule):
activation_post_process = mod.activation_post_process
if type_before_parametrizations(mod) == nni.LinearReLU:
mod = mod[0]
-weight_post_process = mod.qconfig.weight()
-weight_post_process(mod.weight)
+weight_post_process = mod.qconfig.weight() if not hasattr(mod, "weight_fake_quant") else mod.weight_fake_quant
+if not use_precomputed_fake_quant:
+# Observer may not have been called yet
+# Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound
+weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams()
assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'

View File

@@ -30,7 +30,7 @@ class LayerNorm(torch.nn.LayerNorm):
return 'QuantizedLayerNorm'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.normalized_shape, mod.weight, mod.bias, float(scale),

@@ -71,7 +71,7 @@ class GroupNorm(torch.nn.GroupNorm):
return 'QuantizedGroupNorm'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),

@@ -105,7 +105,7 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
return 'QuantizedInstanceNorm1d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),

@@ -145,7 +145,7 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
return 'QuantizedInstanceNorm2d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),

@@ -185,7 +185,7 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
return 'QuantizedInstanceNorm3d'
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),

View File

@@ -213,7 +213,7 @@ class LSTMCell(RNNCellBase):
return ret
@classmethod
-def from_float(cls, mod, weight_qparams_dict):
+def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False):
ref_mod = cls(
mod.input_size,
mod.hidden_size,

View File

@@ -76,7 +76,7 @@ class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
self.padding_idx)
@classmethod
-def from_float(cls, mod, weight_qparams):
+def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False):
return cls(
mod.num_embeddings,
mod.embedding_dim,

View File

@@ -92,7 +92,7 @@ class Linear(torch.nn.Module):
self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized sparse dynamic module from a float module.
We only care about the convert at this stage, no need for observers just yet.

View File

@@ -146,7 +146,7 @@ class Linear(torch.nn.Module):
self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
@classmethod
-def from_float(cls, mod):
+def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized sparse module from a float module.
We only care about the convert at this stage, no need for observers just yet.

View File

@@ -235,6 +235,13 @@ def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=N
if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
and type_before_parametrizations(module) in qconfig_propagation_list:
insert_activation_post_process(module)
+# This is a special case for AdaRound eager mode
+# AdaRound contains weight_fake_quant to be propagated from API to convert
+# leaf node check with a number of children looks naive assumption that blocks
+# Adding an exception case for AdaRound
+if hasattr(module, "weight_fake_quant") and not isinstance(module, torch.nn.Sequential) \
+and type_before_parametrizations(module) in qconfig_propagation_list:
+insert_activation_post_process(module)
def _get_unique_devices_(module):
return {p.device for p in module.parameters()} | \

@@ -520,7 +527,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
def convert(
module, mapping=None, inplace=False, remove_qconfig=True,
-is_reference=False, convert_custom_config_dict=None):
+is_reference=False, convert_custom_config_dict=None,
+use_precomputed_fake_quant=False):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class. And remove qconfig at the
end if remove_qconfig is set to True.

@@ -533,6 +541,7 @@ def convert(
`inplace`: carry out model transformations in-place, the original module
is mutated
`convert_custom_config_dict`: custom configuration dictionary for convert function
+`use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant
.. code-block:: python

@@ -552,14 +561,16 @@ def convert(
module = copy.deepcopy(module)
_convert(
module, mapping, inplace=True, is_reference=is_reference,
-convert_custom_config_dict=convert_custom_config_dict)
+convert_custom_config_dict=convert_custom_config_dict,
+use_precomputed_fake_quant=use_precomputed_fake_quant)
if remove_qconfig:
_remove_qconfig(module)
return module
def _convert(
module, mapping=None, inplace=False,
-is_reference=False, convert_custom_config_dict=None):
+is_reference=False, convert_custom_config_dict=None,
+use_precomputed_fake_quant=False):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class

@@ -571,6 +582,7 @@ def _convert(
inplace: carry out model transformations in-place, the original module
is mutated
is_reference: a flag to enable quantized reference module
+use_precomputed_fake_quant: a flag to enable use of precomputed fake quant
"""
if mapping is None:

@@ -589,15 +601,16 @@ def _convert(
if not isinstance(mod, _FusedModule) and \
type_before_parametrizations(mod) not in custom_module_class_mapping:
_convert(mod, mapping, True, # inplace
-is_reference, convert_custom_config_dict)
-reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
+is_reference, convert_custom_config_dict,
+use_precomputed_fake_quant=use_precomputed_fake_quant)
+reassign[name] = swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant)
for key, value in reassign.items():
module._modules[key] = value
return module
-def swap_module(mod, mapping, custom_module_class_mapping):
+def swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.

@@ -623,7 +636,7 @@ def swap_module(mod, mapping, custom_module_class_mapping):
weight_qparams = get_qparam_dict(weight_post_process)
new_mod = qmod.from_float(mod, weight_qparams)
else:
-new_mod = qmod.from_float(mod)
+new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
swapped = True
if swapped: