[quant][embedding qat] Add basic EmbeddingBag QAT fakeQuant workflow (#65443)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65443

Test Plan: Imported from OSS

Reviewed By: dagitses, supriyar

Differential Revision: D31456445

Pulled By: b-koopman

fbshipit-source-id: 0edda6e272d9005fce65f2ba6a5e6abc831836de
This commit is contained in:
parent 64caee1356
commit a58ff186e8
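For orientation, here is a minimal sketch (not part of the diff) of the end-to-end workflow this commit enables, distilled from the new test below; it assumes a PyTorch build that contains this change, and the model/variable names are illustrative only:

    import torch
    import torch.nn as nn

    class BagModel(nn.Module):  # hypothetical example model
        def __init__(self):
            super().__init__()
            self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')

        def forward(self, indices):
            return self.emb(indices)

    model = BagModel()
    # New qconfig: NoopObserver for activations, fake-quant for embedding weights.
    model.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
    qat_model = torch.quantization.prepare_qat(model)   # nn.EmbeddingBag -> nn.qat.EmbeddingBag
    qat_model(torch.randint(0, 10, (5, 12)))            # train with fake-quantized weights
    quantized = torch.quantization.convert(qat_model.eval().cpu())
    assert type(quantized.emb) == torch.nn.quantized.EmbeddingBag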
@@ -1118,6 +1118,40 @@ class TestFusedObsFakeQuantModule(TestCase):
         self.assertEqual(obs.quant_min, 0)
         self.assertEqual(obs.quant_max, 127)
 
+    def test_embedding_bag_qat_config(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.emb1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
+                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
+                self.emb2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
+                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
+
+            def forward(self, indices):
+                return torch.cat((self.emb1(indices), self.emb2(indices)))
+
+        model = Model()
+        indices = torch.randint(0, 10, (5, 12))
+
+        model.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
+
+        quant_model = torch.quantization.prepare_qat(model)
+
+        count_fake_quant = 0
+        for name, mod in quant_model.named_modules():
+            if name.endswith('weight_fake_quant'):
+                count_fake_quant += 1
+                self.assertEqual(type(mod), FakeQuantize)
+        self.assertEqual(count_fake_quant, 2)
+
+        quant_model(indices)
+        inference_gm = torch.quantization.convert(quant_model.eval().cpu())
+
+        # Ensure that EmbeddingBags are now quantized
+        self.assertEqual(type(inference_gm.emb1), torch.nn.quantized.EmbeddingBag)
+        self.assertEqual(type(inference_gm.emb2), torch.nn.quantized.EmbeddingBag)
+
     def test_default_fused_qat_config(self):
         class Model(nn.Module):
             def __init__(self):
@@ -7,7 +7,6 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 from torch.testing._internal.common_utils import TestCase, run_tests
 
-
 # Returns a database of args & kwargs that can be used to construct each module.
 # Each entry is in class -> (args, kwargs) format.
 # Example: torch.nn.Linear -> ([10, 5], {})
@@ -174,6 +173,9 @@ def build_constructor_arg_db():
        torch.nn.qat.Linear: ((5, 2), {
            'qconfig': torch.ao.quantization.default_qconfig,
        }),
+       torch.nn.qat.EmbeddingBag: ((10, 12), {
+           'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
+       }),
        torch.nn.quantizable.LSTM: ((5, 6), {}),
        torch.nn.quantizable.LSTMCell: ((5, 6), {}),
        torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),
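Per the db's own comment, each entry is (args, kwargs) splatted into the module constructor. A hedged sketch of what the new entry builds (variable names are illustrative):

    import torch

    args, kwargs = (10, 12), {'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig}
    qat_bag = torch.nn.qat.EmbeddingBag(*args, **kwargs)
    # The qconfig's weight observer must use torch.per_channel_affine_float_qparams,
    # which float_qparams_weight_only_qconfig does.
    print(type(qat_bag.weight_fake_quant))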
@@ -194,6 +194,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
         set([
             nn.EmbeddingBag,
             nnq.EmbeddingBag,
+            nnqat.EmbeddingBag,
         ]),
         # GroupNorm
         set([
@@ -494,6 +495,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nn.Conv3d,
         nnqat.Conv2d,
         nnqat.Conv3d,
+        nnqat.EmbeddingBag,
         nn.LSTM,
         # note: nnqd.Linear is an instance of nnq.Linear, so this
         # check has to happen before the int8 module check
@@ -553,6 +555,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nnq.Hardswish,
         nnq.LeakyReLU,
         nnq.ReLU6,
+        nnq.EmbeddingBag,
         nniq.BNReLU2d,
         nniq.BNReLU3d,
         nniq.ConvReLU1d,
@@ -4,6 +4,7 @@ from torch.ao.quantization.observer import (
     MovingAverageMinMaxObserver,
     HistogramObserver,
     MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
     _with_args,
 )
 import re
@@ -11,7 +12,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Tuple
 
 def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
-    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
+    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
 
 def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
     return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
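Widening _is_per_channel lets FakeQuantize route the float-qparams scheme through the per-channel fake-quant path (one scale/zero-point per row along ch_axis). A quick check, assuming the private helper remains importable from this module as the diff shows:

    import torch
    from torch.ao.quantization.fake_quantize import _is_per_channel  # private helper

    assert _is_per_channel(torch.per_channel_affine_float_qparams)
    assert not _is_per_channel(torch.per_tensor_affine)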
@@ -344,6 +345,12 @@ default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAv
                                                                qscheme=torch.per_channel_symmetric,
                                                                reduce_range=False,
                                                                ch_axis=0)
+
+default_embedding_fake_quant = FakeQuantize.with_args(observer=PerChannelMinMaxObserver,
+                                                      qscheme=torch.per_channel_affine_float_qparams,
+                                                      ch_axis=0,
+                                                      memoryless=True)
+
 default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                       quant_min=0,
                                                       quant_max=255,
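with_args produces a constructor with bound defaults, so every call yields a fresh FakeQuantize wrapping a new PerChannelMinMaxObserver. A sketch, assuming the observer accepts the memoryless flag introduced in this stack:

    import torch
    from torch.ao.quantization.fake_quantize import FakeQuantize, default_embedding_fake_quant

    fq = default_embedding_fake_quant()   # instantiate the partial
    assert isinstance(fq, FakeQuantize)
    assert fq.qscheme == torch.per_channel_affine_float_qparams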
@@ -12,11 +12,13 @@ from torch.ao.quantization.fake_quantize import (
     default_fused_wt_fake_quant,
     FusedMovingAvgObsFakeQuantize,
     default_fused_per_channel_wt_fake_quant,
+    default_embedding_fake_quant,
 )
 
 from .observer import (
     HistogramObserver,
     MovingAverageMinMaxObserver,
+    NoopObserver,
     PlaceholderObserver,
     default_debug_observer,
     default_dynamic_quant_observer,
@@ -127,6 +129,9 @@ def get_default_qconfig(backend='fbgemm'):
         qconfig = default_qconfig
     return qconfig
 
+default_embedding_qat_qconfig = QConfig(activation=NoopObserver,
+                                        weight=default_embedding_fake_quant)
+
 def get_default_qat_qconfig(backend='fbgemm', version=1):
     # Histogram observer is too slow for quantization aware training
     if version is None:
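The new qconfig pairs a no-op activation observer (embedding-bag outputs remain float) with the embedding weight fake-quant, so only weights are quantized. Inspecting it:

    import torch

    qc = torch.ao.quantization.default_embedding_qat_qconfig
    w_fq = qc.weight()   # a FakeQuantize over PerChannelMinMaxObserver
    assert w_fq.qscheme == torch.per_channel_affine_float_qparams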
@@ -76,6 +76,7 @@ DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nnqat.Linear: nnq.Linear,
     nnqat.Conv2d: nnq.Conv2d,
     nnqat.Conv3d: nnq.Conv3d,
+    nnqat.EmbeddingBag: nnq.EmbeddingBag,
 }
 
 # Default map for swapping float module to qat modules
@@ -83,6 +84,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nn.Conv2d: nnqat.Conv2d,
     nn.Conv3d: nnqat.Conv3d,
     nn.Linear: nnqat.Linear,
+    nn.EmbeddingBag: nnqat.EmbeddingBag,
     nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
     # Intrinsic modules:
     nni.ConvBn1d: nniqat.ConvBn1d,
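prepare_qat walks this mapping and calls each QAT class's from_float to swap modules in place. A rough hand-rolled equivalent of the new swap (a sketch, not the internal code path):

    import torch
    import torch.nn as nn

    emb = nn.EmbeddingBag(10, 12, mode='sum')
    emb.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
    qat_emb = torch.nn.qat.EmbeddingBag.from_float(emb)   # what the swap does internally
    assert isinstance(qat_emb, torch.nn.qat.EmbeddingBag)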
@@ -1,9 +1,11 @@
 from .linear import Linear
 from .conv import Conv2d
 from .conv import Conv3d
+from .embedding_ops import EmbeddingBag
 
 __all__ = [
     "Linear",
     "Conv2d",
     "Conv3d",
+    "EmbeddingBag",
 ]
torch/nn/qat/modules/embedding_ops.py (new file, 70 lines)

@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class EmbeddingBag(nn.EmbeddingBag):
+    r"""
+    An embedding bag module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
+    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
+    for documentation.
+
+    Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight: fake quant module for weight
+    """
+    _FLOAT_MODULE = nn.EmbeddingBag
+
+    def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2.0,
+                 scale_grad_by_freq=False, mode='mean', sparse=False, _weight=None,
+                 include_last_offset=False, padding_idx=None, qconfig=None, device=None,
+                 dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
+                         scale_grad_by_freq, mode, sparse, _weight,
+                         include_last_offset, padding_idx, **factory_kwargs)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
+            str(qconfig.weight().qscheme)
+        self.qconfig = qconfig
+        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
+
+    def forward(self, input):
+        # Weights pass through fake-quant on every forward; offsets and pooling
+        # options are not forwarded in this basic workflow.
+        return F.embedding_bag(input, self.weight_fake_quant(self.weight))
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module
+
+        Args: `mod` a float module, either produced by torch.quantization utilities
+        or directly from user
+        """
+        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        assert mod.qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
+            str(mod.qconfig.weight().qscheme)
+
+        qconfig = mod.qconfig
+        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
+                                mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
+                                mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
+        qat_embedding_bag.weight = mod.weight
+
+        return qat_embedding_bag
+
+    def to_float(self):
+        embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
+                                              self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                              None, self.include_last_offset, self.padding_idx,
+                                              self.weight.device, self.weight.dtype)  # nn.Module has no .device/.dtype
+        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
+        embedding_bag.train(self.training)
+        return embedding_bag
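Standalone use of the new module: the stored weight stays float, but each forward sees it through weight_fake_quant, so training observes quantization error. A sketch with a 2D input, where each row is one bag (this forward does not pass offsets or a mode, so F.embedding_bag defaults to mean pooling):

    import torch
    from torch.nn.qat.modules.embedding_ops import EmbeddingBag

    qat_bag = EmbeddingBag(10, 12,
                           qconfig=torch.ao.quantization.float_qparams_weight_only_qconfig)
    out = qat_bag(torch.randint(0, 10, (4, 6)))
    print(out.shape)   # torch.Size([4, 12])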
@@ -222,14 +222,18 @@ class EmbeddingBag(Embedding):
             mod (Module): a float module, either produced by torch.quantization
                 utilities or provided by user
         """
-        assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
-            nn.EmbeddingBag.__name__
-        assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
-        from torch.quantization.qconfig import float_qparams_weight_only_qconfig
-        if mod.qconfig is not None and mod.qconfig.weight is not None:
-            weight_observer = mod.qconfig.weight()
-        else:
-            weight_observer = float_qparams_weight_only_qconfig.weight()
+        if hasattr(mod, 'weight_fake_quant'):
+            weight_observer = mod.weight_fake_quant
+            activation_post_process = mod.activation_post_process
+        else:
+            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
+                nn.EmbeddingBag.__name__
+            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
+            from torch.quantization.qconfig import float_qparams_weight_only_qconfig
+            if mod.qconfig is not None and mod.qconfig.weight is not None:
+                weight_observer = mod.qconfig.weight()
+            else:
+                weight_observer = float_qparams_weight_only_qconfig.weight()
 
         dtype = weight_observer.dtype
         is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
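from_float now dispatches on the presence of weight_fake_quant: modules coming out of prepare_qat carry one (plus an attached activation_post_process), while plain float modules fall back to the qconfig-or-default weight observer. The pre-existing float path still works when called directly; a sketch:

    import torch
    import torch.nn as nn

    emb = nn.EmbeddingBag(10, 12, mode='sum')
    emb.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
    qemb = torch.nn.quantized.EmbeddingBag.from_float(emb)   # else-branch: observer from qconfig
    assert type(qemb) == torch.nn.quantized.EmbeddingBag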