pytorch/torch/nn/quantized/modules/linear.py
David Riazati 10c4b98ade Remove weak script (#22212)
Summary:
* Deletes all weak script decorators / associated data structures / methods
   * In order to keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn`
   * Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`; only `rnn.py` needed manual editing, to use the `ignore` and `export` decorators so that the overloaded `forward` methods remain supported
* `Sequential`/`ModuleList` no longer need to be added to constants since they are compiled on demand

This should also fix https://github.com/pytorch/pytorch/issues/22212
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22212

Differential Revision: D15988346

Pulled By: driazati

fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
2019-07-03 17:28:25 -07:00

139 lines
5.3 KiB
Python

from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from ...modules.module import Module
from ...modules.linear import Linear as NNLinear
class Quantize(Module):
    r"""Quantizes an incoming floating-point tensor using fixed quantization parameters.

    Args:
        `out_scale`: scale of the output Quantized Tensor
        `out_zero_point`: zero_point of the output Quantized Tensor
        `out_dtype`: data type of the output Quantized Tensor

    Attributes:
        `out_scale`, `out_zero_point`, `out_dtype`

    Examples::

        >>> t = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> qt = qm(t)
        >>> print(qt)
        tensor([[ 1., -1.],
                [ 1., -1.]], size=(2, 2), dtype=torch.qint8, scale=1.0, zero_point=2)
    """

    def __init__(self, out_scale, out_zero_point, out_dtype):
        super(Quantize, self).__init__()
        # Scale and zero point are kept as one-element buffers (scale first,
        # then zero point); the zero point is stored as an integer tensor.
        scale_buf = torch.tensor([out_scale])
        zero_point_buf = torch.tensor([out_zero_point], dtype=torch.long)
        self.register_buffer('out_scale', scale_buf)
        self.register_buffer('out_zero_point', zero_point_buf)
        # The dtype is a plain attribute, not a buffer.
        self.out_dtype = out_dtype

    def forward(self, X):
        # Unwrap the one-element buffers into Python scalars for the op call.
        scale = self.out_scale.item()
        zero_point = self.out_zero_point.item()
        return torch.quantize_linear(X, scale, zero_point, self.out_dtype)

    @staticmethod
    def from_float(mod):
        # `mod` is expected to expose `qparams` as a (scale, zero_point) pair
        # of tensors -- presumably an observer module; confirm with callers.
        # The output dtype is always torch.quint8 here.
        scale, zero_point = mod.qparams[0].item(), mod.qparams[1].item()
        return Quantize(scale, zero_point, torch.quint8)
class DeQuantize(Module):
    r"""Dequantizes an incoming quantized tensor back to floating point.

    Examples::

        >>> input = torch.tensor([[1., -1.], [1., -1.]])
        >>> scale, zero_point, dtype = 1.0, 2, torch.qint8
        >>> qm = Quantize(scale, zero_point, dtype)
        >>> quantized_input = qm(input)
        >>> dqm = DeQuantize()
        >>> dequantized = dqm(quantized_input)
        >>> print(dequantized)
        tensor([[ 1., -1.],
                [ 1., -1.]])
    """

    def __init__(self):
        super(DeQuantize, self).__init__()

    def forward(self, Xq):
        # All the work is delegated to the quantized tensor's own dequantize().
        dequantized = Xq.dequantize()
        return dequantized

    @staticmethod
    def from_float(mod):
        # `mod` is ignored: this module has no state to copy over.
        return DeQuantize()
class Linear(NNLinear):
    r"""
    A quantized linear module with quantized tensors as inputs and outputs.

    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to `torch.nn.Linear`, the weight and bias are allocated at module
    creation time. Here they are created with ``torch._empty_affine_quantized``
    and are therefore NOT initialized; they are expected to be overwritten
    later (e.g. through ``load_state_dict`` or the ``weight`` setter).

    Attributes:
        weight: the non-learnable quantized (``torch.qint8``) weights of the
            module, of shape :math:`(\text{out\_features}, \text{in\_features})`.
            Kept internally in packed form in the ``_packed_weight`` buffer and
            unpacked on access.
        bias: the non-learnable quantized (``torch.qint32``) bias of the module,
            of shape :math:`(\text{out\_features})`. Uninitialized until
            overwritten.
        out_scale: `scale` parameter of output Quantized Tensor, type: float
        out_zero_point: `zero_point` parameter for output Quantized Tensor, type: long

    Examples::

        >>> m = nn.quantized.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> input = torch.quantize_linear(input, 1.0, 0, torch.quint8)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        assert bias, 'nobias is not supported in Quantized Linear module yet'
        super(Linear, self).__init__(in_features, out_features, bias)
        # Drop the float Parameters created by torch.nn.Linear; they are
        # replaced below by quantized buffers.
        del self.weight
        del self.bias
        # Placeholders only: "_empty_" allocations hold uninitialized values.
        qweight = torch._empty_affine_quantized(
            [out_features, in_features], scale=1, zero_point=0,
            dtype=torch.qint8)
        qbias = torch._empty_affine_quantized(
            [out_features], scale=1, zero_point=0, dtype=torch.qint32)
        self.register_buffer('_packed_weight',
                             torch.ops.quantized.fbgemm_linear_prepack(qweight))
        self.register_buffer('bias', qbias)
        self.register_buffer('out_scale', torch.Tensor([1]))
        # The zero point is an integer quantity (documented as long above and
        # stored as long by Quantize), so keep it in an integer tensor.
        self.register_buffer('out_zero_point',
                             torch.tensor([0], dtype=torch.long))

    @property
    def weight(self):
        """The quantized weight, unpacked from ``_packed_weight`` on each access."""
        return torch.ops.quantized.fbgemm_linear_unpack(self._packed_weight)

    @weight.setter
    def weight(self, w):
        # `w` is presumably a qint8 quantized tensor of shape
        # (out_features, in_features) -- the same layout __init__ packs.
        self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(w)

    def forward(self, x):
        """Apply the quantized linear op to quantized input ``x``.

        The output is produced with ``out_scale`` / ``out_zero_point`` as its
        quantization parameters.
        """
        return torch.ops.quantized.fbgemm_linear(
            x, self._packed_weight,
            self.bias,
            self.out_scale,
            self.out_zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Explicit super() arguments keep the module Python-2 compatible,
        # matching the rest of this file.
        super(Linear, self)._save_to_state_dict(destination, prefix, keep_vars)
        # Replace the packed buffer entry with the unpacked tensor under the
        # public name 'weight'.
        packed = destination.pop(prefix + '_packed_weight')
        destination[prefix + 'weight'] = \
            torch.ops.quantized.fbgemm_linear_unpack(packed)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load 'weight' (repacking it) and 'bias', then defer to the base class.

        The consumed 'weight' and 'bias' entries are removed from
        ``state_dict``. When ``strict`` is set, an absent entry is recorded in
        ``missing_keys`` (so ``load_state_dict`` can report it) instead of
        raising a bare ``KeyError``; when not strict, it is simply skipped.
        """
        weight_key = prefix + 'weight'
        bias_key = prefix + 'bias'
        if weight_key in state_dict:
            self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(
                state_dict.pop(weight_key))
        elif strict:
            missing_keys.append(weight_key)
        if bias_key in state_dict:
            self.bias.copy_(state_dict.pop(bias_key))
        elif strict:
            missing_keys.append(bias_key)
        # The base class is run non-strict: '_packed_weight' is never present
        # under its own name (it is serialized as 'weight'), so a strict pass
        # would always flag it as missing.
        super(Linear, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, False,
            missing_keys, unexpected_keys, error_msgs)