[quant][embedding qat] Add basic EmbeddingBag QAT fakeQuant workflow (#65443)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65443

Test Plan: Imported from OSS

Reviewed By: dagitses, supriyar

Differential Revision: D31456445

Pulled By: b-koopman

fbshipit-source-id: 0edda6e272d9005fce65f2ba6a5e6abc831836de
This commit is contained in:
parent 64caee1356
commit a58ff186e8
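The hunks below wire eager-mode QAT for EmbeddingBag end to end: a per-channel float-qparams fake quant default, a new qconfig, module-mapping entries, and a new torch.nn.qat.EmbeddingBag module. A minimal usage sketch of the resulting workflow, assuming only the APIs this commit introduces (the TinyModel class is illustrative, not from the patch):

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum')

        def forward(self, indices):
            return self.emb(indices)

    model = TinyModel()
    # New qconfig from this commit: NoopObserver activations, fake-quantized weights.
    model.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
    # Swaps nn.EmbeddingBag -> torch.nn.qat.EmbeddingBag via the QAT mappings below.
    qat_model = torch.quantization.prepare_qat(model)
    qat_model(torch.randint(0, 10, (5, 12)))  # train/calibrate with fake quant in the loop
    # Swaps the QAT module for torch.nn.quantized.EmbeddingBag.
    int8_model = torch.quantization.convert(qat_model.eval())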
@@ -1118,6 +1118,40 @@ class TestFusedObsFakeQuantModule(TestCase):
         self.assertEqual(obs.quant_min, 0)
         self.assertEqual(obs.quant_max, 127)
 
+    def test_embedding_bag_qat_config(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.emb1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
+                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
+                self.emb2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
+                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
+
+            def forward(self, indices):
+                return torch.cat((self.emb1(indices), self.emb2(indices)))
+
+        model = Model()
+        indices = torch.randint(0, 10, (5, 12))
+
+        model.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
+
+        quant_model = torch.quantization.prepare_qat(model)
+
+        count_fake_quant = 0
+        for name, mod in quant_model.named_modules():
+            if name.endswith('weight_fake_quant'):
+                count_fake_quant += 1
+                self.assertEqual(type(mod), FakeQuantize)
+        self.assertEqual(count_fake_quant, 2)
+
+        quant_model(indices)
+        inference_gm = torch.quantization.convert(quant_model.eval().cpu())
+
+        # Ensure that EmbeddingBags are now quantized
+        self.assertEqual(type(inference_gm.emb1), torch.nn.quantized.EmbeddingBag)
+        self.assertEqual(type(inference_gm.emb2), torch.nn.quantized.EmbeddingBag)
+
+
     def test_default_fused_qat_config(self):
         class Model(nn.Module):
             def __init__(self):
@@ -7,7 +7,6 @@ from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 from torch.testing._internal.common_utils import TestCase, run_tests
 
-
 # Returns a database of args & kwargs that can be used to construct each module.
 # Each entry is in class -> (args, kwargs) format.
 # Example: torch.nn.Linear -> ([10, 5], {})
@@ -174,6 +173,9 @@ def build_constructor_arg_db():
         torch.nn.qat.Linear: ((5, 2), {
             'qconfig': torch.ao.quantization.default_qconfig,
         }),
+        torch.nn.qat.EmbeddingBag: ((10, 12), {
+            'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
+        }),
         torch.nn.quantizable.LSTM: ((5, 6), {}),
         torch.nn.quantizable.LSTMCell: ((5, 6), {}),
         torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),
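The new entry doubles as a recipe for building the QAT module by hand. A short sketch with the same constructor arguments as above; float_qparams_weight_only_qconfig satisfies the per_channel_affine_float_qparams check enforced in the new module file below:

    import torch
    import torch.nn.qat as nnqat

    # The qconfig's weight observer must use torch.per_channel_affine_float_qparams,
    # matching the assert in torch/nn/qat/modules/embedding_ops.py.
    qat_emb = nnqat.EmbeddingBag(
        10, 12, qconfig=torch.ao.quantization.float_qparams_weight_only_qconfig)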
@@ -194,6 +194,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
         set([
             nn.EmbeddingBag,
             nnq.EmbeddingBag,
+            nnqat.EmbeddingBag,
         ]),
         # GroupNorm
         set([
@@ -494,6 +495,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nn.Conv3d,
         nnqat.Conv2d,
         nnqat.Conv3d,
+        nnqat.EmbeddingBag,
         nn.LSTM,
         # note: nnqd.Linear is an instance of nnq.Linear, so this
         # check has to happen before the int8 module check
@@ -553,6 +555,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nnq.Hardswish,
         nnq.LeakyReLU,
         nnq.ReLU6,
+        nnq.EmbeddingBag,
         nniq.BNReLU2d,
         nniq.BNReLU3d,
         nniq.ConvReLU1d,
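These three mapping updates teach the FX Numeric Suite that the float, QAT, and quantized EmbeddingBag variants are one logical op, so it can pair them up when comparing models. A hedged check, assuming the functions live in torch.ao.ns.fx.mappings as their signatures here suggest:

    import torch.nn as nn
    import torch.nn.qat as nnqat
    import torch.nn.quantized as nnq
    from torch.ao.ns.fx.mappings import get_base_name_to_sets_of_related_ops

    related = get_base_name_to_sets_of_related_ops()
    # After this change, all three EmbeddingBag variants live in the same set.
    assert any({nn.EmbeddingBag, nnq.EmbeddingBag, nnqat.EmbeddingBag} <= s
               for s in related.values())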
@@ -4,6 +4,7 @@ from torch.ao.quantization.observer import (
     MovingAverageMinMaxObserver,
     HistogramObserver,
     MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
     _with_args,
 )
 import re
@@ -11,7 +12,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Tuple
 
 def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
-    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
+    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
 
 def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
     return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
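The widened predicate is what routes embedding weights down FakeQuantize's per-channel path. A tiny check, noting that _is_per_channel is a private helper:

    import torch
    from torch.ao.quantization.fake_quantize import _is_per_channel

    # The float-qparams qscheme used for embedding weights now counts as per-channel.
    assert _is_per_channel(torch.per_channel_affine_float_qparams)
    assert not _is_per_channel(torch.per_tensor_affine)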
@@ -344,6 +345,12 @@ default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                                qscheme=torch.per_channel_symmetric,
                                                                reduce_range=False,
                                                                ch_axis=0)
+
+default_embedding_fake_quant = FakeQuantize.with_args(observer=PerChannelMinMaxObserver,
+                                                      qscheme=torch.per_channel_affine_float_qparams,
+                                                      ch_axis=0,
+                                                      memoryless=True)
+
 default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                       quant_min=0,
                                                       quant_max=255,
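A sketch of what the new default does when applied to an embedding weight of shape (num_embeddings, embedding_dim); with ch_axis=0, each row gets its own float scale and zero point, and the memoryless observer recomputes them from scratch on every forward:

    import torch
    from torch.ao.quantization.fake_quantize import default_embedding_fake_quant

    fq = default_embedding_fake_quant()  # FakeQuantize over PerChannelMinMaxObserver
    w = torch.randn(10, 12)              # an EmbeddingBag weight; ch_axis=0 -> per-row qparams
    w_fq = fq(w)                         # observe, compute float qparams, fake-quantize
    assert w_fq.shape == w.shape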
@@ -12,11 +12,13 @@ from torch.ao.quantization.fake_quantize import (
     default_fused_wt_fake_quant,
     FusedMovingAvgObsFakeQuantize,
     default_fused_per_channel_wt_fake_quant,
+    default_embedding_fake_quant,
 )
 
 from .observer import (
     HistogramObserver,
     MovingAverageMinMaxObserver,
+    NoopObserver,
     PlaceholderObserver,
     default_debug_observer,
     default_dynamic_quant_observer,
@@ -127,6 +129,9 @@ def get_default_qconfig(backend='fbgemm'):
         qconfig = default_qconfig
     return qconfig
 
+default_embedding_qat_qconfig = QConfig(activation=NoopObserver,
+                                        weight=default_embedding_fake_quant)
+
 def get_default_qat_qconfig(backend='fbgemm', version=1):
     # Histogram observer is too slow for quantization aware training
     if version is None:
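A quick look at what the new qconfig pairs together: this is weight-only QAT, so activations pass through a NoopObserver and only the weight is fake-quantized. A sketch, relying on FakeQuantize instances exposing the qscheme of their observer:

    import torch
    from torch.ao.quantization.qconfig import default_embedding_qat_qconfig

    wfq = default_embedding_qat_qconfig.weight()
    # The weight fake quant carries the float-qparams per-channel scheme
    # that quantized embedding kernels require.
    assert wfq.qscheme == torch.per_channel_affine_float_qparams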
@@ -76,6 +76,7 @@ DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nnqat.Linear: nnq.Linear,
     nnqat.Conv2d: nnq.Conv2d,
     nnqat.Conv3d: nnq.Conv3d,
+    nnqat.EmbeddingBag: nnq.EmbeddingBag,
 }
 
 # Default map for swapping float module to qat modules
@@ -83,6 +84,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nn.Conv2d: nnqat.Conv2d,
     nn.Conv3d: nnqat.Conv3d,
     nn.Linear: nnqat.Linear,
+    nn.EmbeddingBag: nnqat.EmbeddingBag,
     nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
     # Intrinsic modules:
     nni.ConvBn1d: nniqat.ConvBn1d,
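These two mapping entries are what make the eager-mode passes pick the new module up automatically. A hedged sketch using the public getters, assuming they live in torch.ao.quantization.quantization_mappings alongside these dicts:

    import torch.nn as nn
    import torch.nn.qat as nnqat
    import torch.nn.quantized as nnq
    from torch.ao.quantization.quantization_mappings import (
        get_default_qat_module_mappings,
        get_default_static_quant_module_mappings,
    )

    # prepare_qat uses the QAT map to swap the float module in...
    assert get_default_qat_module_mappings()[nn.EmbeddingBag] is nnqat.EmbeddingBag
    # ...and convert uses the static map to swap the QAT module out.
    assert get_default_static_quant_module_mappings()[nnqat.EmbeddingBag] is nnq.EmbeddingBag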
@@ -1,9 +1,11 @@
 from .linear import Linear
 from .conv import Conv2d
 from .conv import Conv3d
+from .embedding_ops import EmbeddingBag
 
 __all__ = [
     "Linear",
     "Conv2d",
     "Conv3d",
+    "EmbeddingBag",
 ]
torch/nn/qat/modules/embedding_ops.py (new file, 70 lines)
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class EmbeddingBag(nn.EmbeddingBag):
+    r"""
+    An embedding bag module attached with FakeQuantize modules for weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
+    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
+    for documentation.
+
+    Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
+    default.
+
+    Attributes:
+        weight: fake quant module for weight
+    """
+    _FLOAT_MODULE = nn.EmbeddingBag
+
+    def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2.0,
+                 scale_grad_by_freq=False, mode='mean', sparse=False, _weight=None,
+                 include_last_offset=False, padding_idx=None, qconfig=None, device=None,
+                 dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
+                         scale_grad_by_freq, mode, sparse, _weight,
+                         include_last_offset, padding_idx, **factory_kwargs)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
+            str(qconfig.weight().qscheme)
+        self.qconfig = qconfig
+        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
+
+    def forward(self, input):
+        return F.embedding_bag(input, self.weight_fake_quant(self.weight))
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module
+
+        Args: `mod` a float module, either produced by torch.quantization utilities
+            or directly from user
+        """
+        assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        assert mod.qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
+            str(mod.qconfig.weight().qscheme)
+
+        qconfig = mod.qconfig
+        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
+                                mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
+                                mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
+        qat_embedding_bag.weight = mod.weight
+
+        return qat_embedding_bag
+
+    def to_float(self):
+        embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
+                                              self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                              None, self.include_last_offset, self.padding_idx,
+                                              self.weight.device, self.weight.dtype)
+        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
+        embedding_bag.train(self.training)
+        return embedding_bag
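A round-trip sketch of the new module on its own, assuming the defaults added elsewhere in this commit. Note that forward only fake-quantizes the weight; offsets and per_sample_weights are not plumbed through in this basic version:

    import torch
    import torch.nn as nn
    import torch.nn.qat as nnqat

    float_emb = nn.EmbeddingBag(10, 12, mode='sum')
    float_emb.qconfig = torch.ao.quantization.default_embedding_qat_qconfig

    qat_emb = nnqat.EmbeddingBag.from_float(float_emb)  # keeps the float weight, adds weight_fake_quant
    out = qat_emb(torch.randint(0, 10, (5, 12)))        # weight is fake-quantized on the fly

    back = qat_emb.to_float()                           # drop the fake quant, back to plain nn.EmbeddingBag
    assert torch.equal(back.weight, float_emb.weight)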
@@ -222,6 +222,10 @@ class EmbeddingBag(Embedding):
             mod (Module): a float module, either produced by torch.quantization
                 utilities or provided by user
         """
-        assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
-            nn.EmbeddingBag.__name__
-        assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
+        if hasattr(mod, 'weight_fake_quant'):
+            weight_observer = mod.weight_fake_quant
+            activation_post_process = mod.activation_post_process
+        else:
+            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
+                nn.EmbeddingBag.__name__
+            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
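End to end, convert() now lands in this new branch: the QAT module's weight_fake_quant serves as the weight observer for the quantized swap. A hedged sketch (nn.Sequential is used for brevity; prepare_qat also attaches the activation_post_process the branch reads, as exercised by the test above):

    import torch
    import torch.nn as nn
    import torch.nn.quantized as nnq

    model = nn.Sequential(nn.EmbeddingBag(10, 12, mode='sum'))
    model.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
    qat_model = torch.quantization.prepare_qat(model)
    qat_model(torch.randint(0, 10, (5, 12)))   # let the fake quant observe the weight

    # convert() hits the hasattr(mod, 'weight_fake_quant') branch above.
    q_model = torch.quantization.convert(qat_model.eval())
    assert type(q_model[0]) is nnq.EmbeddingBag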