Initial implementation of AdaRound (#126153)

Summary:
This is an implementation of AdaRound from the paper "Up or Down? Adaptive Rounding for Post-Training Quantization" (https://arxiv.org/abs/2004.10568).

This algorithm is going to be used by multiple users, so we are making it an official implementation.
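
For reference, here is a minimal usage sketch modeled on the new unit test (the module path and constructor signature come from the diff below; the toy model, callback, and data shapes are illustrative placeholders):

    import torch
    from torch.ao.quantization.experimental.adaround_optimization import (
        AdaptiveRoundingOptimizer,
    )

    # Toy FP32 model; any model containing Linear/Conv layers is handled the same way.
    float_model = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 5))

    def feedforward_callback(model, data):
        # Runs one batch of calibration data through the model.
        model(data)

    def forward_wrapper(fetcher):
        # Builds a forward hook that records a module's input and output tensors.
        def forward(module, input, output):
            fetcher.append(input[0].detach())
            fetcher.append(output.detach())
        return forward

    calibration_data = [torch.rand(10, 3) for _ in range(50)]
    adaround_optimizer = AdaptiveRoundingOptimizer(
        float_model,
        feedforward_callback,
        forward_wrapper,
        calibration_data,
        max_iter=100,
        batch_size=10,
    )
    adarounded_model = adaround_optimizer.run_adaround()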

Differential Revision: D57227565

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126153
Approved by: https://github.com/jerryzh168, https://github.com/huydhn
Kwanghoon An 2024-05-17 19:44:50 +00:00 committed by PyTorch MergeBot
parent 875221dedf
commit eb0b16db92
4 changed files with 600 additions and 0 deletions


@@ -0,0 +1,118 @@
# Owner(s): ["oncall: speech_infra"]
import copy
import torch
import torch.nn as nn
from torch.ao.quantization.experimental.adaround_optimization import (
AdaptiveRoundingOptimizer,
)
from torch.nn import functional as F
from torch.quantization.observer import MinMaxObserver
from torch.testing._internal.common_quantization import QuantizationTestCase
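# Forward-hook factory: returns a hook that records a module's input and output
# tensors into `fetcher`, so the optimizer can compare the FP32 and quantized paths.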
def forward_wrapper(fetcher):
def forward(module, input, output):
fetcher.append(input[0].detach())
fetcher.append(output.detach())
return forward
class TestAdaround(QuantizationTestCase):
    def feedforward_callback(
self,
model,
data,
) -> None:
model(data)
def run_adaround(self, model, img_data):
adaround_optimizer = AdaptiveRoundingOptimizer(
model,
            self.feedforward_callback,
forward_wrapper,
img_data,
max_iter=100,
batch_size=10,
)
adarounded_model = adaround_optimizer.run_adaround()
return adarounded_model
def get_fake_quant(self, model):
hard_fake_quant_model = copy.deepcopy(model)
for _, module in hard_fake_quant_model.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
weight_observer = MinMaxObserver(
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=torch.per_tensor_symmetric,
)
weight_observer(module.weight)
scale, zero_point = weight_observer.calculate_qparams()
fake_quant_module = torch.fake_quantize_per_tensor_affine(
module.weight,
scale=scale,
zero_point=zero_point,
quant_min=-128,
quant_max=127,
)
module.weight.data.copy_(fake_quant_module)
return hard_fake_quant_model
def test_linear_chain(self):
class LinearChain(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(3, 4)
self.linear2 = nn.Linear(4, 5)
self.linear3 = nn.Linear(5, 6)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
float_model = LinearChain()
img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)]
adarounded_model = self.run_adaround(float_model, img_data)
fq_model = self.get_fake_quant(float_model)
rand_input = torch.rand(10, 3)
with torch.no_grad():
ada_out = adarounded_model(rand_input)
fq_out = fq_model(rand_input)
float_out = float_model(rand_input)
ada_loss = F.mse_loss(ada_out, float_out)
fq_loss = F.mse_loss(fq_out, float_out)
self.assertTrue(ada_loss.item() < fq_loss.item())
def test_conv_chain(self):
class ConvChain(nn.Module):
def __init__(self):
super().__init__()
self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
self.conv2d2 = nn.Conv2d(4, 5, 5, 5)
self.conv2d3 = nn.Conv2d(5, 6, 5, 5)
def forward(self, x):
x = self.conv2d1(x)
x = self.conv2d2(x)
x = self.conv2d3(x)
return x
float_model = ConvChain()
img_data = [torch.rand(10, 3, 125, 125, dtype=torch.float) for _ in range(50)]
adarounded_model = self.run_adaround(float_model, img_data)
fq_model = self.get_fake_quant(float_model)
rand_input = torch.rand(10, 3, 256, 256)
with torch.no_grad():
ada_out = adarounded_model(rand_input)
fq_out = fq_model(rand_input)
float_out = float_model(rand_input)
ada_loss = F.mse_loss(ada_out, float_out)
fq_loss = F.mse_loss(fq_out, float_out)
self.assertTrue(ada_loss.item() < fq_loss.item())


@@ -0,0 +1,148 @@
from typing import Tuple
import torch
from torch.ao.quantization.fake_quantize import _is_symmetric_quant
from torch.ao.quantization.utils import is_per_tensor
from torch.quantization import FakeQuantize
from torch.quantization.observer import MinMaxObserver
class AdaroundFakeQuantizer(FakeQuantize):
"""
    A FakeQuantize module that performs adaptive rounding (AdaRound) of weights.
    AdaRound is a technique to adaptively round weights, derived from the paper https://arxiv.org/pdf/2004.10568.pdf
    For HTP compatibility, we target symmetric quantization.
"""
scale: torch.Tensor
zero_point: torch.Tensor
V: torch.nn.Parameter
# pyre-fixme[3]: Return type must be annotated.
def __init__(
self,
observer=MinMaxObserver,
qscheme=torch.per_tensor_symmetric, # not used, but needed for fakequant
quant_min: int = -128,
quant_max: int = 127,
ch_axis: int = 0,
# pyre-fixme[2]: Parameter must be annotated.
**observer_kwargs,
):
super().__init__(
observer=observer,
qscheme=qscheme,
quant_min=quant_min,
quant_max=quant_max,
is_dynamic=False,
**observer_kwargs,
)
        # Validate quant_min/quant_max if both are provided
if quant_min is not None and quant_max is not None:
assert (
quant_min <= quant_max
), "quant_min must be less than or equal to quant_max"
# pyre-fixme[4]: Attribute must be annotated.
self.qscheme = qscheme
self.is_per_tensor: bool = is_per_tensor(qscheme)
self.is_symmetric: bool = _is_symmetric_quant(qscheme)
assert self.is_symmetric, "Only symmetric quantization is supported"
self.ch_axis: int = ch_axis
self.scale = torch.tensor([], requires_grad=False)
self.zero_point = torch.tensor([], requires_grad=False)
self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True)
# Fixed Stretch parameters
self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False)
self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False)
self.sigmoid = torch.nn.Sigmoid()
self.use_soft_rounding = True
@torch.jit.export
def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
return self.scale, self.zero_point
@torch.jit.export
def extra_repr(self) -> str:
return (
f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}"
)
def enable_weight_fake_quant(self) -> None:
self.fake_quant_enabled[0] = 1
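    # Rectified sigmoid h(V) = clip(sigmoid(V) * (zeta - gamma) + gamma, 0, 1),
    # Eq. (23) in the AdaRound paper. Soft rounding keeps h(V) continuous in [0, 1];
    # hard rounding collapses it to the binary indicator (V >= 0).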
def get_rectified_sigmoid_func(self) -> torch.Tensor:
if self.use_soft_rounding:
return torch.clamp(
self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma,
min=0,
max=1,
)
else:
            # This returns the hard (binary) rounding solution
return (self.V >= 0).int()
@torch.jit.ignore
def update_scale(
self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor
) -> None:
if self.scale.numel() == 0:
self.scale.data = _scale.to(X.device)
self.zero_point = _zero_point.to(X.device)
else:
self.scale.data = _scale
if not self.is_symmetric:
self.zero_point = _zero_point
else:
self.zero_point = torch.zeros_like(_zero_point)
for i in range(X.dim()):
if i == self.ch_axis:
continue
self.zero_point = self.zero_point.unsqueeze(i)
X_q = X / self.scale
X_q_floor = torch.floor(X_q)
residual = X_q - X_q_floor # [0,1)
assert torch.all(
torch.ge(residual, 0)
), "residual should be non-negative [0, 1)"
V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1)
self.V.data = V_init
def forward(self, X: torch.Tensor) -> torch.Tensor:
if self.observer_enabled[0] == 1:
X_detached = X.detach()
self.activation_post_process(X_detached)
_scale, _zero_point = self.activation_post_process.calculate_qparams()
_scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
self.zero_point.device
)
dims = list(range(X.dim()))
if not self.is_per_tensor:
dims.remove(self.ch_axis)
if not self.is_per_tensor:
for i in range(X.dim()):
if i == self.ch_axis:
continue
_scale = _scale.unsqueeze(i)
_zero_point = _zero_point.unsqueeze(i)
self.update_scale(X_detached, _scale, _zero_point)
if self.fake_quant_enabled[0] == 1:
            # Perform soft quantization; see Eq. (23) in the AdaRound paper
h_v = self.get_rectified_sigmoid_func()
X_q = X / self.scale
# Straight-Through Estimator for floor function
X_q_floor = torch.floor(X_q) + self.zero_point
            # Regardless of the rounding mode, gradients must be able to flow back from X_q_dq to self.V.
            # With AdaRound we do not train the weights; we train only V.
X_q_dq = (
torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max)
- self.zero_point
) * self.scale
return X_q_dq
else:
return X


@@ -0,0 +1,96 @@
from typing import Tuple
import numpy as np
import torch
from torch.nn import functional as F
ADAROUND_ZETA: float = 1.1
ADAROUND_GAMMA: float = -0.1
class AdaptiveRoundingLoss(torch.nn.Module):
"""
Adaptive Rounding Loss functions described in https://arxiv.org/pdf/2004.10568.pdf
    The rounding regularization is Eq. [24].
    The reconstruction loss is Eq. [25] without the regularization term.
"""
def __init__(
self,
max_iter: int,
warm_start: float = 0.2,
beta_range: Tuple[int, int] = (20, 2),
reg_param: float = 0.001,
) -> None:
super().__init__()
self.max_iter = max_iter
self.warm_start = warm_start
self.beta_range = beta_range
self.reg_param = reg_param
def rounding_regularization(
self,
V: torch.Tensor,
curr_iter: int,
) -> torch.Tensor:
"""
        Main logic copied from the official AdaRound implementation.
Apply rounding regularization to the input tensor V.
"""
assert (
curr_iter < self.max_iter
), "Current iteration strictly les sthan max iteration"
if curr_iter < self.warm_start * self.max_iter:
return torch.tensor(0.0)
else:
start_beta, end_beta = self.beta_range
warm_start_end_iter = self.warm_start * self.max_iter
            # Compute the progress of the current iteration relative to the post-warm-start window
rel_iter = (curr_iter - warm_start_end_iter) / (
self.max_iter - warm_start_end_iter
)
beta = end_beta + 0.5 * (start_beta - end_beta) * (
1 + np.cos(rel_iter * np.pi)
)
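            # beta anneals from start_beta down to end_beta on a cosine schedule
            # over the iterations after the warm-start phase.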
            # A rectified sigmoid for soft quantization, as formulated in Eq. [23] of https://arxiv.org/pdf/2004.10568.pdf
h_alpha = torch.clamp(
torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
min=0,
max=1,
)
            # Apply rounding regularization.
            # This term encourages h_alpha to converge to a binary solution (either 0 or 1) by the end of optimization.
inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta)
regularization_term = torch.add(1, -inner_term).sum()
return regularization_term * self.reg_param
def reconstruction_loss(
self,
soft_quantized_output: torch.Tensor,
original_output: torch.Tensor,
) -> torch.Tensor:
"""
Compute the reconstruction loss between the soft quantized output and the original output.
"""
return F.mse_loss(
soft_quantized_output, original_output, reduction="none"
).mean()
def forward(
self,
soft_quantized_output: torch.Tensor,
original_output: torch.Tensor,
V: torch.Tensor,
curr_iter: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Compute the asymmetric reconstruction formulation as in Eq. [25], returning the regularization and reconstruction terms separately.
"""
regularization_term = self.rounding_regularization(V, curr_iter)
reconstruction_term = self.reconstruction_loss(
soft_quantized_output, original_output
)
return regularization_term, reconstruction_term


@@ -0,0 +1,238 @@
import copy
import logging
from typing import Any, Callable, List, Optional, Tuple, Type, Union
import torch
from torch.ao.quantization.experimental.adaround_fake_quantize import (
AdaroundFakeQuantizer,
)
from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss
from torch.ao.quantization.observer import MinMaxObserver
from torch.nn import functional as F
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, TensorDataset
logger: logging.Logger = logging.getLogger(__name__)
class AdaptiveRoundingOptimizer:
def __init__(
self,
model: Union[torch.nn.Module, torch.nn.DataParallel],
callback: Callable[[torch.nn.Module, List[Any]], None],
forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable],
data: List[Any],
observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver,
max_iter=10000,
dtype: torch.dtype = torch.qint8,
quant_min=-128,
quant_max=127,
qscheme: torch.qscheme = torch.per_tensor_symmetric,
batch_size: int = 256,
):
self.model = model
self.q_model = copy.deepcopy(self.model)
self.device = torch.device("cuda") if torch.cuda.is_available() else None
self.callback = callback
self.forward_hook_wrapper = forward_hook_wrapper
        # TODO: rather than taking `data` as a list, we should pass an *iterator* instead
self.data = data
self.batch_size = min(batch_size, len(data))
self.max_iter = max_iter
self.adaptive_round_loss_fn = AdaptiveRoundingLoss(
max_iter=self.max_iter, warm_start=0.2
)
self.dtype = dtype
self.observer = observer
self.quant_min = quant_min
self.quant_max = quant_max
self.qscheme = qscheme
def run_adaround(self) -> torch.nn.Module:
layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = []
for (name, module), q_module in zip(
self.model.named_modules(), self.q_model.modules()
):
if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
                # Knowing the activation ahead of time would be helpful for the asymmetric formulation,
                # but this is challenging in eager mode (unlike with a graph module).
layer_list.append((name, module, q_module))
logger.info(f"Total number of layers : {len(layer_list)}") # noqa: G004
for name, module, q_module in layer_list:
logger.info(
f"Kick start adaptive rounding on {name} module {module}" # noqa: G004
)
self.optimize_adaptive_rounding(
module,
q_module,
None,
)
return (
self.q_model.module
if isinstance(self.q_model, DataParallel)
else self.q_model
)
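    # Collect, per layer, the inputs seen by the quantized copy and the FP32
    # inputs/outputs of the original module via temporarily attached forward hooks.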
def get_data_inp_out(
self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any]
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
fp_out: List[torch.Tensor] = []
q_input: List[torch.Tensor] = []
fp_input: List[torch.Tensor] = []
fp32_fetcher: List[torch.Tensor] = []
quant_fetcher: List[torch.Tensor] = []
handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher))
handler2 = q_module.register_forward_hook(
self.forward_hook_wrapper(quant_fetcher)
)
for data_ in data:
with torch.no_grad():
self.callback(self.model, data_)
self.callback(self.q_model, data_)
fp32_output = fp32_fetcher[1]
quant_input = quant_fetcher[0]
fp_out.append(fp32_output)
q_input.append(quant_input)
fp_input.append(fp32_fetcher[0])
handler1.remove()
handler2.remove()
return q_input, fp_out, fp_input
@torch.no_grad()
def feed_forward(self, x, weight, module):
if isinstance(module, torch.nn.Conv1d):
out = torch.nn.functional.conv1d(
x,
weight,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
elif isinstance(module, torch.nn.Linear):
out = torch.nn.functional.linear(
x,
weight,
bias=module.bias,
)
else:
raise NotImplementedError
return out
def _compute_and_display_local_losses(
self,
ada_quantizer: AdaroundFakeQuantizer,
q_module: torch.nn.Module,
q_inp: torch.Tensor,
fp_out: torch.Tensor,
):
with torch.no_grad():
ada_quantizer.use_soft_rounding = False
q_w_hard_round = ada_quantizer(q_module.weight)
out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module)
ada_quantizer.use_soft_rounding = True
q_w_soft_round = ada_quantizer(q_module.weight)
out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module)
soft_quant_loss = F.mse_loss(out_soft_quant, fp_out)
hard_quant_loss = F.mse_loss(out_hard_quant, fp_out)
logger.info(
f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" # noqa: G004
)
def optimize_adaptive_rounding(
self,
module: torch.nn.Module,
q_module: torch.nn.Module,
activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> None:
ada_quantizer = AdaroundFakeQuantizer(
dtype=self.dtype,
observer=self.observer,
qscheme=self.qscheme,
quant_min=self.quant_min,
quant_max=self.quant_max,
reduce_range=False,
)
ada_quantizer.enable_observer()
ada_quantizer(q_module.weight)
ada_quantizer.disable_observer()
ada_quantizer.enable_fake_quant()
optimizer = torch.optim.Adam([ada_quantizer.V])
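        # Only the continuous rounding parameter V is trained; the module weights
        # stay frozen and are overwritten with the adarounded weights at the end.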
inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data)
logger.info("==================== Before adaround ====================")
test_in, test_out, fp_test_in = self.get_data_inp_out(
module, q_module, self.data[0]
)
assert (
torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0
), "In-placed activation is detected, please do not use activation in-placed"
        # Stack the per-batch tensors in `inp` and `out` into single tensors
inp_tensor = torch.vstack(inp)
out_tensor = torch.vstack(out)
dataset = TensorDataset(inp_tensor, out_tensor)
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
self._compute_and_display_local_losses(
ada_quantizer, q_module, test_in[0], test_out[0]
)
global_idx = 0
one_iter = len(out) // self.batch_size
for iteration in range(self.max_iter // one_iter):
reconstruction_loss = regularization_loss = torch.tensor(0)
for q_inp, fp_out in dataloader:
optimizer.zero_grad()
q_weight = ada_quantizer(q_module.weight)
if isinstance(module, torch.nn.Conv1d):
q_out = torch.nn.functional.conv1d(
q_inp,
q_weight,
stride=q_module.stride,
padding=q_module.padding,
dilation=q_module.dilation,
groups=q_module.groups,
)
elif isinstance(q_module, torch.nn.Linear):
q_out = torch.nn.functional.linear(
q_inp,
q_weight,
bias=q_module.bias,
)
else:
raise NotImplementedError
regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn(
fp_out,
q_out,
ada_quantizer.V,
curr_iter=global_idx,
)
loss = regularization_loss + reconstruction_loss
loss.backward()
optimizer.step()
global_idx += 1
if global_idx >= self.max_iter:
break
if global_idx >= self.max_iter:
break
if iteration % 30 == 0:
logger.info(
f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " # noqa: G004
f"reconstruction_loss {reconstruction_loss.item()}" # noqa: G004
)
logger.info("==================== After adaround ====================")
self._compute_and_display_local_losses(
ada_quantizer, q_module, test_in[0], test_out[0]
)
ada_quantizer.use_soft_rounding = True
ada_quantizer.V.requires_grad = False
ada_quantizer = ada_quantizer.eval()
q_weight = ada_quantizer(q_module.weight)
# At the end of optimization, we need to copy the adarounded weight back to the original module
q_module.weight.data.copy_(q_weight)
        # Eager mode requires the observer to be attached as "weight_fake_quant" for the module to be parsed
q_module.weight_fake_quant = ada_quantizer.activation_post_process