mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00

Support min/max carry over for eager mode from_float method (#127309)

Summary: Once QAT has finished, or when a tunable PTQ algorithm supplies a pre-tuned weight observer, the eager-mode from_float path should not re-run the weight observer and overwrite its min/max with the raw weight range: static QAT must never do this, and dynamic QAT does not need to re-run the weight observer by design. This change fixes that by threading a use_precomputed_fake_quant flag from convert() through swap_module() into the from_float methods.

Test Plan: Signals

Differential Revision: D57747749

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127309
Approved by: https://github.com/jerryzh168

This commit is contained in:
parent 82a370ae3a
commit c404b2968c
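For orientation, here is a minimal sketch of the eager-mode flow this change targets. It is not code from the commit: MyModel, the calibration input, and the external tuning step are placeholders, QuantStub/DeQuantStub handling is omitted, and the only API taken from this PR is the use_precomputed_fake_quant flag on convert().

import torch
import torch.ao.quantization as tq


class MyModel(torch.nn.Module):  # placeholder model for the sketch
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


model = MyModel()
model.qconfig = tq.default_qconfig
tq.prepare(model, inplace=True)   # attaches activation/weight observers

model(torch.randn(8, 4))          # calibrate activation observers as usual

# A tunable PTQ algorithm (e.g. AdaRound) would pre-tune the weight observers
# here instead of letting convert() re-observe the raw weights.

quantized = tq.convert(model, use_precomputed_fake_quant=True)  # flag added by this PR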
@@ -2,62 +2,63 @@
 import copy
 import math

 import torch
-import torch.nn as nn
-import torch.backends.mkldnn
-from torch.nn import Conv2d, BatchNorm2d, ReLU, init
-from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
-from torch.nn.modules.utils import _pair
+import torch.ao.nn.intrinsic.qat as nniqat
+import torch.ao.nn.qat as nnqat
+import torch.ao.nn.qat.dynamic as nnqatd
 import torch.ao.nn.quantized as nnq
 import torch.ao.nn.quantized.dynamic as nnqd
-import torch.ao.nn.qat as nnqat
-import torch.ao.nn.intrinsic.qat as nniqat
-import torch.ao.nn.qat.dynamic as nnqatd
+import torch.backends.mkldnn
+import torch.nn as nn
+import torch.testing._internal.hypothesis_utils as hu

+from hypothesis import given, strategies as st
+from torch.ao.nn.intrinsic.qat import ConvBn2d, ConvBnReLU2d
 from torch.ao.quantization import (
-    prepare,
     convert,
-    prepare_qat,
-    quantize_qat,
-    QuantStub,
-    DeQuantStub,
-    default_qconfig,
-    default_qat_qconfig,
     default_embedding_qat_qconfig,
+    default_qat_qconfig,
+    default_qconfig,
     default_symmetric_qnnpack_qat_qconfig,
-    get_default_qat_qconfig,
+    DeQuantStub,
     FixedQParamsFakeQuantize,
     FusedMovingAvgObsFakeQuantize,
+    get_default_qat_qconfig,
     get_embedding_qat_module_mappings,
     get_embedding_static_quant_module_mappings,
     NoopObserver,
+    prepare,
+    prepare_qat,
+    quantize_qat,
+    QuantStub,
 )
 from torch.ao.quantization.qconfig import qconfig_equals
+from torch.nn import BatchNorm2d, Conv2d, init, ReLU
+from torch.nn.modules.utils import _pair
 from torch.testing._internal.common_quantization import (
     DeFusedEmbeddingBagLinear,
-    QuantizationTestCase,
-    QuantStubModel,
-    ManualLinearQATModel,
-    ManualDropoutQATModel,
-    ManualLinearDynamicQATModel,
     ManualConvLinearQATModel,
     ManualConvLinearSymmQATModel,
+    ManualDropoutQATModel,
     ManualEmbeddingBagLinear,
-    TwoLayerLinearModel,
+    ManualLinearDynamicQATModel,
+    ManualLinearQATModel,
+    QuantizationTestCase,
+    QuantStubModel,
     test_only_eval_fn,
     test_only_train_fn,
+    TwoLayerLinearModel,
 )

 from torch.testing._internal.common_quantized import (
+    override_qengines,
     override_quantized_engine,
     supported_qengines,
-    override_qengines,
 )

 from torch.testing._internal.common_utils import skipIfNoXNNPACK

-from hypothesis import given
-from hypothesis import strategies as st
-import torch.testing._internal.hypothesis_utils as hu
 hu.assert_deadline_disabled()
 from functools import reduce
@@ -1099,6 +1100,33 @@ class TestQuantizeEagerQATNumerics(QuantizationTestCase):
         self.assertTrue(type(mq[1]) == nnq.Linear)
         self.assertTrue(type(mq[2]) == nn.Identity)

+    @skipIfNoXNNPACK
+    @override_qengines
+    def test_linear_precomputed_fake_quant(self):
+        qengine = torch.backends.quantized.engine
+        if qengine != "qnnpack":
+            return  # Only qnnpack support symmetric quantization
+        m_ref = nn.Linear(4, 4)
+
+        m_ref_copy = copy.deepcopy(m_ref)
+        qconfig = default_qconfig
+        m_ref_copy.qconfig = qconfig
+        weight_post_process = copy.deepcopy(qconfig.weight())
+        activation = copy.deepcopy(qconfig.activation())
+        activation(torch.randn(4, 4))
+        m_ref_copy.activation_post_process = activation
+        m_ref_copy = nnq.Linear.from_float(m_ref_copy)
+        weight_post_process = qconfig.weight()
+        weight_post_process.min_val = torch.tensor(-1)
+        weight_post_process.max_val = torch.tensor(1)
+        m_ref.weight_post_process = weight_post_process
+        m_ref.activation_post_process = activation
+        m_ref.qconfig = qconfig
+        m_ref = nnq.Linear.from_float(m_ref, use_precomputed_fake_quant=True)
+        self.assertTrue(m_ref._weight_bias()[0].q_scale != m_ref_copy._weight_bias()[0].q_scale)

 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                        "\tpython test/test_quantization.py TESTNAME\n\n"
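A side note on what the test above seeds: for the default weight observer, the quantization parameters are derived entirely from the observer's min_val/max_val buffers, so pre-setting them fixes the scale without ever observing data. A minimal sketch (illustrative, not part of the commit):

import torch
from torch.ao.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
obs.min_val = torch.tensor(-1.0)   # pre-tuned range, as the test seeds it
obs.max_val = torch.tensor(1.0)
scale, zero_point = obs.calculate_qparams()  # computed from the seeded range only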
@@ -289,7 +289,7 @@ class _ConvBnNd(nn.modules.conv._ConvNd, nni._FusedModule):
                               missing_keys, unexpected_keys, error_msgs)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a qat module from a float module or qparams_dict

         Args: `mod` a float module, either produced by torch.ao.quantization utilities

@@ -453,8 +453,8 @@ class ConvBnReLU1d(ConvBn1d):
         return F.relu(ConvBn1d._forward(self, input))

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant)

 class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
     r"""A ConvReLU1d module is a fused module of Conv1d and ReLU, attached with

@@ -490,8 +490,8 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
             self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 class ConvBn2d(_ConvBnNd, nn.Conv2d):
     r"""

@@ -585,8 +585,8 @@ class ConvBnReLU2d(ConvBn2d):
         return F.relu(ConvBn2d._forward(self, input))

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant)

 class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
     r"""A ConvReLU2d module is a fused module of Conv2d and ReLU, attached with

@@ -622,8 +622,8 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
             self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias))

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 class ConvBn3d(_ConvBnNd, nn.Conv3d):
     r"""

@@ -758,8 +758,8 @@ class ConvBnReLU3d(ConvBn3d):
         return F.relu(ConvBn3d._forward(self, input))

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
     r"""A ConvReLU3d module is a fused module of Conv3d and ReLU, attached with

@@ -813,8 +813,8 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
         )

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 def update_bn_stats(mod):
     if type(mod) in {ConvBnReLU1d, ConvBnReLU2d, ConvBnReLU3d, ConvBn1d, ConvBn2d, ConvBn3d}:
@@ -133,7 +133,7 @@ class LinearBn1d(nn.modules.linear.Linear, nni._FusedModule):
         return self

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a qat module from a float module or qparams_dict

         Args: `mod' a float module, either produced by torch.ao.quantization

@@ -36,8 +36,8 @@ class LinearReLU(nnqat.Linear, nni._FusedModule):
         return F.relu(F.linear(input, self.weight_fake_quant(self.weight), self.bias))

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant)

     def to_float(self):
         linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)

@@ -47,8 +47,8 @@ class LinearReLU(nnqd.Linear):
         return 'DynamicQuantizedLinearReLU'

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_qlinear_relu):
@@ -37,9 +37,9 @@ class BNReLU2d(nnq.BatchNorm2d):
         return 'QuantizedBNReLU2d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         # TODO: Add qat support for BNReLU2d
-        return super().from_float(mod)
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, bn_relu, output_scale, output_zero_point):

@@ -73,9 +73,9 @@ class BNReLU3d(nnq.BatchNorm3d):
         return 'QuantizedBNReLU3d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         # TODO: Add qat support for BNReLU3d
-        return super().from_float(mod)
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, bn_relu, output_scale, output_zero_point):

@@ -42,8 +42,8 @@ class ConvAdd2d(nnq.Conv2d):
         return 'QuantizedConvAdd2d'

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_qconv, output_scale, output_zero_point):

@@ -85,8 +85,8 @@ class ConvAddReLU2d(nnq.Conv2d):
         return 'QuantizedConvAddReLU2d'

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_qconv, output_scale, output_zero_point):
@@ -53,13 +53,13 @@ class ConvReLU1d(nnq.Conv1d):
         return 'QuantizedConvReLU1d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                 mod.bn.eps, mod.bn.weight, mod.bn.bias)
-        return super().from_float(mod)
+        return super().from_float(mod, use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_qconv, output_scale, output_zero_point):

@@ -103,13 +103,13 @@ class ConvReLU2d(nnq.Conv2d):
         return 'QuantizedConvReLU2d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(
                 mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                 mod.bn.eps, mod.bn.weight, mod.bn.bias)
-        return super().from_float(mod)
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_qconv, output_scale, output_zero_point):

@@ -154,7 +154,7 @@ class ConvReLU3d(nnq.Conv3d):
         return 'QuantizedConvReLU3d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
             assert mod.bn.running_var is not None and mod.bn.running_mean is not None
             mod.weight, mod.bias = fuse_conv_bn_weights(

@@ -166,7 +166,7 @@ class ConvReLU3d(nnq.Conv3d):
                 mod.bn.weight,
                 mod.bn.bias,
             )
-        return super().from_float(mod)
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_qconv, output_scale, output_zero_point):
@@ -40,8 +40,8 @@ class LinearReLU(nnq.Linear):
         return 'QuantizedLinearReLU'

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_linear_relu, output_scale, output_zero_point):

@@ -77,7 +77,7 @@ class LinearLeakyReLU(nnq.Linear):
         return 'QuantizedLinearLeakyReLU'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         assert type(mod) == nni.LinearLeakyReLU, 'Input float module should be LinearLeakyReLU'
         assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
         activation_post_process = mod.activation_post_process

@@ -144,7 +144,7 @@ class LinearTanh(nnq.Linear):
         return 'QuantizedLinearTanh'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         assert type(mod) == nni.LinearTanh, 'Input float module should be LinearTanh'
         assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
         activation_post_process = mod.activation_post_process
@@ -44,7 +44,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
         return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

     @staticmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a qat module from a float module

         Args:

@@ -150,8 +150,8 @@ class Conv1d(_ConvNd, nn.Conv1d):
             dtype=dtype)

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(cls, mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 class Conv2d(_ConvNd, nn.Conv2d):
     r"""

@@ -208,8 +208,8 @@ class Conv2d(_ConvNd, nn.Conv2d):
         return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(cls, mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 class Conv3d(_ConvNd, nn.Conv3d):
     r"""

@@ -266,5 +266,5 @@ class Conv3d(_ConvNd, nn.Conv3d):
         return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(cls, mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@@ -42,7 +42,7 @@ class Embedding(nn.Embedding):
                            self.sparse)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a qat module from a float module

         Args: `mod` a float module, either produced by torch.ao.quantization utilities

@@ -112,7 +112,7 @@ class EmbeddingBag(nn.EmbeddingBag):
                            self.padding_idx)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a qat module from a float module

         Args: `mod` a float module, either produced by torch.ao.quantization utilities

@@ -41,7 +41,7 @@ class Linear(nn.Linear):
         return F.linear(input, self.weight_fake_quant(self.weight), self.bias)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a qat module from a float module or qparams_dict
         Args: `mod` a float module, either produced by torch.ao.quantization utilities
         or directly from user
@@ -122,7 +122,7 @@ class LSTMCell(torch.nn.Module):
         return cell

     @classmethod
-    def from_float(cls, other):
+    def from_float(cls, other, use_precomputed_fake_quant=False):
         assert type(other) == cls._FLOAT_MODULE
         assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
         observed = cls.from_params(other.weight_ih, other.weight_hh,

@@ -77,7 +77,7 @@ class Linear(nnq.Linear):
                               missing_keys, unexpected_keys, error_msgs)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a dynamic quantized module from a float module or qparams_dict

         Args:

@@ -268,7 +268,7 @@ class RNNBase(torch.nn.Module):
         self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         assert type(mod) in {torch.nn.LSTM,
                              torch.nn.GRU}, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
         assert hasattr(

@@ -495,8 +495,8 @@ class LSTM(RNNBase):
         return self.forward_tensor(input, hx)

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_mod):

@@ -747,8 +747,8 @@ class GRU(RNNBase):
         return self.forward_tensor(input, hx)

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

     @classmethod
     def from_reference(cls, ref_mod):

@@ -839,7 +839,7 @@ class RNNCellBase(torch.nn.Module):
                 f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}")

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         assert type(mod) in {torch.nn.LSTMCell,
                              torch.nn.GRUCell,
                              torch.nn.RNNCell}, 'nn.quantized.dynamic.RNNCellBase.from_float \

@@ -1012,8 +1012,8 @@ class RNNCell(RNNCellBase):
         return ret

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)


 class LSTMCell(RNNCellBase):

@@ -1055,8 +1055,8 @@ class LSTMCell(RNNCellBase):
                             self.bias_ih, self.bias_hh)

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)


 class GRUCell(RNNCellBase):

@@ -1096,5 +1096,5 @@ class GRUCell(RNNCellBase):
         )

     @classmethod
-    def from_float(cls, mod):
-        return super().from_float(mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return super().from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@@ -98,7 +98,7 @@ class Quantize(torch.nn.Module):
                               int(self.zero_point), self.dtype)

     @staticmethod
-    def from_float(mod):
+    def from_float(mod, use_precomputed_fake_quant=False):
         assert hasattr(mod, 'activation_post_process')
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return Quantize(scale.float().item(), zero_point.long().item(), mod.activation_post_process.dtype)

@@ -127,5 +127,5 @@ class DeQuantize(torch.nn.Module):
         return Xq.dequantize()

     @staticmethod
-    def from_float(mod):
+    def from_float(mod, use_precomputed_fake_quant=False):
         return DeQuantize()

@@ -46,7 +46,7 @@ class ReLU6(torch.nn.ReLU):
         return 'QuantizedReLU6'

     @staticmethod
-    def from_float(mod):
+    def from_float(mod, use_precomputed_fake_quant=False):
         return ReLU6(mod.inplace)

 class Hardswish(torch.nn.Hardswish):

@@ -69,7 +69,7 @@ class Hardswish(torch.nn.Hardswish):
         return 'QuantizedHardswish'

     @staticmethod
-    def from_float(mod):
+    def from_float(mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return Hardswish(float(scale), int(zero_point))

@@ -98,7 +98,7 @@ class ELU(torch.nn.ELU):
         return 'QuantizedELU'

     @staticmethod
-    def from_float(mod):
+    def from_float(mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return ELU(float(scale), int(zero_point), mod.alpha)

@@ -129,7 +129,7 @@ class LeakyReLU(torch.nn.LeakyReLU):
         return 'QuantizedLeakyReLU'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)

@@ -154,7 +154,7 @@ class Sigmoid(torch.nn.Sigmoid):
         return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
         return cls(float(output_scale), int(output_zero_point))

@@ -187,7 +187,7 @@ class Softmax(torch.nn.Softmax):
         return 'QuantizedSoftmax'

     @staticmethod
-    def from_float(mod):
+    def from_float(mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         return Softmax(mod.dim, float(scale), int(zero_point))

@@ -269,7 +269,7 @@ class PReLU(torch.nn.Module):
         return 'QuantizedPReLU'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
         float_wt = mod.weight.float()

@@ -14,7 +14,7 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
         self.register_buffer('zero_point', torch.tensor(0, **factory_kwargs))

     @staticmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         activation_post_process = mod.activation_post_process
         if type(mod) == cls._NNI_BN_RELU_MODULE:
             mod = mod[0]

@@ -72,8 +72,8 @@ class BatchNorm2d(_BatchNorm):
                                        self.running_var, self.eps, self.scale, self.zero_point)

     @classmethod
-    def from_float(cls, mod):
-        return _BatchNorm.from_float(cls, mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 class BatchNorm3d(_BatchNorm):
     r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.

@@ -102,5 +102,5 @@ class BatchNorm3d(_BatchNorm):
                                        self.running_var, self.eps, self.scale, self.zero_point)

     @classmethod
-    def from_float(cls, mod):
-        return _BatchNorm.from_float(cls, mod)
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
+        return _BatchNorm.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
@@ -215,7 +215,7 @@ class _ConvNd(WeightedQuantizedModule):
         return qconv

     @staticmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         if hasattr(mod, "weight_fake_quant"):
             # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
             # ".from_float only works for " + cls.__QAT_MODULE.__name__

@@ -368,14 +368,14 @@ class Conv1d(_ConvNd):
         return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Creates a quantized module from a float module or qparams_dict.

         Args:
             mod (Module): a float module, either produced by torch.ao.quantization
               utilities or provided by the user
         """
-        return _ConvNd.from_float(cls, mod)
+        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)


 class Conv2d(_ConvNd):

@@ -469,14 +469,14 @@ class Conv2d(_ConvNd):
             input, self._packed_params, self.scale, self.zero_point)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Creates a quantized module from a float module or qparams_dict.

         Args:
             mod (Module): a float module, either produced by torch.ao.quantization
               utilities or provided by the user
         """
-        return _ConvNd.from_float(cls, mod)
+        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)


 class Conv3d(_ConvNd):

@@ -571,14 +571,14 @@ class Conv3d(_ConvNd):
             input, self._packed_params, self.scale, self.zero_point)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Creates a quantized module from a float module or qparams_dict.

         Args:
             mod (Module): a float module, either produced by torch.ao.quantization
               utilities or provided by the user
         """
-        return _ConvNd.from_float(cls, mod)
+        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)

 # === Transposed Convolutions ===
 MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)

@@ -609,7 +609,7 @@ class _ConvTransposeNd(_ConvNd):
         return res

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Creates a quantized module from a float module or qparams_dict.
         Args:
             mod (Module): a float module, either produced by torch.ao.quantization
@@ -19,7 +19,7 @@ class Dropout(torch.nn.Dropout):
         return 'QuantizedDropout'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         return cls(mod.p, mod.inplace)

     @classmethod

@@ -137,7 +137,7 @@ class Embedding(torch.nn.Module):
         return self._packed_params._weight()

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a quantized embedding module from a float module

         Args:

@@ -241,7 +241,7 @@ class EmbeddingBag(Embedding):
         return 'QuantizedEmbeddingBag'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a quantized embedding_bag module from a float module

         Args:

@@ -239,7 +239,7 @@ class QFunctional(torch.nn.Module):
         return r

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         assert type(mod) == FloatFunctional, \
             "QFunctional.from_float expects an instance of FloatFunctional"
         scale, zero_point = mod.activation_post_process.calculate_qparams()  # type: ignore[operator]
@@ -240,12 +240,14 @@ class Linear(WeightedQuantizedModule):
         self._packed_params.set_weight_bias(w, b)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a quantized module from an observed float module

         Args:
             mod (Module): a float module, either produced by torch.ao.quantization
               utilities or provided by the user
+            use_precomputed_fake_quant (bool): if True, the module will reuse min/max
+              values from the precomputed fake quant module.
         """
         if hasattr(mod, 'weight_fake_quant'):
             if type_before_parametrizations(mod) == nniqat.LinearBn1d:

@@ -267,8 +269,12 @@ class Linear(WeightedQuantizedModule):
             activation_post_process = mod.activation_post_process
             if type_before_parametrizations(mod) == nni.LinearReLU:
                 mod = mod[0]
-            weight_post_process = mod.qconfig.weight()
-        weight_post_process(mod.weight)
+            weight_post_process = mod.qconfig.weight() if not hasattr(mod, "weight_fake_quant") else mod.weight_fake_quant
+        if not use_precomputed_fake_quant:
+            # Observer may not have been called yet
+            # Observer might have been called in the previous stage via PTQ algorithm e.g. AdaRound
+            weight_post_process(mod.weight)
         dtype = weight_post_process.dtype
         act_scale, act_zp = activation_post_process.calculate_qparams()
         assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
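Paraphrasing the hunk above as a standalone sketch (illustrative only; the real method also handles the fused and QAT branches, and the helper name is made up):

def _resolve_weight_observer(mod, use_precomputed_fake_quant=False):
    # Reuse an observer/fake-quant already attached to the module when present,
    # otherwise create a fresh one from the qconfig.
    if hasattr(mod, "weight_fake_quant"):
        weight_post_process = mod.weight_fake_quant
    else:
        weight_post_process = mod.qconfig.weight()
    if not use_precomputed_fake_quant:
        # Re-observing here overwrites any pre-tuned min/max with the raw weight range.
        weight_post_process(mod.weight)
    return weight_post_process

With use_precomputed_fake_quant=True, whatever range the observer already carries is handed to calculate_qparams() unchanged; that is the "min/max carry over" named in the title.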
@@ -30,7 +30,7 @@ class LayerNorm(torch.nn.LayerNorm):
         return 'QuantizedLayerNorm'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         new_mod = cls(
             mod.normalized_shape, mod.weight, mod.bias, float(scale),

@@ -71,7 +71,7 @@ class GroupNorm(torch.nn.GroupNorm):
         return 'QuantizedGroupNorm'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         new_mod = cls(
             mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),

@@ -105,7 +105,7 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
         return 'QuantizedInstanceNorm1d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         new_mod = cls(
             mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),

@@ -145,7 +145,7 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
         return 'QuantizedInstanceNorm2d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         new_mod = cls(
             mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),

@@ -185,7 +185,7 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
         return 'QuantizedInstanceNorm3d'

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         scale, zero_point = mod.activation_post_process.calculate_qparams()
         new_mod = cls(
             mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),

@@ -213,7 +213,7 @@ class LSTMCell(RNNCellBase):
         return ret

     @classmethod
-    def from_float(cls, mod, weight_qparams_dict):
+    def from_float(cls, mod, weight_qparams_dict, use_precomputed_fake_quant=False):
         ref_mod = cls(
             mod.input_size,
             mod.hidden_size,
@@ -76,7 +76,7 @@ class EmbeddingBag(nn.EmbeddingBag, ReferenceQuantizedModule):
                            self.padding_idx)

     @classmethod
-    def from_float(cls, mod, weight_qparams):
+    def from_float(cls, mod, weight_qparams, use_precomputed_fake_quant=False):
         return cls(
             mod.num_embeddings,
             mod.embedding_dim,

@@ -92,7 +92,7 @@ class Linear(torch.nn.Module):
         self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a quantized sparse dynamic module from a float module.

         We only care about the convert at this stage, no need for observers just yet.

@@ -146,7 +146,7 @@ class Linear(torch.nn.Module):
         self._packed_params.set_weight_bias(w, b, row_block_size, col_block_size)

     @classmethod
-    def from_float(cls, mod):
+    def from_float(cls, mod, use_precomputed_fake_quant=False):
         r"""Create a quantized sparse module from a float module.

         We only care about the convert at this stage, no need for observers just yet.
@@ -235,6 +235,13 @@ def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=N
     if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
        and type_before_parametrizations(module) in qconfig_propagation_list:
         insert_activation_post_process(module)
+    # This is a special case for AdaRound eager mode
+    # AdaRound contains weight_fake_quant to be propagated from API to convert
+    # leaf node check with a number of children looks naive assumption that blocks
+    # Adding an exception case for AdaRound
+    if hasattr(module, "weight_fake_quant") and not isinstance(module, torch.nn.Sequential) \
+       and type_before_parametrizations(module) in qconfig_propagation_list:
+        insert_activation_post_process(module)

 def _get_unique_devices_(module):
     return {p.device for p in module.parameters()} | \
@@ -520,7 +527,8 @@ def quantize_qat(model, run_fn, run_args, inplace=False):

 def convert(
         module, mapping=None, inplace=False, remove_qconfig=True,
-        is_reference=False, convert_custom_config_dict=None):
+        is_reference=False, convert_custom_config_dict=None,
+        use_precomputed_fake_quant=False):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class. And remove qconfig at the
     end if remove_qconfig is set to True.

@@ -533,6 +541,7 @@ def convert(
         `inplace`: carry out model transformations in-place, the original module
                    is mutated
         `convert_custom_config_dict`: custom configuration dictionary for convert function
+        `use_precomputed_fake_quant`: a flag to enable use of precomputed fake quant

     .. code-block:: python

@@ -552,14 +561,16 @@ def convert(
         module = copy.deepcopy(module)
     _convert(
         module, mapping, inplace=True, is_reference=is_reference,
-        convert_custom_config_dict=convert_custom_config_dict)
+        convert_custom_config_dict=convert_custom_config_dict,
+        use_precomputed_fake_quant=use_precomputed_fake_quant)
     if remove_qconfig:
         _remove_qconfig(module)
     return module

 def _convert(
         module, mapping=None, inplace=False,
-        is_reference=False, convert_custom_config_dict=None):
+        is_reference=False, convert_custom_config_dict=None,
+        use_precomputed_fake_quant=False):
     r"""Converts submodules in input module to a different module according to `mapping`
     by calling `from_float` method on the target module class

@@ -571,6 +582,7 @@ def _convert(
         inplace: carry out model transformations in-place, the original module
                  is mutated
         is_reference: a flag to enable quantized reference module
+        use_precomputed_fake_quant: a flag to enable use of precomputed fake quant

     """
     if mapping is None:

@@ -589,15 +601,16 @@ def _convert(
         if not isinstance(mod, _FusedModule) and \
            type_before_parametrizations(mod) not in custom_module_class_mapping:
             _convert(mod, mapping, True, # inplace
-                     is_reference, convert_custom_config_dict)
-        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
+                     is_reference, convert_custom_config_dict,
+                     use_precomputed_fake_quant=use_precomputed_fake_quant)
+        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant)

     for key, value in reassign.items():
         module._modules[key] = value

     return module

-def swap_module(mod, mapping, custom_module_class_mapping):
+def swap_module(mod, mapping, custom_module_class_mapping, use_precomputed_fake_quant=False):
     r"""Swaps the module if it has a quantized counterpart and it has an
     `observer` attached.

@@ -623,7 +636,7 @@ def swap_module(mod, mapping, custom_module_class_mapping):
                 weight_qparams = get_qparam_dict(weight_post_process)
                 new_mod = qmod.from_float(mod, weight_qparams)
             else:
-                new_mod = qmod.from_float(mod)
+                new_mod = qmod.from_float(mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
             swapped = True

     if swapped:
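A design note on the plumbing above: every new parameter defaults to False, so existing callers of convert(), _convert() and swap_module() keep the previous behavior and the precomputed path is strictly opt-in. A small sketch under that assumption (prepared_model is a placeholder for a model that went through prepare() and had its weight observers tuned externally):

import copy

import torch.ao.quantization as tq

same_as_before = tq.convert(copy.deepcopy(prepared_model))
keeps_tuned_range = tq.convert(copy.deepcopy(prepared_model), use_precomputed_fake_quant=True)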