[Quant] Added 4 bit support for embedding quantized module (#69769)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69769

Added 4-bit support and the corresponding test in the module API. Restructured test_quantized_module for both 4- and 8-bit support.

Test Plan:
In the PyTorch main directory, execute
```
python test/test_quantization.py TestStaticQuantizedModule.test_embedding_api
```

Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D33152674

fbshipit-source-id: 73e63383cf60994ab34cc7b4eedd8f32a806cf7f
This commit is contained in: parent b331752314, commit 9f512e129b
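As a usage illustration that is not part of the commit itself, the sketch below mirrors the updated test shown in the diff that follows: it quantizes an embedding table to 4-bit `quint4x2` weights and runs it through the quantized `Embedding` module, which after this change dispatches to `torch.ops.quantized.embedding_4bit`. Shapes, tensor values, and the `default_float_qparams_observer` import path are assumptions for illustration only.

```
import torch
import torch.nn.quantized as nnq
from torch.quantization import default_float_qparams_observer  # assumed import path

# Illustrative sizes; an even embedding_dim, since quint4x2 packs two 4-bit values per byte.
num_embeddings, embedding_dim = 10, 16
weights = torch.randn(num_embeddings, embedding_dim, dtype=torch.float)
indices = torch.tensor([0, 3, 5, 7])

# Per-row (axis=0) float qparams, as in the updated test.
obs = default_float_qparams_observer()
obs(weights)
scales, zero_points = obs.calculate_qparams()

# Quantize the weights to 4 bits and hand them to the quantized module.
qweight = torch.quantize_per_channel(weights, scales, zero_points, axis=0, dtype=torch.quint4x2)
qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=torch.quint4x2)
qemb.set_weight(qweight)

# forward() now routes to torch.ops.quantized.embedding_4bit for quint4x2 weights.
out = qemb(indices)
print(out.shape)  # torch.Size([4, 16])
```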
```
@@ -819,23 +819,27 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         obs = default_float_qparams_observer()
         obs(weights)
         qparams = obs.calculate_qparams()
-        # Quantize the weights to 8bits
-        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
-        qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
-        qemb.set_weight(qweight)
-        qemb(indices)
-
-        # Ensure the module has the correct weights
-        self.assertEqual(qweight, qemb.weight())
+        dtypes = [torch.quint4x2, torch.quint8]
+        embedding_funcs = [torch.ops.quantized.embedding_4bit, torch.ops.quantized.embedding_byte]
 
-        w_packed = qemb._packed_params._packed_weight
-        module_out = qemb(indices)
+        for dtype, embedding_func in zip(dtypes, embedding_funcs):
+            # Quantize the weights
+            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype)
+            qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
+            qemb.set_weight(qweight)
+            qemb(indices)
 
-        # Call the qembedding operator directly
-        ref = torch.ops.quantized.embedding_byte(w_packed, indices, pruned_weights=False)
-        self.assertEqual(module_out, ref)
+            # Ensure the module has the correct weights
+            self.assertEqual(qweight, qemb.weight())
 
-        self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False, is_emb_bag=False)
+            w_packed = qemb._packed_params._packed_weight
+            module_out = qemb(indices)
+
+            # Call the bit qembedding operator directly
+            ref = embedding_func(w_packed, indices, pruned_weights=False)
+            self.assertEqual(module_out, ref)
+            self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False,
+                                             is_emb_bag=False, dtype=dtype)
 
     @given(
         num_embeddings=st.integers(10, 50),
```
```
@@ -93,6 +93,7 @@ class Embedding(torch.nn.Module):
         super(Embedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
+        self.dtype = dtype
 
         if _weight is None:
             scales = torch.ones(num_embeddings, dtype=torch.float)
```
```
@@ -109,7 +110,10 @@
         self._packed_params.set_weight(qweight)
 
     def forward(self, indices: Tensor) -> Tensor:
-        return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)
+        if self.dtype == torch.quint4x2:
+            return torch.ops.quantized.embedding_4bit(self._packed_params._packed_weight, indices)
+        else:
+            return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)
 
     def _get_name(self):
         return 'QuantizedEmbedding'
```
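As a follow-up sketch that paraphrases the loop in the updated test rather than defining any new API, one can cross-check the module's forward against the raw quantized operators for both supported dtypes; names and sizes below are again illustrative.

```
import torch
import torch.nn.quantized as nnq
from torch.quantization import default_float_qparams_observer  # assumed import path

num_embeddings, embedding_dim = 10, 16
weights = torch.randn(num_embeddings, embedding_dim, dtype=torch.float)
indices = torch.tensor([1, 4, 8])

obs = default_float_qparams_observer()
obs(weights)
scales, zero_points = obs.calculate_qparams()

for dtype, op in [(torch.quint4x2, torch.ops.quantized.embedding_4bit),
                  (torch.quint8, torch.ops.quantized.embedding_byte)]:
    qweight = torch.quantize_per_channel(weights, scales, zero_points, axis=0, dtype=dtype)
    qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
    qemb.set_weight(qweight)
    module_out = qemb(indices)
    # The module should produce the same result as calling the operator on the packed weight.
    ref = op(qemb._packed_params._packed_weight, indices, pruned_weights=False)
    assert torch.equal(module_out, ref)
```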