mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48038. nn.ReLU works for both float and quantized input, so we don't want to define an nn.quantized.ReLU that does the same thing as nn.ReLU; the same applies to nn.quantized.functional.relu. This also removes the numerical inconsistency for models that quantize nn.ReLU independently in QAT mode.
Test Plan: Imported from OSS
Reviewed By: vkuzo
Differential Revision: D25000462
fbshipit-source-id: e3609a3ae4a3476a42f61276619033054194a0d2
69 lines
2.4 KiB
Python
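As the commit summary notes, nn.ReLU already accepts both float and quantized tensors, which is why a separate nn.quantized.ReLU is unnecessary. A minimal sketch (not part of this file) of that behavior; the shapes and quantization parameters below are arbitrary illustration values:

import torch
import torch.nn as nn

relu = nn.ReLU()
x = torch.randn(2, 3)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=128, dtype=torch.quint8)

# The same module handles float and quantized input.
y = relu(x)    # float output
qy = relu(qx)  # quantized output, same scale/zero_point as the input
assert qy.is_quantized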
import torch
import torch.nn.quantized.functional
import torch.nn.intrinsic as nni


class BatchNorm2d(torch.nn.BatchNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm2d`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm2d, self).__init__(num_features)
        self.eps = eps
        # Output quantization parameters; overwritten by from_float() during conversion.
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        return torch.ops.quantized.batch_norm2d(input, self.weight, self.bias, self.running_mean,
                                                self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm2d'

    @classmethod
    def from_float(cls, mod):
        activation_post_process = mod.activation_post_process
        # For the fused BNReLU2d module, convert from its BatchNorm2d child.
        if type(mod) == nni.BNReLU2d:
            mod = mod[0]
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod


# TODO: dedup with BatchNorm2d
class BatchNorm3d(torch.nn.BatchNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.BatchNorm3d`.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm3d, self).__init__(num_features)
        self.eps = eps
        # Output quantization parameters; overwritten by from_float() during conversion.
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        return torch.ops.quantized.batch_norm3d(input, self.weight, self.bias, self.running_mean,
                                                self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm3d'

    @classmethod
    def from_float(cls, mod):
        activation_post_process = mod.activation_post_process
        # For the fused BNReLU3d module, convert from its BatchNorm3d child.
        if type(mod) == nni.BNReLU3d:
            mod = mod[0]
        scale, zero_point = activation_post_process.calculate_qparams()
        new_mod = cls(mod.num_features, mod.eps)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.running_mean = mod.running_mean
        new_mod.running_var = mod.running_var
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod
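A rough usage sketch of the from_float path above (not part of this file), assuming the class is exposed as torch.nn.quantized.BatchNorm2d; attaching the MinMaxObserver by hand stands in for what torch.quantization.prepare would normally do during post-training quantization:

import torch
import torch.nn.quantized as nnq
from torch.quantization import MinMaxObserver

# Float module with an output observer attached (normally done by prepare()).
float_bn = torch.nn.BatchNorm2d(4).eval()
float_bn.activation_post_process = MinMaxObserver(dtype=torch.quint8)

# Run calibration data through the observer so calculate_qparams() has statistics.
x = torch.randn(2, 4, 8, 8)
float_bn.activation_post_process(float_bn(x))

# Convert and run on a quantized input; the output uses the observer's scale/zero_point.
qbn = nnq.BatchNorm2d.from_float(float_bn)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)
qy = qbn(qx)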