pytorch/torch/ao/quantization/_equalize.py
Andrew Hoblitzell 9724d0fd87 docstyle _correct_bias.py _equalize.py _learnable_fake_quantize.py backend_config experimental fake_quantize.py fuse_modules.py fuser_method_mappings.py (#112992)
Fixes #112988

For files

__init__.py
_correct_bias.py
_equalize.py
_learnable_fake_quantize.py
backend_config
experimental
fake_quantize.py
fuse_modules.py
fuser_method_mappings.py

Correct the following

__init__.py:1 at module level:
        D104: Missing docstring in public package
__init__.py:144 in public function `default_eval_fn`:
        D205: 1 blank line required between summary line and description (found 0)
__init__.py:144 in public function `default_eval_fn`:
        D400: First line should end with a period (not 'f')
__init__.py:144 in public function `default_eval_fn`:
        D401: First line should be in imperative mood; try rephrasing (found 'Default')
__init__.py:152 in private class `_DerivedObserverOrFakeQuantize`:
        D204: 1 blank line required after class docstring (found 0)
__init__.py:152 in private class `_DerivedObserverOrFakeQuantize`:
        D205: 1 blank line required between summary line and description (found 0)
__init__.py:152 in private class `_DerivedObserverOrFakeQuantize`:
        D210: No whitespaces allowed surrounding docstring text
__init__.py:152 in private class `_DerivedObserverOrFakeQuantize`:
        D400: First line should end with a period (not 's')
_correct_bias.py:20 in public function `get_module`:
        D200: One-line docstring should fit on one line with quotes (found 2)
_correct_bias.py:20 in public function `get_module`:
        D210: No whitespaces allowed surrounding docstring text
_correct_bias.py:20 in public function `get_module`:
        D300: Use """triple double quotes""" (found '''-quotes)
_correct_bias.py:20 in public function `get_module`:
        D400: First line should end with a period (not 'l')
_correct_bias.py:25 in public function `parent_child_names`:
        D200: One-line docstring should fit on one line with quotes (found 2)
_correct_bias.py:25 in public function `parent_child_names`:
        D300: Use """triple double quotes""" (found '''-quotes)
_correct_bias.py:25 in public function `parent_child_names`:
        D400: First line should end with a period (not 'e')
_correct_bias.py:25 in public function `parent_child_names`:
        D401: First line should be in imperative mood (perhaps 'Split', not 'Splits')
_correct_bias.py:34 in public function `get_param`:
        D205: 1 blank line required between summary line and description (found 0)
_correct_bias.py:34 in public function `get_param`:
        D210: No whitespaces allowed surrounding docstring text
_correct_bias.py:34 in public function `get_param`:
        D300: Use """triple double quotes""" (found '''-quotes)
_correct_bias.py:34 in public function `get_param`:
        D400: First line should end with a period (not 's')
_correct_bias.py:44 in public class `MeanShadowLogger`:
        D204: 1 blank line required after class docstring (found 0)
_correct_bias.py:44 in public class `MeanShadowLogger`:
        D205: 1 blank line required between summary line and description (found 0)
_correct_bias.py:44 in public class `MeanShadowLogger`:
        D400: First line should end with a period (not 'n')
_correct_bias.py:47 in public method `__init__`:
        D107: Missing docstring in __init__
_correct_bias.py:56 in public method `forward`:
        D205: 1 blank line required between summary line and description (found 0)
_correct_bias.py:56 in public method `forward`:
        D210: No whitespaces allowed surrounding docstring text
_correct_bias.py:56 in public method `forward`:
        D300: Use """triple double quotes""" (found '''-quotes)
_correct_bias.py:56 in public method `forward`:
        D401: First line should be in imperative mood; try rephrasing (found 'The')
_correct_bias.py:77 in public method `clear`:
        D102: Missing docstring in public method
_correct_bias.py:85 in public function `bias_correction`:
        D205: 1 blank line required between summary line and description (found 0)
_correct_bias.py:85 in public function `bias_correction`:
        D210: No whitespaces allowed surrounding docstring text
_correct_bias.py:85 in public function `bias_correction`:
        D300: Use """triple double quotes""" (found '''-quotes)
_correct_bias.py:85 in public function `bias_correction`:
        D400: First line should end with a period (not 's')
_correct_bias.py:85 in public function `bias_correction`:
        D401: First line should be in imperative mood (perhaps 'Use', not 'Using')
_equalize.py:22 in public function `set_module_weight`:
        D103: Missing docstring in public function
_equalize.py:28 in public function `set_module_bias`:
        D103: Missing docstring in public function
_equalize.py:34 in public function `get_module_weight`:
        D103: Missing docstring in public function
_equalize.py:40 in public function `get_module_bias`:
        D103: Missing docstring in public function
_equalize.py:47 in public function `max_over_ndim`:
        D200: One-line docstring should fit on one line with quotes (found 2)
_equalize.py:47 in public function `max_over_ndim`:
        D210: No whitespaces allowed surrounding docstring text
_equalize.py:47 in public function `max_over_ndim`:
        D300: Use """triple double quotes""" (found '''-quotes)
_equalize.py:47 in public function `max_over_ndim`:
        D400: First line should end with a period (not 's')
_equalize.py:47 in public function `max_over_ndim`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
_equalize.py:55 in public function `min_over_ndim`:
        D200: One-line docstring should fit on one line with quotes (found 2)
_equalize.py:55 in public function `min_over_ndim`:
        D210: No whitespaces allowed surrounding docstring text
_equalize.py:55 in public function `min_over_ndim`:
        D300: Use """triple double quotes""" (found '''-quotes)
_equalize.py:55 in public function `min_over_ndim`:
        D400: First line should end with a period (not 's')
_equalize.py:55 in public function `min_over_ndim`:
        D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies')
_equalize.py:63 in public function `channel_range`:
        D200: One-line docstring should fit on one line with quotes (found 2)
_equalize.py:63 in public function `channel_range`:
        D210: No whitespaces allowed surrounding docstring text
_equalize.py:63 in public function `channel_range`:
        D300: Use """triple double quotes""" (found '''-quotes)
_equalize.py:63 in public function `channel_range`:
        D400: First line should end with a period (not 'l')
_equalize.py:63 in public function `channel_range`:
        D401: First line should be in imperative mood (perhaps 'Find', not 'finds')
_equalize.py:63 in public function `channel_range`:
        D403: First word of the first line should be properly capitalized ('Finds', not 'finds')
_equalize.py:76 in public function `cross_layer_equalization`:
        D205: 1 blank line required between summary line and description (found 0)
_equalize.py:76 in public function `cross_layer_equalization`:
        D210: No whitespaces allowed surrounding docstring text
_equalize.py:76 in public function `cross_layer_equalization`:
        D300: Use """triple double quotes""" (found '''-quotes)
_equalize.py:76 in public function `cross_layer_equalization`:
        D400: First line should end with a period (not 't')
_equalize.py:120 in public function `equalize`:
        D205: 1 blank line required between summary line and description (found 0)
_equalize.py:120 in public function `equalize`:
        D210: No whitespaces allowed surrounding docstring text
_equalize.py:120 in public function `equalize`:
        D300: Use """triple double quotes""" (found '''-quotes)
_equalize.py:120 in public function `equalize`:
        D400: First line should end with a period (not 'l')
_equalize.py:159 in public function `converged`:
        D205: 1 blank line required between summary line and description (found 0)
_equalize.py:159 in public function `converged`:
        D210: No whitespaces allowed surrounding docstring text
_equalize.py:159 in public function `converged`:
        D300: Use """triple double quotes""" (found '''-quotes)
_equalize.py:159 in public function `converged`:
        D400: First line should end with a period (not 's')
_equalize.py:159 in public function `converged`:
        D401: First line should be in imperative mood (perhaps 'Test', not 'Tests')
_learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`:
        D204: 1 blank line required after class docstring (found 0)
_learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`:
        D205: 1 blank line required between summary line and description (found 0)
_learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`:
        D210: No whitespaces allowed surrounding docstring text
_learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`:
        D400: First line should end with a period (not 'h')
_learnable_fake_quantize.py:68 in private method `enable_param_learning`:
        D205: 1 blank line required between summary line and description (found 0)
_learnable_fake_quantize.py:68 in private method `enable_param_learning`:
        D400: First line should end with a period (not 'd')
_learnable_fake_quantize.py:68 in private method `enable_param_learning`:
        D401: First line should be in imperative mood (perhaps 'Enable', not 'Enables')
_learnable_fake_quantize.py:78 in private method `enable_static_estimate`:
        D205: 1 blank line required between summary line and description (found 0)
_learnable_fake_quantize.py:78 in private method `enable_static_estimate`:
        D400: First line should end with a period (not 'f')
_learnable_fake_quantize.py:78 in private method `enable_static_estimate`:
        D401: First line should be in imperative mood (perhaps 'Enable', not 'Enables')
_learnable_fake_quantize.py:87 in private method `enable_static_observation`:
        D205: 1 blank line required between summary line and description (found 0)
_learnable_fake_quantize.py:87 in private method `enable_static_observation`:
        D400: First line should end with a period (not 't')
_learnable_fake_quantize.py:87 in private method `enable_static_observation`:
        D401: First line should be in imperative mood (perhaps 'Enable', not 'Enables')
fake_quantize.py:1 at module level:
        D205: 1 blank line required between summary line and description (found 0)
fake_quantize.py:1 at module level:
        D400: First line should end with a period (not 'n')
fake_quantize.py:61 in public class `FakeQuantizeBase`:
        D205: 1 blank line required between summary line and description (found 0)
fake_quantize.py:61 in public class `FakeQuantizeBase`:
        D210: No whitespaces allowed surrounding docstring text
fake_quantize.py:61 in public class `FakeQuantizeBase`:
        D400: First line should end with a period (not 'e')
fake_quantize.py:74 in public method `__init__`:
        D107: Missing docstring in __init__
fake_quantize.py:83 in public method `forward`:
        D102: Missing docstring in public method
fake_quantize.py:87 in public method `calculate_qparams`:
        D102: Missing docstring in public method
fake_quantize.py:91 in public method `enable_fake_quant`:
        D102: Missing docstring in public method
fake_quantize.py:95 in public method `disable_fake_quant`:
        D102: Missing docstring in public method
fake_quantize.py:99 in public method `enable_observer`:
        D102: Missing docstring in public method
fake_quantize.py:103 in public method `disable_observer`:
        D102: Missing docstring in public method
fake_quantize.py:107 in public method `with_args`:
        D102: Missing docstring in public method
fake_quantize.py:115 in public class `FakeQuantize`:
        D205: 1 blank line required between summary line and description (found 0)
fake_quantize.py:115 in public class `FakeQuantize`:
        D210: No whitespaces allowed surrounding docstring text
fake_quantize.py:115 in public class `FakeQuantize`:
        D412: No blank lines allowed between a section header and its content ('Attributes')
fake_quantize.py:150 in public method `__init__`:
        D107: Missing docstring in __init__
fake_quantize.py:188 in public method `calculate_qparams`:
        D102: Missing docstring in public method
fake_quantize.py:191 in public method `forward`:
        D102: Missing docstring in public method
fake_quantize.py:214 in public method `extra_repr`:
        D102: Missing docstring in public method
fake_quantize.py:262 in public class `FixedQParamsFakeQuantize`:
        D205: 1 blank line required between summary line and description (found 0)
fake_quantize.py:262 in public class `FixedQParamsFakeQuantize`:
        D210: No whitespaces allowed surrounding docstring text
fake_quantize.py:262 in public class `FixedQParamsFakeQuantize`:
        D400: First line should end with a period (not 'n')
fake_quantize.py:268 in public method `__init__`:
        D107: Missing docstring in __init__
fake_quantize.py:279 in public method `calculate_qparams`:
        D102: Missing docstring in public method
fake_quantize.py:283 in public method `extra_repr`:
        D102: Missing docstring in public method
fake_quantize.py:292 in public class `FusedMovingAvgObsFakeQuantize`:
        D205: 1 blank line required between summary line and description (found 0)
fake_quantize.py:292 in public class `FusedMovingAvgObsFakeQuantize`:
        D400: First line should end with a period (not 'e')
fake_quantize.py:307 in public method `__init__`:
        D107: Missing docstring in __init__
fake_quantize.py:322 in public method `calculate_qparams`:
        D102: Missing docstring in public method
fake_quantize.py:326 in public method `extra_repr`:
        D102: Missing docstring in public method
fake_quantize.py:342 in public method `forward`:
        D102: Missing docstring in public method
fake_quantize.py:480 in private function `_is_fake_quant_script_module`:
        D200: One-line docstring should fit on one line with quotes (found 2)
fake_quantize.py:480 in private function `_is_fake_quant_script_module`:
        D210: No whitespaces allowed surrounding docstring text
fake_quantize.py:480 in private function `_is_fake_quant_script_module`:
        D300: Use """triple double quotes""" (found '''-quotes)
fake_quantize.py:480 in private function `_is_fake_quant_script_module`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
fake_quantize.py:491 in public function `disable_fake_quant`:
        D400: First line should end with a period (not ':')
fake_quantize.py:502 in public function `enable_fake_quant`:
        D400: First line should end with a period (not ':')
fake_quantize.py:513 in public function `disable_observer`:
        D400: First line should end with a period (not ':')
fake_quantize.py:524 in public function `enable_observer`:
        D400: First line should end with a period (not ':')
fuse_modules.py:1 at module level:
        D100: Missing docstring in public module
fuse_modules.py:39 in public function `fuse_known_modules`:
        D205: 1 blank line required between summary line and description (found 0)
fuse_modules.py:39 in public function `fuse_known_modules`:
        D400: First line should end with a period (not 'd')
fuse_modules.py:39 in public function `fuse_known_modules`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
fuse_modules.py:104 in public function `fuse_modules`:
        D400: First line should end with a period (not 'e')
fuse_modules.py:167 in public function `fuse_modules_qat`:
        D200: One-line docstring should fit on one line with quotes (found 2)
fuse_modules.py:167 in public function `fuse_modules_qat`:
        D210: No whitespaces allowed surrounding docstring text
fuse_modules.py:167 in public function `fuse_modules_qat`:
        D400: First line should end with a period (not '`')
fuser_method_mappings.py:1 at module level:
        D100: Missing docstring in public module
fuser_method_mappings.py:18 in public function `fuse_conv_bn`:
        D400: First line should end with a period (not 'e')
fuser_method_mappings.py:55 in public function `fuse_conv_bn_relu`:
        D400: First line should end with a period (not 'e')
fuser_method_mappings.py:102 in public function `fuse_linear_bn`:
        D400: First line should end with a period (not 'e')
fuser_method_mappings.py:131 in public function `fuse_convtranspose_bn`:
        D400: First line should end with a period (not 'e')
fuser_method_mappings.py:154 in private function `_sequential_wrapper2`:
        D205: 1 blank line required between summary line and description (found 0)
fuser_method_mappings.py:154 in private function `_sequential_wrapper2`:
        D210: No whitespaces allowed surrounding docstring text
fuser_method_mappings.py:154 in private function `_sequential_wrapper2`:
        D400: First line should end with a period (not 's')
fuser_method_mappings.py:182 in public function `get_fuser_method`:
        D205: 1 blank line required between summary line and description (found 0)
fuser_method_mappings.py:182 in public function `get_fuser_method`:
        D210: No whitespaces allowed surrounding docstring text
fuser_method_mappings.py:182 in public function `get_fuser_method`:
        D300: Use """triple double quotes""" (found '''-quotes)
fuser_method_mappings.py:182 in public function `get_fuser_method`:
        D400: First line should end with a period (not ',')
fuser_method_mappings.py:205 in private function `_get_valid_patterns`:
        D205: 1 blank line required between summary line and description (found 0)
fuser_method_mappings.py:205 in private function `_get_valid_patterns`:
        D400: First line should end with a period (not ',')
fuser_method_mappings.py:205 in private function `_get_valid_patterns`:
        D401: First line should be in imperative mood (perhaps 'Return', not 'Returns')
fuser_method_mappings.py:238 in public function `get_fuser_method_new`:
        D205: 1 blank line required between summary line and description (found 0)
fuser_method_mappings.py:238 in public function `get_fuser_method_new`:
        D210: No whitespaces allowed surrounding docstring text
fuser_method_mappings.py:238 in public function `get_fuser_method_new`:
        D400: First line should end with a period (not 'd')
fuser_method_mappings.py:238 in public function `get_fuser_method_new`:
        D401: First line should be in imperative mood; try rephrasing (found 'This')

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112992
Approved by: https://github.com/kit1980
2023-11-15 00:59:44 +00:00

183 lines
6.8 KiB
Python

import torch
import copy
from typing import Dict, Any
__all__ = [
    "set_module_weight",
    "set_module_bias",
    "get_module_weight",
    "get_module_bias",
    "max_over_ndim",
    "min_over_ndim",
    "channel_range",
    "cross_layer_equalization",
    "equalize",
    "converged",
]

# Plain layer types: ``weight``/``bias`` are attributes of the module itself.
_supported_types = {torch.nn.Conv2d, torch.nn.Linear}
# Fused layer+ReLU types: the accessors in this file reach the wrapped layer
# through ``module[0]``.
_supported_intrinsic_types = {
    torch.ao.nn.intrinsic.ConvReLU2d,
    torch.ao.nn.intrinsic.LinearReLU,
}
# Every module type accepted by ``cross_layer_equalization``.
_all_supported_types = _supported_types | _supported_intrinsic_types


def set_module_weight(module, weight) -> None:
    """Replace the weight of ``module`` with ``weight`` wrapped in a Parameter.

    For plain ``Conv2d``/``Linear`` modules the weight is set on the module
    itself; any other module is expected to be a fused container (such as
    ``ConvReLU2d``/``LinearReLU``) whose first element owns the weight.

    Args:
        module: a supported layer or fused layer container.
        weight: tensor to install as the new weight.
    """
    new_weight = torch.nn.Parameter(weight)
    owner = module if type(module) in _supported_types else module[0]
    owner.weight = new_weight


def set_module_bias(module, bias) -> None:
    """Replace the bias of ``module`` (or of the layer it wraps).

    For plain ``Conv2d``/``Linear`` modules the bias is set on the module
    itself; any other module is expected to be a fused container (such as
    ``ConvReLU2d``/``LinearReLU``) whose first element owns the bias.

    Args:
        module: a supported layer or fused layer container.
        bias: tensor to install as the new bias, or ``None`` to leave the
            layer without a bias.
    """
    # ``torch.nn.Parameter(None)`` silently creates an empty (0-element)
    # parameter rather than removing the bias, so pass ``None`` through as-is.
    new_bias = None if bias is None else torch.nn.Parameter(bias)
    if type(module) in _supported_types:
        module.bias = new_bias
    else:
        module[0].bias = new_bias


def get_module_weight(module):
    """Return the weight parameter of ``module``.

    Plain ``Conv2d``/``Linear`` modules are read directly; any other module is
    expected to be a fused container whose first element owns the weight.
    """
    if type(module) not in _supported_types:
        module = module[0]
    return module.weight


def get_module_bias(module):
    """Return the bias of ``module`` (``None`` if the layer has no bias).

    Plain ``Conv2d``/``Linear`` modules are read directly; any other module is
    expected to be a fused container whose first element owns the bias.
    """
    owner = module if type(module) in _supported_types else module[0]
    return owner.bias


def max_over_ndim(input, axis_list, keepdim=False):
    """Apply ``torch.max`` over the given axes.

    Args:
        input: tensor to reduce.
        axis_list: iterable of axes to reduce over; negative axes count from
            the end of ``input``. The caller's sequence is not modified.
        keepdim: forwarded to ``Tensor.max``; when True the reduced axes are
            kept with size 1.

    Returns:
        ``input`` reduced by ``max`` over every axis in ``axis_list``.
    """
    ndim = input.dim()
    # Normalize negative axes against the ORIGINAL rank, then reduce from the
    # highest axis down so that dropping a dimension (keepdim=False) never
    # shifts the index of an axis still waiting to be reduced. ``sorted``
    # builds a new list, so the caller's ``axis_list`` is left untouched.
    axes = sorted((axis + ndim if axis < 0 else axis for axis in axis_list), reverse=True)
    for axis in axes:
        input, _ = input.max(axis, keepdim)
    return input


def min_over_ndim(input, axis_list, keepdim=False):
    """Apply ``torch.min`` over the given axes.

    Args:
        input: tensor to reduce.
        axis_list: iterable of axes to reduce over; negative axes count from
            the end of ``input``. The caller's sequence is not modified.
        keepdim: forwarded to ``Tensor.min``; when True the reduced axes are
            kept with size 1.

    Returns:
        ``input`` reduced by ``min`` over every axis in ``axis_list``.
    """
    ndim = input.dim()
    # Normalize negative axes against the ORIGINAL rank, then reduce from the
    # highest axis down so that dropping a dimension (keepdim=False) never
    # shifts the index of an axis still waiting to be reduced. ``sorted``
    # builds a new list, so the caller's ``axis_list`` is left untouched.
    axes = sorted((axis + ndim if axis < 0 else axis for axis in axis_list), reverse=True)
    for axis in axes:
        input, _ = input.min(axis, keepdim)
    return input


def channel_range(input, axis=0):
    """Find the range of weights associated with a specific channel.

    Args:
        input: weight tensor.
        axis: the channel axis; negative values count from the end.

    Returns:
        A 1-D tensor whose i-th entry is ``max - min`` over every element of
        ``input`` whose index along ``axis`` is ``i``.

    Raises:
        ValueError: if ``axis`` is out of range for ``input``.
    """
    size_of_tensor_dim = input.ndim
    if axis < 0:
        axis += size_of_tensor_dim
    axis_list = list(range(size_of_tensor_dim))
    # ``remove`` raises ValueError for an axis outside the tensor's rank.
    axis_list.remove(axis)

    # Reduce over every axis except the channel axis.
    mins = min_over_ndim(input, axis_list)
    maxs = max_over_ndim(input, axis_list)

    assert mins.size(0) == input.size(axis), "Dimensions of resultant channel range does not match size of requested axis"
    return maxs - mins


def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
    """Scale the range of Tensor1.output to equal Tensor2.input.

    Given two adjacent modules, output channel ``c`` of ``module1``'s weight is
    divided by a per-channel factor ``s_c`` and input channel ``c`` of
    ``module2``'s weight is multiplied by it, with
    ``s_c = sqrt(range1_c / range2_c)``. Afterwards both ranges equal
    ``sqrt(range1_c * range2_c)`` (up to a small epsilon). ``module1``'s bias,
    if present, is divided by the same factors. Both modules are updated in
    place.

    Args:
        module1: first layer (``Conv2d``/``Linear`` or their fused ReLU forms).
        module2: second layer, consuming ``module1``'s output.
        output_axis: output-channel axis of ``module1``'s weight.
        input_axis: input-channel axis of ``module2``'s weight.

    Raises:
        ValueError: if either module's type is not supported.
        TypeError: if the channel counts of the two weights do not match.
    """
    if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types:
        raise ValueError(f"module type not supported: {type(module1)} {type(module2)}")

    weight1 = get_module_weight(module1)
    weight2 = get_module_weight(module2)

    if weight1.size(output_axis) != weight2.size(input_axis):
        # TypeError is kept (rather than ValueError) for backward compatibility.
        raise TypeError(
            "Number of output channels of first arg do not match "
            "number input channels of second arg"
        )

    bias = get_module_bias(module1)

    # Pure weight rewriting: no autograd graph is needed.
    with torch.no_grad():
        weight1_range = channel_range(weight1, output_axis)
        weight2_range = channel_range(weight2, input_axis)

        # producing scaling factors to be applied; epsilon avoids division by zero
        weight2_range += 1e-9
        scaling_factors = torch.sqrt(weight1_range / weight2_range)
        # A channel of weight1 with zero range would get s_c == 0, whose
        # reciprocal is inf and turns that channel's weights/bias into inf/NaN.
        # Leave such channels unscaled instead.
        scaling_factors = torch.where(
            scaling_factors == 0, torch.ones_like(scaling_factors), scaling_factors
        )
        inverse_scaling_factors = torch.reciprocal(scaling_factors)

        # Layers built without a bias have ``bias is None``; nothing to scale.
        if bias is not None:
            bias = bias * inverse_scaling_factors

        # formatting the scaling (1D) tensors to be applied on the given argument
        # tensors: pad with singleton axes so they broadcast along the channel axis
        size1 = [1] * weight1.ndim
        size1[output_axis] = weight1.size(output_axis)
        size2 = [1] * weight2.ndim
        size2[input_axis] = weight2.size(input_axis)

        scaling_factors = torch.reshape(scaling_factors, size2)
        inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)

        weight1 = weight1 * inverse_scaling_factors
        weight2 = weight2 * scaling_factors

    set_module_weight(module1, weight1)
    if bias is not None:
        set_module_bias(module1, bias)
    set_module_weight(module2, weight2)


def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
    """Equalize modules until convergence is achieved.

    Given a list of adjacent modules within a model, equalization is applied
    between each pair, and this is repeated until convergence is achieved.

    Before every sweep over the pairs, a copy of each participating module is
    kept. If the copies are not that different from the modules after the
    sweep (as determined by ``converged``), the modules have converged enough
    that further equalizing is not necessary.

    Implementation of this references section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf

    Args:
        model: a model (nn.module) that equalization is to be applied on
        paired_modules_list: a list of lists where each sublist is a pair of two
            submodules found in the model, for each pair the two submodules generally
            have to be adjacent in the model to get expected/reasonable results
        threshold: a number used by the converged function to determine what degree
            similarity between models is necessary for them to be called equivalent
        inplace: determines if function is inplace or not

    Returns:
        The equalized model: ``model`` itself when ``inplace`` is True,
        otherwise a deep copy.
    """
    if not inplace:
        model = copy.deepcopy(model)

    name_to_module: Dict[str, torch.nn.Module] = {}
    previous_name_to_module: Dict[str, Any] = {}
    name_set = {name for pair in paired_modules_list for name in pair}

    for name, module in model.named_modules():
        if name in name_set:
            name_to_module[name] = module
            # ``None`` tells ``converged`` that no sweep has happened yet.
            previous_name_to_module[name] = None

    while not converged(name_to_module, previous_name_to_module, threshold):
        # Snapshot every module once, before the sweep. Snapshotting per pair
        # would overwrite the snapshot of a module shared by two pairs (e.g. B
        # in A-B, B-C) with its state after the first pair, so ``converged``
        # would miss part of this sweep's change.
        for name, module in name_to_module.items():
            previous_name_to_module[name] = copy.deepcopy(module)
        for pair in paired_modules_list:
            cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])
    return model


def converged(curr_modules, prev_modules, threshold=1e-4):
    """Test whether modules are converged to a specified threshold.

    Tests whether the summed norm of the weight differences between each pair
    of same-named modules is less than the given threshold.

    Args:
        curr_modules: dict mapping names to the current modules.
        prev_modules: dict mapping the same names to the previous modules, or
            to ``None`` when no previous state exists yet.
        threshold: the summed norm must be strictly below this value.

    Returns:
        ``False`` if any entry of ``prev_modules`` is ``None``; otherwise whether
        the summed norm of weight differences is below ``threshold``.

    Raises:
        ValueError: if the two mappings do not have the same set of names.
    """
    if curr_modules.keys() != prev_modules.keys():
        raise ValueError("The keys to the given mappings must have the same set of names of modules")

    if any(module is None for module in prev_modules.values()):
        return False

    # Accumulate as a Python float so the sum does not depend on the weights'
    # device (an in-place add of a CUDA tensor into a CPU tensor raises a
    # device-mismatch error).
    summed_norms = 0.0
    with torch.no_grad():
        for name, curr_module in curr_modules.items():
            curr_weight = get_module_weight(curr_module)
            prev_weight = get_module_weight(prev_modules[name])
            summed_norms += torch.norm(curr_weight.sub(prev_weight)).item()
    return summed_norms < threshold