Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61318

Remove the `float()` and `int()` calls in the forward function so that we can directly use the tensor qparams in the fake_quantize operator. Calling `float()`/`int()` internally calls `item()`, which can trigger a GPU -> CPU copy if the original tensors reside on GPU.

Local benchmark P427668213

Before this change:

```
Name  Self CPU %  Self CPU  CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
-------------------------------------------------------------------------------------------------------------------
aten::_aminmax  2.57%  1.507ms  3.10%  1.819ms  36.371us  2.872ms  4.81%  2.872ms  57.446us  50
aten::fake_quantize_per_tensor_affine  1.04%  610.915us  3.60%  2.114ms  42.276us  472.896us  0.79%  2.698ms  53.962us  50
aten::fake_quantize_per_tensor_affine_cachemask  1.69%  993.626us  2.56%  1.503ms  30.058us  2.225ms  3.73%  2.225ms  44.504us  50
aten::is_nonzero  3.85%  2.258ms  19.68%  11.540ms  46.161us  2.168ms  3.63%  11.084ms  44.336us  250
aten::zeros_like  1.82%  1.064ms  6.65%  3.901ms  39.007us  1.531ms  2.57%  3.905ms  39.045us  100
aten::eq  13.80%  8.093ms  25.90%  15.189ms  37.972us  9.580ms  16.05%  15.566ms  38.914us  400
aten::item  5.67%  3.323ms  21.50%  12.607ms  36.019us  3.233ms  5.42%  12.167ms  34.762us  350
aten::zeros  0.94%  549.208us  2.93%  1.717ms  34.343us  688.928us  1.15%  1.695ms  33.894us  50
aten::le  2.52%  1.478ms  4.50%  2.641ms  26.411us  1.753ms  2.94%  2.845ms  28.448us  100
aten::rsub  1.04%  608.715us  2.44%  1.433ms  28.667us  532.000us  0.89%  1.418ms  28.353us  50
aten::max  1.54%  905.401us  4.62%  2.711ms  27.106us  847.488us  1.42%  2.697ms  26.969us  100
aten::ones  0.92%  542.159us  2.16%  1.266ms  25.324us  661.856us  1.11%  1.301ms  26.017us  50
aten::min  0.82%  479.167us  2.15%  1.258ms  25.160us  407.808us  0.68%  1.276ms  25.530us  50
aten::_local_scalar_dense  15.83%  9.284ms  15.83%  9.284ms  26.526us  8.934ms  14.97%  8.934ms  25.524us  350
aten::clamp  2.35%  1.378ms  4.21%  2.467ms  24.669us  1.546ms  2.59%  2.461ms  24.612us  100
aten::zero_  2.53%  1.482ms  5.65%  3.316ms  22.108us  1.326ms  2.22%  3.380ms  22.531us  150
aten::maximum  3.08%  1.805ms  3.08%  1.805ms  18.052us  1.849ms  3.10%  1.849ms  18.494us  100
aten::minimum  1.33%  778.854us  1.33%  778.854us  15.577us  868.672us  1.46%  868.672us  17.373us  50
aten::round  1.36%  799.910us  1.36%  799.910us  15.998us  809.568us  1.36%  809.568us  16.191us  50
aten::copy_  6.61%  3.878ms  6.61%  3.878ms  15.513us  4.036ms  6.76%  4.036ms  16.143us  250
aten::div  2.53%  1.483ms  2.53%  1.483ms  14.833us  1.535ms  2.57%  1.535ms  15.353us  100
aten::mul  2.44%  1.431ms  2.44%  1.431ms  14.314us  1.478ms  2.48%  1.478ms  14.782us  100
aten::detach  1.46%  855.670us  2.41%  1.411ms  14.110us  832.448us  1.39%  1.395ms  13.949us  100
aten::add  2.22%  1.301ms  2.22%  1.301ms  13.008us  1.383ms  2.32%  1.383ms  13.828us  100
aten::fill_  4.18%  2.452ms  4.18%  2.452ms  12.262us  2.693ms  4.51%  2.693ms  13.463us  200
aten::sub  5.06%  2.967ms  5.06%  2.967ms  14.837us  2.675ms  4.48%  2.675ms  13.374us  200
aten::to  2.10%  1.230ms  3.65%  2.140ms  10.701us  1.310ms  2.20%  2.062ms  10.310us  200
aten::select  1.28%  749.144us  1.49%  874.227us  8.742us  863.232us  1.45%  863.232us  8.632us  100
detach  0.95%  555.326us  0.95%  555.326us  5.553us  562.496us  0.94%  562.496us  5.625us  100
aten::as_strided  0.40%  232.289us  0.40%  232.289us  1.161us  0.000us  0.00%  0.000us  0.000us  200
aten::empty  2.93%  1.720ms  2.93%  1.720ms  3.439us  0.000us  0.00%  0.000us  0.000us  500
aten::resize_  1.04%  611.313us  1.04%  611.313us  2.038us  0.000us  0.00%  0.000us  0.000us  300
aten::empty_like  0.75%  438.585us  1.77%  1.036ms  5.180us  0.000us  0.00%  0.000us  0.000us  200
aten::empty_strided  1.36%  799.442us  1.36%  799.442us  3.198us  0.000us  0.00%  0.000us  0.000us  250
-------------------------------------------------------------------------------------------------------------------
Self CPU time total: 58.645ms
Self CUDA time total: 59.674ms
```

After this change:

```
test_fake_quant_profiler (scripts.supriyar.benchmark.module_bench.ProfilerBench) ...
Name  Self CPU %  Self CPU  CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
-------------------------------------------------------------------------------------------------------------------
aten::fake_quantize_per_tensor_affine  0.98%  505.210us  4.38%  2.259ms  45.187us  419.424us  0.78%  3.218ms  64.367us  50
aten::_aminmax  2.78%  1.434ms  3.42%  1.766ms  35.321us  2.825ms  5.27%  2.825ms  56.505us  50
aten::fake_quantize_per_tensor_affine_cachemask_tens...  2.38%  1.229ms  3.40%  1.754ms  35.083us  2.799ms  5.22%  2.799ms  55.979us  50
aten::rsub  0.94%  485.040us  5.02%  2.590ms  51.793us  458.976us  0.86%  2.587ms  51.747us  50
aten::is_nonzero  3.78%  1.952ms  23.64%  12.196ms  48.786us  2.055ms  3.83%  11.986ms  47.944us  250
aten::item  6.92%  3.572ms  19.86%  10.244ms  40.977us  3.670ms  6.85%  9.931ms  39.724us  250
aten::zeros_like  1.65%  848.874us  6.64%  3.426ms  34.260us  1.397ms  2.61%  3.572ms  35.717us  100
aten::zeros  0.85%  436.691us  3.00%  1.549ms  30.984us  551.936us  1.03%  1.576ms  31.516us  50
aten::eq  10.60%  5.467ms  20.26%  10.452ms  26.130us  7.018ms  13.09%  10.832ms  27.079us  400
aten::le  2.58%  1.332ms  4.67%  2.407ms  24.074us  1.580ms  2.95%  2.614ms  26.144us  100
aten::_local_scalar_dense  12.93%  6.673ms  12.93%  6.673ms  26.691us  6.261ms  11.68%  6.261ms  25.046us  250
aten::clamp  2.43%  1.253ms  4.37%  2.256ms  22.560us  1.431ms  2.67%  2.273ms  22.725us  100
aten::ones  0.89%  460.133us  2.18%  1.123ms  22.467us  570.496us  1.06%  1.128ms  22.551us  50
aten::min  0.74%  383.132us  2.06%  1.065ms  21.296us  377.536us  0.70%  1.091ms  21.824us  50
aten::zero_  2.36%  1.219ms  5.87%  3.029ms  20.194us  1.261ms  2.35%  3.199ms  21.327us  150
aten::max  1.51%  779.081us  4.06%  2.096ms  20.960us  791.680us  1.48%  2.130ms  21.295us  100
aten::sub  7.97%  4.111ms  7.97%  4.111ms  20.556us  3.847ms  7.18%  3.847ms  19.234us  200
aten::div  2.94%  1.516ms  2.94%  1.516ms  15.158us  1.580ms  2.95%  1.580ms  15.798us  100
aten::round  1.45%  750.445us  1.45%  750.445us  15.009us  756.064us  1.41%  756.064us  15.121us  50
aten::copy_  6.88%  3.548ms  6.88%  3.548ms  14.190us  3.701ms  6.90%  3.701ms  14.803us  250
aten::minimum  1.32%  681.654us  1.32%  681.654us  13.633us  713.664us  1.33%  713.664us  14.273us  50
aten::maximum  2.55%  1.317ms  2.55%  1.317ms  13.169us  1.338ms  2.50%  1.338ms  13.378us  100
aten::mul  2.63%  1.358ms  2.63%  1.358ms  13.581us  1.328ms  2.48%  1.328ms  13.283us  100
aten::detach  1.34%  688.820us  2.35%  1.211ms  12.110us  772.800us  1.44%  1.278ms  12.779us  100
aten::fill_  4.53%  2.338ms  4.53%  2.338ms  11.692us  2.495ms  4.65%  2.495ms  12.473us  200
aten::add  2.32%  1.197ms  2.32%  1.197ms  11.968us  1.240ms  2.31%  1.240ms  12.405us  100
aten::to  2.07%  1.069ms  3.66%  1.889ms  9.443us  1.224ms  2.28%  1.975ms  9.874us  200
aten::select  1.44%  743.042us  1.64%  848.207us  8.482us  641.600us  1.20%  641.600us  6.416us  100
detach  1.01%  522.155us  1.01%  522.155us  5.222us  505.088us  0.94%  505.088us  5.051us  100
aten::as_strided  0.44%  227.884us  0.44%  227.884us  1.139us  0.000us  0.00%  0.000us  0.000us  200
aten::empty  3.20%  1.652ms  3.20%  1.652ms  3.304us  0.000us  0.00%  0.000us  0.000us  500
aten::resize_  1.25%  646.711us  1.25%  646.711us  2.156us  0.000us  0.00%  0.000us  0.000us  300
aten::empty_like  0.79%  407.768us  2.07%  1.067ms  5.334us  0.000us  0.00%  0.000us  0.000us  200
aten::empty_strided  1.52%  785.788us  1.52%  785.788us  3.143us  0.000us  0.00%  0.000us  0.000us  250
-------------------------------------------------------------------------------------------------------------------
Self CPU time total: 51.590ms
Self CUDA time total: 53.609ms
```

ghstack-source-id: 133370215

Test Plan: buck test mode/dev-nosan caffe2/test/:quantization

Reviewed By: raghuramank100

Differential Revision: D29566512

fbshipit-source-id: 1aefca51f99949da7334bcfe504848275c9f952c
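
For reference, a minimal sketch of the pattern this change enables (assuming a PyTorch build that includes the tensor-qparams overload from this stack; the tensor names below are illustrative, not taken from the diff):

```
import torch

# Illustrative qparams living on the same device as the input.
x = torch.randn(16, device="cuda")
scale = torch.tensor([0.1], device="cuda")
zero_point = torch.tensor([0], dtype=torch.int, device="cuda")

# Old pattern: float()/int() call item() under the hood, which forces a
# GPU -> CPU copy (and a device synchronization) before the op can run.
y_old = torch.fake_quantize_per_tensor_affine(x, float(scale), int(zero_point), 0, 255)

# New pattern: pass the tensor qparams directly; no host round trip.
y_new = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)
```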
import torch
from torch.nn import Module
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
import re
from abc import ABC, abstractmethod


def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]


def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]


class FakeQuantizeBase(ABC, Module):
    r""" Base fake quantize module
    Any fake quantize implementation should derive from this class.

    Concrete fake quantize modules should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.

    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    with_args = classmethod(_with_args)
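
# `with_args` (from `_with_args` in the observer module) pre-binds constructor
# arguments and returns a factory that builds the module later. A small usage
# sketch; the name `MyFakeQuant` is illustrative, not part of this file:
#
#     MyFakeQuant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
#                                          quant_min=0, quant_max=255,
#                                          dtype=torch.quint8)
#     fq = MyFakeQuant()  # instantiated lazily, e.g. by prepare_qat via a QConfig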


class FakeQuantize(FakeQuantizeBase):
    r""" Simulate the quantize and dequantize operations in training time.
    The output of this module is given by

    x_out = (clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point) * scale

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps.

    * :attr:`quant_min` specifies the minimum allowable quantized value.

    * :attr:`quant_max` specifies the maximum allowable quantized value.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors.

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization;
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and
      quant_max should be chosen to be consistent with the dtype.

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
            and zero-point.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        observer (Module): User provided module that collects statistics on the input tensor and
            provides a method to calculate scale and zero-point.

    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
        super().__init__()
        assert quant_min <= quant_max, \
            'quant_min must be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, 'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, 'quant_max out of bound'
        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([0], dtype=torch.int))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        assert _is_per_channel(self.qscheme) or \
            _is_per_tensor(self.qscheme), \
            'Only per channel and per tensor quantization are supported in fake quantize,' + \
            ' got qscheme: ' + str(self.qscheme)
        self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.quant_min, self.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so we need to manually
        # specify serialization here.
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Removing this function throws an error that the size of the loaded tensor does not match
        # the original size, i.e., these buffers start out with numel 0 and become numel 1 once
        # they have their first forward pass.
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == 'scale':
                    self.scale.resize_(val.shape)
                else:
                    assert name == 'zero_point'
                    self.zero_point.resize_(val.shape)
                # For torchscript modules we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined in module.py
                if torch.jit.is_scripting():
                    if name == 'scale':
                        self.scale.copy_(val)
                    else:
                        assert name == 'zero_point'
                        self.zero_point.copy_(val)
            elif strict:
                missing_keys.append(key)
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                        missing_keys, unexpected_keys, error_msgs)
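
# A minimal usage sketch of the module above (not part of the file; names are
# illustrative). The forward pass first lets the observer update the qparams,
# then fake-quantizes the input with those qparams kept as tensors:
#
#     fq = FakeQuantize(observer=MovingAverageMinMaxObserver,
#                       quant_min=0, quant_max=255, dtype=torch.quint8)
#     y = fq(torch.randn(4, 8))              # observe, then fake quantize
#     scale, zero_point = fq.calculate_qparams()
#     fq.disable_observer()                  # freeze qparams, keep fake quant on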


class FixedQParamsFakeQuantize(FakeQuantizeBase):
    """ Simulate quantize and dequantize with fixed quantization
    parameters in training time. Only per tensor quantization
    is supported.
    Args:
        `scale` (float): fixed scale for the fake quantize module
        `zero_point` (int): fixed zero point for the fake quantize module
        `dtype`, `qscheme`, `quant_min`, `quant_max`
    """

    scale: torch.Tensor
    zero_point: torch.Tensor

    def __init__(self,
                 scale,
                 zero_point,
                 dtype=torch.quint8,
                 qscheme=torch.per_tensor_affine,
                 quant_min=0,
                 quant_max=255):
        super().__init__()
        assert quant_min <= quant_max, 'quant_min should be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.register_buffer('scale', torch.tensor([scale], dtype=torch.float))
        self.register_buffer('zero_point', torch.tensor([zero_point], dtype=torch.int))
        self.dtype = dtype
        self.qscheme = qscheme
        assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported in' + \
            ' FixedQParamsFakeQuantize module, got qscheme: ' + str(self.qscheme)

    def forward(self, X):
        if self.fake_quant_enabled[0] == 1:
            X = torch.fake_quantize_per_tensor_affine(X, self.scale,
                                                      self.zero_point, self.quant_min,
                                                      self.quant_max)
        return X

    @torch.jit.export
    def calculate_qparams(self):
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
               'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.scale, self.zero_point, self.dtype,
                   self.quant_min, self.quant_max, self.qscheme)


default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)

# TODO(future PR): remove these defaults and enforce activation functions
# to explicitly specify their output range
default_symmetric_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=2.0 / 256.0, zero_point=128, dtype=torch.quint8, quant_min=0, quant_max=255)
default_affine_fixed_qparams_fake_quant = FixedQParamsFakeQuantize.with_args(
    scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8, quant_min=0, quant_max=255)

default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128,
                                                               quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)
default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      dtype=torch.quint8,
                                                      qscheme=torch.per_tensor_affine,
                                                      reduce_range=True)
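
# Sketch of how these defaults are typically consumed (assuming the QConfig API
# in torch.quantization; the variable names are illustrative):
#
#     from torch.quantization import QConfig
#     qat_qconfig = QConfig(activation=default_fake_quant,
#                           weight=default_per_channel_weight_fake_quant)
#     # Assigning `model.qconfig = qat_qconfig` and calling
#     # torch.quantization.prepare_qat(model) attaches FakeQuantize modules
#     # built from these factories.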


def _is_fake_quant_script_module(mod):
    ''' Returns true if given mod is an instance of FakeQuantize script module.
    '''
    if isinstance(mod, torch.jit.RecursiveScriptModule):
        # qualified name looks like '__torch__.torch.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
        suffix = mod._c.qualified_name.split('.', 1)[1]
        name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
        return name == 'torch.quantization.fake_quantize.FakeQuantize'
    return False


def disable_fake_quant(mod):
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_fake_quant()


def disable_observer(mod):
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_observer()


def enable_observer(mod):
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_observer()