mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Fixes #112988 For files __init__.py _correct_bias.py _equalize.py _learnable_fake_quantize.py backend_config experimental fake_quantize.py fuse_modules.py fuser_method_mappings.py Correct the following __init__.py:1 at module level: D104: Missing docstring in public package __init__.py:144 in public function `default_eval_fn`: D205: 1 blank line required between summary line and description (found 0) __init__.py:144 in public function `default_eval_fn`: D400: First line should end with a period (not 'f') __init__.py:144 in public function `default_eval_fn`: D401: First line should be in imperative mood; try rephrasing (found 'Default') __init__.py:152 in private class `_DerivedObserverOrFakeQuantize`: D204: 1 blank line required after class docstring (found 0) __init__.py:152 in private class `_DerivedObserverOrFakeQuantize`: D205: 1 blank line required between summary line and description (found 0) __init__.py:152 in private class `_DerivedObserverOrFakeQuantize`: D210: No whitespaces allowed surrounding docstring text __init__.py:152 in private class `_DerivedObserverOrFakeQuantize`: D400: First line should end with a period (not 's') _correct_bias.py:20 in public function `get_module`: D200: One-line docstring should fit on one line with quotes (found 2) _correct_bias.py:20 in public function `get_module`: D210: No whitespaces allowed surrounding docstring text _correct_bias.py:20 in public function `get_module`: D300: Use """triple double quotes""" (found '''-quotes) _correct_bias.py:20 in public function `get_module`: D400: First line should end with a period (not 'l') _correct_bias.py:25 in public function `parent_child_names`: D200: One-line docstring should fit on one line with quotes (found 2) _correct_bias.py:25 in public function `parent_child_names`: D300: Use """triple double quotes""" (found '''-quotes) _correct_bias.py:25 in public function `parent_child_names`: D400: First line should end with a period (not 'e') _correct_bias.py:25 in public 
function `parent_child_names`: D401: First line should be in imperative mood (perhaps 'Split', not 'Splits') _correct_bias.py:34 in public function `get_param`: D205: 1 blank line required between summary line and description (found 0) _correct_bias.py:34 in public function `get_param`: D210: No whitespaces allowed surrounding docstring text _correct_bias.py:34 in public function `get_param`: D300: Use """triple double quotes""" (found '''-quotes) _correct_bias.py:34 in public function `get_param`: D400: First line should end with a period (not 's') _correct_bias.py:44 in public class `MeanShadowLogger`: D204: 1 blank line required after class docstring (found 0) _correct_bias.py:44 in public class `MeanShadowLogger`: D205: 1 blank line required between summary line and description (found 0) _correct_bias.py:44 in public class `MeanShadowLogger`: D400: First line should end with a period (not 'n') _correct_bias.py:47 in public method `__init__`: D107: Missing docstring in __init__ _correct_bias.py:56 in public method `forward`: D205: 1 blank line required between summary line and description (found 0) _correct_bias.py:56 in public method `forward`: D210: No whitespaces allowed surrounding docstring text _correct_bias.py:56 in public method `forward`: D300: Use """triple double quotes""" (found '''-quotes) _correct_bias.py:56 in public method `forward`: D401: First line should be in imperative mood; try rephrasing (found 'The') _correct_bias.py:77 in public method `clear`: D102: Missing docstring in public method _correct_bias.py:85 in public function `bias_correction`: D205: 1 blank line required between summary line and description (found 0) _correct_bias.py:85 in public function `bias_correction`: D210: No whitespaces allowed surrounding docstring text _correct_bias.py:85 in public function `bias_correction`: D300: Use """triple double quotes""" (found '''-quotes) _correct_bias.py:85 in public function `bias_correction`: D400: First line should end with a period 
(not 's') _correct_bias.py:85 in public function `bias_correction`: D401: First line should be in imperative mood (perhaps 'Use', not 'Using') _equalize.py:22 in public function `set_module_weight`: D103: Missing docstring in public function _equalize.py:28 in public function `set_module_bias`: D103: Missing docstring in public function _equalize.py:34 in public function `get_module_weight`: D103: Missing docstring in public function _equalize.py:40 in public function `get_module_bias`: D103: Missing docstring in public function _equalize.py:47 in public function `max_over_ndim`: D200: One-line docstring should fit on one line with quotes (found 2) _equalize.py:47 in public function `max_over_ndim`: D210: No whitespaces allowed surrounding docstring text _equalize.py:47 in public function `max_over_ndim`: D300: Use """triple double quotes""" (found '''-quotes) _equalize.py:47 in public function `max_over_ndim`: D400: First line should end with a period (not 's') _equalize.py:47 in public function `max_over_ndim`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies') _equalize.py:55 in public function `min_over_ndim`: D200: One-line docstring should fit on one line with quotes (found 2) _equalize.py:55 in public function `min_over_ndim`: D210: No whitespaces allowed surrounding docstring text _equalize.py:55 in public function `min_over_ndim`: D300: Use """triple double quotes""" (found '''-quotes) _equalize.py:55 in public function `min_over_ndim`: D400: First line should end with a period (not 's') _equalize.py:55 in public function `min_over_ndim`: D401: First line should be in imperative mood (perhaps 'Apply', not 'Applies') _equalize.py:63 in public function `channel_range`: D200: One-line docstring should fit on one line with quotes (found 2) _equalize.py:63 in public function `channel_range`: D210: No whitespaces allowed surrounding docstring text _equalize.py:63 in public function `channel_range`: D300: Use """triple double quotes""" 
(found '''-quotes) _equalize.py:63 in public function `channel_range`: D400: First line should end with a period (not 'l') _equalize.py:63 in public function `channel_range`: D401: First line should be in imperative mood (perhaps 'Find', not 'finds') _equalize.py:63 in public function `channel_range`: D403: First word of the first line should be properly capitalized ('Finds', not 'finds') _equalize.py:76 in public function `cross_layer_equalization`: D205: 1 blank line required between summary line and description (found 0) _equalize.py:76 in public function `cross_layer_equalization`: D210: No whitespaces allowed surrounding docstring text _equalize.py:76 in public function `cross_layer_equalization`: D300: Use """triple double quotes""" (found '''-quotes) _equalize.py:76 in public function `cross_layer_equalization`: D400: First line should end with a period (not 't') _equalize.py:120 in public function `equalize`: D205: 1 blank line required between summary line and description (found 0) _equalize.py:120 in public function `equalize`: D210: No whitespaces allowed surrounding docstring text _equalize.py:120 in public function `equalize`: D300: Use """triple double quotes""" (found '''-quotes) _equalize.py:120 in public function `equalize`: D400: First line should end with a period (not 'l') _equalize.py:159 in public function `converged`: D205: 1 blank line required between summary line and description (found 0) _equalize.py:159 in public function `converged`: D210: No whitespaces allowed surrounding docstring text _equalize.py:159 in public function `converged`: D300: Use """triple double quotes""" (found '''-quotes) _equalize.py:159 in public function `converged`: D400: First line should end with a period (not 's') _equalize.py:159 in public function `converged`: D401: First line should be in imperative mood (perhaps 'Test', not 'Tests') _learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`: D204: 1 blank line required after class docstring 
(found 0) _learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`: D205: 1 blank line required between summary line and description (found 0) _learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`: D210: No whitespaces allowed surrounding docstring text _learnable_fake_quantize.py:8 in private class `_LearnableFakeQuantize`: D400: First line should end with a period (not 'h') _learnable_fake_quantize.py:68 in private method `enable_param_learning`: D205: 1 blank line required between summary line and description (found 0) _learnable_fake_quantize.py:68 in private method `enable_param_learning`: D400: First line should end with a period (not 'd') _learnable_fake_quantize.py:68 in private method `enable_param_learning`: D401: First line should be in imperative mood (perhaps 'Enable', not 'Enables') _learnable_fake_quantize.py:78 in private method `enable_static_estimate`: D205: 1 blank line required between summary line and description (found 0) _learnable_fake_quantize.py:78 in private method `enable_static_estimate`: D400: First line should end with a period (not 'f') _learnable_fake_quantize.py:78 in private method `enable_static_estimate`: D401: First line should be in imperative mood (perhaps 'Enable', not 'Enables') _learnable_fake_quantize.py:87 in private method `enable_static_observation`: D205: 1 blank line required between summary line and description (found 0) _learnable_fake_quantize.py:87 in private method `enable_static_observation`: D400: First line should end with a period (not 't') _learnable_fake_quantize.py:87 in private method `enable_static_observation`: D401: First line should be in imperative mood (perhaps 'Enable', not 'Enables') fake_quantize.py:1 at module level: D205: 1 blank line required between summary line and description (found 0) fake_quantize.py:1 at module level: D400: First line should end with a period (not 'n') fake_quantize.py:61 in public class `FakeQuantizeBase`: D205: 1 blank line required 
between summary line and description (found 0) fake_quantize.py:61 in public class `FakeQuantizeBase`: D210: No whitespaces allowed surrounding docstring text fake_quantize.py:61 in public class `FakeQuantizeBase`: D400: First line should end with a period (not 'e') fake_quantize.py:74 in public method `__init__`: D107: Missing docstring in __init__ fake_quantize.py:83 in public method `forward`: D102: Missing docstring in public method fake_quantize.py:87 in public method `calculate_qparams`: D102: Missing docstring in public method fake_quantize.py:91 in public method `enable_fake_quant`: D102: Missing docstring in public method fake_quantize.py:95 in public method `disable_fake_quant`: D102: Missing docstring in public method fake_quantize.py:99 in public method `enable_observer`: D102: Missing docstring in public method fake_quantize.py:103 in public method `disable_observer`: D102: Missing docstring in public method fake_quantize.py:107 in public method `with_args`: D102: Missing docstring in public method fake_quantize.py:115 in public class `FakeQuantize`: D205: 1 blank line required between summary line and description (found 0) fake_quantize.py:115 in public class `FakeQuantize`: D210: No whitespaces allowed surrounding docstring text fake_quantize.py:115 in public class `FakeQuantize`: D412: No blank lines allowed between a section header and its content ('Attributes') fake_quantize.py:150 in public method `__init__`: D107: Missing docstring in __init__ fake_quantize.py:188 in public method `calculate_qparams`: D102: Missing docstring in public method fake_quantize.py:191 in public method `forward`: D102: Missing docstring in public method fake_quantize.py:214 in public method `extra_repr`: D102: Missing docstring in public method fake_quantize.py:262 in public class `FixedQParamsFakeQuantize`: D205: 1 blank line required between summary line and description (found 0) fake_quantize.py:262 in public class `FixedQParamsFakeQuantize`: D210: No whitespaces 
allowed surrounding docstring text fake_quantize.py:262 in public class `FixedQParamsFakeQuantize`: D400: First line should end with a period (not 'n') fake_quantize.py:268 in public method `__init__`: D107: Missing docstring in __init__ fake_quantize.py:279 in public method `calculate_qparams`: D102: Missing docstring in public method fake_quantize.py:283 in public method `extra_repr`: D102: Missing docstring in public method fake_quantize.py:292 in public class `FusedMovingAvgObsFakeQuantize`: D205: 1 blank line required between summary line and description (found 0) fake_quantize.py:292 in public class `FusedMovingAvgObsFakeQuantize`: D400: First line should end with a period (not 'e') fake_quantize.py:307 in public method `__init__`: D107: Missing docstring in __init__ fake_quantize.py:322 in public method `calculate_qparams`: D102: Missing docstring in public method fake_quantize.py:326 in public method `extra_repr`: D102: Missing docstring in public method fake_quantize.py:342 in public method `forward`: D102: Missing docstring in public method fake_quantize.py:480 in private function `_is_fake_quant_script_module`: D200: One-line docstring should fit on one line with quotes (found 2) fake_quantize.py:480 in private function `_is_fake_quant_script_module`: D210: No whitespaces allowed surrounding docstring text fake_quantize.py:480 in private function `_is_fake_quant_script_module`: D300: Use """triple double quotes""" (found '''-quotes) fake_quantize.py:480 in private function `_is_fake_quant_script_module`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') fake_quantize.py:491 in public function `disable_fake_quant`: D400: First line should end with a period (not ':') fake_quantize.py:502 in public function `enable_fake_quant`: D400: First line should end with a period (not ':') fake_quantize.py:513 in public function `disable_observer`: D400: First line should end with a period (not ':') fake_quantize.py:524 in public function 
`enable_observer`: D400: First line should end with a period (not ':') fuse_modules.py:1 at module level: D100: Missing docstring in public module fuse_modules.py:39 in public function `fuse_known_modules`: D205: 1 blank line required between summary line and description (found 0) fuse_modules.py:39 in public function `fuse_known_modules`: D400: First line should end with a period (not 'd') fuse_modules.py:39 in public function `fuse_known_modules`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') fuse_modules.py:104 in public function `fuse_modules`: D400: First line should end with a period (not 'e') fuse_modules.py:167 in public function `fuse_modules_qat`: D200: One-line docstring should fit on one line with quotes (found 2) fuse_modules.py:167 in public function `fuse_modules_qat`: D210: No whitespaces allowed surrounding docstring text fuse_modules.py:167 in public function `fuse_modules_qat`: D400: First line should end with a period (not '`') fuser_method_mappings.py:1 at module level: D100: Missing docstring in public module fuser_method_mappings.py:18 in public function `fuse_conv_bn`: D400: First line should end with a period (not 'e') fuser_method_mappings.py:55 in public function `fuse_conv_bn_relu`: D400: First line should end with a period (not 'e') fuser_method_mappings.py:102 in public function `fuse_linear_bn`: D400: First line should end with a period (not 'e') fuser_method_mappings.py:131 in public function `fuse_convtranspose_bn`: D400: First line should end with a period (not 'e') fuser_method_mappings.py:154 in private function `_sequential_wrapper2`: D205: 1 blank line required between summary line and description (found 0) fuser_method_mappings.py:154 in private function `_sequential_wrapper2`: D210: No whitespaces allowed surrounding docstring text fuser_method_mappings.py:154 in private function `_sequential_wrapper2`: D400: First line should end with a period (not 's') fuser_method_mappings.py:182 in public 
function `get_fuser_method`: D205: 1 blank line required between summary line and description (found 0) fuser_method_mappings.py:182 in public function `get_fuser_method`: D210: No whitespaces allowed surrounding docstring text fuser_method_mappings.py:182 in public function `get_fuser_method`: D300: Use """triple double quotes""" (found '''-quotes) fuser_method_mappings.py:182 in public function `get_fuser_method`: D400: First line should end with a period (not ',') fuser_method_mappings.py:205 in private function `_get_valid_patterns`: D205: 1 blank line required between summary line and description (found 0) fuser_method_mappings.py:205 in private function `_get_valid_patterns`: D400: First line should end with a period (not ',') fuser_method_mappings.py:205 in private function `_get_valid_patterns`: D401: First line should be in imperative mood (perhaps 'Return', not 'Returns') fuser_method_mappings.py:238 in public function `get_fuser_method_new`: D205: 1 blank line required between summary line and description (found 0) fuser_method_mappings.py:238 in public function `get_fuser_method_new`: D210: No whitespaces allowed surrounding docstring text fuser_method_mappings.py:238 in public function `get_fuser_method_new`: D400: First line should end with a period (not 'd') fuser_method_mappings.py:238 in public function `get_fuser_method_new`: D401: First line should be in imperative mood; try rephrasing (found 'This') Pull Request resolved: https://github.com/pytorch/pytorch/pull/112992 Approved by: https://github.com/kit1980
183 lines
6.8 KiB
Python
183 lines
6.8 KiB
Python
import torch
|
|
import copy
|
|
from typing import Dict, Any
|
|
|
|
# Public API of this module, grouped by role.
__all__ = [
    # accessors for module parameters
    "set_module_weight", "set_module_bias",
    "get_module_weight", "get_module_bias",
    # tensor range helpers
    "max_over_ndim", "min_over_ndim", "channel_range",
    # equalization entry points
    "cross_layer_equalization", "equalize", "converged",
]
|
|
|
|
# Module types whose ``weight``/``bias`` are read and written directly on the
# module by the accessor helpers below.
_supported_types = {torch.nn.Conv2d, torch.nn.Linear}

# Intrinsic (fused) module types; the accessor helpers below reach their
# parameters through ``module[0]``.
_supported_intrinsic_types = {
    torch.ao.nn.intrinsic.ConvReLU2d,
    torch.ao.nn.intrinsic.LinearReLU,
}

# Every module type accepted by ``cross_layer_equalization``.
_all_supported_types = _supported_types | _supported_intrinsic_types
|
|
|
|
def set_module_weight(module, weight) -> None:
    """Replace the weight of ``module`` with a new Parameter wrapping ``weight``.

    A module whose exact type is in ``_supported_types`` is updated directly;
    for any other module the weight of ``module[0]`` is replaced instead.
    """
    holder = module if type(module) in _supported_types else module[0]
    holder.weight = torch.nn.Parameter(weight)
|
|
|
|
def set_module_bias(module, bias) -> None:
    """Replace the bias of ``module`` with a new Parameter wrapping ``bias``.

    A module whose exact type is in ``_supported_types`` is updated directly;
    for any other module the bias of ``module[0]`` is replaced instead.
    """
    holder = module if type(module) in _supported_types else module[0]
    holder.bias = torch.nn.Parameter(bias)
|
|
|
|
def get_module_weight(module):
    """Return the weight of ``module``.

    A module whose exact type is in ``_supported_types`` yields its own
    ``weight``; any other module yields ``module[0].weight``.
    """
    if type(module) not in _supported_types:
        # Not a plain supported layer: read from the first sub-module.
        return module[0].weight
    return module.weight
|
|
|
|
def get_module_bias(module):
    """Return the bias of ``module``.

    A module whose exact type is in ``_supported_types`` yields its own
    ``bias``; any other module yields ``module[0].bias``.
    """
    if type(module) not in _supported_types:
        # Not a plain supported layer: read from the first sub-module.
        return module[0].bias
    return module.bias
|
|
|
|
def max_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.max' over the given axes.

    Args:
        input: tensor to reduce.
        axis_list: iterable of integer axes to reduce over. Negative axes
            count from the end. The caller's sequence is not modified.
        keepdim: forwarded to ``Tensor.max``; when true each reduced axis is
            kept with size 1.

    Returns:
        The tensor of maxima after reducing every requested axis in turn.
        With an empty ``axis_list`` the input is returned unchanged.
    """
    ndim = input.ndim
    # Reduce from the highest axis down: with keepdim=False each reduction
    # removes an axis, so going high-to-low keeps the remaining (lower) axis
    # indices valid. Negative axes are normalised first so the ordering is
    # meaningful, and ``sorted`` builds a new list instead of reordering the
    # caller's one.
    axes = sorted(
        (axis + ndim if axis < 0 else axis for axis in axis_list),
        reverse=True,
    )
    for axis in axes:
        input, _ = input.max(axis, keepdim)
    return input
|
|
|
|
def min_over_ndim(input, axis_list, keepdim=False):
    """Apply 'torch.min' over the given axes.

    Args:
        input: tensor to reduce.
        axis_list: iterable of integer axes to reduce over. Negative axes
            count from the end. The caller's sequence is not modified.
        keepdim: forwarded to ``Tensor.min``; when true each reduced axis is
            kept with size 1.

    Returns:
        The tensor of minima after reducing every requested axis in turn.
        With an empty ``axis_list`` the input is returned unchanged.
    """
    ndim = input.ndim
    # Reduce from the highest axis down: with keepdim=False each reduction
    # removes an axis, so going high-to-low keeps the remaining (lower) axis
    # indices valid. Negative axes are normalised first so the ordering is
    # meaningful, and ``sorted`` builds a new list instead of reordering the
    # caller's one.
    axes = sorted(
        (axis + ndim if axis < 0 else axis for axis in axis_list),
        reverse=True,
    )
    for axis in axes:
        input, _ = input.min(axis, keepdim)
    return input
|
|
|
|
def channel_range(input, axis=0):
    """Find the range of weights associated with a specific channel.

    Every axis of ``input`` except ``axis`` is reduced with
    ``min_over_ndim`` and ``max_over_ndim``; the result is the element-wise
    difference ``max - min``, one value per index along ``axis``.

    Raises:
        ValueError: if ``axis`` is not one of ``0 .. input.ndim - 1``
            (raised by ``list.remove``).
    """
    other_axes = list(range(input.ndim))
    other_axes.remove(axis)

    lows = min_over_ndim(input, other_axes)
    highs = max_over_ndim(input, other_axes)

    # Sanity check: one range entry per index along the requested axis.
    assert lows.size(0) == input.size(axis), "Dimensions of resultant channel range does not match size of requested axis"
    return highs - lows
|
|
|
|
def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
    """Scale the range of Tensor1.output to equal Tensor2.input.

    Given two adjacent modules, their weights are rescaled in place so that
    the range of each output channel of the first module's weight equals the
    range of the matching input channel of the second module's weight. The
    first module's weight and bias are multiplied by the inverse scaling
    factors and the second module's weight by the scaling factors.

    Args:
        module1: first module; its type must be in ``_all_supported_types``.
        module2: second module; its type must be in ``_all_supported_types``.
        output_axis: axis of ``module1``'s weight that indexes output channels.
        input_axis: axis of ``module2``'s weight that indexes input channels.

    Raises:
        ValueError: if either module's type is not supported.
        TypeError: if the size of ``module1``'s weight along ``output_axis``
            differs from the size of ``module2``'s weight along ``input_axis``.
    """
    if type(module1) not in _all_supported_types or type(module2) not in _all_supported_types:
        # Build one readable message instead of passing several positional
        # args, which would make the exception print as a tuple.
        raise ValueError(f"module type not supported: {type(module1)} {type(module2)}")

    weight1 = get_module_weight(module1)
    weight2 = get_module_weight(module2)

    if weight1.size(output_axis) != weight2.size(input_axis):
        # Implicit string concatenation keeps source indentation out of the
        # message (a backslash continuation inside the literal would not).
        raise TypeError(
            "Number of output channels of first arg does not match "
            "number of input channels of second arg"
        )

    # May be None when the first module was built without a bias.
    bias = get_module_bias(module1)

    weight1_range = channel_range(weight1, output_axis)
    weight2_range = channel_range(weight2, input_axis)

    # Produce the scaling factors to be applied; the small epsilon guards
    # against dividing by a zero range.
    weight2_range += 1e-9
    scaling_factors = torch.sqrt(weight1_range / weight2_range)
    inverse_scaling_factors = torch.reciprocal(scaling_factors)

    if bias is not None:
        bias = bias * inverse_scaling_factors

    # Reshape the 1-D factor tensors so they broadcast along the channel axis
    # of the weight they are applied to: every axis has size 1 except the
    # channel axis.
    size1 = [1] * weight1.ndim
    size1[output_axis] = weight1.size(output_axis)
    size2 = [1] * weight2.ndim
    size2[input_axis] = weight2.size(input_axis)

    scaling_factors = torch.reshape(scaling_factors, size2)
    inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)

    weight1 = weight1 * inverse_scaling_factors
    weight2 = weight2 * scaling_factors

    set_module_weight(module1, weight1)
    if bias is not None:
        # Leave a bias-free module bias-free rather than wrapping None in a
        # Parameter.
        set_module_bias(module1, bias)
    set_module_weight(module2, weight2)
|
|
|
|
def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
    """Equalize modules until convergence is achieved.

    Cross-layer equalization is applied to every listed pair of modules, and
    the whole pass is repeated until ``converged`` reports that the modules
    have stopped changing.

    Before a pair is equalized, a deep copy of both modules is stored. After
    each pass those copies are compared against the current modules by
    ``converged``; once the difference falls below ``threshold`` further
    equalizing is considered unnecessary.

    The implementation refers to section 4.1 of
    https://arxiv.org/pdf/1906.04721.pdf

    Args:
        model: a model (nn.module) that equalization is to be applied on
        paired_modules_list: a list of lists where each sublist is a pair of
            names of two submodules found in the model; the two submodules
            generally have to be adjacent in the model to get
            expected/reasonable results
        threshold: a number used by the converged function to determine what
            degree of similarity between models is necessary for them to be
            called equivalent
        inplace: when false, a deep copy of ``model`` is equalized and
            returned and the given model is left untouched

    Returns:
        The equalized model (the deep copy when ``inplace`` is false).
    """
    target_model = model if inplace else copy.deepcopy(model)

    wanted_names = {name for pair in paired_modules_list for name in pair}
    name_to_module: Dict[str, torch.nn.Module] = {
        name: module
        for name, module in target_model.named_modules()
        if name in wanted_names
    }
    # No snapshot exists yet for any tracked module.
    previous_name_to_module: Dict[str, Any] = dict.fromkeys(name_to_module)

    while not converged(name_to_module, previous_name_to_module, threshold):
        for pair in paired_modules_list:
            first_name, second_name = pair[0], pair[1]
            # Snapshot both modules before this pair is rescaled.
            previous_name_to_module[first_name] = copy.deepcopy(name_to_module[first_name])
            previous_name_to_module[second_name] = copy.deepcopy(name_to_module[second_name])

            cross_layer_equalization(name_to_module[first_name], name_to_module[second_name])

    return target_model
|
|
|
|
def converged(curr_modules, prev_modules, threshold=1e-4):
    """Test whether modules are converged to a specified threshold.

    Both arguments are dictionaries mapping names to modules and must share
    the same set of names. For every name the difference between the current
    and previous module weights is taken, and the norms of those differences
    are summed. The modules count as converged when that sum is below
    ``threshold``.

    Returns ``False`` whenever any previous module is ``None``.

    Raises:
        ValueError: if the two mappings do not have the same keys.
    """
    if curr_modules.keys() != prev_modules.keys():
        raise ValueError("The keys to the given mappings must have the same set of names of modules")

    # A missing snapshot means there is nothing to compare against yet.
    if None in prev_modules.values():
        return False

    total_norm = torch.tensor(0.)
    for name, current in curr_modules.items():
        delta = get_module_weight(current).sub(get_module_weight(prev_modules[name]))
        total_norm += torch.norm(delta)
    return bool(total_norm < threshold)
|