# mypy: allow-untyped-defs
import copy
from collections.abc import Callable
from typing import Any, Optional, Union

import torch
from torch.ao.quantization.experimental.adaround_fake_quantize import (
    AdaroundFakeQuantizer,
)
from torch.ao.quantization.experimental.adaround_loss import AdaptiveRoundingLoss
from torch.ao.quantization.observer import MinMaxObserver
from torch.nn import functional as F
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, TensorDataset


class AdaptiveRoundingOptimizer:
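    """Apply adaptive rounding (AdaRound) to the weights of a model.

    For every supported layer (``torch.nn.Conv1d`` and ``torch.nn.Linear``), a
    learnable rounding perturbation ``V`` is optimized so that the layer output
    produced with the quantized weight matches the floating-point layer output
    on the provided calibration data.
    """
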
    def __init__(
        self,
        model: Union[torch.nn.Module, torch.nn.DataParallel],
        callback: Callable[
            [
                Union[torch.nn.Module, torch.nn.DataParallel],
                Any,
                Optional[torch.nn.Module],
            ],
            None,
        ],
        forward_hook_wrapper: Callable[[list[torch.Tensor]], Callable],
        data: Any,
        observer: type[torch.ao.quantization.observer.ObserverBase] = MinMaxObserver,
        max_iter=10000,
        dtype: torch.dtype = torch.qint8,
        quant_min=-128,
        quant_max=127,
        qscheme: torch.qscheme = torch.per_tensor_symmetric,
        batch_size: int = 256,
        feed_forward_wrapper: Optional[torch.nn.Module] = None,
    ):
        if torch.cuda.is_available():
            self.model = model.cuda()
            if torch.cuda.device_count() > 1:
                self.model = torch.nn.DataParallel(model)
        else:
            self.model = model
        self.q_model = copy.deepcopy(self.model)
        self.device = torch.device("cuda") if torch.cuda.is_available() else None
        self.callback = callback
        self.forward_hook_wrapper = forward_hook_wrapper
        # TODO: rather than taking `data` as a list, accept an *iterator* instead
        self.data = data
        self.batch_size = min(batch_size, len(data))
        self.max_iter = max_iter
        self.adaptive_round_loss_fn = AdaptiveRoundingLoss(
            max_iter=self.max_iter, warm_start=0.2
        )
        self.dtype = dtype
        self.observer = observer
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qscheme = qscheme
        self.feed_forward_wrapper = feed_forward_wrapper

    def run_adaround(self) -> torch.nn.Module:
        """Run adaptive rounding on every supported layer and return the adarounded model."""
        layer_list: list[tuple[str, torch.nn.Module, torch.nn.Module]] = []
        for (name, module), q_module in zip(
            self.model.named_modules(), self.q_model.modules()
        ):
            if isinstance(module, torch.nn.ReLU):
                # Disable all in-place operations
                module.inplace = False
            if isinstance(q_module, torch.nn.ReLU):
                # Disable all in-place operations
                q_module.inplace = False
            if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
                # Knowing the activation ahead of time would help an asymmetric
                # formulation, but that is challenging in eager mode (unlike with a graph module).
                layer_list.append((name, module, q_module))
        print(f"Total number of layers : {len(layer_list)}")  # noqa: G004

        for name, module, q_module in layer_list:
            print(
                f"Kick start adaptive rounding on {name} module {module}"  # noqa: G004
            )
            self.optimize_adaptive_rounding(
                module,
                q_module,
                None,
            )

        return (
            self.q_model.module
            if isinstance(self.q_model, DataParallel)
            else self.q_model
        )

    def get_data_inp_out(
        self, module: torch.nn.Module, q_module: torch.nn.Module, data: list[Any]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Collect per-batch quantized inputs, FP32 outputs, and FP32 inputs for the given layer."""
        fp_out: list[torch.Tensor] = []
        q_input: list[torch.Tensor] = []
        fp_input: list[torch.Tensor] = []
        fp32_fetcher: list[torch.Tensor] = []
        quant_fetcher: list[torch.Tensor] = []
        handler1 = module.register_forward_hook(self.forward_hook_wrapper(fp32_fetcher))
        handler2 = q_module.register_forward_hook(
            self.forward_hook_wrapper(quant_fetcher)
        )
        if torch.cuda.is_available():
            # Somehow the models need to be moved to CUDA again here;
            # otherwise they are mysteriously lowered to the CPU.
            self.model = self.model.cuda()
            self.q_model = self.q_model.cuda()
        for data_ in data:
            with torch.no_grad():
                self.callback(self.model, data_, self.feed_forward_wrapper)
                self.callback(self.q_model, data_, self.feed_forward_wrapper)
                fp32_output = fp32_fetcher[1]
                quant_input = quant_fetcher[0]
                fp_out.append(fp32_output)
                q_input.append(quant_input)
                fp_input.append(fp32_fetcher[0])
        handler1.remove()
        handler2.remove()
        return q_input, fp_out, fp_input

    @torch.no_grad()
    def feed_forward(self, x, weight, module):
        if isinstance(module, torch.nn.Conv1d):
            # pyrefly: ignore # no-matching-overload
            out = torch.nn.functional.conv1d(
                x,
                weight,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
            )
        elif isinstance(module, torch.nn.Linear):
            out = torch.nn.functional.linear(
                x,
                weight,
                bias=module.bias,
            )
        else:
            raise NotImplementedError
        return out

    def _compute_and_display_local_losses(
        self,
        ada_quantizer: AdaroundFakeQuantizer,
        q_module: torch.nn.Module,
        q_inp: torch.Tensor,
        fp_out: torch.Tensor,
    ):
        with torch.no_grad():
            ada_quantizer.use_soft_rounding = False
            q_w_hard_round = ada_quantizer(q_module.weight)
            out_hard_quant = self.feed_forward(q_inp, q_w_hard_round, q_module)
            ada_quantizer.use_soft_rounding = True
            q_w_soft_round = ada_quantizer(q_module.weight)
            out_soft_quant = self.feed_forward(q_inp, q_w_soft_round, q_module)
            soft_quant_loss = F.mse_loss(out_soft_quant, fp_out)
            hard_quant_loss = F.mse_loss(out_hard_quant, fp_out)
            print(
                f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}"  # noqa: G004
            )

    def optimize_adaptive_rounding(
        self,
        module: torch.nn.Module,
        q_module: torch.nn.Module,
        activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        """Learn the rounding perturbation ``V`` for ``q_module``'s weight on this layer."""
        ada_quantizer = AdaroundFakeQuantizer(
            dtype=self.dtype,
            observer=self.observer,
            qscheme=self.qscheme,
            quant_min=self.quant_min,
            quant_max=self.quant_max,
            reduce_range=False,
        )
        ada_quantizer.enable_observer()
        ada_quantizer(q_module.weight)
        ada_quantizer.disable_observer()
        ada_quantizer.enable_fake_quant()
        optimizer = torch.optim.Adam([ada_quantizer.V])
        inp, out, fp_in = self.get_data_inp_out(module, q_module, self.data)

        print("==================== Before adaround ====================")
        assert torch.abs(out[0] - module(fp_in[0])).sum().item() == 0, (
            "In-place activation detected; please do not use in-place activations"
        )
        # Stack the per-batch tensors in `inp` and `out` into single tensors
        inp_tensor = torch.vstack(inp)
        out_tensor = torch.vstack(out)
        dataset = TensorDataset(inp_tensor, out_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0])
        global_idx = 0
        one_iter = len(out) // self.batch_size
        for iteration in range(self.max_iter // one_iter):
            reconstruction_loss = regularization_loss = torch.tensor(0)
            for q_inp, fp_out in dataloader:
                optimizer.zero_grad()
                q_weight = ada_quantizer(q_module.weight)
                if isinstance(module, torch.nn.Conv1d):
                    q_out = torch.nn.functional.conv1d(  # type: ignore[call-overload, misc]
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                        stride=q_module.stride,
                        padding=q_module.padding,
                        dilation=q_module.dilation,
                        groups=q_module.groups,
                    )
                elif isinstance(q_module, torch.nn.Linear):
                    q_out = torch.nn.functional.linear(
                        q_inp,
                        q_weight,
                        bias=q_module.bias,
                    )
                else:
                    raise NotImplementedError
                regularization_loss, reconstruction_loss = self.adaptive_round_loss_fn(
                    fp_out,
                    q_out,
                    ada_quantizer.V,
                    curr_iter=global_idx,
                )
                loss = regularization_loss + reconstruction_loss
                loss.backward()
                optimizer.step()
                global_idx += 1
                if global_idx >= self.max_iter:
                    break
            if global_idx >= self.max_iter:
                break
            if iteration % 30 == 0:
                print(
                    f"glob iter {global_idx} regularization_loss {regularization_loss.item()} "  # noqa: G004
                    f"reconstruction_loss {reconstruction_loss.item()}"  # noqa: G004
                )
        print("==================== After adaround ====================")
        self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0])

        ada_quantizer.use_soft_rounding = True
        ada_quantizer.V.requires_grad = False
        ada_quantizer = ada_quantizer.eval()
        q_weight = ada_quantizer(q_module.weight)
        # At the end of the optimization, copy the adarounded weight back into the quantized module
        q_module.weight.data.copy_(q_weight)  # type: ignore[operator]
        # Eager mode requires the observer to be set as "weight_fake_quant" to be parsed
        q_module.weight_fake_quant = ada_quantizer.activation_post_process
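

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the public API).
# Assumptions: a tiny nn.Sequential(Linear, ReLU) model and random calibration
# batches; the example hook wrapper records the layer input at index 0 and the
# layer output at index 1, which is what get_data_inp_out above expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def _example_forward_hook_wrapper(fetcher: list[torch.Tensor]) -> Callable:
        # Return a forward hook that records the hooked module's input/output.
        def _hook(module, inputs, output):
            fetcher.clear()
            fetcher.append(inputs[0].detach())
            fetcher.append(output.detach())

        return _hook

    def _example_callback(model, batch, wrapper=None) -> None:
        # Run one calibration batch through the (FP32 or quantized) model.
        model(batch)

    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU())
    _calibration_data = [torch.randn(4, 16, device=_device) for _ in range(8)]

    _adaround = AdaptiveRoundingOptimizer(
        _model,
        _example_callback,
        _example_forward_hook_wrapper,
        _calibration_data,
        max_iter=100,
        batch_size=4,
    )
    _adarounded_model = _adaround.run_adaround()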
|