Support min/max carry over for eager mode from_float method (#127309)

Summary:
After QAT completes, or when a pre-tuned weight observer is supplied via a tunable PTQ algorithm, convert should not overwrite those statistics by re-running the weight observer on the weight; for static QAT this must never happen.

By design, dynamic QAT likewise does not require re-running the weight observer.

This diff fixes that: from_float now takes a use_precomputed_fake_quant flag so precomputed min/max values are carried over instead of being recomputed from the weight.
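A minimal usage sketch, mirroring the test_linear_precomputed_fake_quant test added below (assumes the qnnpack engine, as in the test; illustrative only, not part of the diff):

    import torch
    import torch.nn as nn
    import torch.ao.nn.quantized as nnq
    from torch.ao.quantization import default_qconfig

    m = nn.Linear(4, 4)
    m.qconfig = default_qconfig

    # Pretend a PTQ algorithm (e.g. AdaRound) already tuned the weight observer:
    weight_obs = default_qconfig.weight()
    weight_obs.min_val = torch.tensor(-1.0)
    weight_obs.max_val = torch.tensor(1.0)
    m.weight_post_process = weight_obs

    act_obs = default_qconfig.activation()
    act_obs(torch.randn(4, 4))  # collect activation statistics
    m.activation_post_process = act_obs

    # With the flag set, the tuned min/max are carried over instead of
    # re-running the weight observer on m.weight.
    mq = nnq.Linear.from_float(m, use_precomputed_fake_quant=True)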

Test Plan: Signals

Differential Revision: D57747749

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127309
Approved by: https://github.com/jerryzh168
Author: Kwanghoon An, 2024-05-29 19:33:26 +00:00 (committed by PyTorch MergeBot)
parent 82a370ae3a
commit c404b2968c
29 changed files with 178 additions and 131 deletions

View File

@@ -2,62 +2,63 @@
import copy
import math
import torch
import torch.nn as nn
import torch.backends.mkldnn
from torch.nn import Conv2d, BatchNorm2d, ReLU, init
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.nn.modules.utils import _pair
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.qat as nnqat
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat.dynamic as nnqatd
import torch.backends.mkldnn
import torch.nn as nn
import torch.testing._internal.hypothesis_utils as hu
from hypothesis import given, strategies as st
from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.ao.quantization import (
prepare,
convert,
prepare_qat,
quantize_qat,
QuantStub,
DeQuantStub,
default_qconfig,
default_qat_qconfig,
default_embedding_qat_qconfig,
default_qat_qconfig,
default_qconfig,
default_symmetric_qnnpack_qat_qconfig,
get_default_qat_qconfig,
DeQuantStub,
FixedQParamsFakeQuantize,
FusedMovingAvgObsFakeQuantize,
get_default_qat_qconfig,
get_embedding_qat_module_mappings,
get_embedding_static_quant_module_mappings,
NoopObserver,
prepare,
prepare_qat,
quantize_qat,
QuantStub,
)
from torch.ao.quantization.qconfig import qconfig_equals
from torch.nn import BatchNorm2d, Conv2d, init, ReLU
from torch.nn.modules.utils import _pair
from torch.testing._internal.common_quantization import (
DeFusedEmbeddingBagLinear,
QuantizationTestCase,
QuantStubModel,
ManualLinearQATModel,
ManualDropoutQATModel,
ManualLinearDynamicQATModel,
ManualConvLinearQATModel,
ManualConvLinearSymmQATModel,
ManualDropoutQATModel,
ManualEmbeddingBagLinear,
TwoLayerLinearModel,
ManualLinearDynamicQATModel,
ManualLinearQATModel,
QuantizationTestCase,
QuantStubModel,
test_only_eval_fn,
test_only_train_fn,
TwoLayerLinearModel,
)
from torch.testing._internal.common_quantized import (
override_qengines,
override_quantized_engine,
supported_qengines,
override_qengines,
)
from torch.testing._internal.common_utils import skipIfNoXNNPACK
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from functools import reduce
@@ -1099,6 +1100,33 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
self.assertTrue(type(mq[1]) == nnq.Linear)
self.assertTrue(type(mq[2]) == nn.Identity)
@skipIfNoXNNPACK
@override_qengines
def test_linear_precomputed_fake_quant(self):
qengine = torch.backends.quantized.engine
if qengine != "qnnpack":
return  # Only qnnpack supports symmetric quantization
m_ref = nn.Linear(4, 4)
m_ref_copy = copy.deepcopy(m_ref)
qconfig = default_qconfig
m_ref_copy.qconfig = qconfig
weight_post_process = copy.deepcopy(qconfig.weight())
activation = copy.deepcopy(qconfig.activation())
activation(torch.randn(4, 4))
m_ref_copy.activation_post_process = activation
m_ref_copy = nnq.Linear.from_float(m_ref_copy)
weight_post_process = qconfig.weight()
weight_post_process.min_val = torch.tensor(-1)
weight_post_process.max_val = torch.tensor(1)
m_ref.weight_post_process = weight_post_process
m_ref.activation_post_process = activation
m_ref.qconfig = qconfig
m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
self.assertTrue(m_ref._weight_bias()[0].q_scale() != m_ref_copy._weight_bias()[0].q_scale())
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"

View File

@@ -289,7 +289,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
missing_keys, unexpected_keys, error_msgs)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
@@ -453,8 +453,8 @@ class ConvBnReLU1d(ConvBn1d):
return F.relu(ConvBn1d._forward(self, input))
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with
@@ -490,8 +490,8 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class ConvBn2d(_ConvBnNd, nn.Conv2d):
r"""
@@ -585,8 +585,8 @@ class ConvBnReLU2d(ConvBn2d):
return F.relu(ConvBn2d._forward(self, input))
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with
@@ -622,8 +622,8 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class ConvBn3d(_ConvBnNd, nn.Conv3d):
r"""
@@ -758,8 +758,8 @@ class ConvBnReLU3d(ConvBn3d):
return F.relu(ConvBn3d._forward(self, input))
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with
@@ -813,8 +813,8 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
)
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
def update_bn_stats(mod):
if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:

View File

@@ -133,7 +133,7 @@ class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
return self
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization

View File

@@ -36,8 +36,8 @@ class LinearReLU(nnqat.Linear, nni._FusedModule):
return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
def to_float(self):
linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)

View File

@@ -47,8 +47,8 @@ class LinearReLU(nnqd.Linear):
return 'DynamicQuantizedLinearReLU'
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qlinear_relu):

View File

@@ -37,9 +37,9 @@ class BNReLU2d(nnq.BatchNorm2d):
return 'QuantizedBNReLU2d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU2d
return super().from_float(mod)
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):
@@ -73,9 +73,9 @@ class BNReLU3d(nnq.BatchNorm3d):
return 'QuantizedBNReLU3d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
# TODO: Add qat support for BNReLU3d
return super().from_float(mod)
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, bn_relu, output_scale, output_zero_point):

View File

@@ -42,8 +42,8 @@ class ConvAdd2d(nnq.Conv2d):
return 'QuantizedConvAdd2d'
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
@@ -85,8 +85,8 @@ class ConvAddReLU2d(nnq.Conv2d):
return 'QuantizedConvAddReLU2d'
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):

View File

@@ -53,13 +53,13 @@ class ConvReLU1d(nnq.Conv1d):
return 'QuantizedConvReLU1d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
return super().from_float(mod)
return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
@@ -103,13 +103,13 @@ class ConvReLU2d(nnq.Conv2d):
return 'QuantizedConvReLU2d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
return super().from_float(mod)
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):
@@ -154,7 +154,7 @@ class ConvReLU3d(nnq.Conv3d):
return 'QuantizedConvReLU3d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
assert mod.bn.running_var is not None and mod.bn.running_mean is not None
mod.weight, mod.bias = fuse_conv_bn_weights(
@@ -166,7 +166,7 @@ class ConvReLU3d(nnq.Conv3d):
mod.bn.weight,
mod.bn.bias,
)
return super().from_float(mod)
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_qconv, output_scale, output_zero_point):

View File

@@ -40,8 +40,8 @@ class LinearReLU(nnq.Linear):
return 'QuantizedLinearReLU'
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):
@@ -77,7 +77,7 @@ class LinearLeakyReLU(nnq.Linear):
return 'QuantizedLinearLeakyReLU'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU'
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
activation_post_process = mod.activation_post_process
@@ -144,7 +144,7 @@ class LinearTanh(nnq.Linear):
return 'QuantizedLinearTanh'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh'
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
activation_post_process = mod.activation_post_process

View File

@@ -44,7 +44,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@staticmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args:
@@ -150,8 +150,8 @@ class Conv1d(_ConvNd, nn.Conv1d):
dtype=dtype)
@classmethod
def from_float(cls, mod):
return super().from_float(cls, mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv2d(_ConvNd, nn.Conv2d):
r"""
@@ -208,8 +208,8 @@ class Conv2d(_ConvNd, nn.Conv2d):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod):
return super().from_float(cls, mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv3d(_ConvNd, nn.Conv3d):
r"""
@@ -266,5 +266,5 @@ class Conv3d(_ConvNd, nn.Conv3d):
return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod):
return super().from_float(cls, mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

View File

@@ -42,7 +42,7 @@ class Embedding(nn.Embedding):
self.sparse)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities
@@ -112,7 +112,7 @@ class EmbeddingBag(nn.EmbeddingBag):
self.padding_idx)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities

View File

@@ -41,7 +41,7 @@ class Linear(nn.Linear):
return F.linear(input, self.weight_fake_quant(self.weight), self.bias)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a qat module from a float module or qparams_dict
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user

View File

@@ -122,7 +122,7 @@ class LSTMCell(torch.nn.Module):
return cell
@classmethod
def from_float(cls, other):
def from_float(cls, other, use_precomputed_fake_quant=False):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
observed = cls.from_params(other.weight_ih, other.weight_hh,

View File

@@ -77,7 +77,7 @@ class Linear(nnq.Linear):
missing_keys, unexpected_keys, error_msgs)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a dynamic quantized module from a float module or qparams_dict
Args:

View File

@@ -268,7 +268,7 @@ class RNNBase(torch.nn.Module):
self._all_weight_values = torch.nn.ModuleList(_all_weight_values)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) in {torch.nn.LSTM,
torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
assert hasattr(
@@ -495,8 +495,8 @@ class LSTM(RNNBase):
return self.forward_tensor(input, hx)
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_mod):
@@ -747,8 +747,8 @@ class GRU(RNNBase):
return self.forward_tensor(input, hx)
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@classmethod
def from_reference(cls, ref_mod):
@@ -839,7 +839,7 @@ class RNNCellBase(torch.nn.Module):
f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}")
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) in {torch.nn.LSTMCell,
torch.nn.GRUCell,
torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \
@@ -1012,8 +1012,8 @@ class RNNCell(RNNCellBase):
return ret
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class LSTMCell(RNNCellBase):
@@ -1055,8 +1055,8 @@ class LSTMCell(RNNCellBase):
self.bias_ih, self.bias_hh)
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class GRUCell(RNNCellBase):
@@ -1096,5 +1096,5 @@ class GRUCell(RNNCellBase):
)
@classmethod
def from_float(cls, mod):
return super().from_float(mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

View File

@@ -98,7 +98,7 @@ class Quantize(torch.nn.Module):
int(self.zero_point), self.dtype)
@staticmethod
def from_float(mod):
def from_float(mod, use_precomputed_fake_quant=False):
assert hasattr(mod, 'activation_post_process')
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Quantize(scale.float().item(), zero_point.long().item(), mod.activation_post_process.dtype)
@@ -127,5 +127,5 @@ class DeQuantize(torch.nn.Module):
return Xq.dequantize()
@staticmethod
def from_float(mod):
def from_float(mod, use_precomputed_fake_quant=False):
return DeQuantize()

View File

@@ -46,7 +46,7 @@ class ReLU6(torch.nn.ReLU):
return 'QuantizedReLU6'
@staticmethod
def from_float(mod):
def from_float(mod, use_precomputed_fake_quant=False):
return ReLU6(mod.inplace)
class Hardswish(torch.nn.Hardswish):
@@ -69,7 +69,7 @@ class Hardswish(torch.nn.Hardswish):
return 'QuantizedHardswish'
@staticmethod
def from_float(mod):
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Hardswish(float(scale), int(zero_point))
@@ -98,7 +98,7 @@ class ELU(torch.nn.ELU):
return 'QuantizedELU'
@staticmethod
def from_float(mod):
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return ELU(float(scale), int(zero_point), mod.alpha)
@@ -129,7 +129,7 @@ class LeakyReLU(torch.nn.LeakyReLU):
return 'QuantizedLeakyReLU'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
@@ -154,7 +154,7 @@ class Sigmoid(torch.nn.Sigmoid):
return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
return cls(float(output_scale), int(output_zero_point))
@@ -187,7 +187,7 @@ class Softmax(torch.nn.Softmax):
return 'QuantizedSoftmax'
@staticmethod
def from_float(mod):
def from_float(mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
return Softmax(mod.dim, float(scale), int(zero_point))
@@ -269,7 +269,7 @@ class PReLU(torch.nn.Module):
return 'QuantizedPReLU'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
float_wt = mod.weight.float()

View File

@@ -14,7 +14,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs))
@staticmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
activation_post_process = mod.activation_post_process
if type(mod) == cls._NNI_BN_RELU_MODULE:
mod = mod[0]
@@ -72,8 +72,8 @@ class BatchNorm2d(_BatchNorm):
self.running_var, self.eps, self.scale, self.zero_point)
@classmethod
def from_float(cls, mod):
return _BatchNorm.from_float(cls, mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class BatchNorm3d(_BatchNorm):
r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
@@ -102,5 +102,5 @@ class BatchNorm3d(_BatchNorm):
self.running_var, self.eps, self.scale, self.zero_point)
@classmethod
def from_float(cls, mod):
return _BatchNorm.from_float(cls, mod)
def from_float(cls, mod, use_precomputed_fake_quant=False):
return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

View File

@@ -215,7 +215,7 @@ class _ConvNd(WeightedQuantizedModule):
return qconv
@staticmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
if hasattr(mod, "weight_fake_quant"):
# assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
# ".from_float only works for " + cls.__QAT_MODULE.__name__
@@ -368,14 +368,14 @@ class Conv1d(_ConvNd):
return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
return _ConvNd.from_float(cls, mod)
return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv2d(_ConvNd):
@@ -469,14 +469,14 @@ class Conv2d(_ConvNd):
input, self._packed_params, self.scale, self.zero_point)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
return _ConvNd.from_float(cls, mod)
return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
class Conv3d(_ConvNd):
@@ -571,14 +571,14 @@ class Conv3d(_ConvNd):
input, self._packed_params, self.scale, self.zero_point)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
return _ConvNd.from_float(cls, mod)
return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
# === Transposed Convolutions ===
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)
@@ -609,7 +609,7 @@ class _ConvTransposeNd(_ConvNd):
return res
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Creates a quantized module from a float module or qparams_dict.
Args:
mod (Module): a float module, either produced by torch.ao.quantization

View File

@@ -19,7 +19,7 @@ class Dropout(torch.nn.Dropout):
return 'QuantizedDropout'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
return cls(mod.p, mod.inplace)
@classmethod

View File

@@ -137,7 +137,7 @@ class Embedding(torch.nn.Module):
return self._packed_params._weight()
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding module from a float module
Args:
@@ -241,7 +241,7 @@ class EmbeddingBag(Embedding):
return 'QuantizedEmbeddingBag'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized embedding_bag module from a float module
Args:

View File

@@ -239,7 +239,7 @@ class QFunctional(torch.nn.Module):
return r
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
assert type(mod) == FloatFunctional, \
"QFunctional.from_float expects an instance of FloatFunctional"
scale, zero_point = mod.activation_post_process.calculate_qparams() # type: ignore[operator]

View File

@@ -240,12 +240,14 @@ class Linear(WeightedQuantizedModule):
self._packed_params.set_weight_bias(w, b)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized module from an observed float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
use_precomputed_fake_quant (bool): if True, the module will reuse min/max
values from the precomputed fake quant module.
"""
if hasattr(mod, 'weight_fake_quant'):
if type_before_parametrizations(mod) == nniqat.LinearBn1d:
@@ -267,8 +269,12 @@ class Linear(WeightedQuantizedModule):
activation_post_process = mod.activation_post_process
if type_before_parametrizations(mod) == nni.LinearReLU:
mod = mod[0]
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
weight_post_process = mod.qconfig.weight() if not hasattr(mod, "weight_fake_quant") else mod.weight_fake_quant
if not use_precomputed_fake_quant:
# Without precomputed fake quant the observer may not have run yet, so run it on the weight here.
# With it, the observer already ran in a previous stage (a PTQ algorithm, e.g. AdaRound), so skip.
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams()
assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
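Stripped of diff context, the guarded observer call above reads as follows (a sketch using the names from this hunk):

    # Prefer a weight fake-quant already attached to the module (QAT, AdaRound);
    # otherwise create a fresh weight observer from the qconfig.
    weight_post_process = (
        mod.weight_fake_quant
        if hasattr(mod, "weight_fake_quant")
        else mod.qconfig.weight()
    )
    if not use_precomputed_fake_quant:
        # No precomputed statistics to preserve: observe the float weight now.
        # Otherwise keep the min/max tuned in an earlier stage (e.g. AdaRound).
        weight_post_process(mod.weight)
    dtype = weight_post_process.dtype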

View File

@@ -30,7 +30,7 @@ class LayerNorm(torch.nn.LayerNorm):
return 'QuantizedLayerNorm'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.normalized_shape, mod.weight, mod.bias, float(scale),
@@ -71,7 +71,7 @@ class GroupNorm(torch.nn.GroupNorm):
return 'QuantizedGroupNorm'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),
@@ -105,7 +105,7 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
return 'QuantizedInstanceNorm1d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
@@ -145,7 +145,7 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
return 'QuantizedInstanceNorm2d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
@@ -185,7 +185,7 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
return 'QuantizedInstanceNorm3d'
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
scale, zero_point = mod.activation_post_process.calculate_qparams()
new_mod = cls(
mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),

View File

@@ -213,7 +213,7 @@ class LSTMCell(RNNCellBase):
return ret
@classmethod
def from_float(cls, mod, weight_qparams_dict):
def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False):
ref_mod = cls(
mod.input_size,
mod.hidden_size,

View File

@@ -76,7 +76,7 @@ class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
self.padding_idx)
@classmethod
def from_float(cls, mod, weight_qparams):
def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False):
return cls(
mod.num_embeddings,
mod.embedding_dim,

View File

@@ -92,7 +92,7 @@ class Linear(torch.nn.Module):
self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized sparse dynamic module from a float module.
We only care about the convert at this stage, no need for observers just yet.

View File

@@ -146,7 +146,7 @@ class Linear(torch.nn.Module):
self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)
@classmethod
def from_float(cls, mod):
def from_float(cls, mod, use_precomputed_fake_quant=False):
r"""Create a quantized sparse module from a float module.
We only care about the convert at this stage, no need for observers just yet.

View File

@@ -235,6 +235,13 @@ def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=N
if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
and type_before_parametrizations(module) in qconfig_propagation_list:
insert_activation_post_process(module)
# Special case for AdaRound in eager mode:
# AdaRound attaches weight_fake_quant, which must be carried from the API through convert.
# The leaf-node check above (based on the number of children) is a naive assumption that blocks this,
# so add an exception for modules that carry weight_fake_quant.
if hasattr(module, "weight_fake_quant") and not isinstance(module, torch.nn.Sequential) \
and type_before_parametrizations(module) in qconfig_propagation_list:
insert_activation_post_process(module)
def _get_unique_devices_(module):
return {p.device for p in module.parameters()} | \
@@ -520,7 +527,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False):
def convert(
module, mapping=None, inplace=False, remove_qconfig=True,
is_reference=False, convert_custom_config_dict=None):
is_reference=False, convert_custom_config_dict=None,
use_precomputed_fake_quant=False):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class. And remove qconfig at the
end if remove_qconfig is set to True.
@@ -533,6 +541,7 @@
`inplace`: carry out model transformations in-place, the original module
is mutated
`convert_custom_config_dict`: custom configuration dictionary for convert function
`use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant
.. code-block:: python
@@ -552,14 +561,16 @@
module = copy.deepcopy(module)
_convert(
module, mapping, inplace=True, is_reference=is_reference,
convert_custom_config_dict=convert_custom_config_dict)
convert_custom_config_dict=convert_custom_config_dict,
use_precomputed_fake_quant=use_precomputed_fake_quant)
if remove_qconfig:
_remove_qconfig(module)
return module
def _convert(
module, mapping=None, inplace=False,
is_reference=False, convert_custom_config_dict=None):
is_reference=False, convert_custom_config_dict=None,
use_precomputed_fake_quant=False):
r"""Converts submodules in input module to a different module according to `mapping`
by calling `from_float` method on the target module class
@@ -571,6 +582,7 @@ def _convert(
inplace: carry out model transformations in-place, the original module
is mutated
is_reference: a flag to enable quantized reference module
use_precomputed_fake_quant: a flag to enable use of precomputed fake quant
"""
if mapping is None:
@@ -589,15 +601,16 @@
if not isinstance(mod, _FusedModule) and \
type_before_parametrizations(mod) not in custom_module_class_mapping:
_convert(mod, mapping, True, # inplace
is_reference, convert_custom_config_dict)
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
is_reference, convert_custom_config_dict,
use_precomputed_fake_quant=use_precomputed_fake_quant)
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant)
for key, value in reassign.items():
module._modules[key] = value
return module
def swap_module(mod, mapping, custom_module_class_mapping):
def swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False):
r"""Swaps the module if it has a quantized counterpart and it has an
`observer` attached.
@@ -623,7 +636,7 @@ def swap_module(mod, mapping, custom_module_class_mapping):
weight_qparams = get_qparam_dict(weight_post_process)
new_mod = qmod.from_float(mod, weight_qparams)
else:
new_mod = qmod.from_float(mod)
new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
swapped = True
if swapped:
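End to end, the flag threads from convert through _convert and swap_module into each module's from_float. A hedged usage sketch, assuming model is an eager-mode float module whose observers were already tuned in a prior stage:

    import torch.ao.quantization as tq

    model.qconfig = tq.default_qconfig
    tq.prepare(model, inplace=True)
    # ... weight/activation observers tuned here (QAT or a PTQ algorithm such as AdaRound) ...
    quantized = tq.convert(model, use_precomputed_fake_quant=True)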