[quant][embedding qat] Add basic EmbeddingBag QAT fakeQuant workflow (#65443)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65443

Test Plan: Imported from OSS

Reviewed By: dagitses, supriyar

Differential Revision: D31456445

Pulled By: b-koopman

fbshipit-source-id: 0edda6e272d9005fce65f2ba6a5e6abc831836de
This commit is contained in:
Ben Koopman 2021-10-07 20:15:46 -07:00 committed by Facebook GitHub Bot
parent 64caee1356
commit a58ff186e8
9 changed files with 138 additions and 9 deletions

View File

@ -1118,6 +1118,40 @@ class TestFusedObsFakeQuantModule(TestCase):
self.assertEqual(obs.quant_min, 0)
self.assertEqual(obs.quant_max, 127)
def test_embedding_bag_qat_config(self):
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.emb1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
include_last_offset=True, scale_grad_by_freq=False, mode='sum')
self.emb2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
include_last_offset=True, scale_grad_by_freq=False, mode='sum')
def forward(self, indices):
return torch.cat((self.emb1(indices), self.emb2(indices)))
model = Model()
indices = torch.randint(0, 10, (5, 12))
model.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
quant_model = torch.quantization.prepare_qat(model)
count_fake_quant = 0
for name, mod in quant_model.named_modules():
if name.endswith('weight_fake_quant'):
count_fake_quant += 1
self.assertEqual(type(mod), FakeQuantize)
self.assertEqual(count_fake_quant, 2)
quant_model(indices)
inference_gm = torch.quantization.convert(quant_model.eval().cpu())
# Ensure that EmbeddingBags are now quantized
self.assertEqual(type(inference_gm.emb1), torch.nn.quantized.EmbeddingBag)
self.assertEqual(type(inference_gm.emb2), torch.nn.quantized.EmbeddingBag)
def test_default_fused_qat_config(self):
class Model(nn.Module):
def __init__(self):

View File

@ -7,7 +7,6 @@ from torch.testing._internal.common_device_type import instantiate_device_type_t
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_utils import TestCase, run_tests
# Returns a database of args & kwargs that can be used to construct each module.
# Each entry is in class -> (args, kwargs) format.
# Example: torch.nn.Linear -> ([10, 5], {})
@ -174,6 +173,9 @@ def build_constructor_arg_db():
torch.nn.qat.Linear: ((5, 2), {
'qconfig': torch.ao.quantization.default_qconfig,
}),
torch.nn.qat.EmbeddingBag: ((10, 12), {
'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
}),
torch.nn.quantizable.LSTM: ((5, 6), {}),
torch.nn.quantizable.LSTMCell: ((5, 6), {}),
torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),

View File

@ -194,6 +194,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
set([
nn.EmbeddingBag,
nnq.EmbeddingBag,
nnqat.EmbeddingBag,
]),
# GroupNorm
set([
@ -494,6 +495,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
nn.Conv3d,
nnqat.Conv2d,
nnqat.Conv3d,
nnqat.EmbeddingBag,
nn.LSTM,
# note: nnqd.Linear is an instance of nnq.Linear, so this
# check has to happen before the int8 module check
@ -553,6 +555,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
nnq.Hardswish,
nnq.LeakyReLU,
nnq.ReLU6,
nnq.EmbeddingBag,
nniq.BNReLU2d,
nniq.BNReLU3d,
nniq.ConvReLU1d,

View File

@ -4,6 +4,7 @@ from torch.ao.quantization.observer import (
MovingAverageMinMaxObserver,
HistogramObserver,
MovingAveragePerChannelMinMaxObserver,
PerChannelMinMaxObserver,
_with_args,
)
import re
@ -11,7 +12,7 @@ from abc import ABC, abstractmethod
from typing import Any, Tuple
def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
@ -344,6 +345,12 @@ default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAv
qscheme=torch.per_channel_symmetric,
reduce_range=False,
ch_axis=0)
default_embedding_fake_quant = FakeQuantize.with_args(observer=PerChannelMinMaxObserver,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0,
memoryless=True)
default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
quant_min=0,
quant_max=255,

View File

@ -12,11 +12,13 @@ from torch.ao.quantization.fake_quantize import (
default_fused_wt_fake_quant,
FusedMovingAvgObsFakeQuantize,
default_fused_per_channel_wt_fake_quant,
default_embedding_fake_quant,
)
from .observer import (
HistogramObserver,
MovingAverageMinMaxObserver,
NoopObserver,
PlaceholderObserver,
default_debug_observer,
default_dynamic_quant_observer,
@ -127,6 +129,9 @@ def get_default_qconfig(backend='fbgemm'):
qconfig = default_qconfig
return qconfig
default_embedding_qat_qconfig = QConfig(activation=NoopObserver,
weight=default_embedding_fake_quant)
def get_default_qat_qconfig(backend='fbgemm', version=1):
# Histogram observer is too slow for quantization aware training
if version is None:

View File

@ -76,6 +76,7 @@ DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
nnqat.Linear: nnq.Linear,
nnqat.Conv2d: nnq.Conv2d,
nnqat.Conv3d: nnq.Conv3d,
nnqat.EmbeddingBag: nnq.EmbeddingBag,
}
# Default map for swapping float module to qat modules
@ -83,6 +84,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
nn.Conv2d: nnqat.Conv2d,
nn.Conv3d: nnqat.Conv3d,
nn.Linear: nnqat.Linear,
nn.EmbeddingBag: nnqat.EmbeddingBag,
nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
# Intrinsic modules:
nni.ConvBn1d: nniqat.ConvBn1d,

View File

@ -1,9 +1,11 @@
from .linear import Linear
from .conv import Conv2d
from .conv import Conv3d
from .embedding_ops import EmbeddingBag
__all__ = [
"Linear",
"Conv2d",
"Conv3d",
"EmbeddingBag",
]

View File

@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class EmbeddingBag(nn.EmbeddingBag):
r"""
An embedding bag module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.EmbeddingBag`, please see
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
for documentation.
Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
default.
Attributes:
weight: fake quant module for weight
"""
_FLOAT_MODULE = nn.EmbeddingBag
def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2.0,
scale_grad_by_freq=False, mode='mean', sparse=False, _weight=None,
include_last_offset=False, padding_idx=None, qconfig=None, device=None,
dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
scale_grad_by_freq, mode, sparse, _weight,
include_last_offset, padding_idx, **factory_kwargs)
assert qconfig, 'qconfig must be provided for QAT module'
assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
str(qconfig.weight().qscheme)
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input):
return F.embedding_bag(input, self.weight_fake_quant(self.weight))
@classmethod
def from_float(cls, mod):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
assert mod.qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
mod.qconfig.weight().qscheme.__name__
qconfig = mod.qconfig
qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
qat_embedding_bag.weight = mod.weight
return qat_embedding_bag
def to_float(self):
embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
None, self.include_last_offset, self.padding_idx,
self.device, self.dtype)
embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
embedding_bag.train(self.training)
return embedding_bag

View File

@ -222,14 +222,18 @@ class EmbeddingBag(Embedding):
mod (Module): a float module, either produced by torch.quantization
utilities or provided by user
"""
assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
nn.EmbeddingBag.__name__
assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
from torch.quantization.qconfig import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
if hasattr(mod, 'weight_fake_quant'):
weight_observer = mod.weight_fake_quant
activation_post_process = mod.activation_post_process
else:
weight_observer = float_qparams_weight_only_qconfig.weight()
assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
nn.EmbeddingBag.__name__
assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
from torch.quantization.qconfig import float_qparams_weight_only_qconfig
if mod.qconfig is not None and mod.qconfig.weight is not None:
weight_observer = mod.qconfig.weight()
else:
weight_observer = float_qparams_weight_only_qconfig.weight()
dtype = weight_observer.dtype
is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams