mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133492 Approved by: https://github.com/albanD
64 lines
1.3 KiB
Python
64 lines
1.3 KiB
Python
# mypy: allow-untyped-defs
|
|
r"""
|
|
This file contains utility functions to convert values
|
|
using APoT nonuniform quantization methods.
|
|
"""
|
|
|
|
import math
|
|
|
|
|
|
def float_to_apot(x, levels, indices, alpha):
    r"""Converts floating point input into APoT number
    based on quantization levels.

    ``x`` is first clipped to ``[-alpha, alpha]``. The result is the entry of
    ``indices`` paired (by position) with the entry of ``levels`` closest to
    the clipped value. On ties the earliest level wins.

    Args:
        x: floating point value to convert.
        levels: quantization levels, iterated in lockstep with ``indices``.
        indices: APoT indices, one per level.
        alpha: clipping bound applied to ``x`` before the search.

    Returns:
        The index paired with the nearest level. Returns ``0`` when no
        distance compares below infinity, e.g. when ``levels``/``indices``
        are empty or ``x`` is NaN.
    """
    # Clip the *value* rather than returning the bound itself. The bound is a
    # level value, not an APoT index, so returning it would mix two kinds of
    # output. apot_to_float looks its argument up in ``indices`` and would
    # not find it.
    if x < -alpha:
        x = -alpha
    elif x > alpha:
        x = alpha

    min_delta = math.inf
    best_idx = 0

    # Strict ``<`` keeps the first level on ties.
    for level, idx in zip(levels, indices):
        cur_delta = abs(level - x)
        if cur_delta < min_delta:
            min_delta = cur_delta
            best_idx = idx

    return best_idx
|
|
|
|
|
|
def quant_dequant_util(x, levels, indices):
    r"""Converts floating point input into
    reduced precision floating point value
    based on quantization levels.

    Returns the first element of ``levels`` whose absolute distance to ``x``
    is strictly smallest. Returns ``0.0`` when no distance compares below
    infinity, e.g. when ``levels`` is empty or ``x`` is NaN. ``indices`` is
    accepted but not read by this function.
    """
    closest = 0.0
    smallest_gap = math.inf
    for candidate in levels:
        gap = abs(candidate - x)
        # Only a strictly smaller gap replaces the current pick, so the
        # earliest level wins on ties.
        if not (gap < smallest_gap):
            continue
        smallest_gap = gap
        closest = candidate
    return closest
|
|
|
|
|
|
def apot_to_float(x_apot, levels, indices):
    r"""Converts APoT input into floating point number
    based on quantization levels.

    Finds the position of ``x_apot`` within ``indices`` and returns the entry
    of ``levels`` at that same position.

    Raises:
        ValueError: if ``x_apot`` is not present in ``indices``.
    """
    # list.index returns the first matching position, or raises ValueError.
    return levels[list(indices).index(x_apot)]
|