Summary: Previously we still relied on the registration mechanism and fetched the default quantize handlers that were registered; now that all registration has moved to backend_config_dict, we can get all quant patterns directly from backend_config_dict. This PR enables using the native backend_config_dict everywhere in prepare when backend_config_dict is None; similar changes will follow in convert.

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestFXNumericSuiteCoreAPIs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75469
Approved by: https://github.com/vkuzo
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F


class Embedding(nn.Embedding):
    r"""
    An embedding module attached with a FakeQuantize module for the weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Embedding`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
    for documentation.

    Similar to `torch.nn.Embedding`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight: fake quant module for weight
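
    Example (a minimal usage sketch added for illustration; it assumes
    ``torch.ao.quantization.default_embedding_qat_qconfig`` is available and
    supplies the ``per_channel_affine_float_qparams`` weight fake-quant this
    module requires)::

        >>> from torch.ao.quantization import default_embedding_qat_qconfig
        >>> m = Embedding(num_embeddings=10, embedding_dim=4,
        ...               qconfig=default_embedding_qat_qconfig)
        >>> indices = torch.tensor([[1, 2, 4], [4, 3, 9]])
        >>> out = m(indices)  # weight passes through weight_fake_quant
        >>> out.shape
        torch.Size([2, 3, 4])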
"""
    _FLOAT_MODULE = nn.Embedding

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                 max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
                 sparse=False, _weight=None, device=None, dtype=None, qconfig=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight,
                         **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(qconfig.weight().qscheme)
        self.qconfig = qconfig
        # Weight fake-quant module created from the qconfig; applied in forward().
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input) -> Tensor:
        # Fake-quantize the weight on every forward pass, then run the regular
        # functional embedding lookup.
        return F.embedding(input, self.weight_fake_quant(self.weight), self.padding_idx,
                           self.max_norm, self.norm_type, self.scale_grad_by_freq,
                           self.sparse)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
            mod: a float module, either produced by torch.ao.quantization
                utilities or directly from the user
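
        Example (a hedged sketch, assuming ``default_embedding_qat_qconfig``
        from ``torch.ao.quantization`` satisfies the qscheme check below)::

            >>> from torch.ao.quantization import default_embedding_qat_qconfig
            >>> float_emb = nn.Embedding(10, 4)
            >>> float_emb.qconfig = default_embedding_qat_qconfig
            >>> qat_emb = Embedding.from_float(float_emb)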
        """
        assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'Embedding weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(weight_qscheme)

        qconfig = mod.qconfig
        qat_embedding = cls(mod.num_embeddings, mod.embedding_dim, mod.padding_idx,
                            mod.max_norm, mod.norm_type, mod.scale_grad_by_freq,
                            mod.sparse, mod.weight, qconfig=qconfig)

        return qat_embedding

    def to_float(self):
        # Rebuild a plain nn.Embedding and copy the (float) weight back into it.
        embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim, self.padding_idx,
                                       self.max_norm, self.norm_type, self.scale_grad_by_freq,
                                       self.sparse, None)
        embedding.weight = torch.nn.Parameter(self.weight.detach())
        embedding.train(self.training)
        return embedding


class EmbeddingBag(nn.EmbeddingBag):
    r"""
    An embedding bag module attached with a FakeQuantize module for the weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.EmbeddingBag`, please see
    https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html#torch.nn.EmbeddingBag
    for documentation.

    Similar to `torch.nn.EmbeddingBag`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight: fake quant module for weight
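
    Example (a minimal usage sketch added for illustration; it assumes
    ``torch.ao.quantization.default_embedding_qat_qconfig`` is available and
    supplies the ``per_channel_affine_float_qparams`` weight fake-quant this
    module requires)::

        >>> from torch.ao.quantization import default_embedding_qat_qconfig
        >>> m = EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='sum',
        ...                  qconfig=default_embedding_qat_qconfig)
        >>> indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        >>> offsets = torch.tensor([0, 4])  # two bags: indices[0:4] and indices[4:8]
        >>> out = m(indices, offsets)
        >>> out.shape
        torch.Size([2, 4])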
"""
    _FLOAT_MODULE = nn.EmbeddingBag

    def __init__(self, num_embeddings, embedding_dim, max_norm=None,
                 norm_type=2.0, scale_grad_by_freq=False, mode='mean',
                 sparse=False, _weight=None, include_last_offset=False,
                 padding_idx=None, qconfig=None, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_embeddings, embedding_dim, max_norm, norm_type,
                         scale_grad_by_freq, mode, sparse, _weight,
                         include_last_offset, padding_idx, **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        assert qconfig.weight().qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(qconfig.weight().qscheme)
        self.qconfig = qconfig
        # Weight fake-quant module created from the qconfig; applied in forward().
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input, offsets=None, per_sample_weights=None) -> Tensor:
        # Fake-quantize the weight on every forward pass, then run the regular
        # functional embedding-bag lookup.
        return F.embedding_bag(input, self.weight_fake_quant(self.weight), offsets,
                               self.max_norm, self.norm_type,
                               self.scale_grad_by_freq, self.mode, self.sparse,
                               per_sample_weights, self.include_last_offset,
                               self.padding_idx)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
            mod: a float module, either produced by torch.ao.quantization
                utilities or directly from the user
        """
        assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        weight_qscheme = mod.qconfig.weight().qscheme  # type: ignore[union-attr, operator]
        assert weight_qscheme == torch.per_channel_affine_float_qparams, \
            'EmbeddingBag weights require a qscheme of torch.per_channel_affine_float_qparams, got ' + \
            str(weight_qscheme)

        qconfig = mod.qconfig
        qat_embedding_bag = cls(mod.num_embeddings, mod.embedding_dim, mod.max_norm, mod.norm_type,
                                mod.scale_grad_by_freq, mod.mode, mod.sparse, mod.weight,
                                mod.include_last_offset, mod.padding_idx, qconfig=qconfig)

        return qat_embedding_bag

    def to_float(self):
        # Rebuild a plain nn.EmbeddingBag and copy the (float) weight back into it.
        embedding_bag = torch.nn.EmbeddingBag(self.num_embeddings, self.embedding_dim, self.max_norm,
                                              self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                              None, self.include_last_offset, self.padding_idx)
        embedding_bag.weight = torch.nn.Parameter(self.weight.detach())
        embedding_bag.train(self.training)
        return embedding_bag
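

# Editorial usage sketch (not part of the original module): a minimal round trip
# through from_float -> QAT forward -> to_float. It assumes
# torch.ao.quantization.default_embedding_qat_qconfig exists and carries the
# per_channel_affine_float_qparams weight fake-quant required by the asserts above.
if __name__ == "__main__":
    from torch.ao.quantization import default_embedding_qat_qconfig

    float_bag = nn.EmbeddingBag(10, 4, mode='sum')
    float_bag.qconfig = default_embedding_qat_qconfig
    qat_bag = EmbeddingBag.from_float(float_bag)    # swap in the QAT module
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])                  # two bags: indices[0:4] and indices[4:8]
    out = qat_bag(indices, offsets)                 # weight goes through weight_fake_quant
    restored = qat_bag.to_float()                   # back to a plain nn.EmbeddingBag
    print(out.shape, type(restored).__name__)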