Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "Initial implementation of AdaRound (#126153)"
This reverts commit 175c18af81.
Reverted https://github.com/pytorch/pytorch/pull/126153 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but the lint failure is legit because there is more than one lint issue; torch/optim/asgd.py is just the last one ([comment](https://github.com/pytorch/pytorch/pull/126153#issuecomment-2113902522))
This commit is contained in:
parent e3c5d1b7d7
commit ae6fdfa539
@@ -1,115 +0,0 @@
# Owner(s): ["oncall: speech_infra"]

import copy

import torch
import torch.nn as nn
from torch.ao.quantization.experimental.adaround_optimization import (
    AdaptiveRoundingOptimizer,
)
from torch.nn import functional as F
from torch.quantization.observer import MinMaxObserver
from torch.testing._internal.common_quantization import QuantizationTestCase


def forward_wrapper(fetcher):
    def forward(module, input, output):
        fetcher.append(input[0].detach())
        fetcher.append(output.detach())

    return forward


class TestAdaround(QuantizationTestCase):
    def feedforward_callback(
        self,
        model,
        data,
    ) -> None:
        model(data)

    def run_adaround(self, model, img_data):
        adaround_optimizer = AdaptiveRoundingOptimizer(
            model,
            self.feedforward_callback,
            forward_wrapper,
            img_data,
            max_iter=100,
            batch_size=10,
        )
        adarounded_model = adaround_optimizer.run_adaround()
        return adarounded_model

    def get_fake_quant(self, model):
        hard_fake_quant_model = copy.deepcopy(model)
        for _, module in hard_fake_quant_model.named_modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                weight_observer = MinMaxObserver(
                    quant_min=-128,
                    quant_max=127,
                    dtype=torch.qint8,
                    qscheme=torch.per_tensor_symmetric,
                )
                weight_observer(module.weight)
                scale, zero_point = weight_observer.calculate_qparams()
                fake_quant_module = torch.fake_quantize_per_tensor_affine(
                    module.weight,
                    scale=scale,
                    zero_point=zero_point,
                    quant_min=-128,
                    quant_max=127,
                )
                module.weight.data.copy_(fake_quant_module)
        return hard_fake_quant_model

    def test_linear_chain(self):
        class LinearChain(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = nn.Linear(3, 4)
                self.linear2 = nn.Linear(4, 5)
                self.linear3 = nn.Linear(5, 6)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.linear3(x)
                return x

        float_model = LinearChain()
        img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)]
        adarounded_model = self.run_adaround(float_model, img_data)
        fq_model = self.get_fake_quant(float_model)
        rand_input = torch.rand(10, 3)
        with torch.no_grad():
            ada_out = adarounded_model(rand_input)
            fq_out = fq_model(rand_input)
            float_out = float_model(rand_input)
            ada_loss = F.mse_loss(ada_out, float_out)
            fq_loss = F.mse_loss(fq_out, float_out)
            self.assertTrue(ada_loss.item() < fq_loss.item())

    def test_conv_chain(self):
        class ConvChain(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
                self.conv2d2 = nn.Conv2d(4, 5, 5, 5)
                self.conv2d3 = nn.Conv2d(5, 6, 5, 5)

            def forward(self, x):
                x = self.conv2d1(x)
                x = self.conv2d2(x)
                x = self.conv2d3(x)
                return x

        float_model = ConvChain()
        img_data = [torch.rand(10, 3, 125, 125, dtype=torch.float) for _ in range(50)]
        adarounded_model = self.run_adaround(float_model, img_data)
        fq_model = self.get_fake_quant(float_model)
        rand_input = torch.rand(10, 3, 256, 256)
        with torch.no_grad():
            ada_out = adarounded_model(rand_input)
            fq_out = fq_model(rand_input)
            float_out = float_model(rand_input)
            ada_loss = F.mse_loss(ada_out, float_out)
            fq_loss = F.mse_loss(fq_out, float_out)
            self.assertTrue(ada_loss.item() < fq_loss.item())
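Editor's note: the test's forward_wrapper is a hook factory; it returns a forward hook that appends a layer's input and output tensors to a caller-owned list, which is how the optimizer below captures per-layer activations. A minimal sketch of that capture pattern outside the test harness (the layer and tensor shapes here are illustrative assumptions, not part of the original test):

import torch
import torch.nn as nn

def forward_wrapper(fetcher):
    def forward(module, input, output):
        fetcher.append(input[0].detach())
        fetcher.append(output.detach())
    return forward

layer = nn.Linear(3, 4)        # hypothetical layer, for illustration only
captured = []                  # fetcher list owned by the caller
handle = layer.register_forward_hook(forward_wrapper(captured))
layer(torch.rand(10, 3))       # after this call, captured holds [input, output]
handle.remove()                # remove the hook once the capture is done
assert len(captured) == 2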
@@ -1,146 +0,0 @@
from typing import Tuple

import torch
from torch.ao.quantization.fake_quantize import _is_symmetric_quant
from torch.ao.quantization.utils import is_per_tensor
from torch.quantization import FakeQuantize
from torch.quantization.observer import MinMaxObserver


class AdaroundFakeQuantizer(FakeQuantize):
    """
    A FakeQuantizer that enables adaptive rounding (Adaround).
    Adaround is a technique for adaptively rounding weights, derived from the paper
    https://arxiv.org/pdf/2004.10568.pdf.
    For HTP compatibility, we target symmetric quantization only.
    """

    scale: torch.Tensor
    zero_point: torch.Tensor
    V: torch.nn.Parameter

    # pyre-fixme[3]: Return type must be annotated.
    def __init__(
        self,
        observer=MinMaxObserver,
        qscheme=torch.per_tensor_symmetric,  # not used, but needed for fake quant
        quant_min: int = -128,
        quant_max: int = 127,
        ch_axis: int = 0,
        # pyre-fixme[2]: Parameter must be annotated.
        **observer_kwargs,
    ):
        super().__init__(
            observer=observer,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            is_dynamic=False,
            **observer_kwargs,
        )
        # Populate quant_min/quant_max into observer_kwargs if valid
        if quant_min is not None and quant_max is not None:
            assert (
                quant_min <= quant_max
            ), "quant_min must be less than or equal to quant_max"
        # pyre-fixme[4]: Attribute must be annotated.
        self.qscheme = qscheme
        self.is_per_tensor: bool = is_per_tensor(qscheme)
        self.is_symmetric: bool = _is_symmetric_quant(qscheme)
        assert self.is_symmetric, "Only symmetric quantization is supported"
        self.ch_axis: int = ch_axis

        self.scale = torch.tensor([], requires_grad=False)
        self.zero_point = torch.tensor([], requires_grad=False)
        self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True)
        # Fixed stretch parameters
        self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False)
        self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False)
        self.sigmoid = torch.nn.Sigmoid()
        self.use_soft_rounding = True

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
            f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}"
        )

    def enable_weight_fake_quant(self) -> None:
        self.fake_quant_enabled[0] = 1

    def get_rectified_sigmoid_func(self) -> torch.Tensor:
        if self.use_soft_rounding:
            return torch.clamp(
                self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma,
                min=0,
                max=1,
            )
        else:
            # This yields a binary (hard-rounding) solution
            return (self.V >= 0).int()

    @torch.jit.ignore
    def update_scale(
        self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor
    ) -> None:
        if self.scale.numel() == 0:
            self.scale.data = _scale.to(X.device)
            self.zero_point = _zero_point.to(X.device)
        else:
            self.scale.data = _scale
            if not self.is_symmetric:
                self.zero_point = _zero_point
            else:
                self.zero_point = torch.zeros_like(_zero_point)
        for i in range(X.dim()):
            if i == self.ch_axis:
                continue
            self.zero_point = self.zero_point.unsqueeze(i)
        X_q = X / self.scale
        X_q_floor = torch.floor(X_q)
        residual = X_q - X_q_floor  # [0, 1)
        assert torch.all(
            torch.ge(residual, 0)
        ), "residual should be non-negative [0, 1)"
        # Initialize V so that the rectified sigmoid reproduces the residual,
        # i.e. sigmoid(V_init) * (zeta - gamma) + gamma == residual.
        V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1)
        self.V.data = V_init

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if self.observer_enabled[0] == 1:
            X_detached = X.detach()
            self.activation_post_process(X_detached)
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
                self.zero_point.device
            )
            dims = list(range(X.dim()))
            if not self.is_per_tensor:
                dims.remove(self.ch_axis)
            if not self.is_per_tensor:
                for i in range(X.dim()):
                    if i == self.ch_axis:
                        continue
                    _scale = _scale.unsqueeze(i)
                    _zero_point = _zero_point.unsqueeze(i)
            self.update_scale(X_detached, _scale, _zero_point)

        if self.fake_quant_enabled[0] == 1:
            # Perform soft quantization; see equation (23) in the Adaround paper
            h_v = self.get_rectified_sigmoid_func()
            X_q = X / self.scale
            # Straight-Through Estimator for the floor function
            X_q_floor = torch.floor(X_q) + self.zero_point
            # Regardless of rounding, the gradient should be able to flow back to self.V from X_q_dq.
            # With Adaround, we do not train the weight; we train V only.
            X_q_dq = (
                torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max)
                - self.zero_point
            ) * self.scale
            return X_q_dq
        else:
            return X
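Editor's note: the forward pass above implements the soft rounding of eq. (23): each weight is divided by the scale, floored, and nudged up by h(V) = clamp(sigmoid(V) * (zeta - gamma) + gamma, 0, 1), so the learnable tensor V decides per element whether to round down (h near 0) or up (h near 1). A minimal standalone sketch of that computation, with an assumed scale and zero point for a symmetric per-tensor quantizer (values chosen only for illustration):

import torch

zeta, gamma = 1.1, -0.1                      # fixed stretch parameters, as in the class above
W = torch.randn(4, 3)                        # hypothetical weight tensor
scale = torch.tensor(0.05)                   # assumed quantization scale
zero_point = torch.tensor(0.0)               # symmetric quantization, so zero point is 0
V = torch.zeros_like(W, requires_grad=True)  # rounding variable; sigmoid(0) gives h(V) = 0.5

h_v = torch.clamp(torch.sigmoid(V) * (zeta - gamma) + gamma, min=0, max=1)
W_q_floor = torch.floor(W / scale) + zero_point
# Soft quantize-dequantize; gradients flow to V through h_v only (floor is treated as constant).
W_q_dq = (torch.clamp(W_q_floor + h_v, min=-128, max=127) - zero_point) * scale
W_q_dq.sum().backward()                      # V.grad is populated; W itself is never trained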
@@ -1,96 +0,0 @@
from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F

ADAROUND_ZETA: float = 1.1
ADAROUND_GAMMA: float = -0.1


class AdaptiveRoundingLoss(torch.nn.Module):
    """
    Adaptive rounding loss functions described in https://arxiv.org/pdf/2004.10568.pdf.
    The rounding regularization is eq. [24]; the reconstruction loss is eq. [25]
    without the regularization term.
    """

    def __init__(
        self,
        max_iter: int,
        warm_start: float = 0.2,
        beta_range: Tuple[int, int] = (20, 2),
        reg_param: float = 0.001,
    ) -> None:
        super().__init__()
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.beta_range = beta_range
        self.reg_param = reg_param

    def rounding_regularization(
        self,
        V: torch.Tensor,
        curr_iter: int,
    ) -> torch.Tensor:
        """
        Apply rounding regularization to the input tensor V.
        Main logic adapted from the original Adaround implementation:
        https://github.com/quic/aimet/blob/develop/TrainingExtensions/torch/src/python/aimet_torch/adaround/adaround_loss.py#L114
        """
        assert (
            curr_iter < self.max_iter
        ), "Current iteration must be strictly less than max_iter"
        if curr_iter < self.warm_start * self.max_iter:
            return torch.tensor(0.0)
        else:
            start_beta, end_beta = self.beta_range
            warm_start_end_iter = self.warm_start * self.max_iter

            # Compute the relative progress of the current iteration
            rel_iter = (curr_iter - warm_start_end_iter) / (
                self.max_iter - warm_start_end_iter
            )
            beta = end_beta + 0.5 * (start_beta - end_beta) * (
                1 + np.cos(rel_iter * np.pi)
            )

            # A rectified sigmoid for soft quantization, as formulated in eq. [23] of
            # https://arxiv.org/pdf/2004.10568.pdf
            h_alpha = torch.clamp(
                torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
                min=0,
                max=1,
            )

            # Apply rounding regularization.
            # This term pushes h_alpha toward a binary solution (0 or 1) by the end of optimization.
            inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta)
            regularization_term = torch.add(1, -inner_term).sum()
            return regularization_term * self.reg_param

    def reconstruction_loss(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the reconstruction loss between the soft-quantized output and the original output.
        """
        return F.mse_loss(
            soft_quantized_output, original_output, reduction="none"
        ).mean()

    def forward(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
        V: torch.Tensor,
        curr_iter: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the asymmetric reconstruction formulation of eq. [25].
        """
        regularization_term = self.rounding_regularization(V, curr_iter)
        reconstruction_term = self.reconstruction_loss(
            soft_quantized_output, original_output
        )
        return regularization_term, reconstruction_term
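Editor's note: the caller sums the two returned terms. The regularization term is zero during the warm-start fraction of iterations; after that, beta anneals from 20 down to 2 on a cosine schedule, which gradually forces h(V) toward 0 or 1. A short sketch of driving this loss across iterations (tensor shapes and iteration count are illustrative assumptions; the import path matches the one used by the optimizer module below):

import torch
from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss

loss_fn = AdaptiveRoundingLoss(max_iter=100, warm_start=0.2)
V = torch.zeros(4, 3, requires_grad=True)       # rounding variable of one layer
soft_out = torch.randn(8, 4)                    # hypothetical soft-quantized layer output
fp_out = soft_out + 0.01 * torch.randn(8, 4)    # hypothetical float reference output

for it in (0, 19, 20, 99):                      # warm start covers iterations 0..19
    reg, rec = loss_fn(soft_out, fp_out, V, curr_iter=it)
    # reg stays 0.0 while it < warm_start * max_iter, then grows as beta anneals
    print(it, reg.item(), rec.item())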
@@ -1,237 +0,0 @@
import copy
import logging
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch
from torch.ao.quantization.experimental.adaround_fake_quantize import (
    AdaroundFakeQuantizer,
)
from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss
from torch.ao.quantization.observer import MinMaxObserver
from torch.nn import functional as F
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, TensorDataset

logger: logging.Logger = logging.getLogger(__name__)


class AdaptiveRoundingOptimizer:
    def __init__(
        self,
        model: Union[torch.nn.Module, torch.nn.DataParallel],
        callback: Callable[[torch.nn.Module, List[Any]], None],
        forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable],
        data: List[Any],
        observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver,
        max_iter=10000,
        dtype: torch.dtype = torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme: torch.qscheme = torch.per_tensor_symmetric,
        batch_size: int = 256,
    ):
        self.model = model
        self.q_model = copy.deepcopy(self.model)
        self.device = torch.device("cuda") if torch.cuda.is_available() else None
        self.callback = callback
        self.forward_hook_wrapper = forward_hook_wrapper
        # TODO: rather than taking data as a list, pass an *iterator* instead
        self.data = data
        self.batch_size = min(batch_size, len(data))
        self.max_iter = max_iter
        self.adaptive_round_loss_fn = AdaptiveRoundingLoss(
            max_iter=self.max_iter, warm_start=0.2
        )
        self.dtype = dtype
        self.observer = observer
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme

    def run_adaround(self) -> torch.nn.Module:
        layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = []
        for (name, module), q_module in zip(
            self.model.named_modules(), self.q_model.modules()
        ):
            if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
                # Knowing the activation ahead of time would help the asymmetric formulation,
                # but this is hard in eager mode (it is easier with a graph module).
                layer_list.append((name, module, q_module))
        logger.info(f"Total number of layers: {len(layer_list)}")

        for name, module, q_module in layer_list:
            logger.info(f"Kick-starting adaptive rounding on {name} module {module}")
            self.optimize_adaptive_rounding(
                module,
                q_module,
                None,
            )

        return (
            self.q_model.module
            if isinstance(self.q_model, DataParallel)
            else self.q_model
        )

    def get_data_inp_out(
        self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        fp_out: List[torch.Tensor] = []
        q_input: List[torch.Tensor] = []
        fp_input: List[torch.Tensor] = []
        fp32_fetcher: List[torch.Tensor] = []
        quant_fetcher: List[torch.Tensor] = []
        handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher))
        handler2 = q_module.register_forward_hook(
            self.forward_hook_wrapper(quant_fetcher)
        )
        for data_ in data:
            with torch.no_grad():
                self.callback(self.model, data_)
                self.callback(self.q_model, data_)
            fp32_output = fp32_fetcher[1]
            quant_input = quant_fetcher[0]
            fp_out.append(fp32_output)
            q_input.append(quant_input)
            fp_input.append(fp32_fetcher[0])
        handler1.remove()
        handler2.remove()
        return q_input, fp_out, fp_input

    @torch.no_grad()
    def feed_forward(self, x, weight, module):
        if isinstance(module, torch.nn.Conv1d):
            out = torch.nn.functional.conv1d(
                x,
                weight,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
            )
        elif isinstance(module, torch.nn.Linear):
            out = torch.nn.functional.linear(
                x,
                weight,
                bias=module.bias,
            )
        else:
            raise NotImplementedError
        return out

    def _compute_and_display_local_losses(
        self,
        ada_quantizer: AdaroundFakeQuantizer,
        q_module: torch.nn.Module,
        q_inp: torch.Tensor,
        fp_out: torch.Tensor,
    ):
        with torch.no_grad():
            ada_quantizer.use_soft_rounding = False
            q_w_hard_round = ada_quantizer(q_module.weight)
            out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module)
            ada_quantizer.use_soft_rounding = True
            q_w_soft_round = ada_quantizer(q_module.weight)
            out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module)
            soft_quant_loss = F.mse_loss(out_soft_quant, fp_out)
            hard_quant_loss = F.mse_loss(out_hard_quant, fp_out)
            logger.info(
                f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}"
            )

    def optimize_adaptive_rounding(
        self,
        module: torch.nn.Module,
        q_module: torch.nn.Module,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        ada_quantizer = AdaroundFakeQuantizer(
            dtype=self.dtype,
            observer=self.observer,
            qscheme=self.qscheme,
            quant_min=self.quant_min,
            quant_max=self.quant_max,
            reduce_range=False,
        )
        ada_quantizer.enable_observer()
        ada_quantizer(q_module.weight)
        ada_quantizer.disable_observer()
        ada_quantizer.enable_fake_quant()
        optimizer = torch.optim.Adam([ada_quantizer.V])
        inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data)

        logger.info("==================== Before adaround ====================")
        test_in, test_out, fp_test_in = self.get_data_inp_out(
            module, q_module, self.data[0]
        )

        assert (
            torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0
        ), "In-place activation detected; please do not use in-place activations"
        # Stack the tensors in each list into a single tensor
        inp_tensor = torch.vstack(inp)
        out_tensor = torch.vstack(out)
        dataset = TensorDataset(inp_tensor, out_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._compute_and_display_local_losses(
            ada_quantizer, q_module, test_in[0], test_out[0]
        )
        global_idx = 0
        one_iter = len(out) // self.batch_size
        for iteration in range(self.max_iter // one_iter):
            reconstruction_loss = regularization_loss = torch.tensor(0)
            for q_inp, fp_out in dataloader:
                optimizer.zero_grad()
                q_weight = ada_quantizer(q_module.weight)
                if isinstance(module, torch.nn.Conv1d):
                    q_out = torch.nn.functional.conv1d(
                        q_inp,
                        q_weight,
                        stride=q_module.stride,
                        padding=q_module.padding,
                        dilation=q_module.dilation,
                        groups=q_module.groups,
                    )
                elif isinstance(q_module, torch.nn.Linear):
                    q_out = torch.nn.functional.linear(
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                    )
                else:
                    raise NotImplementedError
                regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn(
                    fp_out,
                    q_out,
                    ada_quantizer.V,
                    curr_iter=global_idx,
                )
                loss = regularization_loss + reconstruction_loss
                loss.backward()
                optimizer.step()
                global_idx += 1
                if global_idx >= self.max_iter:
                    break
            if global_idx >= self.max_iter:
                break
            if iteration % 30 == 0:
                logger.info(
                    f"glob iter {global_idx} regularization_loss {regularization_loss.item()} "
                    f"reconstruction_loss {reconstruction_loss.item()}"
                )
        logger.info("==================== After adaround ====================")
        self._compute_and_display_local_losses(
            ada_quantizer, q_module, test_in[0], test_out[0]
        )

        ada_quantizer.use_soft_rounding = True
        ada_quantizer.V.requires_grad = False
        ada_quantizer = ada_quantizer.eval()
        q_weight = ada_quantizer(q_module.weight)
        # Reference: https://github.com/quic/aimet/blob/develop/TrainingExtensions/torch/src/python/aimet_torch/adaround/adaround_weight.py#L374
        # At the end of optimization, copy the adarounded weight back to the original module
        q_module.weight.data.copy_(q_weight)
        # Eager mode requires the observer to be set as "weight_fake_quant" to be parsed
        q_module.weight_fake_quant = ada_quantizer.activation_post_process
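Editor's note: putting the pieces together, the optimizer is driven exactly as in the reverted test above: provide a float model, a callback that runs one batch through a model, a forward-hook factory for capturing per-layer inputs and outputs, and a list of calibration batches; run_adaround returns a deep copy of the model with adarounded weights. A condensed sketch of that flow (the model and data sizes are illustrative, mirroring the linear-chain test):

import torch
import torch.nn as nn
from torch.ao.quantization.experimental.adaround_optimization import (
    AdaptiveRoundingOptimizer,
)

def forward_wrapper(fetcher):
    def forward(module, input, output):
        fetcher.append(input[0].detach())
        fetcher.append(output.detach())
    return forward

model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))   # hypothetical float model
calib_data = [torch.rand(10, 3) for _ in range(50)]        # calibration batches

ada_optimizer = AdaptiveRoundingOptimizer(
    model,
    lambda m, batch: m(batch),   # feed-forward callback, as in the test
    forward_wrapper,
    calib_data,
    max_iter=100,
    batch_size=10,
)
adarounded_model = ada_optimizer.run_adaround()   # copy of model with adarounded weights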