pytorch/torch/nn/qat/modules/embedding_ops.py
Ben Koopman 0036e41143 [quant][embedding qat] Add eager QAT test for EmbeddingBag+Linear model (#66334)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66334

Test Plan: Imported from OSS

Reviewed By: HDCharles

Differential Revision: D31618283

Pulled By: b-koopman

fbshipit-source-id: bb824a341f1aa9d7e83f8e66d320a9dfd348a1d7
2021-10-19 07:03:36 -07:00

76 lines
3.7 KiB
Python

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
class EmbeddingBag(nn.EmbeddingBag):
r"""
An embedding bag module attached with FakeQuantize modules for weight,
used for quantization aware training.
We adopt the same interface as `torch.nn.EmbeddingBag`, please see
https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
for documentation.
Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
default.
Attributes:
weight: fake quant module for weight
"""
_FLOAT_MODULE = nn.EmbeddingBag
def __init__(self, num_embeddings, embedding_dim, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, mode='mean',
sparse=False, _weight=None, include_last_offset=False,
padding_idx=None, qconfig=None, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
scale_grad_by_freq, mode, sparse, _weight,
include_last_offset, padding_idx, **factory_kwargs)
assert qconfig, 'qconfig must be provided for QAT module'
assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
str(qconfig.weight().qscheme)
self.qconfig = qconfig
self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)
def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
return F.embedding_bag(input, self.weight_fake_quant(self.weight), offsets,
self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset,
self.padding_idx)
@classmethod
def from_float(cls, mod):
r"""Create a qat module from a float module
Args: `mod` a float module, either produced by torch.ao.quantization utilities
or directly from user
"""
assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
cls._FLOAT_MODULE.__name__
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
assert mod.qconfig, 'Input float module must have a valid qconfig'
assert mod.qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
'Embedding Bag weights requires a qscheme of torch.per_channel_affine_float_qparams Got ' + \
mod.qconfig.weight().qscheme.__name__
qconfig = mod.qconfig
qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
mod.include_last_offset, mod.padding_idx, qconfig=qconfig)
qat_embedding_bag.weight = mod.weight
return qat_embedding_bag
def to_float(self):
embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
None, self.include_last_offset, self.padding_idx,
self.device, self.dtype)
embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
embedding_bag.train(self.training)
return embedding_bag