diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py
index 4a08de59790..25ab934226c 100644
--- a/test/quantization/core/test_workflow_module.py
+++ b/test/quantization/core/test_workflow_module.py
@@ -1118,6 +1118,40 @@ class TestFusedObsFakeQuantModule(TestCase):
         self.assertEqual(obs.quant_min, 0)
         self.assertEqual(obs.quant_max, 127)
 
+    def test_embedding_bag_qat_config(self):
+        class Model(nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.emb1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
+                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
+                self.emb2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
+                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
+
+            def forward(self, indices):
+                return torch.cat((self.emb1(indices), self.emb2(indices)))
+
+        model = Model()
+        indices = torch.randint(0, 10, (5, 12))
+
+        model.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
+
+        quant_model = torch.quantization.prepare_qat(model)
+
+        count_fake_quant = 0
+        for name, mod in quant_model.named_modules():
+            if name.endswith('weight_fake_quant'):
+                count_fake_quant += 1
+                self.assertEqual(type(mod), FakeQuantize)
+        self.assertEqual(count_fake_quant, 2)
+
+        quant_model(indices)
+        inference_model = torch.quantization.convert(quant_model.eval().cpu())
+
+        # Ensure that the EmbeddingBags are now quantized
+        self.assertEqual(type(inference_model.emb1), torch.nn.quantized.EmbeddingBag)
+        self.assertEqual(type(inference_model.emb2), torch.nn.quantized.EmbeddingBag)
+
+
     def test_default_fused_qat_config(self):
         class Model(nn.Module):
             def __init__(self):
diff --git a/test/test_module_init.py b/test/test_module_init.py
index ed92fa86824..6f0c1263a18 100644
--- a/test/test_module_init.py
+++ b/test/test_module_init.py
@@ -7,7 +7,6 @@ from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM
 from torch.testing._internal.common_utils import TestCase, run_tests
 
-
 # Returns a database of args & kwargs that can be used to construct each module.
 # Each entry is in class -> (args, kwargs) format.
 # Example: torch.nn.Linear -> ([10, 5], {})
@@ -174,6 +173,9 @@ def build_constructor_arg_db():
         torch.nn.qat.Linear: ((5, 2), {
             'qconfig': torch.ao.quantization.default_qconfig,
         }),
+        torch.nn.qat.EmbeddingBag: ((10, 12), {
+            'qconfig': torch.ao.quantization.float_qparams_weight_only_qconfig,
+        }),
         torch.nn.quantizable.LSTM: ((5, 6), {}),
         torch.nn.quantizable.LSTMCell: ((5, 6), {}),
         torch.nn.quantizable.MultiheadAttention: ((10, 2), {}),
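For context on the constructor-db entry above, here is a minimal standalone sketch (not from the patch itself) of building the new QAT module directly; `float_qparams_weight_only_qconfig` satisfies the `torch.per_channel_affine_float_qparams` assertion in the module's `__init__`:

```python
import torch

# Build the QAT EmbeddingBag exactly as the (args, kwargs) entry in
# build_constructor_arg_db() does: (10, 12) with a float-qparams qconfig.
qat_emb = torch.nn.qat.EmbeddingBag(
    10, 12,
    qconfig=torch.ao.quantization.float_qparams_weight_only_qconfig,
)
indices = torch.randint(0, 10, (5, 12))
out = qat_emb(indices)  # the weight is fake-quantized before the bag reduction
print(out.shape)        # torch.Size([5, 12]) with the default 'mean' mode
```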
diff --git a/torch/ao/ns/fx/mappings.py b/torch/ao/ns/fx/mappings.py
index e97d77119d0..01d9d9f8166 100644
--- a/torch/ao/ns/fx/mappings.py
+++ b/torch/ao/ns/fx/mappings.py
@@ -194,6 +194,7 @@ def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
         set([
             nn.EmbeddingBag,
             nnq.EmbeddingBag,
+            nnqat.EmbeddingBag,
         ]),
         # GroupNorm
         set([
@@ -494,6 +495,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nn.Conv3d,
         nnqat.Conv2d,
         nnqat.Conv3d,
+        nnqat.EmbeddingBag,
         nn.LSTM,
         # note: nnqd.Linear is an instance of nnq.Linear, so this
         # check has to happen before the int8 module check
@@ -553,6 +555,7 @@ def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
         nnq.Hardswish,
         nnq.LeakyReLU,
         nnq.ReLU6,
+        nnq.EmbeddingBag,
         nniq.BNReLU2d,
         nniq.BNReLU3d,
         nniq.ConvReLU1d,
diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py
index 5d5dea07d1c..5327e278375 100644
--- a/torch/ao/quantization/fake_quantize.py
+++ b/torch/ao/quantization/fake_quantize.py
@@ -4,6 +4,7 @@ from torch.ao.quantization.observer import (
     MovingAverageMinMaxObserver,
     HistogramObserver,
     MovingAveragePerChannelMinMaxObserver,
+    PerChannelMinMaxObserver,
     _with_args,
 )
 import re
@@ -11,7 +12,7 @@ from abc import ABC, abstractmethod
 from typing import Any, Tuple
 
 def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
-    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]
+    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]
 
 def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
     return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]
@@ -344,6 +345,12 @@ default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                                qscheme=torch.per_channel_symmetric,
                                                                reduce_range=False,
                                                                ch_axis=0)
+
+default_embedding_fake_quant = FakeQuantize.with_args(observer=PerChannelMinMaxObserver,
+                                                      qscheme=torch.per_channel_affine_float_qparams,
+                                                      ch_axis=0,
+                                                      memoryless=True)
+
 default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                       quant_min=0,
                                                       quant_max=255,
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index 9dce628178e..869724bd03e 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -12,11 +12,13 @@ from torch.ao.quantization.fake_quantize import (
     default_fused_wt_fake_quant,
     FusedMovingAvgObsFakeQuantize,
     default_fused_per_channel_wt_fake_quant,
+    default_embedding_fake_quant,
 )
 
 from .observer import (
     HistogramObserver,
     MovingAverageMinMaxObserver,
+    NoopObserver,
     PlaceholderObserver,
     default_debug_observer,
     default_dynamic_quant_observer,
@@ -127,6 +129,9 @@ def get_default_qconfig(backend='fbgemm'):
         qconfig = default_qconfig
     return qconfig
 
+default_embedding_qat_qconfig = QConfig(activation=NoopObserver,
+                                        weight=default_embedding_fake_quant)
+
 def get_default_qat_qconfig(backend='fbgemm', version=1):
     # Histogram observer is too slow for quantization aware training
     if version is None:
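To make the new qconfig concrete, a short sketch (assuming the patched `fake_quantize.py` above) of what `default_embedding_qat_qconfig` hands out: a no-op activation observer and a per-row weight fake quant:

```python
import torch
from torch.ao.quantization.qconfig import default_embedding_qat_qconfig

# The weight factory yields a FakeQuantize wrapping PerChannelMinMaxObserver,
# with one float (scale, zero_point) pair per embedding row (ch_axis=0).
wfq = default_embedding_qat_qconfig.weight()
assert wfq.qscheme == torch.per_channel_affine_float_qparams

w = torch.randn(10, 12)
w_fq = wfq(w)  # observes per-row min/max, then fake-quantizes the weight
assert w_fq.shape == w.shape
```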
diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py
index 09832d14ad0..f8e42c2ccd7 100644
--- a/torch/ao/quantization/quantization_mappings.py
+++ b/torch/ao/quantization/quantization_mappings.py
@@ -76,6 +76,7 @@ DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nnqat.Linear: nnq.Linear,
     nnqat.Conv2d: nnq.Conv2d,
     nnqat.Conv3d: nnq.Conv3d,
+    nnqat.EmbeddingBag: nnq.EmbeddingBag,
 }
 
 # Default map for swapping float module to qat modules
@@ -83,6 +84,7 @@ DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
     nn.Conv2d: nnqat.Conv2d,
     nn.Conv3d: nnqat.Conv3d,
     nn.Linear: nnqat.Linear,
+    nn.EmbeddingBag: nnqat.EmbeddingBag,
     nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
     # Intrinsic modules:
     nni.ConvBn1d: nniqat.ConvBn1d,
diff --git a/torch/nn/qat/modules/__init__.py b/torch/nn/qat/modules/__init__.py
index 585ccb7ad58..1a0707c05d8 100644
--- a/torch/nn/qat/modules/__init__.py
+++ b/torch/nn/qat/modules/__init__.py
@@ -1,9 +1,11 @@
 from .linear import Linear
 from .conv import Conv2d
 from .conv import Conv3d
+from .embedding_ops import EmbeddingBag
 
 __all__ = [
     "Linear",
     "Conv2d",
     "Conv3d",
+    "EmbeddingBag",
 ]
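The two mapping tables above are what drive the eager-mode swaps end to end; a hedged sketch of the resulting flow (the `Toy` wrapper is illustrative, not from the patch):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.EmbeddingBag(10, 12, mode='sum', include_last_offset=True)

    def forward(self, indices):
        return self.emb(indices)

m = Toy()
m.qconfig = torch.ao.quantization.default_embedding_qat_qconfig
m = torch.quantization.prepare_qat(m)     # DEFAULT_QAT_MODULE_MAPPINGS: nn -> nnqat
m(torch.randint(0, 10, (5, 12)))          # training step: weight ranges observed
m = torch.quantization.convert(m.eval())  # DEFAULT_STATIC_QUANT_MODULE_MAPPINGS: nnqat -> nnq
print(type(m.emb))                        # torch.nn.quantized.EmbeddingBag
```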
diff --git a/torch/nn/qat/modules/embedding_ops.py b/torch/nn/qat/modules/embedding_ops.py
new file mode 100644
index 00000000000..aaea288bc6b
--- /dev/null
+++ b/torch/nn/qat/modules/embedding_ops.py
@@ -0,0 +1,75 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class EmbeddingBag(nn.EmbeddingBag):
+    r"""
+    An embedding bag module attached with a FakeQuantize module for its weight,
+    used for quantization aware training.
+
+    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
+    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
+    for documentation.
+
+    Similar to `torch.nn.EmbeddingBag`, with the weight FakeQuantize module
+    initialized to default.
+
+    Attributes:
+        weight_fake_quant: fake quant module for the weight
+    """
+    _FLOAT_MODULE = nn.EmbeddingBag
+
+    def __init__(self, num_embeddings, embedding_dim, max_norm=None, norm_type=2.0,
+                 scale_grad_by_freq=False, mode='mean', sparse=False, _weight=None,
+                 include_last_offset=False, padding_idx=None, qconfig=None, device=None,
+                 dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
+                         scale_grad_by_freq, mode, sparse, _weight,
+                         include_last_offset, padding_idx, **factory_kwargs)
+        assert qconfig, 'qconfig must be provided for QAT module'
+        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
+            str(qconfig.weight().qscheme)
+        self.qconfig = qconfig
+        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
+
+    def forward(self, input, offsets=None, per_sample_weights=None):
+        # Fake-quantize the weight, then dispatch to the functional op so that
+        # mode, max_norm, etc. are honored exactly as in nn.EmbeddingBag.
+        return F.embedding_bag(input, self.weight_fake_quant(self.weight), offsets,
+                               self.max_norm, self.norm_type,
+                               self.scale_grad_by_freq, self.mode, self.sparse,
+                               per_sample_weights, self.include_last_offset,
+                               self.padding_idx)
+
+    @classmethod
+    def from_float(cls, mod):
+        r"""Create a qat module from a float module
+
+        Args: `mod` a float module, either produced by torch.quantization utilities
+        or directly from user
+        """
+        assert type(mod) == cls._FLOAT_MODULE, 'qat.' + cls.__name__ + '.from_float only works for ' + \
+            cls._FLOAT_MODULE.__name__
+        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
+        assert mod.qconfig, 'Input float module must have a valid qconfig'
+        assert mod.qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
+            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
+            str(mod.qconfig.weight().qscheme)
+
+        qconfig = mod.qconfig
+        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
+                                mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
+                                mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
+        qat_embedding_bag.weight = mod.weight
+
+        return qat_embedding_bag
+
+    def to_float(self):
+        embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
+                                              self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                              None, self.include_last_offset, self.padding_idx,
+                                              self.weight.device, self.weight.dtype)
+        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
+        embedding_bag.train(self.training)
+        return embedding_bag
diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py
index 4ab7a6e826b..cda7afa9957 100644
--- a/torch/nn/quantized/modules/embedding_ops.py
+++ b/torch/nn/quantized/modules/embedding_ops.py
@@ -222,14 +222,19 @@ class EmbeddingBag(Embedding):
             mod (Module): a float module, either produced by torch.quantization
                           utilities or provided by user
         """
-        assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
-            nn.EmbeddingBag.__name__
-        assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
-        from torch.quantization.qconfig import float_qparams_weight_only_qconfig
-        if mod.qconfig is not None and mod.qconfig.weight is not None:
-            weight_observer = mod.qconfig.weight()
+        if hasattr(mod, 'weight_fake_quant'):
+            # QAT path: reuse the statistics already observed by the fake quant module.
+            weight_observer = mod.weight_fake_quant
+            activation_post_process = mod.activation_post_process
         else:
-            weight_observer = float_qparams_weight_only_qconfig.weight()
+            assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
+                nn.EmbeddingBag.__name__
+            assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
+            from torch.quantization.qconfig import float_qparams_weight_only_qconfig
+            if mod.qconfig is not None and mod.qconfig.weight is not None:
+                weight_observer = mod.qconfig.weight()
+            else:
+                weight_observer = float_qparams_weight_only_qconfig.weight()
 
         dtype = weight_observer.dtype
         is_float_qparams_qconfig = weight_observer.qscheme == torch.per_channel_affine_float_qparams
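Finally, both the QAT fake quant and the converted `nnq.EmbeddingBag` lean on the float-qparams scheme; a small sketch of the per-row qparams that `PerChannelMinMaxObserver` computes under it:

```python
import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.quint8,
                               qscheme=torch.per_channel_affine_float_qparams)
w = torch.randn(10, 12)
obs(w)
scale, zero_point = obs.calculate_qparams()
# One (scale, zero_point) pair per embedding row; with float qparams the
# zero_point tensor stays floating point (roughly -min/scale per row),
# which is what the 8-bit embedding_bag kernels consume.
print(scale.shape, zero_point.dtype)  # torch.Size([10]) torch.float32
```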