mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: * Deletes all weak script decorators / associated data structures / methods * In order to keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn` * Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`, only `rnn.py` needed manual editing to use the `ignore` and `export` to continue supporting the overloaded `forward` methods * `Sequential`/`ModuleList` no longer need to be added to constants since they are compiled on demand This should also fix https://github.com/pytorch/pytorch/issues/22212 Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212 Differential Revision: D15988346 Pulled By: driazati fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
139 lines
5.3 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
import torch
|
|
from ...modules.module import Module
|
|
from ...modules.linear import Linear as NNLinear
|
|
|
|
class Quantize(Module):
    r"""Quantize an incoming floating-point tensor with fixed output qparams.

    Args:
        out_scale: scale of the output quantized tensor.
        out_zero_point: zero point of the output quantized tensor.
        out_dtype: quantized data type of the output tensor
            (e.g. ``torch.quint8`` or ``torch.qint8``).

    Attributes:
        out_scale: one-element buffer holding the output scale.
        out_zero_point: one-element ``torch.long`` buffer holding the output
            zero point.
        out_dtype: the output quantized dtype (a plain attribute, not a buffer).

    Examples::

        >>> t = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> qt = qm(t)
        >>> print(qt)
        tensor([[ 1., -1.],
                [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
    """

    def __init__(self, out_scale, out_zero_point, out_dtype):
        super(Quantize, self).__init__()
        # Scale and zero point are kept as buffers (not plain attributes) so
        # they are part of the module's state_dict; the zero point is integral.
        scale_buf = torch.tensor([out_scale])
        zero_point_buf = torch.tensor([out_zero_point], dtype=torch.long)
        self.register_buffer('out_scale', scale_buf)
        self.register_buffer('out_zero_point', zero_point_buf)
        self.out_dtype = out_dtype

    def forward(self, X):
        """Return ``X`` quantized with the stored scale, zero point and dtype."""
        # Each buffer holds a single element; unwrap to Python scalars for the op.
        scale = self.out_scale.item()
        zero_point = self.out_zero_point.item()
        return torch.quantize_linear(X, scale, zero_point, self.out_dtype)

    @staticmethod
    def from_float(mod):
        """Build a ``Quantize`` from ``mod.qparams`` with ``torch.quint8`` output.

        ``mod.qparams[0]`` is used as the scale and ``mod.qparams[1]`` as the
        zero point; both must support ``.item()``. (Presumably ``mod`` is an
        observer-like module -- confirm against callers.)
        """
        scale = mod.qparams[0].item()
        zero_point = mod.qparams[1].item()
        return Quantize(scale, zero_point, torch.quint8)


class DeQuantize(Module):
    r"""Dequantize an incoming quantized tensor back to floating point.

    The module holds no parameters or buffers.

    Examples::

        >>> x = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> qx = qm(x)
        >>> dqm = DeQuantize()
        >>> dequantized = dqm(qx)
        >>> print(dequantized)
        tensor([[ 1., -1.],
                [ 1., -1.]])
    """

    def __init__(self):
        super(DeQuantize, self).__init__()

    def forward(self, Xq):
        """Return the floating-point tensor produced by ``Xq.dequantize()``."""
        dequantized = Xq.dequantize()
        return dequantized

    @staticmethod
    def from_float(mod):
        """Return a fresh ``DeQuantize``; ``mod`` is accepted but not read."""
        # Nothing is taken from ``mod``; the argument only keeps the same
        # call shape as ``Quantize.from_float(mod)``.
        return DeQuantize()


class Linear(NNLinear):
    r"""
    A quantized linear module with quantized tensors as inputs and outputs.

    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to `torch.nn.Linear`, the attributes are allocated at module
    creation time (here as *uninitialized* quantized tensors) and are expected
    to be overwritten later, e.g. by assigning ``weight`` / ``bias`` or by
    loading a state dict.

    Attributes:
        weight: the non-learnable quantized weights of the module, of shape
            :math:`(\text{out\_features}, \text{in\_features})`. Stored
            internally in packed form in the ``_packed_weight`` buffer and
            unpacked on access.
        bias: the non-learnable quantized (``torch.qint32``) bias of the
            module, of shape :math:`(\text{out\_features})`. Allocated
            uninitialized; it is not zero-filled.
        out_scale: `scale` parameter of the output quantized tensor, stored as
            a one-element float tensor.
        out_zero_point: `zero_point` parameter of the output quantized tensor,
            stored as a one-element ``torch.long`` tensor.

    Examples::

        >>> m = nn.quantized.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> input = torch.quantize_linear(input, 1.0, 0, torch.quint8)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        assert bias, 'nobias is not supported in Quantized Linear module yet'
        super(Linear, self).__init__(in_features, out_features, bias)
        # Drop the float Parameters created by nn.Linear; they are replaced by
        # quantized buffers below, and `weight` is served by the property.
        del self.weight
        del self.bias
        qweight = torch._empty_affine_quantized(
            [out_features, in_features], scale=1, zero_point=0,
            dtype=torch.qint8)
        qbias = torch._empty_affine_quantized(
            [out_features], scale=1, zero_point=0, dtype=torch.qint32)
        self.register_buffer('_packed_weight',
                             torch.ops.quantized.fbgemm_linear_prepack(qweight))
        self.register_buffer('bias', qbias)
        self.register_buffer('out_scale', torch.Tensor([1]))
        # Integer dtype, matching the documented type of `out_zero_point` and
        # the way `Quantize` stores its zero point.
        self.register_buffer('out_zero_point',
                             torch.tensor([0], dtype=torch.long))

    @property
    def weight(self):
        """Quantized weight, unpacked from the ``_packed_weight`` buffer."""
        return torch.ops.quantized.fbgemm_linear_unpack(self._packed_weight)

    @weight.setter
    def weight(self, w):
        # Presumably `w` is a quantized weight like the qint8 placeholder
        # created in __init__ -- TODO confirm accepted dtypes. It is stored
        # prepacked for fbgemm.
        self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(w)

    def forward(self, x):
        """Apply the quantized linear op to the quantized input ``x``."""
        return torch.ops.quantized.fbgemm_linear(
            x, self._packed_weight,
            self.bias,
            self.out_scale,
            self.out_zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Explicit two-argument super() keeps Python 2 compatibility, which the
        # file's `from __future__` imports indicate is targeted.
        super(Linear, self)._save_to_state_dict(destination, prefix, keep_vars)
        # Serialize the weight unpacked under the public 'weight' key instead
        # of the internal packed buffer.
        packed_key = prefix + '_packed_weight'
        destination[prefix + 'weight'] = torch.ops.quantized.fbgemm_linear_unpack(
            destination[packed_key])
        destination.pop(packed_key)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        # 'weight' and 'bias' are consumed here (and removed so the base class
        # does not see them). A missing key is reported like any other missing
        # key under `strict` and skipped otherwise, instead of raising KeyError.
        if weight_key in state_dict:
            self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(
                state_dict[weight_key])
            state_dict.pop(weight_key)
        elif strict:
            missing_keys.append(weight_key)
        if bias_key in state_dict:
            self.bias.copy_(state_dict[bias_key])
            state_dict.pop(bias_key)
        elif strict:
            missing_keys.append(bias_key)
        # strict=False is passed on purpose: the `_packed_weight` buffer never
        # appears in a saved state dict (it is saved as 'weight'), so a strict
        # base-class load would always report it missing.
        # NOTE(review): this also hides missing out_scale / out_zero_point keys.
        super(Linear, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, False,
            missing_keys, unexpected_keys, error_msgs)