import torch
import numpy as np

from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT

class LinearAPoT(WeightedQuantizedModule):
    r"""
    A quantized linear module with quantized tensor as inputs and outputs
    to support APoT quantization.
    We adopt the same interface as `torch.nn.Linear`, see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`~torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        alpha: `alpha` qparam of output Quantized Tensor, type: Tensor
        gamma: `gamma` qparam of output Quantized Tensor, type: Tensor
        quantization_levels: `quantization_levels` qparam of output Quantized Tensor, type: Tensor
        level_indices: `level_indices` qparam of output Quantized Tensor, type: Tensor
        weight: APoT quantized tensor from weight2quantize
        weight_transposed: transposed weight tensor, used in linear transformation calculation (y = x * A^T + b)
    """
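
    # Minimal usage sketch (hedged: the values below are illustrative
    # placeholders, and the resulting qparams depend entirely on what
    # APoTObserver and quantize_APoT compute for the given weight):
    #
    #   float_weight = torch.rand(8, 8)
    #   qlinear = LinearAPoT(float_weight, b=4, k=2)
    #   out = qlinear(torch.randint(0, 16, (2, 8)).float())  # torch.float32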

    def __init__(self, weight2quantize: torch.Tensor, b: int, k: int):
        assert weight2quantize.dim() == 2
        assert b % k == 0

        super().__init__()

        self.b = b
        self.k = k
        self.n = self.b // self.k

        observer = APoTObserver(b=self.b, k=self.k)

        observer(weight2quantize)

        self.alpha, self.gamma, self.quantization_levels, self.level_indices = observer.calculate_qparams(signed=False)

        quantized_weight = quantize_APoT(weight2quantize, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
        self.weight = quantized_weight.data
        self.weight_transposed = torch.transpose(self.weight, 0, 1)

    def decompose_APoT(self, x):
        r"""
        Decompose binary representation of an APoT value into a list of k-sized blocks
        Args:
            x (str): binary string representation (output of `bin`) of an APoT quantized value
        """
        # remove "0b" prefix from binary representation
        x = x[2:]

        # initialize list of blocks
        blocks = []

        while x:
            blocks.append(x[0:self.k])
            x = x[self.k:]

        return blocks
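
    # Illustration (assuming k=2): bin(13) == "0b1101", so decompose_APoT
    # strips the "0b" prefix and splits "1101" into blocks ["11", "01"].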

    def bitshift_mul(self, weight_val, r):
        r"""
        Compute multiplication of weight_val * r using bitshifting
        method discussed in APoT paper: https://arxiv.org/pdf/1909.13144.pdf
        Args:
            weight_val: list of k-sized binary string blocks (as produced by
                decompose_APoT) representing an APoT quantized weight value
            r: int representing uniformly quantized activation value
        """
        product = 0

        idx = len(weight_val) - 1
        place = 0

        while idx >= 0:
            block = weight_val[idx]

            # reverse digits in block so the least significant digit comes first
            block = block[::-1]

            curr_block_result = 0

            for ele in block:
                if int(ele):
                    curr_block_result += r << place
                place += 1

            idx -= 1
            product += curr_block_result

        return product
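
    # Worked example (hedged): with weight_val == ["11", "01"] (binary 1101,
    # i.e. 13) and r == 3, the loop accumulates 3 << 0, 3 << 2, and 3 << 3,
    # giving 3 + 12 + 24 == 39 == 13 * 3, with no explicit multiplication.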

    def matmul(self, decomposed_weight, activation):
        r"""
        Perform matrix multiplication between decomposed_weight and
        activation by calling bitshift_mul function for each value
        Args:
            decomposed_weight (np.ndarray): APoT quantized weight decomposed
                into binary string blocks (dtype=object)
            activation (Tensor): uniformly quantized activation
        """
        rows1 = activation.size(dim=0)
        cols1 = activation.size(dim=1)

        rows2 = decomposed_weight.shape[0]
        cols2 = decomposed_weight.shape[1]

        result = torch.zeros(rows1, cols2)

        # compute matrix multiplication with bitshifts
        for i in range(rows1):
            for j in range(cols2):
                for k in range(rows2):
                    weight_val = decomposed_weight[k][j]
                    r = int(activation[i][k])

                    product = self.bitshift_mul(weight_val, r)

                    result[i][j] += product

        return result
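
    # Shape sketch (hedged reading of the loop above): activation is
    # (rows1, cols1), decomposed_weight is (rows2, cols2), and the result is
    # (rows1, cols2); the indexing is only valid when cols1 == rows2.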

    def forward(self, activation: torch.Tensor) -> torch.FloatTensor:
        r"""
        Multiply APoT quantized weight and uniformly quantized activation (dtype: quint8)
        with bitshifting instead of matrix multiplication.
        Result has dtype torch.float32
        Args:
            activation (Tensor): uniformly quantized activation tensor
        """
        assert activation.dim() == 2

        weight_rows = self.weight_transposed.size()[0]
        weight_cols = self.weight_transposed.size()[1]

        decomposed_weight = np.empty(shape=(weight_rows, weight_cols), dtype=object)
        for row in range(weight_rows):
            for col in range(weight_cols):
                decomposed_weight[row][col] = self.decompose_APoT(bin(self.weight_transposed[row][col]))

        result = self.matmul(decomposed_weight, activation).type(torch.FloatTensor)

        return result
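
    # Pipeline recap (hedged reading of the code above): forward decomposes
    # each APoT weight value into k-sized binary blocks, then matmul and
    # bitshift_mul accumulate shifted copies of the activation; the Python-level
    # loops make this a readability-first reference path, not a fast kernel.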

    @classmethod
    def from_reference(cls,  # type: ignore[override]
                       ref_qlinear,
                       alpha: torch.Tensor,
                       gamma: torch.Tensor,
                       quantization_levels: torch.Tensor,
                       level_indices: torch.Tensor):
        raise NotImplementedError