mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50418

Previously we were storing the quantized weight as a module attribute, which resulted in the weight being stored as part of the model. We don't need this, since we already store the unpacked weights as part of the model.

Test Plan:
Before
```
Archive:  tmp.pt
  Length   Method    Size  Cmpr    Date    Time   CRC-32   Name
--------  ------  ------- ---- ---------- ----- --------  ----
     586  Stored      586   0% 00-00-1980 00:00 5fefdda0  tmp/extra/producer_info.json
 1588700  Stored  1588700   0% 00-00-1980 00:00 04e0da4c  tmp/data/0
   63548  Stored    63548   0% 00-00-1980 00:00 0ceb1f45  tmp/data/1
   63548  Stored    63548   0% 00-00-1980 00:00 517bc3ab  tmp/data/2
 1588700  Stored  1588700   0% 00-00-1980 00:00 dbe88c73  tmp/data/3
   63548  Stored    63548   0% 00-00-1980 00:00 d8dc47c4  tmp/data/4
   63548  Stored    63548   0% 00-00-1980 00:00 b9e0c20f  tmp/data/5
    1071  Stored     1071   0% 00-00-1980 00:00 10dc9350  tmp/data.pkl
     327  Defl:N      203  38% 00-00-1980 00:00 dfddb661  tmp/code/__torch__/___torch_mangle_0.py
     185  Stored      185   0% 00-00-1980 00:00 308f580b  tmp/code/__torch__/___torch_mangle_0.py.debug_pkl
    1730  Defl:N      515  70% 00-00-1980 00:00 aa11f799  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py
    1468  Defl:N      636  57% 00-00-1980 00:00 779609a6  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py.debug_pkl
       0  Stored        0   0% 00-00-1980 00:00 00000000  tmp/code/__torch__/torch/classes/quantized.py
       6  Stored        6   0% 00-00-1980 00:00 816d0907  tmp/code/__torch__/torch/classes/quantized.py.debug_pkl
       4  Stored        4   0% 00-00-1980 00:00 57092f6d  tmp/constants.pkl
       2  Stored        2   0% 00-00-1980 00:00 55679ed1  tmp/version
--------          -------  ---                            -------
 3436971          3434800   0%                            16 files
```
After
```
Archive:  tmp.pt
  Length   Method    Size  Cmpr    Date    Time   CRC-32   Name
--------  ------  ------- ---- ---------- ----- --------  ----
 1588700  Stored  1588700   0% 00-00-1980 00:00 a4da6981  tmp/data/0
   63548  Stored    63548   0% 00-00-1980 00:00 74d9b607  tmp/data/1
   63548  Stored    63548   0% 00-00-1980 00:00 e346a0c2  tmp/data/2
     952  Stored      952   0% 00-00-1980 00:00 eff8706e  tmp/data.pkl
     375  Defl:N      227  40% 00-00-1980 00:00 96c77b68  tmp/code/__torch__/quantization/test_quantize/___torch_mangle_23.py
     228  Defl:N      162  29% 00-00-1980 00:00 6a378113  tmp/code/__torch__/quantization/test_quantize/___torch_mangle_23.py.debug_pkl
    1711  Defl:N      509  70% 00-00-1980 00:00 66d8fd61  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py
    1473  Defl:N      634  57% 00-00-1980 00:00 beb2323b  tmp/code/__torch__/torch/nn/quantized/modules/embedding_ops.py.debug_pkl
       0  Stored        0   0% 00-00-1980 00:00 00000000  tmp/code/__torch__/torch/classes/quantized.py
       6  Stored        6   0% 00-00-1980 00:00 816d0907  tmp/code/__torch__/torch/classes/quantized.py.debug_pkl
       4  Stored        4   0% 00-00-1980 00:00 57092f6d  tmp/constants.pkl
       2  Stored        2   0% 00-00-1980 00:00 55679ed1  tmp/version
--------          -------  ---                            -------
 1720547          1718292   0%                            12 files
```

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D25879879

fbshipit-source-id: e09427a60d4c44dd1a190575e75f3ed9cde6358f
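The listings above are `unzip -l`-style output on the saved TorchScript archive. As a minimal sketch (assuming a model has already been saved to `tmp.pt` with `torch.jit.save`, which writes a zip archive), the same inspection can be done from Python:

```python
import zipfile

# List every entry in the serialized TorchScript archive with its payload size.
with zipfile.ZipFile('tmp.pt') as zf:
    infos = zf.infolist()
    total = 0
    for info in infos:
        total += info.file_size
        print(f'{info.file_size:>10}  {info.filename}')
    print(f'{total:>10}  total across {len(infos)} files')
```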
243 lines
11 KiB
Python
import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Optional, List  # noqa: F401

from torch.nn.quantized.modules.utils import hide_packed_params_repr
from torch.nn.quantized.modules.utils import _quantize_weight

class EmbeddingPackedParams(torch.nn.Module):
    _version = 1

    def __init__(self, num_embeddings, embedding_dim, dtype=torch.quint8):
        super(EmbeddingPackedParams, self).__init__()
        self.dtype = dtype
        if self.dtype in [torch.quint8, torch.quint4x2]:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            wq = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim], scales=scales,
                                                           zero_points=zero_points,
                                                           axis=0, dtype=self.dtype)
            self.set_weight(wq)
        else:
            raise NotImplementedError('Unsupported dtype on quantized embedding! Supports quint8 and quint4x2.')

    @torch.jit.export
    def set_weight(self, weight: torch.Tensor) -> None:
        if self.dtype in [torch.quint8, torch.quint4x2]:
            self._packed_weight = torch.ops.quantized.embedding_bag_prepack(weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding prepack! Supports quint8 and quint4x2.')

    @torch.jit.export
    def _weight(self):
        if self.dtype in [torch.quint8, torch.quint4x2]:
            return torch.ops.quantized.embedding_bag_unpack(self._packed_weight)
        else:
            raise NotImplementedError('Unsupported dtype for quantized embedding unpack! Supports quint8 and quint4x2.')

    def forward(self, x):
        return x
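
    # forward is the identity: this module only holds the packed weight.
    # Embedding and EmbeddingBag call the quantized ops on
    # self._packed_params._packed_weight directly rather than going through
    # this forward.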

    # Version 1
    #   self
    #   |--- _packed_weight : Tensor representing weight of EmbeddingPackedParamsBase
    #   |--- dtype : torch.dtype

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super(EmbeddingPackedParams, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'dtype'] = self.dtype
        destination[prefix + '_packed_weight'] = self._weight()

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.dtype = state_dict[prefix + 'dtype']
        state_dict.pop(prefix + 'dtype')

        weight = state_dict[prefix + '_packed_weight']
        state_dict.pop(prefix + '_packed_weight')
        self.set_weight(weight)
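
        # The custom keys (dtype, _packed_weight) were consumed above, and the
        # packed weight itself is a plain attribute rather than a registered
        # buffer, so the base-class load is invoked with strict=False.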
        super(EmbeddingPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                                 missing_keys, unexpected_keys, error_msgs)

    def __repr__(self):
        return self._weight().__repr__()

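# A minimal round-trip sketch (illustrative only, assuming a build with a
# quantized engine such as fbgemm): packing then unpacking should preserve
# the quantized values.
#
#     params = EmbeddingPackedParams(num_embeddings=10, embedding_dim=12)
#     wq = params._weight()        # unpack to a per-channel quantized Tensor
#     params.set_weight(wq)        # re-pack the same weight
#     assert torch.equal(params._weight().int_repr(), wq.int_repr())
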
class Embedding(torch.nn.Module):
    r"""
    A quantized Embedding module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.Embedding`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding for documentation.

    Similar to :class:`~torch.nn.Embedding`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.Embedding(num_embeddings=10, embedding_dim=12)
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        >>> output = m(indices)
        >>> print(output.size())
        torch.Size([9, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None, dtype=torch.quint8) -> None:
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        if _weight is None:
            scales = torch.ones(num_embeddings, dtype=torch.float)
            zero_points = torch.zeros(num_embeddings, dtype=torch.float)
            qweight = torch._empty_per_channel_affine_quantized([num_embeddings, embedding_dim],
                                                                scales=scales, zero_points=zero_points,
                                                                axis=0, dtype=torch.quint8)
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            qweight = _weight

        self._packed_params = EmbeddingPackedParams(num_embeddings, embedding_dim, dtype)
        self._packed_params.set_weight(qweight)

    def forward(self, indices: Tensor) -> Tensor:
        return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)

    def _get_name(self):
        return 'QuantizedEmbedding'

    def __repr__(self):
        return hide_packed_params_repr(self, EmbeddingPackedParams)

    def extra_repr(self):
        extra_repr_str = 'num_embeddings={}, embedding_dim={}, dtype={}, qscheme={}'.format(
            self.num_embeddings, self.embedding_dim, self._packed_params.dtype, self.weight().qscheme()
        )

        return extra_repr_str

    def set_weight(self, w: torch.Tensor) -> None:
        self._packed_params.set_weight(w)

    def weight(self):
        return self._packed_params._weight()

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding module from a float module

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user
        """
        assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
            nn.Embedding.__name__
        assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
        from torch.quantization import float_qparams_weight_only_qconfig
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype

        assert dtype == torch.quint8, 'The only supported dtype for nnq.Embedding is torch.quint8'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized Embedding module and pass in the quantized weight.
        qembedding = Embedding(mod.num_embeddings, mod.embedding_dim)
        qembedding.set_weight(qweight)
        return qembedding

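# A conversion sketch (illustrative only): quantizing a float nn.Embedding via
# from_float, using the weight-only qconfig referenced above.
#
#     from torch.quantization import float_qparams_weight_only_qconfig
#     float_embedding = nn.Embedding(num_embeddings=10, embedding_dim=12)
#     float_embedding.qconfig = float_qparams_weight_only_qconfig
#     q_embedding = Embedding.from_float(float_embedding)
#     out = q_embedding(torch.tensor([0, 3, 7]))   # dequantized float output
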
class EmbeddingBag(Embedding):
    r"""
    A quantized EmbeddingBag module with quantized packed weights as inputs.
    We adopt the same interface as `torch.nn.EmbeddingBag`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.EmbeddingBag for documentation.

    Similar to :class:`~torch.nn.EmbeddingBag`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module of
                         shape :math:`(\text{num\_embeddings}, \text{embedding\_dim})`.

    Examples::
        >>> m = nn.quantized.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True, mode='sum')
        >>> indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
        >>> offsets = torch.tensor([0, 19, 20, 28, 28, 32])
        >>> output = m(indices, offsets)
        >>> print(output.size())
        torch.Size([5, 12])

    """
    _version = 1

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 mode: str = 'sum', sparse: bool = False, _weight: Optional[Tensor] = None,
                 include_last_offset: bool = False, dtype=torch.quint8) -> None:
        super(EmbeddingBag, self).__init__(num_embeddings, embedding_dim, _weight=_weight, dtype=dtype)

        self.mode = mode
        self.pruned_weights = False
        self.include_last_offset = include_last_offset
        self.dtype = dtype

    def forward(self, indices: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None,
                compressed_indices_mapping: Optional[Tensor] = None) -> Tensor:
        if self.dtype == torch.quint4x2:
            return torch.ops.quantized.embedding_bag_4bit(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights,
                                                          compressed_indices_mapping, self.include_last_offset)
        else:
            return torch.ops.quantized.embedding_bag_byte(self._packed_params._packed_weight, indices, offsets, False, 0,
                                                          self.pruned_weights, per_sample_weights,
                                                          compressed_indices_mapping, self.include_last_offset)
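
    # Dispatch note: quint4x2 packed weights are consumed by embedding_bag_4bit
    # and quint8 by embedding_bag_byte; both ops return a dequantized float
    # Tensor.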

    def _get_name(self):
        return 'QuantizedEmbeddingBag'

    @classmethod
    def from_float(cls, mod):
        r"""Create a quantized embedding_bag module from a float module

        Args:
            mod (Module): a float module, either produced by torch.quantization
                          utilities or provided by the user
        """
        assert type(mod) == nn.EmbeddingBag, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
            nn.EmbeddingBag.__name__
        assert hasattr(mod, 'qconfig'), 'EmbeddingBag input float module must have qconfig defined'
        from torch.quantization.qconfig import float_qparams_weight_only_qconfig
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            weight_observer = float_qparams_weight_only_qconfig.weight()

        dtype = weight_observer.dtype

        assert dtype == torch.quint8 or dtype == torch.quint4x2, \
            'The only supported dtypes for nnq.EmbeddingBag are torch.quint8 and torch.quint4x2'

        # Run the observer to calculate qparams.
        weight_observer(mod.weight)
        qweight = _quantize_weight(mod.weight.float(), weight_observer)

        # Create quantized EmbeddingBag module and pass in the quantized weight.
        qembedding_bag = EmbeddingBag(mod.num_embeddings, mod.embedding_dim, dtype=dtype)
        qembedding_bag.set_weight(qweight)
        return qembedding_bag
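
# A serialization sketch tying back to the change described in the commit
# message (illustrative only): the state dict carries the *unpacked* weight
# under '_packed_params._packed_weight', so the packed TorchBind object is
# never stored with the model.
#
#     m = Embedding(num_embeddings=10, embedding_dim=12)
#     sorted(m.state_dict().keys())
#     # ['_packed_params._packed_weight', '_packed_params.dtype']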