mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/27399

This was devised in a time when we didn't have module attributes. They are essentially just tensor lists, so represent them that way. This has the additional benefit of making the RNN forward pass faster, because we effectively cache the flattened weights. The only complicated part is that someone may come along and do:

```
my_rnn_mod.w_ih_l0 = torch.nn.Parameter(...)
```

This means we need to override setattr to keep the flattened weights cache up to date.

Test Plan: Imported from OSS

Differential Revision: D17785658

Pulled By: suo

fbshipit-source-id: 7789cd1d0d4922bfd5eba1716976442fbf150766
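A minimal, self-contained sketch of the caching mechanism the summary describes: a module keeps a flattened list of its weight tensors and overrides `__setattr__` so that reassigning an individual weight keeps the cached list in sync. The class and attribute names here are made up for illustration; the real implementation is `QuantizedRNNBase.__setattr__` further down in this file.

```
import torch


class CachedWeightsModule(torch.nn.Module):
    # Hypothetical sketch, not the QuantizedRNNBase implementation below.
    def __init__(self):
        super(CachedWeightsModule, self).__init__()
        self.w_ih_l0 = torch.nn.Parameter(torch.randn(4, 4))
        self.w_hh_l0 = torch.nn.Parameter(torch.randn(4, 4))
        self._weight_names = ['w_ih_l0', 'w_hh_l0']
        # Cached "flattened weights" list used on the hot path.
        self.all_weights = [getattr(self, name) for name in self._weight_names]

    def __setattr__(self, name, value):
        # Keep the cached list up to date when a weight is reassigned,
        # e.g. mod.w_ih_l0 = torch.nn.Parameter(...).
        names = self.__dict__.get('_weight_names')
        if names is not None and name in names:
            self.all_weights[names.index(name)] = value
        super(CachedWeightsModule, self).__setattr__(name, value)


mod = CachedWeightsModule()
mod.w_ih_l0 = torch.nn.Parameter(torch.zeros(4, 4))
assert mod.all_weights[0] is mod.w_ih_l0  # cache stayed consistent
```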
644 lines
28 KiB
Python
import torch

from torch._jit_internal import Tuple, Optional, List  # noqa: F401
from torch import Tensor  # noqa: F401
from torch.nn import _VF
from torch.nn.utils.rnn import PackedSequence


class QuantizedLinear(torch.jit.ScriptModule):
    __constants__ = ['scale', 'zero_point']

    def __init__(self, other):
        super(QuantizedLinear, self).__init__()
        self.in_features = other.in_features
        self.out_features = other.out_features
        # Quantize weight and discard the original
        self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight(
            other.weight.clone().float())
        self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
        self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
        assert other.bias is not None, 'QuantizedLinear requires a bias'
        self.bias = torch.nn.Parameter(other.bias.clone().float(), requires_grad=False)

        self.register_buffer(
            'packed_tensor_ptr',
            torch.fbgemm_pack_quantized_matrix(self.weight.clone()))

    @torch.jit.script_method
    def _unpack(self):
        self.packed_tensor_ptr.set_(
            torch.fbgemm_pack_quantized_matrix(self.weight))

    @torch.jit.script_method
    def _pack(self):
        self.packed_tensor_ptr.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())

    @torch.jit.script_method
    def forward(self, input):
        out = torch.fbgemm_linear_int8_weight_fp32_activation(
            input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
            self.scale, self.zero_point, self.bias)
        return out.to(input.dtype)

    def extra_repr(self):
        repr = 'in_features={in_features}, out_features={out_features}, ' \
               'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
        return repr


# FP16 weights
class QuantizedLinearFP16(torch.jit.ScriptModule):

    def __init__(self, other):
        super(QuantizedLinearFP16, self).__init__()
        self.in_features = other.in_features
        self.out_features = other.out_features
        self.original_weight = other.weight
        self.weight = torch.fbgemm_pack_gemm_matrix_fp16(
            other.weight.clone().float())
        assert other.bias is not None, 'QuantizedLinearFP16 requires a bias'
        self.bias = torch.nn.Parameter(other.bias.clone().float(), requires_grad=False)
        self.register_buffer('packed_weight', self.weight)

    @torch.jit.script_method
    def _unpack(self):
        self.packed_weight.set_(
            torch.fbgemm_pack_gemm_matrix_fp16(
                self.original_weight))

    @torch.jit.script_method
    def _pack(self):
        self.packed_weight.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())

    @torch.jit.script_method
    def forward(self, input):
        out = torch.fbgemm_linear_fp16_weight_fp32_activation(
            input.float(), self.packed_weight, self.bias)
        return out

    def extra_repr(self):
        repr = 'in_features={in_features}, out_features={out_features}, '.format(**self.__dict__)
        return repr


# Quantized RNN cell implementations
class QuantizedRNNCellBase(torch.jit.ScriptModule):
    __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
                     'zero_point_ih', 'zero_point_hh']

    def __init__(self, other):
        super(QuantizedRNNCellBase, self).__init__()
        self.input_size = other.input_size
        self.hidden_size = other.hidden_size
        self.bias = other.bias
        if not self.bias:
            raise ValueError("Quantized RNN cells require bias terms")

        weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \
            torch.fbgemm_linear_quantize_weight(other.weight_ih.clone().float())
        self.register_buffer('weight_ih', weight_ih)
        self.register_buffer('col_offsets_ih', col_offsets_ih)
        weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \
            torch.fbgemm_linear_quantize_weight(other.weight_hh.clone().float())
        self.register_buffer('weight_hh', weight_hh)
        self.register_buffer('col_offsets_hh', col_offsets_hh)

        packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih)
        self.register_buffer('packed_ih', packed_ih)
        packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh)
        self.register_buffer('packed_hh', packed_hh)

        self.bias_ih = torch.nn.Parameter(other.bias_ih.clone().float(), requires_grad=False)
        self.bias_hh = torch.nn.Parameter(other.bias_hh.clone().float(), requires_grad=False)

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if 'bias' in self.__dict__ and self.bias is not True:
            s += ', bias={bias}'
        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
            s += ', nonlinearity={nonlinearity}'
        return s.format(**self.__dict__)

    @torch.jit.script_method
    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))

    @torch.jit.script_method
    def check_forward_hidden(self, input, hx, hidden_label=''):
        # type: (Tensor, Tensor, str) -> None
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input.size(0), hidden_label, hx.size(0)))

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hx.size(1), self.hidden_size))

    # TODO: for some reason weak_script_method causes a destruction of the
    # module to occur, which in turn frees the packed_ih object via its DataPtr
    # deleter. This is bizarre and should probably get fixed.
    # @torch._jit_internal.weak_script_method
    @torch.jit.script_method
    def _unpack(self):
        self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih))
        self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh))

    # @torch._jit_internal.weak_script_method
    @torch.jit.script_method
    def _pack(self):
        self.packed_ih.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
        self.packed_hh.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())


class QuantizedRNNCell(QuantizedRNNCellBase):
    __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
                     'zero_point_ih', 'zero_point_hh', 'nonlinearity']

    def __init__(self, other):
        super(QuantizedRNNCell, self).__init__(other)
        self.nonlinearity = other.nonlinearity

    @torch.jit.script_method
    def forward(self, input, hx=None):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        if self.nonlinearity == "tanh":
            ret = _VF.quantized_rnn_tanh_cell(
                input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
                self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
                self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
                self.zero_point_hh
            )
        elif self.nonlinearity == "relu":
            ret = _VF.quantized_rnn_relu_cell(
                input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
                self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
                self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
                self.zero_point_hh
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
        return ret


class QuantizedLSTMCell(QuantizedRNNCellBase):
    def __init__(self, other):
        super(QuantizedLSTMCell, self).__init__(other)

    @torch.jit.script_method
    def forward(self, input, hx=None):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
        self.check_forward_input(input)
        if hx is None:
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return _VF.quantized_lstm_cell(
            input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
            self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
            self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
            self.zero_point_hh
        )


class QuantizedGRUCell(QuantizedRNNCellBase):
    def __init__(self, other):
        super(QuantizedGRUCell, self).__init__(other)

    @torch.jit.script_method
    def forward(self, input, hx=None):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        return _VF.quantized_gru_cell(
            input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
            self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
            self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
            self.zero_point_hh
        )


@torch.jit.script
def apply_permutation(tensor, permutation, dim=1):
    # type: (Tensor, Tensor, int) -> Tensor
    return tensor.index_select(dim, permutation)


class QuantizedRNNBase(torch.jit.ScriptModule):
    __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
                     'batch_first', 'dropout', 'bidirectional', 'dtype']

    def __init__(self, other, dtype=torch.int8):
        super(QuantizedRNNBase, self).__init__()
        self.mode = other.mode
        self.input_size = other.input_size
        self.hidden_size = other.hidden_size
        self.num_layers = other.num_layers
        self.bias = other.bias
        self.batch_first = other.batch_first
        if self.mode != 'GRU':
            assert not self.batch_first
        self.dropout = other.dropout
        self.bidirectional = other.bidirectional
        num_directions = 2 if self.bidirectional else 1
        self.dtype = dtype

        assert self.bias

        # TODO: support more than just LSTM
        if self.mode != 'LSTM' and self.mode != 'GRU':
            raise RuntimeError('Only LSTM or GRU is supported for QuantizedRNN')

        if dtype != torch.int8 and dtype != torch.float16:
            raise RuntimeError('Unsupported dtype: {}'.format(dtype))

        self._all_weights_names = []
        self._packed_weights_names = []
        self._quantized_weights_names = []
        self._orig_weights_names = []
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                layer_input_size = self.input_size if layer == 0 else self.hidden_size * num_directions

                def process_weights(ihhh, layer, suffix, dtype):
                    weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                    bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

                    weight = getattr(other, weight_name)
                    bias = getattr(other, bias_name)

                    if dtype == torch.int8:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        # w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
                        # col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh
                        qweight, col_offsets, scale, zero_point = \
                            torch.fbgemm_linear_quantize_weight(weight.clone().float())
                        packed_weight = torch.fbgemm_pack_quantized_matrix(qweight)

                        params = [qweight, bias, packed_weight, col_offsets, scale, zero_point]
                        pos_names = ['w', 'b', 'packed', 'col_offsets', 'scale', 'zero_point']
                        ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
                        self._quantized_weights_names.append(ret_name[0])
                        self._packed_weights_names.append(ret_name[2])
                        return params, ret_name
                    else:
                        # for each layer, for each direction we need to quantize and pack
                        # weights and pack parameters in this order:
                        #
                        # packed_ih, packed_hh, b_ih, b_hh
                        packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
                            weight.clone().float())

                        self._orig_weights_names.append(weight_name)
                        self.register_buffer(weight_name, weight)
                        params = [packed_weight, bias]
                        pos_names = ['packed', 'b']
                        ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
                        self._packed_weights_names.append(ret_name[0])
                        self._quantized_weights_names.append(ret_name[0])
                        return params, ret_name

                suffix = '_reverse' if direction == 1 else ''
                ih_params, ih_param_names = process_weights('ih', layer, suffix, dtype)
                hh_params, hh_param_names = process_weights('hh', layer, suffix, dtype)

                for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
                    self.register_buffer(ih_name, torch.tensor(ih) if not isinstance(ih, torch.Tensor) else ih)
                    self.register_buffer(hh_name, torch.tensor(hh) if not isinstance(hh, torch.Tensor) else hh)
                    self._all_weights_names.extend([ih_name, hh_name])

        # For int8 quantization, _orig_weights is not needed in the quantization logic,
        # however there is a JIT compilation error without it. This is just used to
        # workaround that error.
        if dtype == torch.int8:
            self._orig_weights_names = self._packed_weights_names

        self._packed_weights = torch.jit.Attribute(
            [getattr(self, weight) for weight in self._packed_weights_names],
            List[Tensor],
        )
        self._quantized_weights = torch.jit.Attribute(
            [getattr(self, weight) for weight in self._quantized_weights_names],
            List[Tensor],
        )
        self._orig_weights = torch.jit.Attribute(
            [getattr(self, weight) for weight in self._orig_weights_names], List[Tensor]
        )
        # this one is public
        self.all_weights = torch.jit.Attribute(
            [getattr(self, weight) for weight in self._all_weights_names], List[Tensor]
        )

    @torch.jit.script_method
    def check_input(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> None
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    @torch.jit.script_method
    def get_expected_hidden_size(self, input, batch_sizes):
        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
        if batch_sizes is not None:
            mini_batch = batch_sizes[0]
            mini_batch = int(mini_batch)
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    @torch.jit.script_method
    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
        # type: (Tensor, Tuple[int, int, int], str) -> None
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

    @torch.jit.script_method
    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tensor, Optional[Tensor]) -> None
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(hidden, expected_hidden_size, msg='Expected hidden size {}, got {}')

    @torch.jit.script_method
    def permute_hidden(self, hx, permutation):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)

    def __setattr__(self, attr, value):
        if hasattr(self, "_orig_weights_names"):
            # keep weight attributes up to date if you do self.weight = ...
            if attr in self._all_weights_names:
                idx = self._all_weights_names.index(attr)
                self.all_weights[idx] = value
            elif attr in self._packed_weights_names:
                idx = self._packed_weights_names.index(attr)
                self._packed_weights[idx] = value
            elif attr in self._orig_weights_names:
                idx = self._orig_weights_names.index(attr)
                self._orig_weights[idx] = value
            elif attr in self._quantized_weights_names:
                idx = self._quantized_weights_names.index(attr)
                self._quantized_weights[idx] = value

        return super(QuantizedRNNBase, self).__setattr__(attr, value)

    # TODO: for some reason torch.jit.script_method causes a destruction of the
    # module to occur, which in turn frees the packed_ih object via its DataPtr
    # deleter. This is bizarre and should probably get fixed.
    # @torch._jit_internal.torch.jit.script_method
    @torch.jit.script_method
    def _unpack(self):
        if self.dtype == torch.int8:
            packed_weights = self._packed_weights
            quantized_weights = self._quantized_weights
            assert len(packed_weights) == len(quantized_weights)
            for i in range(len(packed_weights)):
                packed = packed_weights[i]
                quantized = quantized_weights[i]
                packed.set_(torch.fbgemm_pack_quantized_matrix(quantized))
        else:
            packed_weights = self._packed_weights
            orig_weights = self._orig_weights
            assert len(packed_weights) == len(orig_weights)
            for i in range(len(packed_weights)):
                packed = packed_weights[i]
                orig_weight = orig_weights[i]
                packed.set_(torch.fbgemm_pack_gemm_matrix_fp16(
                    orig_weight))

    @torch.jit.script_method
    def _pack(self):
        for weight in self._packed_weights:
            weight.set_(torch.zeros(torch.jit.annotate(List[int], []),
                                    dtype=torch.uint8).detach())


class QuantizedLSTM(QuantizedRNNBase):
    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    def __init__(self, other, dtype):
        super(QuantizedLSTM, self).__init__(other, dtype)

    @torch.jit.script_method
    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        assert batch_sizes is None
        result = _VF.quantized_lstm(input, hx, self.all_weights, self.bias, self.num_layers,
                                    float(self.dropout), self.training, self.bidirectional,
                                    self.batch_first, dtype=self.dtype, use_dynamic=False)
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.script_method
    def forward_tensor(self, input, hx=None):
        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.script_method
    def forward_packed(self, input, hx=None):
        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]  # noqa
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)

        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)

        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.script_method
    def permute_hidden(self, hx, permutation):
        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

    @torch.jit.script_method
    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden[0], expected_hidden_size,
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], expected_hidden_size,
                               'Expected hidden[1] size {}, got {}')

    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)


class QuantizedGRU(QuantizedRNNBase):
    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    @torch.jit.script_method
    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
        # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.quantized_gru(input, hx, self.all_weights, self.bias, self.num_layers,
                                       float(self.dropout), self.training, self.bidirectional,
                                       self.batch_first)
        else:
            result = _VF.quantized_gru(input, batch_sizes, hx, self.all_weights, self.bias, self.num_layers,
                                       float(self.dropout), self.training, self.bidirectional)

        output = result[0]
        hidden = result[1]

        return output, hidden

    @torch.jit.script_method
    def forward_tensor(self, input, hx=None):
        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.script_method
    def forward_packed(self, input, hx=None):
        # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
        input, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)

        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)

        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)


def quantize_rnn_cell_modules(module):
    reassign = {}
    for name, mod in module.named_modules():
        if mod is module:
            continue
        new_mod = quantize_rnn_cell_modules(mod)
        if new_mod is not mod:
            reassign[name] = new_mod
    for name, mod in reassign.items():
        setattr(module, name, mod)
    if isinstance(module, torch.nn.LSTMCell):
        return QuantizedLSTMCell(module)
    if isinstance(module, torch.nn.GRUCell):
        return QuantizedGRUCell(module)
    if isinstance(module, torch.nn.RNNCell):
        return QuantizedRNNCell(module)
    return module


def quantize_linear_modules(module, dtype=torch.int8):
    reassign = {}
    for name, mod in module.named_modules():
        if mod is module:
            continue
        new_mod = quantize_linear_modules(mod, dtype)
        if new_mod is not mod:
            reassign[name] = new_mod

    for name, mod in reassign.items():
        setattr(module, name, mod)
    if isinstance(module, torch.nn.Linear):
        if dtype == torch.int8:
            return QuantizedLinear(module)
        elif dtype == torch.float16:
            return QuantizedLinearFP16(module)
        else:
            raise RuntimeError(
                "Unsupported dtype: {}".format(dtype))
    return module


def quantize_rnn_modules(module, dtype=torch.int8):
    reassign = {}
    for name, mod in module.named_modules():
        if mod is module:
            continue
        new_mod = quantize_rnn_modules(mod, dtype)
        if new_mod is not mod:
            reassign[name] = new_mod

    for name, mod in reassign.items():
        setattr(module, name, mod)
    if isinstance(module, torch.nn.LSTM):
        if dtype != torch.int8 and dtype != torch.float16:
            raise RuntimeError("Unsupported dtype: {}".format(dtype))
        return QuantizedLSTM(module, dtype)
    if isinstance(module, torch.nn.GRU):
        return QuantizedGRU(module)
    return module
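For context, a minimal usage sketch of the conversion helpers defined above. The model definition is made up for illustration, and the conversion assumes a CPU build of this PyTorch version with FBGEMM support; it is not part of the file itself.

```
import torch
from torch.jit.quantized import quantize_linear_modules, quantize_rnn_modules


class TinyModel(torch.nn.Module):
    # Hypothetical float model used only for illustration.
    def __init__(self):
        super(TinyModel, self).__init__()
        self.proj = torch.nn.Linear(16, 32)
        self.rnn = torch.nn.LSTM(32, 32, num_layers=1)


model = TinyModel()
# Swap eligible submodules for their quantized ScriptModule counterparts;
# the converted modules keep the call signatures of the float originals.
model = quantize_linear_modules(model, dtype=torch.int8)
model = quantize_rnn_modules(model, dtype=torch.int8)
print(type(model.proj))  # QuantizedLinear
print(type(model.rnn))   # QuantizedLSTM
```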