[Quant] Added 4 bit support for embedding quantized module (#69769)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69769

Added 4-bit support and the corresponding test in the module API. Restructured the test_quantized_module embedding test to cover both 4-bit and 8-bit support.
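For context, a minimal usage sketch of the new 4-bit path (the flow mirrors the updated test; the shapes, and the choice of an embedding_dim that is a multiple of 8 for 4-bit packing, are illustrative assumptions):

```python
import torch
import torch.nn.quantized as nnq
from torch.quantization import default_float_qparams_observer

num_embeddings, embedding_dim = 10, 16  # embedding_dim kept a multiple of 8 for quint4x2 packing
weights = torch.randn(num_embeddings, embedding_dim)
indices = torch.randint(0, num_embeddings, (8,))

# Per-channel float qparams for the embedding weight, as in the test
obs = default_float_qparams_observer()
obs(weights)
scales, zero_points = obs.calculate_qparams()

# Quantize the weight to 4 bits and load it into the quantized module
qweight = torch.quantize_per_channel(weights, scales, zero_points, axis=0, dtype=torch.quint4x2)
qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=torch.quint4x2)
qemb.set_weight(qweight)
out = qemb(indices)  # forward now dispatches to torch.ops.quantized.embedding_4bit
```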

Test Plan:
In pytorch main dir, execute
```
python test/test_quantization.py TestStaticQuantizedModule.test_embedding_api
```

Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D33152674

fbshipit-source-id: 73e63383cf60994ab34cc7b4eedd8f32a806cf7f
Author: David Dang, 2021-12-18 22:22:57 -08:00 (committed by Facebook GitHub Bot)
parent b331752314
commit 9f512e129b
2 changed files with 22 additions and 14 deletions


@@ -819,23 +819,27 @@ class TestStaticQuantizedModule(QuantizationTestCase):
         obs = default_float_qparams_observer()
         obs(weights)
         qparams = obs.calculate_qparams()
-        # Quantize the weights to 8bits
-        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
-        qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
-        qemb.set_weight(qweight)
-        qemb(indices)
-        # Ensure the module has the correct weights
-        self.assertEqual(qweight, qemb.weight())
+        dtypes = [torch.quint4x2, torch.quint8]
+        embedding_funcs = [torch.ops.quantized.embedding_4bit, torch.ops.quantized.embedding_byte]
-        w_packed = qemb._packed_params._packed_weight
-        module_out = qemb(indices)
+        for dtype, embedding_func in zip(dtypes, embedding_funcs):
+            # Quantize the weights
+            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype)
+            qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype)
+            qemb.set_weight(qweight)
+            qemb(indices)
-        # Call the qembedding operator directly
-        ref = torch.ops.quantized.embedding_byte(w_packed, indices, pruned_weights=False)
-        self.assertEqual(module_out, ref)
-        self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False, is_emb_bag=False)
+            # Ensure the module has the correct weights
+            self.assertEqual(qweight, qemb.weight())
+            w_packed = qemb._packed_params._packed_weight
+            module_out = qemb(indices)
+            # Call the qembedding operator directly
+            ref = embedding_func(w_packed, indices, pruned_weights=False)
+            self.assertEqual(module_out, ref)
+            self.checkEmbeddingSerialization(qemb, num_embeddings, embedding_dim, indices, None, set_qconfig=False,
+                                             is_emb_bag=False, dtype=dtype)

     @given(
         num_embeddings=st.integers(10, 50),


@@ -93,6 +93,7 @@ class Embedding(torch.nn.Module):
         super(Embedding, self).__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
+        self.dtype = dtype
         if _weight is None:
             scales = torch.ones(num_embeddings, dtype=torch.float)

@@ -109,7 +110,10 @@ class Embedding(torch.nn.Module):
         self._packed_params.set_weight(qweight)

     def forward(self, indices: Tensor) -> Tensor:
-        return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)
+        if self.dtype == torch.quint4x2:
+            return torch.ops.quantized.embedding_4bit(self._packed_params._packed_weight, indices)
+        else:
+            return torch.ops.quantized.embedding_byte(self._packed_params._packed_weight, indices)

     def _get_name(self):
         return 'QuantizedEmbedding'
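As a follow-up note, the module-vs-operator check in the updated test can be reproduced standalone, continuing from the sketch in the summary above (the `_packed_params._packed_weight` attribute access mirrors the test code):

```python
# Compare the module output with the underlying 4-bit operator, as the test does
w_packed = qemb._packed_params._packed_weight
ref = torch.ops.quantized.embedding_4bit(w_packed, indices, pruned_weights=False)
assert torch.allclose(qemb(indices), ref)
```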