diff --git a/test/quantization/core/experimental/test_adaround_eager.py b/test/quantization/core/experimental/test_adaround_eager.py
new file mode 100644
index 00000000000..33a16f21bd0
--- /dev/null
+++ b/test/quantization/core/experimental/test_adaround_eager.py
@@ -0,0 +1,118 @@
+# Owner(s): ["oncall: speech_infra"]
+
+import copy
+
+import torch
+import torch.nn as nn
+from torch.ao.quantization.experimental.adaround_optimization import (
+    AdaptiveRoundingOptimizer,
+)
+
+from torch.nn import functional as F
+from torch.quantization.observer import MinMaxObserver
+from torch.testing._internal.common_quantization import QuantizationTestCase
+
+
+def forward_wrapper(fetcher):
+    def forward(module, input, output):
+        fetcher.append(input[0].detach())
+        fetcher.append(output.detach())
+
+    return forward
+
+
+class TestAdaround(QuantizationTestCase):
+    def feedforward_callback(
+        self,
+        model,
+        data,
+    ) -> None:
+        model(data)
+
+    def run_adaround(self, model, img_data):
+        adaround_optimizer = AdaptiveRoundingOptimizer(
+            model,
+            self.feedforward_callback,
+            forward_wrapper,
+            img_data,
+            max_iter=100,
+            batch_size=10,
+        )
+        adarounded_model = adaround_optimizer.run_adaround()
+        return adarounded_model
+
+    def get_fake_quant(self, model):
+        hard_fake_quant_model = copy.deepcopy(model)
+        for _, module in hard_fake_quant_model.named_modules():
+            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
+                weight_observer = MinMaxObserver(
+                    quant_min=-128,
+                    quant_max=127,
+                    dtype=torch.qint8,
+                    qscheme=torch.per_tensor_symmetric,
+                )
+                weight_observer(module.weight)
+                scale, zero_point = weight_observer.calculate_qparams()
+                fake_quant_module = torch.fake_quantize_per_tensor_affine(
+                    module.weight,
+                    scale=scale,
+                    zero_point=zero_point,
+                    quant_min=-128,
+                    quant_max=127,
+                )
+                module.weight.data.copy_(fake_quant_module)
+        return hard_fake_quant_model
+
+    def test_linear_chain(self):
+        class LinearChain(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = nn.Linear(3, 4)
+                self.linear2 = nn.Linear(4, 5)
+                self.linear3 = nn.Linear(5, 6)
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = self.linear3(x)
+                return x
+
+        float_model = LinearChain()
+        img_data = [torch.rand(10, 3, dtype=torch.float) for _ in range(50)]
+        adarounded_model = self.run_adaround(float_model, img_data)
+        fq_model = self.get_fake_quant(float_model)
+        rand_input = torch.rand(10, 3)
+        with torch.no_grad():
+            ada_out = adarounded_model(rand_input)
+            fq_out = fq_model(rand_input)
+            float_out = float_model(rand_input)
+            ada_loss = F.mse_loss(ada_out, float_out)
+            fq_loss = F.mse_loss(fq_out, float_out)
+            self.assertTrue(ada_loss.item() < fq_loss.item())
+
+    def test_conv_chain(self):
+        class ConvChain(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv2d1 = nn.Conv2d(3, 4, 5, 5)
+                self.conv2d2 = nn.Conv2d(4, 5, 5, 5)
+                self.conv2d3 = nn.Conv2d(5, 6, 5, 5)
+
+            def forward(self, x):
+                x = self.conv2d1(x)
+                x = self.conv2d2(x)
+                x = self.conv2d3(x)
+                return x
+
+        float_model = ConvChain()
+        img_data = [torch.rand(10, 3, 125, 125, dtype=torch.float) for _ in range(50)]
+        adarounded_model = self.run_adaround(float_model, img_data)
+        fq_model = self.get_fake_quant(float_model)
+        rand_input = torch.rand(10, 3, 256, 256)
+        with torch.no_grad():
+            ada_out = adarounded_model(rand_input)
+            fq_out = fq_model(rand_input)
+            float_out = float_model(rand_input)
+            ada_loss = F.mse_loss(ada_out, float_out)
+            fq_loss = F.mse_loss(fq_out, float_out)
+            self.assertTrue(ada_loss.item() < fq_loss.item())
diff --git a/torch/ao/quantization/experimental/adaround_fake_quantize.py b/torch/ao/quantization/experimental/adaround_fake_quantize.py
new file mode 100644
index 00000000000..4d988bbb25b
--- /dev/null
+++ b/torch/ao/quantization/experimental/adaround_fake_quantize.py
@@ -0,0 +1,148 @@
+from typing import Tuple
+
+import torch
+from torch.ao.quantization.fake_quantize import _is_symmetric_quant
+from torch.ao.quantization.utils import is_per_tensor
+from torch.quantization import FakeQuantize
+from torch.quantization.observer import MinMaxObserver
+
+
+class AdaroundFakeQuantizer(FakeQuantize):
+    """
+    A FakeQuantize module that implements adaptive rounding (AdaRound).
+    AdaRound is a technique for adaptively rounding weights, introduced in https://arxiv.org/pdf/2004.10568.pdf
+    For HTP compatibility, we target symmetric quantization.
+    """
+
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+    V: torch.nn.Parameter
+
+    # pyre-fixme[3]: Return type must be annotated.
+    def __init__(
+        self,
+        observer=MinMaxObserver,
+        qscheme=torch.per_tensor_symmetric,  # not used, but needed for fakequant
+        quant_min: int = -128,
+        quant_max: int = 127,
+        ch_axis: int = 0,
+        # pyre-fixme[2]: Parameter must be annotated.
+        **observer_kwargs,
+    ):
+        super().__init__(
+            observer=observer,
+            qscheme=qscheme,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            is_dynamic=False,
+            **observer_kwargs,
+        )
+        # Sanity-check the requested quantization range
+        if quant_min is not None and quant_max is not None:
+            assert (
+                quant_min <= quant_max
+            ), "quant_min must be less than or equal to quant_max"
+        # pyre-fixme[4]: Attribute must be annotated.
+        self.qscheme = qscheme
+        self.is_per_tensor: bool = is_per_tensor(qscheme)
+        self.is_symmetric: bool = _is_symmetric_quant(qscheme)
+        assert self.is_symmetric, "Only symmetric quantization is supported"
+        self.ch_axis: int = ch_axis
+
+        self.scale = torch.tensor([], requires_grad=False)
+        self.zero_point = torch.tensor([], requires_grad=False)
+        self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True)
+        # Fixed stretch parameters
+        self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False)
+        self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.use_soft_rounding = True
+
+    @torch.jit.export
+    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.scale, self.zero_point
+
+    @torch.jit.export
+    def extra_repr(self) -> str:
+        return (
+            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
+            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
+            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
+            f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}"
+        )
+
+    def enable_weight_fake_quant(self) -> None:
+        self.fake_quant_enabled[0] = 1
+
+    def get_rectified_sigmoid_func(self) -> torch.Tensor:
+        if self.use_soft_rounding:
+            return torch.clamp(
+                self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma,
+                min=0,
+                max=1,
+            )
+        else:
+            # Return the hard, binary rounding decision
+            return (self.V >= 0).int()
+
+    @torch.jit.ignore
+    def update_scale(
+        self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor
+    ) -> None:
+        if self.scale.numel() == 0:
+            self.scale.data = _scale.to(X.device)
+            self.zero_point = _zero_point.to(X.device)
+        else:
+            self.scale.data = _scale
+            if not self.is_symmetric:
+                self.zero_point = _zero_point
+            else:
+                self.zero_point = torch.zeros_like(_zero_point)
+        for i in range(X.dim()):
+            if i == self.ch_axis:
+                continue
+            self.zero_point = self.zero_point.unsqueeze(i)
+        X_q = X / self.scale
+        X_q_floor = torch.floor(X_q)
+        residual = X_q - X_q_floor  # [0,1)
+        assert torch.all(
+            torch.ge(residual, 0)
+        ), "residual should be non-negative [0, 1)"
+        V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1)
+        self.V.data = V_init
+
+    def forward(self, X: torch.Tensor) -> torch.Tensor:
+        if self.observer_enabled[0] == 1:
+            X_detached = X.detach()
+            self.activation_post_process(X_detached)
+            _scale, _zero_point = self.activation_post_process.calculate_qparams()
+            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
+                self.zero_point.device
+            )
+            dims = list(range(X.dim()))
+            if not self.is_per_tensor:
+                dims.remove(self.ch_axis)
+            if not self.is_per_tensor:
+                for i in range(X.dim()):
+                    if i == self.ch_axis:
+                        continue
+                    _scale = _scale.unsqueeze(i)
+                    _zero_point = _zero_point.unsqueeze(i)
+            self.update_scale(X_detached, _scale, _zero_point)
+
+        if self.fake_quant_enabled[0] == 1:
+            # Perform soft quantization
+            # See equation (23) in the AdaRound paper
+            h_v = self.get_rectified_sigmoid_func()
+            X_q = X / self.scale
+            # Straight-Through Estimator for the floor function
+            X_q_floor = torch.floor(X_q) + self.zero_point
+            # Regardless of rounding, the gradient should be able to flow back to self.V through X_q_dq.
+            # With AdaRound we do not train the weight, we train V only.
+            X_q_dq = (
+                torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max)
+                - self.zero_point
+            ) * self.scale
+            return X_q_dq
+        else:
+            return X
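
For reference, the mask returned by get_rectified_sigmoid_func above is the rectified sigmoid of Eq. (23) in the AdaRound paper, with the same stretch constants zeta=1.1 and gamma=-0.1. A minimal standalone sketch, not part of the patch; the sample values of V are illustrative only:

    import torch

    ZETA, GAMMA = 1.1, -0.1  # stretch constants, mirroring AdaroundFakeQuantizer

    def rectified_sigmoid(V: torch.Tensor) -> torch.Tensor:
        # Sigmoid stretched to (GAMMA, ZETA) and clipped to [0, 1]; large |V|
        # saturates to an exact 0 or 1, i.e. a hard rounding decision.
        return torch.clamp(torch.sigmoid(V) * (ZETA - GAMMA) + GAMMA, min=0, max=1)

    V = torch.tensor([-6.0, -1.0, 0.0, 1.0, 6.0])
    print(rectified_sigmoid(V))  # approximately [0.0000, 0.2227, 0.5000, 0.7773, 1.0000]
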
diff --git a/torch/ao/quantization/experimental/adaround_loss.py b/torch/ao/quantization/experimental/adaround_loss.py
new file mode 100644
index 00000000000..8080d72cc6d
--- /dev/null
+++ b/torch/ao/quantization/experimental/adaround_loss.py
@@ -0,0 +1,96 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+ADAROUND_ZETA: float = 1.1
+ADAROUND_GAMMA: float = -0.1
+
+
+class AdaptiveRoundingLoss(torch.nn.Module):
+    """
+    Adaptive rounding loss functions described in https://arxiv.org/pdf/2004.10568.pdf
+    The rounding regularization is Eq. [24];
+    the reconstruction loss is Eq. [25] without the regularization term.
+    """
+
+    def __init__(
+        self,
+        max_iter: int,
+        warm_start: float = 0.2,
+        beta_range: Tuple[int, int] = (20, 2),
+        reg_param: float = 0.001,
+    ) -> None:
+        super().__init__()
+        self.max_iter = max_iter
+        self.warm_start = warm_start
+        self.beta_range = beta_range
+        self.reg_param = reg_param
+
+    def rounding_regularization(
+        self,
+        V: torch.Tensor,
+        curr_iter: int,
+    ) -> torch.Tensor:
+        """
+        Core logic adapted from the official AdaRound implementation.
+        Apply rounding regularization to the input tensor V.
+        """
+        assert (
+            curr_iter < self.max_iter
+        ), "Current iteration must be strictly less than max_iter"
+        if curr_iter < self.warm_start * self.max_iter:
+            return torch.tensor(0.0)
+        else:
+            start_beta, end_beta = self.beta_range
+            warm_start_end_iter = self.warm_start * self.max_iter
+
+            # Compute the relative progress of the current iteration
+            rel_iter = (curr_iter - warm_start_end_iter) / (
+                self.max_iter - warm_start_end_iter
+            )
+            beta = end_beta + 0.5 * (start_beta - end_beta) * (
+                1 + np.cos(rel_iter * np.pi)
+            )
+
+            # A rectified sigmoid for soft quantization, as formulated in Eq. [23] of https://arxiv.org/pdf/2004.10568.pdf
+            h_alpha = torch.clamp(
+                torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
+                min=0,
+                max=1,
+            )
+
+            # Apply rounding regularization
+            # This term pushes the rounding mask to converge to a binary solution (0 or 1) by the end of optimization.
+            inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta)
+            regularization_term = torch.add(1, -inner_term).sum()
+            return regularization_term * self.reg_param
+
+    def reconstruction_loss(
+        self,
+        soft_quantized_output: torch.Tensor,
+        original_output: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Compute the reconstruction loss between the soft quantized output and the original output.
+        """
+        return F.mse_loss(
+            soft_quantized_output, original_output, reduction="none"
+        ).mean()
+
+    def forward(
+        self,
+        soft_quantized_output: torch.Tensor,
+        original_output: torch.Tensor,
+        V: torch.Tensor,
+        curr_iter: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute the asymmetric reconstruction formulation of Eq. [25]
+        """
+        regularization_term = self.rounding_regularization(V, curr_iter)
+        reconstruction_term = self.reconstruction_loss(
+            soft_quantized_output, original_output
+        )
+        return regularization_term, reconstruction_term
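
The regularizer above is disabled during the warm start and then anneals beta from 20 down to 2 on a cosine schedule, which gradually sharpens |2*h(V) - 1|**beta and drives h(V) toward a binary solution. A small standalone sketch of that schedule, not part of the patch; the iteration counts are made up for the example:

    import numpy as np

    def beta_schedule(curr_iter, max_iter, warm_start=0.2, beta_range=(20, 2)):
        # Mirrors rounding_regularization above: cosine-anneal beta after the warm start.
        start_beta, end_beta = beta_range
        warm_end = warm_start * max_iter
        rel_iter = (curr_iter - warm_end) / (max_iter - warm_end)
        return end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * np.pi))

    print([round(beta_schedule(i, 1000), 1) for i in (200, 600, 999)])  # approximately [20.0, 11.0, 2.0]
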
+ """ + assert ( + curr_iter < self.max_iter + ), "Current iteration strictly les sthan max iteration" + if curr_iter < self.warm_start * self.max_iter: + return torch.tensor(0.0) + else: + start_beta, end_beta = self.beta_range + warm_start_end_iter = self.warm_start * self.max_iter + + # compute relative iteration of current iteration + rel_iter = (curr_iter - warm_start_end_iter) / ( + self.max_iter - warm_start_end_iter + ) + beta = end_beta + 0.5 * (start_beta - end_beta) * ( + 1 + np.cos(rel_iter * np.pi) + ) + + # A rectified sigmoid for soft-quantization as formualted [23] in https://arxiv.org/pdf/2004.10568.pdf + h_alpha = torch.clamp( + torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA, + min=0, + max=1, + ) + + # Apply rounding regularization + # This regularization term helps out term to converge into binary solution either 0 or 1 at the end of optimization. + inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta) + regularization_term = torch.add(1, -inner_term).sum() + return regularization_term * self.reg_param + + def reconstruction_loss( + self, + soft_quantized_output: torch.Tensor, + original_output: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the reconstruction loss between the soft quantized output and the original output. + """ + return F.mse_loss( + soft_quantized_output, original_output, reduction="none" + ).mean() + + def forward( + self, + soft_quantized_output: torch.Tensor, + original_output: torch.Tensor, + V: torch.Tensor, + curr_iter: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the asymmetric reconstruction formulation as eq [25] + """ + regularization_term = self.rounding_regularization(V, curr_iter) + reconstruction_term = self.reconstruction_loss( + soft_quantized_output, original_output + ) + return regularization_term, reconstruction_term diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py new file mode 100644 index 00000000000..7304f885a6f --- /dev/null +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -0,0 +1,238 @@ +import copy +import logging +from typing import Any, Callable, List, Optional, Tuple, Type, Union + +import torch +from torch.ao.quantization.experimental.adaround_fake_quantize import ( + AdaroundFakeQuantizer, +) +from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss +from torch.ao.quantization.observer import MinMaxObserver +from torch.nn import functional as F +from torch.nn.parallel import DataParallel +from torch.utils.data import DataLoader, TensorDataset + +logger: logging.Logger = logging.getLogger(__name__) + + +class AdaptiveRoundingOptimizer: + def __init__( + self, + model: Union[torch.nn.Module, torch.nn.DataParallel], + callback: Callable[[torch.nn.Module, List[Any]], None], + forward_hook_wrapper: Callable[[List[torch.Tensor]], Callable], + data: List[Any], + observer: Type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver, + max_iter=10000, + dtype: torch.dtype = torch.qint8, + quant_min=-128, + quant_max=127, + qscheme: torch.qscheme = torch.per_tensor_symmetric, + batch_size: int = 256, + ): + self.model = model + self.q_model = copy.deepcopy(self.model) + self.device = torch.device("cuda") if torch.cuda.is_available() else None + self.callback = callback + self.forward_hook_wrapper = forward_hook_wrapper + # TODO rather than having a data as list type or, we better pass *iterator* instead of list + self.data = data + self.batch_size = 
min(batch_size, len(data)) + self.max_iter = max_iter + self.adaptive_round_loss_fn = AdaptiveRoundingLoss( + max_iter=self.max_iter, warm_start=0.2 + ) + self.dtype = dtype + self.observer = observer + self.quant_min = quant_min + self.quant_max = quant_max + self.qscheme = qscheme + + def run_adaround(self) -> torch.nn.Module: + layer_list: List[Tuple[str, torch.nn.Module, torch.nn.Module]] = [] + for (name, module), q_module in zip( + self.model.named_modules(), self.q_model.modules() + ): + if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)): + # Knowing activation ahead-of-time would be helpful for asymmetric formulation + # But this is challenging in eager mode, but graph module. + layer_list.append((name, module, q_module)) + logger.info(f"Total number of layers : {len(layer_list)}") # noqa: G004 + + for name, module, q_module in layer_list: + logger.info( + f"Kick start adaptive rounding on {name} module {module}" # noqa: G004 + ) + self.optimize_adaptive_rounding( + module, + q_module, + None, + ) + + return ( + self.q_model.module + if isinstance(self.q_model, DataParallel) + else self.q_model + ) + + def get_data_inp_out( + self, module: torch.nn.Module, q_module: torch.nn.Module, data: List[Any] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + fp_out: List[torch.Tensor] = [] + q_input: List[torch.Tensor] = [] + fp_input: List[torch.Tensor] = [] + fp32_fetcher: List[torch.Tensor] = [] + quant_fetcher: List[torch.Tensor] = [] + handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher)) + handler2 = q_module.register_forward_hook( + self.forward_hook_wrapper(quant_fetcher) + ) + for data_ in data: + with torch.no_grad(): + self.callback(self.model, data_) + self.callback(self.q_model, data_) + fp32_output = fp32_fetcher[1] + quant_input = quant_fetcher[0] + fp_out.append(fp32_output) + q_input.append(quant_input) + fp_input.append(fp32_fetcher[0]) + handler1.remove() + handler2.remove() + return q_input, fp_out, fp_input + + @torch.no_grad() + def feed_forward(self, x, weight, module): + if isinstance(module, torch.nn.Conv1d): + out = torch.nn.functional.conv1d( + x, + weight, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + elif isinstance(module, torch.nn.Linear): + out = torch.nn.functional.linear( + x, + weight, + bias=module.bias, + ) + else: + raise NotImplementedError + return out + + def _compute_and_display_local_losses( + self, + ada_quantizer: AdaroundFakeQuantizer, + q_module: torch.nn.Module, + q_inp: torch.Tensor, + fp_out: torch.Tensor, + ): + with torch.no_grad(): + ada_quantizer.use_soft_rounding = False + q_w_hard_round = ada_quantizer(q_module.weight) + out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module) + ada_quantizer.use_soft_rounding = True + q_w_soft_round = ada_quantizer(q_module.weight) + out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module) + soft_quant_loss = F.mse_loss(out_soft_quant, fp_out) + hard_quant_loss = F.mse_loss(out_hard_quant, fp_out) + logger.info( + f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" # noqa: G004 + ) + + def optimize_adaptive_rounding( + self, + module: torch.nn.Module, + q_module: torch.nn.Module, + activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + ) -> None: + ada_quantizer = AdaroundFakeQuantizer( + dtype=self.dtype, + observer=self.observer, + qscheme=self.qscheme, + quant_min=self.quant_min, + quant_max=self.quant_max, + 
reduce_range=False, + ) + ada_quantizer.enable_observer() + ada_quantizer(q_module.weight) + ada_quantizer.disable_observer() + ada_quantizer.enable_fake_quant() + optimizer = torch.optim.Adam([ada_quantizer.V]) + inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data) + + logger.info("==================== Before adaround ====================") + test_in, test_out, fp_test_in = self.get_data_inp_out( + module, q_module, self.data[0] + ) + + assert ( + torch.abs(test_out[0] - module(fp_test_in[0])).sum().item() == 0 + ), "In-placed activation is detected, please do not use activation in-placed" + # Stack the tensors in each list into a single tensor + # Assuming inp and out are your lists of tensors + inp_tensor = torch.vstack(inp) + out_tensor = torch.vstack(out) + dataset = TensorDataset(inp_tensor, out_tensor) + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) + + self._compute_and_display_local_losses( + ada_quantizer, q_module, test_in[0], test_out[0] + ) + global_idx = 0 + one_iter = len(out) // self.batch_size + for iteration in range(self.max_iter // one_iter): + reconstruction_loss = regularization_loss = torch.tensor(0) + for q_inp, fp_out in dataloader: + optimizer.zero_grad() + q_weight = ada_quantizer(q_module.weight) + if isinstance(module, torch.nn.Conv1d): + q_out = torch.nn.functional.conv1d( + q_inp, + q_weight, + stride=q_module.stride, + padding=q_module.padding, + dilation=q_module.dilation, + groups=q_module.groups, + ) + elif isinstance(q_module, torch.nn.Linear): + q_out = torch.nn.functional.linear( + q_inp, + q_weight, + bias=q_module.bias, + ) + else: + raise NotImplementedError + regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn( + fp_out, + q_out, + ada_quantizer.V, + curr_iter=global_idx, + ) + loss = regularization_loss + reconstruction_loss + loss.backward() + optimizer.step() + global_idx += 1 + if global_idx >= self.max_iter: + break + if global_idx >= self.max_iter: + break + if iteration % 30 == 0: + logger.info( + f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " # noqa: G004 + f"reconstruction_loss {reconstruction_loss.item()}" # noqa: G004 + ) + logger.info("==================== After adaround ====================") + self._compute_and_display_local_losses( + ada_quantizer, q_module, test_in[0], test_out[0] + ) + + ada_quantizer.use_soft_rounding = True + ada_quantizer.V.requires_grad = False + ada_quantizer = ada_quantizer.eval() + q_weight = ada_quantizer(q_module.weight) + # At the end of optimization, we need to copy the adarounded weight back to the original module + q_module.weight.data.copy_(q_weight) + # Eager mode requires observer to be set as "weight_fake_quant" to be parsed + q_module.weight_fake_quant = ada_quantizer.activation_post_process
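
For context, a minimal sketch of how the new optimizer is driven end to end, following the pattern in test_adaround_eager.py above; the toy model, data shapes, and iteration counts are illustrative only:

    import torch
    import torch.nn as nn
    from torch.ao.quantization.experimental.adaround_optimization import (
        AdaptiveRoundingOptimizer,
    )

    def forward_wrapper(fetcher):
        # Forward hook that captures a module's input and output tensors.
        def forward(module, input, output):
            fetcher.append(input[0].detach())
            fetcher.append(output.detach())
        return forward

    def feedforward_callback(model, data):
        # Runs one calibration batch through the given model.
        model(data)

    float_model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))  # toy model
    calibration_data = [torch.rand(10, 3) for _ in range(50)]      # toy calibration set

    ada_optimizer = AdaptiveRoundingOptimizer(
        float_model,
        feedforward_callback,
        forward_wrapper,
        calibration_data,
        max_iter=100,
        batch_size=10,
    )
    # Returns a deep copy of the model whose Linear/Conv1d weights have been
    # replaced with their adaptively rounded fake-quantized values.
    adarounded_model = ada_optimizer.run_adaround()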