Initial implementation of AdaRound (#126153)
Summary: This is an implementation of AdaRound from the paper https://arxiv.org/abs/2004.10568. The algorithm will be used by multiple teams, so we are making it an official implementation. Differential Revision: D57227565 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126153 Approved by: https://github.com/jerryzh168, https://github.com/huydhn
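
For context, a minimal usage sketch pieced together from the test added in this PR. AdaptiveRoundingOptimizer and its arguments come from the diff below; the toy Sequential model, calibration data, and lambda callback are illustrative, not part of the diff.

import torch
import torch.nn as nn
from torch.ao.quantization.experimental.adaround_optimization import (
    AdaptiveRoundingOptimizer,
)

def forward_wrapper(fetcher):
    # Forward hook that records a layer's input and output tensors into `fetcher`.
    def forward(module, input, output):
        fetcher.append(input[0].detach())
        fetcher.append(output.detach())
    return forward

model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))  # toy float model
calib_data = [torch.rand(10, 3) for _ in range(50)]      # calibration batches

optimizer = AdaptiveRoundingOptimizer(
    model,
    lambda m, batch: m(batch),  # callback that runs one forward pass
    forward_wrapper,
    calib_data,
    max_iter=100,
    batch_size=10,
)
adarounded_model = optimizer.run_adaround()  # deep copy with adarounded weights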
This commit is contained in:
parent 875221dedf
commit eb0b16db92

test/quantization/core/experimental/test_adaround_eager.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# Owner(s): ["oncall: speech_infra"]

import copy

import torch
import torch.nn as nn
from torch.ao.quantization.experimental.adaround_optimization import (
    AdaptiveRoundingOptimizer,
)

from torch.nn import functional as F
from torch.quantization.observer import MinMaxObserver
from torch.testing._internal.common_quantization import QuantizationTestCase


def forward_wrapper(fetcher):
    def forward(module, input, output):
        fetcher.append(input[0].detach())
        fetcher.append(output.detach())

    return forward


class TestAdaround(QuantizationTestCase):
    def feedforward_callback(
        self,
        model,
        data,
    ) -> None:
        model(data)

    def run_adaround(self, model, img_data):
        adaround_optimizer = AdaptiveRoundingOptimizer(
            model,
            self.feedforward_callback,
            forward_wrapper,
            img_data,
            max_iter=100,
            batch_size=10,
        )
        adarounded_model = adaround_optimizer.run_adaround()
        return adarounded_model

    def get_fake_quant(self, model):
        hard_fake_quant_model = copy.deepcopy(model)
        for _, module in hard_fake_quant_model.named_modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                weight_observer = MinMaxObserver(
                    quant_min=-128,
                    quant_max=127,
                    dtype=torch.qint8,
                    qscheme=torch.per_tensor_symmetric,
                )
                weight_observer(module.weight)
                scale, zero_point = weight_observer.calculate_qparams()
                fake_quant_module = torch.fake_quantize_per_tensor_affine(
                    module.weight,
                    scale=scale,
                    zero_point=zero_point,
                    quant_min=-128,
                    quant_max=127,
                )
                module.weight.data.copy_(fake_quant_module)
        return hard_fake_quant_model

    def test_linear_chain(self):
        class LinearChain(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = nn.Linear(3, 4)
                self.linear2 = nn.Linear(4, 5)
                self.linear3 = nn.Linear(5, 6)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.linear3(x)
                return x

        float_model = LinearChain()
        img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)]
        adarounded_model = self.run_adaround(float_model, img_data)
        fq_model = self.get_fake_quant(float_model)
        rand_input = torch.rand(10, 3)
        with torch.no_grad():
            ada_out = adarounded_model(rand_input)
            fq_out = fq_model(rand_input)
            float_out = float_model(rand_input)
            ada_loss = F.mse_loss(ada_out, float_out)
            fq_loss = F.mse_loss(fq_out, float_out)
            self.assertTrue(ada_loss.item() < fq_loss.item())

    def test_conv_chain(self):
        class ConvChain(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
                self.conv2d2 = nn.Conv2d(4, 5, 5, 5)
                self.conv2d3 = nn.Conv2d(5, 6, 5, 5)

            def forward(self, x):
                x = self.conv2d1(x)
                x = self.conv2d2(x)
                x = self.conv2d3(x)
                return x

        float_model = ConvChain()
        img_data = [torch.rand(10, 3, 125, 125, dtype=torch.float) for _ in range(50)]
        adarounded_model = self.run_adaround(float_model, img_data)
        fq_model = self.get_fake_quant(float_model)
        rand_input = torch.rand(10, 3, 256, 256)
        with torch.no_grad():
            ada_out = adarounded_model(rand_input)
            fq_out = fq_model(rand_input)
            float_out = float_model(rand_input)
            ada_loss = F.mse_loss(ada_out, float_out)
            fq_loss = F.mse_loss(fq_out, float_out)
            self.assertTrue(ada_loss.item() < fq_loss.item())

torch/ao/quantization/experimental/adaround_fake_quantize.py (new file, 148 lines)
@@ -0,0 +1,148 @@
from typing import Tuple

import torch
from torch.ao.quantization.fake_quantize import _is_symmetric_quant
from torch.ao.quantization.utils import is_per_tensor
from torch.quantization import FakeQuantize
from torch.quantization.observer import MinMaxObserver


class AdaroundFakeQuantizer(FakeQuantize):
    """
    A FakeQuantizer that performs adaptive rounding of weights.
    AdaRound is a technique to adaptively round weights, derived from the paper https://arxiv.org/pdf/2004.10568.pdf
    For HTP compatibility, we target symmetric quantization.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor
    V: torch.nn.Parameter

    # pyre-fixme[3]: Return type must be annotated.
    def __init__(
        self,
        observer=MinMaxObserver,
        qscheme=torch.per_tensor_symmetric,  # not used, but needed for fakequant
        quant_min: int = -128,
        quant_max: int = 127,
        ch_axis: int = 0,
        # pyre-fixme[2]: Parameter must be annotated.
        **observer_kwargs,
    ):
        super().__init__(
            observer=observer,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            is_dynamic=False,
            **observer_kwargs,
        )
        # Populate quant_min/quant_max to observer_kwargs if valid
        if quant_min is not None and quant_max is not None:
            assert (
                quant_min <= quant_max
            ), "quant_min must be less than or equal to quant_max"
        # pyre-fixme[4]: Attribute must be annotated.
        self.qscheme = qscheme
        self.is_per_tensor: bool = is_per_tensor(qscheme)
        self.is_symmetric: bool = _is_symmetric_quant(qscheme)
        assert self.is_symmetric, "Only symmetric quantization is supported"
        self.ch_axis: int = ch_axis

        self.scale = torch.tensor([], requires_grad=False)
        self.zero_point = torch.tensor([], requires_grad=False)
        self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True)
        # Fixed stretch parameters
        self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False)
        self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False)
        self.sigmoid = torch.nn.Sigmoid()
        self.use_soft_rounding = True

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
            f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}"
        )

    def enable_weight_fake_quant(self) -> None:
        self.fake_quant_enabled[0] = 1

    def get_rectified_sigmoid_func(self) -> torch.Tensor:
        if self.use_soft_rounding:
            return torch.clamp(
                self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma,
                min=0,
                max=1,
            )
        else:
            # This will dump a binary solution
            return (self.V >= 0).int()

    @torch.jit.ignore
    def update_scale(
        self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor
    ) -> None:
        if self.scale.numel() == 0:
            self.scale.data = _scale.to(X.device)
            self.zero_point = _zero_point.to(X.device)
        else:
            self.scale.data = _scale
            if not self.is_symmetric:
                self.zero_point = _zero_point
            else:
                self.zero_point = torch.zeros_like(_zero_point)
                for i in range(X.dim()):
                    if i == self.ch_axis:
                        continue
                    self.zero_point = self.zero_point.unsqueeze(i)
        X_q = X / self.scale
        X_q_floor = torch.floor(X_q)
        residual = X_q - X_q_floor  # [0, 1)
        assert torch.all(
            torch.ge(residual, 0)
        ), "residual should be non-negative [0, 1)"
        V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1)
        self.V.data = V_init

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if self.observer_enabled[0] == 1:
            X_detached = X.detach()
            self.activation_post_process(X_detached)
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
                self.zero_point.device
            )
            dims = list(range(X.dim()))
            if not self.is_per_tensor:
                dims.remove(self.ch_axis)
            if not self.is_per_tensor:
                for i in range(X.dim()):
                    if i == self.ch_axis:
                        continue
                    _scale = _scale.unsqueeze(i)
                    _zero_point = _zero_point.unsqueeze(i)
            self.update_scale(X_detached, _scale, _zero_point)

        if self.fake_quant_enabled[0] == 1:
            # Perform soft quantization
            # See equation (23) in the AdaRound paper
            h_v = self.get_rectified_sigmoid_func()
            X_q = X / self.scale
            # Straight-Through Estimator for the floor function
            X_q_floor = torch.floor(X_q) + self.zero_point
            # Regardless of rounding, gradient should be able to flow back to self.V from X_q_dq.
            # With adaround, we don't train weight, but train V only.
            X_q_dq = (
                torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max)
                - self.zero_point
            ) * self.scale
            return X_q_dq
        else:
            return X
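
A short aside on the fake quantizer above: get_rectified_sigmoid_func computes the rectified sigmoid h(V) = clamp(sigmoid(V) * (zeta - gamma) + gamma, 0, 1) from eq. (23) of the paper. The sketch below (not part of the diff) just evaluates it numerically to show that the stretch parameters zeta=1.1 and gamma=-0.1 make h saturate at exactly 0 and 1 for large |V|.

# Not part of the diff: numerical check of the rectified sigmoid h(V)
# implemented by get_rectified_sigmoid_func(), using the fixed stretch
# parameters zeta=1.1 and gamma=-0.1 from above.
import torch

zeta, gamma = 1.1, -0.1
V = torch.linspace(-6.0, 6.0, steps=5)
h_v = torch.clamp(torch.sigmoid(V) * (zeta - gamma) + gamma, min=0, max=1)
print(h_v)  # hits exactly 0 and 1 at the ends, smooth in between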

torch/ao/quantization/experimental/adaround_loss.py (new file, 96 lines)
@@ -0,0 +1,96 @@
from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F

ADAROUND_ZETA: float = 1.1
ADAROUND_GAMMA: float = -0.1


class AdaptiveRoundingLoss(torch.nn.Module):
    """
    Adaptive rounding loss functions described in https://arxiv.org/pdf/2004.10568.pdf
    The rounding regularization is eq. [24]; the reconstruction loss is eq. [25] without the regularization term.
    """

    def __init__(
        self,
        max_iter: int,
        warm_start: float = 0.2,
        beta_range: Tuple[int, int] = (20, 2),
        reg_param: float = 0.001,
    ) -> None:
        super().__init__()
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.beta_range = beta_range
        self.reg_param = reg_param

    def rounding_regularization(
        self,
        V: torch.Tensor,
        curr_iter: int,
    ) -> torch.Tensor:
        """
        Main logic adapted from the official AdaRound implementation.
        Apply rounding regularization to the input tensor V.
        """
        assert (
            curr_iter < self.max_iter
        ), "Current iteration must be strictly less than max_iter"
        if curr_iter < self.warm_start * self.max_iter:
            return torch.tensor(0.0)
        else:
            start_beta, end_beta = self.beta_range
            warm_start_end_iter = self.warm_start * self.max_iter

            # Compute the relative progress of the current iteration
            rel_iter = (curr_iter - warm_start_end_iter) / (
                self.max_iter - warm_start_end_iter
            )
            beta = end_beta + 0.5 * (start_beta - end_beta) * (
                1 + np.cos(rel_iter * np.pi)
            )

            # A rectified sigmoid for soft quantization, as formulated in eq. [23] of https://arxiv.org/pdf/2004.10568.pdf
            h_alpha = torch.clamp(
                torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
                min=0,
                max=1,
            )

            # Apply rounding regularization.
            # This term pushes the soft rounding variable toward a binary solution (0 or 1) by the end of optimization.
            inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta)
            regularization_term = torch.add(1, -inner_term).sum()
            return regularization_term * self.reg_param

    def reconstruction_loss(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the reconstruction loss between the soft quantized output and the original output.
        """
        return F.mse_loss(
            soft_quantized_output, original_output, reduction="none"
        ).mean()

    def forward(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
        V: torch.Tensor,
        curr_iter: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the asymmetric reconstruction formulation as in eq. [25].
        """
        regularization_term = self.rounding_regularization(V, curr_iter)
        reconstruction_term = self.reconstruction_loss(
            soft_quantized_output, original_output
        )
        return regularization_term, reconstruction_term
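
A short aside on the loss above: during the warm start (the first 20% of iterations by default) the regularization term is zero; afterwards beta is annealed with a cosine schedule from beta_range[0] down to beta_range[1]. The sketch below (not part of the diff) traces that schedule for an illustrative max_iter=100 with the default warm_start=0.2 and beta_range=(20, 2).

# Not part of the diff: trace the beta annealing schedule used by
# AdaptiveRoundingLoss.rounding_regularization, assuming max_iter=100 and the
# default warm_start=0.2 and beta_range=(20, 2).
import numpy as np

max_iter, warm_start = 100, 0.2
start_beta, end_beta = 20, 2
warm_end = warm_start * max_iter
for it in (20, 50, 99):  # iterations at and after the end of the warm start
    rel_iter = (it - warm_end) / (max_iter - warm_end)
    beta = end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * np.pi))
    print(it, round(float(beta), 2))  # 20 -> 20.0, 50 -> ~14.4, 99 -> ~2.0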

torch/ao/quantization/experimental/adaround_optimization.py (new file, 238 lines)
@@ -0,0 +1,238 @@
import copy
import logging
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch
from torch.ao.quantization.experimental.adaround_fake_quantize import (
    AdaroundFakeQuantizer,
)
from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss
from torch.ao.quantization.observer import MinMaxObserver
from torch.nn import functional as F
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, TensorDataset

logger: logging.Logger = logging.getLogger(__name__)


class AdaptiveRoundingOptimizer:
    def __init__(
        self,
        model: Union[torch.nn.Module, torch.nn.DataParallel],
        callback: Callable[[torch.nn.Module, List[Any]], None],
        forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable],
        data: List[Any],
        observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver,
        max_iter=10000,
        dtype: torch.dtype = torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme: torch.qscheme = torch.per_tensor_symmetric,
        batch_size: int = 256,
    ):
        self.model = model
        self.q_model = copy.deepcopy(self.model)
        self.device = torch.device("cuda") if torch.cuda.is_available() else None
        self.callback = callback
        self.forward_hook_wrapper = forward_hook_wrapper
        # TODO: rather than taking data as a list, accept an *iterator* instead
        self.data = data
        self.batch_size = min(batch_size, len(data))
        self.max_iter = max_iter
        self.adaptive_round_loss_fn = AdaptiveRoundingLoss(
            max_iter=self.max_iter, warm_start=0.2
        )
        self.dtype = dtype
        self.observer = observer
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme

    def run_adaround(self) -> torch.nn.Module:
        layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = []
        for (name, module), q_module in zip(
            self.model.named_modules(), self.q_model.modules()
        ):
            if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
                # Knowing the activation ahead of time would be helpful for the asymmetric
                # formulation, but this is challenging in eager mode, unlike with graph modules.
                layer_list.append((name, module, q_module))
        logger.info(f"Total number of layers : {len(layer_list)}")  # noqa: G004

        for name, module, q_module in layer_list:
            logger.info(
                f"Kick start adaptive rounding on {name} module {module}"  # noqa: G004
            )
            self.optimize_adaptive_rounding(
                module,
                q_module,
                None,
            )

        return (
            self.q_model.module
            if isinstance(self.q_model, DataParallel)
            else self.q_model
        )

    def get_data_inp_out(
        self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        fp_out: List[torch.Tensor] = []
        q_input: List[torch.Tensor] = []
        fp_input: List[torch.Tensor] = []
        fp32_fetcher: List[torch.Tensor] = []
        quant_fetcher: List[torch.Tensor] = []
        handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher))
        handler2 = q_module.register_forward_hook(
            self.forward_hook_wrapper(quant_fetcher)
        )
        for data_ in data:
            with torch.no_grad():
                self.callback(self.model, data_)
                self.callback(self.q_model, data_)
                fp32_output = fp32_fetcher[1]
                quant_input = quant_fetcher[0]
                fp_out.append(fp32_output)
                q_input.append(quant_input)
                fp_input.append(fp32_fetcher[0])
        handler1.remove()
        handler2.remove()
        return q_input, fp_out, fp_input

    @torch.no_grad()
    def feed_forward(self, x, weight, module):
        if isinstance(module, torch.nn.Conv1d):
            out = torch.nn.functional.conv1d(
                x,
                weight,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
            )
        elif isinstance(module, torch.nn.Linear):
            out = torch.nn.functional.linear(
                x,
                weight,
                bias=module.bias,
            )
        else:
            raise NotImplementedError
        return out

    def _compute_and_display_local_losses(
        self,
        ada_quantizer: AdaroundFakeQuantizer,
        q_module: torch.nn.Module,
        q_inp: torch.Tensor,
        fp_out: torch.Tensor,
    ):
        with torch.no_grad():
            ada_quantizer.use_soft_rounding = False
            q_w_hard_round = ada_quantizer(q_module.weight)
            out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module)
            ada_quantizer.use_soft_rounding = True
            q_w_soft_round = ada_quantizer(q_module.weight)
            out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module)
            soft_quant_loss = F.mse_loss(out_soft_quant, fp_out)
            hard_quant_loss = F.mse_loss(out_hard_quant, fp_out)
            logger.info(
                f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}"  # noqa: G004
            )

    def optimize_adaptive_rounding(
        self,
        module: torch.nn.Module,
        q_module: torch.nn.Module,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        ada_quantizer = AdaroundFakeQuantizer(
            dtype=self.dtype,
            observer=self.observer,
            qscheme=self.qscheme,
            quant_min=self.quant_min,
            quant_max=self.quant_max,
            reduce_range=False,
        )
        ada_quantizer.enable_observer()
        ada_quantizer(q_module.weight)
        ada_quantizer.disable_observer()
        ada_quantizer.enable_fake_quant()
        optimizer = torch.optim.Adam([ada_quantizer.V])
        inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data)

        logger.info("==================== Before adaround ====================")
        test_in, test_out, fp_test_in = self.get_data_inp_out(
            module, q_module, self.data[0]
        )

        assert (
            torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0
        ), "In-place activation detected; please do not use in-place activations"
        # Stack the tensors in the inp and out lists into single tensors
        inp_tensor = torch.vstack(inp)
        out_tensor = torch.vstack(out)
        dataset = TensorDataset(inp_tensor, out_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._compute_and_display_local_losses(
            ada_quantizer, q_module, test_in[0], test_out[0]
        )
        global_idx = 0
        one_iter = len(out) // self.batch_size
        for iteration in range(self.max_iter // one_iter):
            reconstruction_loss = regularization_loss = torch.tensor(0)
            for q_inp, fp_out in dataloader:
                optimizer.zero_grad()
                q_weight = ada_quantizer(q_module.weight)
                if isinstance(module, torch.nn.Conv1d):
                    q_out = torch.nn.functional.conv1d(
                        q_inp,
                        q_weight,
                        stride=q_module.stride,
                        padding=q_module.padding,
                        dilation=q_module.dilation,
                        groups=q_module.groups,
                    )
                elif isinstance(q_module, torch.nn.Linear):
                    q_out = torch.nn.functional.linear(
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                    )
                else:
                    raise NotImplementedError
                regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn(
                    fp_out,
                    q_out,
                    ada_quantizer.V,
                    curr_iter=global_idx,
                )
                loss = regularization_loss + reconstruction_loss
                loss.backward()
                optimizer.step()
                global_idx += 1
                if global_idx >= self.max_iter:
                    break
            if global_idx >= self.max_iter:
                break
            if iteration % 30 == 0:
                logger.info(
                    f"glob iter {global_idx} regularization_loss {regularization_loss.item()} "  # noqa: G004
                    f"reconstruction_loss {reconstruction_loss.item()}"  # noqa: G004
                )
        logger.info("==================== After adaround ====================")
        self._compute_and_display_local_losses(
            ada_quantizer, q_module, test_in[0], test_out[0]
        )

        ada_quantizer.use_soft_rounding = True
        ada_quantizer.V.requires_grad = False
        ada_quantizer = ada_quantizer.eval()
        q_weight = ada_quantizer(q_module.weight)
        # At the end of optimization, we need to copy the adarounded weight back to the original module
        q_module.weight.data.copy_(q_weight)
        # Eager mode requires observer to be set as "weight_fake_quant" to be parsed
        q_module.weight_fake_quant = ada_quantizer.activation_post_process