pytorch/torch/ao/quantization/experimental/apot_utils.py
asl3 13ad4739a6 [quant] Implement PTQ for APoT FakeQuant (#81040)
### Summary:
This PR implements PTQ for APoT FakeQuant. It runs a pre-trained ResNet-18 model on the ImageNet dataset to compare accuracy metrics for different qconfig settings of uniform vs. APoT quantized activations and weights.

According to the collected accuracy stats, model #2 (uniform activation and APoT weight) shows a slight accuracy improvement over model #1 (uniform activation and uniform weight) at 8 bits and a significant improvement at 4 bits (see the "Accuracy Stats" section below).
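
For context, the accuracy tables below describe APoT settings by a total bitwidth `b` and a base bitwidth `k`. The snippet below is a minimal sketch (not the code added in this PR; the exact level construction in the experimental APoT observer may differ) of how additive-powers-of-two levels can be enumerated for such a setting: each level is a sum of `b/k` terms, each term either zero or a power of two drawn from a disjoint set, rescaled so the largest level equals the clipping value `alpha`.

```python
import itertools

def apot_levels(b=4, k=2, alpha=1.0):
    # number of additive terms; each term contributes k bits
    n = b // k
    # candidate values for term i: zero plus 2**k - 1 powers of two,
    # chosen so that different terms never share the same power
    term_sets = [
        [0.0] + [2.0 ** -(i + 1 + j * n) for j in range(2 ** k - 1)]
        for i in range(n)
    ]
    # every combination of one candidate per term yields one level
    sums = sorted({sum(combo) for combo in itertools.product(*term_sets)})
    gamma = alpha / sums[-1]  # rescale so the largest level equals alpha
    return [gamma * s for s in sums]

print(apot_levels(b=4, k=2, alpha=1.0))  # 2**4 = 16 nonuniform levels in [0, 1]
```

These are nonnegative levels; a symmetric (signed) scheme would mirror them around zero.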

### Test Plan:
Run models with: `python test/quantization/core/experimental/fx_graph_mode_apot.py`

### Accuracy Stats:
**8-bit** (uniform int8, APoT b = 8, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.43% (Top-1), 85.62% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.51% (Top-1), 85.78% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.32% (Top-1), 85.78% (Top-5)

**4-bit** (uniform int4, APoT b = 4, k = 2)

**Model #1:** Uniform activation, uniform weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.63% (Top-1), 71.96% (Top-5)

**Model #2:** Uniform activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 64.24% (Top-1), 85.56% (Top-5)

**Model #3:** APoT activation, APoT weight (FX Graph Mode quantized)
Evaluation accuracy on test dataset: 45.40% (Top-1), 76.21% (Top-5)

**Full-precision model (FX Graph Mode)**
Evaluation accuracy on test dataset: 69.76% (Top-1), 89.08% (Top-5)

**Eager mode quantized model**
Evaluation accuracy on test dataset: 69.49% (Top-1), 88.90% (Top-5)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81040
Approved by: https://github.com/jerryzh168
2022-07-28 07:21:31 +00:00


r"""
This file contains utility functions to convert values
using APoT nonuniform quantization methods.
"""
import math
r"""Converts floating point input into APoT number
based on quantization levels
"""
def float_to_apot(x, levels, indices, alpha):
# clip values based on alpha
if x < -alpha:
return -alpha
elif x > alpha:
return alpha
levels_lst = list(levels)
indices_lst = list(indices)
min_delta = math.inf
best_idx = 0
for level, idx in zip(levels_lst, indices_lst):
cur_delta = abs(level - x)
if cur_delta < min_delta:
min_delta = cur_delta
best_idx = idx
return best_idx
r"""Converts floating point input into
reduced precision floating point value
based on quantization levels
"""
def quant_dequant_util(x, levels, indices):
levels_lst = list(levels)
indices_lst = list(indices)
min_delta = math.inf
best_fp = 0.0
for level, idx in zip(levels_lst, indices_lst):
cur_delta = abs(level - x)
if cur_delta < min_delta:
min_delta = cur_delta
best_fp = level
return best_fp
r"""Converts APoT input into floating point number
based on quantization levels
"""
def apot_to_float(x_apot, levels, indices):
idx = list(indices).index(x_apot)
return levels[idx]
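
A minimal usage sketch of the helpers above, with hand-made `levels`/`indices` values (in practice these would come from the experimental APoT observer):

```python
levels = [0.0, 0.25, 0.5, 1.0]   # nonuniform quantization levels
indices = [0, 1, 2, 3]           # APoT codes, one per level
alpha = 1.0                      # clipping range

idx = float_to_apot(0.3, levels, indices, alpha)  # nearest level is 0.25 -> returns 1
val = apot_to_float(idx, levels, indices)         # returns 0.25
qdq = quant_dequant_util(0.3, levels, indices)    # returns 0.25 directly
```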