Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26709

Polishes the implementation from #25975. Primarily, we use NoopObserver to communicate that weights need to be quantized to float16. The very top-level API (quantize_dynamic) stays the same, with a `dtype` argument, but the implementation follows the common flow. One can argue that dynamic fp16 quantization doesn't really fit into the 'observer' mechanism. It's in fact not ideal, but it's better to have the same flow than to branch on both dtype and qconfig.

Test Plan: Imported from OSS

Differential Revision: D17544103

Pulled By: dzhulgakov

fbshipit-source-id: 6af3f18c35929a1a53ea734079c005f656e4925f
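For reference, a minimal usage sketch of the top-level API described above. The wrapper module, sizes, and the assumption that the default qconfig_spec covers nn.LSTM for the chosen dtype are illustrative, not taken from this PR:

    import torch
    import torch.nn as nn
    from torch.quantization import quantize_dynamic

    class SmallModel(nn.Module):
        def __init__(self):
            super(SmallModel, self).__init__()
            self.lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=1)

        def forward(self, x):
            out, _ = self.lstm(x)
            return out

    # dtype=torch.qint8 selects the int8 dynamic path; dtype=torch.float16
    # exercises the fp16 weight-packing path this change wires up.
    quantized_model = quantize_dynamic(SmallModel(), dtype=torch.float16)
    x = torch.randn(5, 1, 16)  # (seq_len, batch, input_size)
    out = quantized_model(x)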
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch.nn import _VF
from torch._jit_internal import Tuple, Optional, List  # noqa: F401
from torch.nn.utils.rnn import PackedSequence

import numbers
import warnings


def apply_permutation(tensor, permutation, dim=1):
    # type: (Tensor, Tensor, int) -> Tensor
    return tensor.index_select(dim, permutation)


class RNNBase(torch.nn.Module):
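    # Dynamic-quantization counterpart of nn.RNNBase: instead of nn.Parameter
    # weights, it stores pre-packed weight/bias structures per layer and per
    # direction in _all_weight_names / _all_weight_values, which the quantized
    # LSTM kernel consumes at run time.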

    _FLOAT_MODULE = nn.RNNBase

    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0., bidirectional=False, dtype=torch.qint8):
        super(RNNBase, self).__init__()

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.dtype = dtype
        num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        # For qint8, each stored entry is a packed struct produced by
        # quantized::linear_prepack (weight and bias together); for float16,
        # packed fp16 weights and float biases are stored as separate entries.
        self._all_weight_names = []
        self._all_weight_values = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                def process_weights(ihhh, layer, suffix, qweight, bias, dtype):
                    if dtype == torch.qint8:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        # w_ih, w_hh
                        packed_weight = \
                            torch.ops.quantized.linear_prepack(qweight, bias)

                        params = [packed_weight]
                        pos_names = ['w']
                        ret_name = ['{}_{}_l{}{}'.format(
                            name, ihhh, layer, suffix) for name in pos_names]
                        return params, ret_name
                    else:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        # packed_ih, packed_hh, b_ih, b_hh
                        packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
                            qweight)

                        params = [packed_weight, bias]
                        pos_names = ['packed', 'b']
                        ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
                        return params, ret_name

                if dtype == torch.qint8:
                    w_ih = torch._empty_affine_quantized(
                        [gate_size, layer_input_size], scale=1, zero_point=0, dtype=torch.qint8)
                    w_hh = torch._empty_affine_quantized(
                        [gate_size, hidden_size], scale=1, zero_point=0, dtype=torch.qint8)
                    b_ih = torch._empty_affine_quantized(
                        [gate_size], scale=1, zero_point=0, dtype=torch.qint32)
                    # Second bias vector included for CuDNN compatibility. Only one
                    # bias vector is needed in standard definition.
                    b_hh = torch._empty_affine_quantized(
                        [gate_size], scale=1, zero_point=0, dtype=torch.qint32)

                else:
                    w_ih = torch.Tensor(gate_size, layer_input_size).float()
                    w_hh = torch.Tensor(gate_size, hidden_size).float()
                    b_ih = torch.Tensor(gate_size).float()
                    # Second bias vector included for CuDNN compatibility. Only one
                    # bias vector is needed in standard definition.
                    b_hh = torch.Tensor(gate_size).float()

                suffix = '_reverse' if direction == 1 else ''
                ih_params, ih_param_names = process_weights(
                    'ih', layer, suffix, w_ih, b_ih, dtype)
                hh_params, hh_param_names = process_weights(
                    'hh', layer, suffix, w_hh, b_hh, dtype)

                for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
                    self._all_weight_names.extend([ih_name, hh_name])
                    self._all_weight_values.extend([ih, hh])

    def check_input(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> None
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    def get_expected_hidden_size(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
        if batch_sizes is not None:
            mini_batch = batch_sizes[0]
            mini_batch = int(mini_batch)
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
        # type: (Tensor, Tuple[int, int, int], str) -> None
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(
                expected_hidden_size, tuple(hx.size())))

    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tensor, Optional[Tensor]) -> None
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(hidden, expected_hidden_size,
                               msg='Expected hidden size {}, got {}')

    def permute_hidden(self, hx, permutation):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)
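
    # Serialization / TorchScript round-trip: __getstate__ unpacks every entry of
    # _all_weight_values with quantized::linear_unpack into (weight, bias) tuples,
    # and __setstate__ re-packs them with quantized::linear_prepack.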

    @torch.jit.export
    def __getstate__(self):
        vals = (
            self.mode,
            self.input_size,
            self.hidden_size,
            self.num_layers,
            self.bias,
            self.batch_first,
            self.dropout,
            self.bidirectional,
            self._all_weight_names,
            self.__overloads__,
            self.training,
            self.dtype,
        )

        dynamic_vals = torch.jit.annotate(List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
                                          [])

        for i in range(len(self._all_weight_names)):
            dynamic_vals.append(torch.ops.quantized.linear_unpack(self._all_weight_values[i]))
        return vals, dynamic_vals

    @torch.jit.export
    def __setstate__(self, state):
        vals, dynamic_vals = state
        self.mode = vals[0]
        self.input_size = vals[1]
        self.hidden_size = vals[2]
        self.num_layers = vals[3]
        self.bias = vals[4]
        self.batch_first = vals[5]
        self.dropout = vals[6]
        self.bidirectional = vals[7]
        self._all_weight_names = vals[8]
        self.__overloads__ = vals[9]
        self.training = vals[10]
        self.dtype = vals[11]

        self._all_weight_values = []
        for i in range(len(self._all_weight_names)):
            self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i]))
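
    # from_float converts a float nn.LSTM into this module. The weight observer's
    # dtype selects the path: qint8 (observe, quantize_per_tensor, linear_prepack)
    # or float16 (fbgemm_pack_gemm_matrix_fp16), matching the observer-driven flow
    # described in the commit summary above.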

    @classmethod
    def from_float(cls, mod):
        assert type(mod) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM'
        assert hasattr(
            mod, 'qconfig'), 'Input float module must have qconfig defined'

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have circular import issues if we import the qconfig at the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.QConfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()

        dtype = weight_observer.dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))

        if mod.mode == 'LSTM':
            qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                            mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
        else:
            raise NotImplementedError('Only LSTM is supported for QuantizedRNN for now')

        num_directions = 2 if mod.bidirectional else 1

        assert mod.bias

        qRNNBase._all_weight_names = []
        qRNNBase._all_weight_values = []
        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions

                def process_weights(ihhh, layer, suffix, dtype):
                    weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                    bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

                    weight = getattr(mod, weight_name)
                    bias = getattr(mod, bias_name)

                    if dtype == torch.qint8:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        # w_ih, w_hh
                        weight_observer(weight)
                        wt_scale, wt_zp = weight_observer.calculate_qparams()
                        qweight = torch.quantize_per_tensor(
                            weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
                        packed_weight = \
                            torch.ops.quantized.linear_prepack(qweight, bias)

                        params = [packed_weight]
                        pos_names = ['w']
                        ret_name = ['{}_{}_l{}{}'.format(
                            name, ihhh, layer, suffix) for name in pos_names]
                        return params, ret_name
                    else:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        # packed_ih, packed_hh, b_ih, b_hh
                        packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
                            weight.float())

                        params = [packed_weight, bias]
                        pos_names = ['packed', 'b']
                        ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
                        return params, ret_name

                suffix = '_reverse' if direction == 1 else ''
                ih_params, ih_param_names = process_weights('ih', layer, suffix, dtype)
                hh_params, hh_param_names = process_weights('hh', layer, suffix, dtype)

                for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
                    qRNNBase._all_weight_names.extend([ih_name, hh_name])
                    qRNNBase._all_weight_values.extend([ih, hh])

        return qRNNBase


class LSTM(RNNBase):

    _FLOAT_MODULE = nn.LSTM

    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    def __init__(self, *args, **kwargs):
        super(LSTM, self).__init__('LSTM', *args, **kwargs)
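
    # forward() (see below) dispatches on the input type: a plain Tensor goes to
    # forward_tensor and a PackedSequence goes to forward_packed, as declared in
    # __overloads__ above; both paths call forward_impl, which runs the dynamic
    # quantized LSTM kernel on the pre-packed weights.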

    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        assert batch_sizes is None

        result = _VF.quantized_lstm(input, hx, self._all_weight_values, self.bias, self.num_layers,
                                    float(self.dropout), self.training, self.bidirectional,
                                    self.batch_first, dtype=self.dtype, use_dynamic=True)
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.export
    def forward_tensor(self, input, hx=None):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices)

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(self, input, hx=None):
        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]  # noqa
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices)

        output = PackedSequence(output, batch_sizes,
                                sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def permute_hidden(self, hx, permutation):
        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden[0], expected_hidden_size,
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], expected_hidden_size,
                               'Expected hidden[1] size {}, got {}')

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod):
        return super(LSTM, cls).from_float(mod)
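
# Usage sketch for the module-level conversion path (illustrative, not part of
# the module definitions above): from_float consumes a float nn.LSTM that carries
# a qconfig. With default_dynamic_qconfig the weight observer reports torch.qint8,
# so the int8 packing path is taken; the sizes below are arbitrary.
#
#     import torch.nn as nn
#     from torch.quantization.QConfig import default_dynamic_qconfig
#
#     float_lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=1)
#     float_lstm.qconfig = default_dynamic_qconfig
#     q_lstm = LSTM.from_float(float_lstm)  # LSTM as defined in this file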