pytorch/torch/jit/quantized.py
Michael Suo ffa422a8b3 kill _parameter_list (#27399)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27399

This was devised at a time when we didn't have module attributes. The
parameter lists are essentially just tensor lists, so represent them that
way. This has the additional benefit of making the RNN forward pass faster,
because we effectively cache the flattened weights.

The only complicated part is that someone may come along and do:
```
my_rnn_mod.w_ih_l0 = torch.nn.Parameter(...)
```

This means we need to override `__setattr__` to keep the flattened-weights
cache up to date.
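
For illustration, here is a minimal sketch of that pattern (standalone toy code, not the code in this file; all names are made up): a module that caches its weights in a flat list and overrides `__setattr__` so the cache stays consistent when a weight is reassigned.

```
import torch

class CachedWeights(torch.nn.Module):
    """Toy module that keeps a flat cache of its weight tensors."""

    def __init__(self):
        super(CachedWeights, self).__init__()
        self.w_ih_l0 = torch.nn.Parameter(torch.randn(4, 4))
        self.w_hh_l0 = torch.nn.Parameter(torch.randn(4, 4))
        self._weight_names = ['w_ih_l0', 'w_hh_l0']
        # the cached "flattened" weights, analogous to all_weights below
        self.all_weights = [self.w_ih_l0, self.w_hh_l0]

    def __setattr__(self, attr, value):
        # keep the cache up to date if someone does `mod.w_ih_l0 = ...`
        if hasattr(self, '_weight_names') and attr in self._weight_names:
            self.all_weights[self._weight_names.index(attr)] = value
        return super(CachedWeights, self).__setattr__(attr, value)

mod = CachedWeights()
mod.w_ih_l0 = torch.nn.Parameter(torch.zeros(4, 4))
assert mod.all_weights[0] is mod.w_ih_l0  # cache tracks the new parameter
```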

Test Plan: Imported from OSS

Differential Revision: D17785658

Pulled By: suo

fbshipit-source-id: 7789cd1d0d4922bfd5eba1716976442fbf150766
2019-10-12 09:51:53 -07:00


import torch
from torch._jit_internal import Tuple, Optional, List # noqa: F401
from torch import Tensor # noqa: F401
from torch.nn import _VF
from torch.nn.utils.rnn import PackedSequence
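
# TorchScript (ScriptModule) versions of nn.Linear and of the RNN cells and
# layers, with weights quantized through the fbgemm ops (int8, or packed fp16
# for some of the modules). The quantize_*_modules helpers at the bottom of
# this file recursively swap supported float modules for these quantized
# equivalents.
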
class QuantizedLinear(torch.jit.ScriptModule):
__constants__ = ['scale', 'zero_point']
def __init__(self, other):
super(QuantizedLinear, self).__init__()
self.in_features = other.in_features
self.out_features = other.out_features
# Quantize weight and discard the original
self.weight, self.col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight(
other.weight.clone().float())
self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
assert other.bias is not None, 'QuantizedLinear requires a bias'
self.bias = torch.nn.Parameter(other.bias.clone().float(), requires_grad=False)
self.register_buffer(
'packed_tensor_ptr',
torch.fbgemm_pack_quantized_matrix(self.weight.clone()))
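    # The packed fbgemm matrix is an opaque buffer (it wraps a raw pointer), so
    # it cannot be serialized directly; _pack replaces it with an empty
    # placeholder and _unpack rebuilds it from the stored quantized weight.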
@torch.jit.script_method
def _unpack(self):
self.packed_tensor_ptr.set_(
torch.fbgemm_pack_quantized_matrix(self.weight))
@torch.jit.script_method
def _pack(self):
self.packed_tensor_ptr.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
@torch.jit.script_method
def forward(self, input):
out = torch.fbgemm_linear_int8_weight_fp32_activation(
input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
self.scale, self.zero_point, self.bias)
return out.to(input.dtype)
def extra_repr(self):
repr = 'in_features={in_features}, out_features={out_features}, ' \
'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
return repr
# FP16 weights
class QuantizedLinearFP16(torch.jit.ScriptModule):
def __init__(self, other):
super(QuantizedLinearFP16, self).__init__()
self.in_features = other.in_features
self.out_features = other.out_features
self.original_weight = other.weight
self.weight = torch.fbgemm_pack_gemm_matrix_fp16(
other.weight.clone().float())
assert other.bias is not None, 'QuantizedLinearFP16 requires a bias'
self.bias = torch.nn.Parameter(other.bias.clone().float(), requires_grad=False)
self.register_buffer('packed_weight', self.weight)
@torch.jit.script_method
def _unpack(self):
self.packed_weight.set_(
torch.fbgemm_pack_gemm_matrix_fp16(
self.original_weight))
@torch.jit.script_method
def _pack(self):
self.packed_weight.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
@torch.jit.script_method
def forward(self, input):
out = torch.fbgemm_linear_fp16_weight_fp32_activation(
input.float(), self.packed_weight, self.bias)
return out
def extra_repr(self):
repr = 'in_features={in_features}, out_features={out_features}, '.format(**self.__dict__)
return repr
# Quantized RNN cell implementations
class QuantizedRNNCellBase(torch.jit.ScriptModule):
__constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
'zero_point_ih', 'zero_point_hh']
def __init__(self, other):
super(QuantizedRNNCellBase, self).__init__()
self.input_size = other.input_size
self.hidden_size = other.hidden_size
self.bias = other.bias
if not self.bias:
raise ValueError("Quantized RNN cells require bias terms")
weight_ih, col_offsets_ih, self.scale_ih, self.zero_point_ih = \
torch.fbgemm_linear_quantize_weight(other.weight_ih.clone().float())
self.register_buffer('weight_ih', weight_ih)
self.register_buffer('col_offsets_ih', col_offsets_ih)
weight_hh, col_offsets_hh, self.scale_hh, self.zero_point_hh = \
torch.fbgemm_linear_quantize_weight(other.weight_hh.clone().float())
self.register_buffer('weight_hh', weight_hh)
self.register_buffer('col_offsets_hh', col_offsets_hh)
packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih)
self.register_buffer('packed_ih', packed_ih)
packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh)
self.register_buffer('packed_hh', packed_hh)
self.bias_ih = torch.nn.Parameter(other.bias_ih.clone().float(), requires_grad=False)
self.bias_hh = torch.nn.Parameter(other.bias_hh.clone().float(), requires_grad=False)
def extra_repr(self):
s = '{input_size}, {hidden_size}'
if 'bias' in self.__dict__ and self.bias is not True:
s += ', bias={bias}'
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
s += ', nonlinearity={nonlinearity}'
return s.format(**self.__dict__)
@torch.jit.script_method
def check_forward_input(self, input):
if input.size(1) != self.input_size:
raise RuntimeError(
"input has inconsistent input_size: got {}, expected {}".format(
input.size(1), self.input_size))
@torch.jit.script_method
def check_forward_hidden(self, input, hx, hidden_label=''):
# type: (Tensor, Tensor, str) -> None
if input.size(0) != hx.size(0):
raise RuntimeError(
"Input batch size {} doesn't match hidden{} batch size {}".format(
input.size(0), hidden_label, hx.size(0)))
if hx.size(1) != self.hidden_size:
raise RuntimeError(
"hidden{} has inconsistent hidden_size: got {}, expected {}".format(
hidden_label, hx.size(1), self.hidden_size))
# TODO: for some reason weak_script_method causes a destruction of the
# module to occur, which in turn frees the packed_ih object via its DataPtr
# deleter. This is bizarre and should probably get fixed.
# @torch._jit_internal.weak_script_method
@torch.jit.script_method
def _unpack(self):
self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih))
self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh))
# @torch._jit_internal.weak_script_method
@torch.jit.script_method
def _pack(self):
self.packed_ih.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
self.packed_hh.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())
class QuantizedRNNCell(QuantizedRNNCellBase):
__constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
'zero_point_ih', 'zero_point_hh', 'nonlinearity']
def __init__(self, other):
super(QuantizedRNNCell, self).__init__(other)
self.nonlinearity = other.nonlinearity
@torch.jit.script_method
def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
self.check_forward_hidden(input, hx, '')
if self.nonlinearity == "tanh":
ret = _VF.quantized_rnn_tanh_cell(
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
self.zero_point_hh
)
elif self.nonlinearity == "relu":
ret = _VF.quantized_rnn_relu_cell(
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
self.zero_point_hh
)
else:
ret = input # TODO: remove when jit supports exception flow
raise RuntimeError(
"Unknown nonlinearity: {}".format(self.nonlinearity))
return ret
class QuantizedLSTMCell(QuantizedRNNCellBase):
def __init__(self, other):
super(QuantizedLSTMCell, self).__init__(other)
@torch.jit.script_method
def forward(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
self.check_forward_input(input)
if hx is None:
zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
hx = (zeros, zeros)
self.check_forward_hidden(input, hx[0], '[0]')
self.check_forward_hidden(input, hx[1], '[1]')
return _VF.quantized_lstm_cell(
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
self.zero_point_hh
)
class QuantizedGRUCell(QuantizedRNNCellBase):
def __init__(self, other):
super(QuantizedGRUCell, self).__init__(other)
@torch.jit.script_method
def forward(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
self.check_forward_hidden(input, hx, '')
return _VF.quantized_gru_cell(
input, hx, self.weight_ih, self.weight_hh, self.bias_ih,
self.bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
self.col_offsets_hh, self.scale_ih, self.scale_hh, self.zero_point_ih,
self.zero_point_hh
)
@torch.jit.script
def apply_permutation(tensor, permutation, dim=1):
# type: (Tensor, Tensor, int) -> Tensor
return tensor.index_select(dim, permutation)
class QuantizedRNNBase(torch.jit.ScriptModule):
__constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
'batch_first', 'dropout', 'bidirectional', 'dtype']
def __init__(self, other, dtype=torch.int8):
super(QuantizedRNNBase, self).__init__()
self.mode = other.mode
self.input_size = other.input_size
self.hidden_size = other.hidden_size
self.num_layers = other.num_layers
self.bias = other.bias
self.batch_first = other.batch_first
if self.mode != 'GRU':
assert not self.batch_first
self.dropout = other.dropout
self.bidirectional = other.bidirectional
num_directions = 2 if self.bidirectional else 1
self.dtype = dtype
assert self.bias
# TODO: support more than just LSTM
if self.mode != 'LSTM' and self.mode != 'GRU':
raise RuntimeError('Only LSTM or GRU is supported for QuantizedRNN')
if dtype != torch.int8 and dtype != torch.float16:
raise RuntimeError('Unsupported dtype: {}'.format(dtype))
self._all_weights_names = []
self._packed_weights_names = []
self._quantized_weights_names = []
self._orig_weights_names = []
for layer in range(self.num_layers):
for direction in range(num_directions):
layer_input_size = self.input_size if layer == 0 else self.hidden_size * num_directions
def process_weights(ihhh, layer, suffix, dtype):
weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)
weight = getattr(other, weight_name)
bias = getattr(other, bias_name)
if dtype == torch.int8:
# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
# col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh
qweight, col_offsets, scale, zero_point = \
torch.fbgemm_linear_quantize_weight(weight.clone().float())
packed_weight = torch.fbgemm_pack_quantized_matrix(qweight)
params = [qweight, bias, packed_weight, col_offsets, scale, zero_point]
pos_names = ['w', 'b', 'packed', 'col_offsets', 'scale', 'zero_point']
ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
self._quantized_weights_names.append(ret_name[0])
self._packed_weights_names.append(ret_name[2])
return params, ret_name
else:
# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# packed_ih, packed_hh, b_ih, b_hh
packed_weight = torch.fbgemm_pack_gemm_matrix_fp16(
weight.clone().float())
self._orig_weights_names.append(weight_name)
self.register_buffer(weight_name, weight)
params = [packed_weight, bias]
pos_names = ['packed', 'b']
ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names]
self._packed_weights_names.append(ret_name[0])
self._quantized_weights_names.append(ret_name[0])
return params, ret_name
suffix = '_reverse' if direction == 1 else ''
ih_params, ih_param_names = process_weights('ih', layer, suffix, dtype)
hh_params, hh_param_names = process_weights('hh', layer, suffix, dtype)
for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
self.register_buffer(ih_name, torch.tensor(ih) if not isinstance(ih, torch.Tensor) else ih)
self.register_buffer(hh_name, torch.tensor(hh) if not isinstance(hh, torch.Tensor) else hh)
self._all_weights_names.extend([ih_name, hh_name])
# For int8 quantization, _orig_weights is not needed in the quantization logic,
# however there is a JIT compilation error without it. This is just used to
# workaround that error.
if dtype == torch.int8:
self._orig_weights_names = self._packed_weights_names
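        # Cache the flattened weights as List[Tensor] attributes so the forward
        # pass can hand them directly to the fused _VF ops; __setattr__ below
        # keeps these cached lists in sync when a weight attribute is reassigned.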
self._packed_weights = torch.jit.Attribute(
[getattr(self, weight) for weight in self._packed_weights_names],
List[Tensor],
)
self._quantized_weights = torch.jit.Attribute(
[getattr(self, weight) for weight in self._quantized_weights_names],
List[Tensor],
)
self._orig_weights = torch.jit.Attribute(
[getattr(self, weight) for weight in self._orig_weights_names], List[Tensor]
)
# this one is public
self.all_weights = torch.jit.Attribute(
[getattr(self, weight) for weight in self._all_weights_names], List[Tensor]
)
@torch.jit.script_method
def check_input(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> None
expected_input_dim = 2 if batch_sizes is not None else 3
if input.dim() != expected_input_dim:
raise RuntimeError(
'input must have {} dimensions, got {}'.format(
expected_input_dim, input.dim()))
if self.input_size != input.size(-1):
raise RuntimeError(
'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
self.input_size, input.size(-1)))
@torch.jit.script_method
def get_expected_hidden_size(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
if batch_sizes is not None:
mini_batch = batch_sizes[0]
mini_batch = int(mini_batch)
else:
mini_batch = input.size(0) if self.batch_first else input.size(1)
num_directions = 2 if self.bidirectional else 1
expected_hidden_size = (self.num_layers * num_directions,
mini_batch, self.hidden_size)
return expected_hidden_size
@torch.jit.script_method
def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
# type: (Tensor, Tuple[int, int, int], str) -> None
if hx.size() != expected_hidden_size:
raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
@torch.jit.script_method
def check_forward_args(self, input, hidden, batch_sizes):
# type: (Tensor, Tensor, Optional[Tensor]) -> None
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
self.check_hidden_size(hidden, expected_hidden_size, msg='Expected hidden size {}, got {}')
@torch.jit.script_method
def permute_hidden(self, hx, permutation):
# type: (Tensor, Optional[Tensor]) -> Tensor
if permutation is None:
return hx
return apply_permutation(hx, permutation)
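    # Someone may come along and do `my_rnn_mod.w_ih_l0 = torch.nn.Parameter(...)`,
    # so __setattr__ is overridden to keep the cached tensor lists (all_weights,
    # _packed_weights, _orig_weights, _quantized_weights) pointing at the
    # current values.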
def __setattr__(self, attr, value):
        if hasattr(self, "_orig_weights_names"):
# keep weight attributes up to date if you do self.weight = ...
if attr in self._all_weights_names:
idx = self._all_weights_names.index(attr)
self.all_weights[idx] = value
elif attr in self._packed_weights_names:
idx = self._packed_weights_names.index(attr)
self._packed_weights[idx] = value
elif attr in self._orig_weights_names:
idx = self._orig_weights_names.index(attr)
self._orig_weights[idx] = value
elif attr in self._quantized_weights_names:
idx = self._quantized_weights_names.index(attr)
self._quantized_weights[idx] = value
return super(QuantizedRNNBase, self).__setattr__(attr, value)
# TODO: for some reason torch.jit.script_method causes a destruction of the
# module to occur, which in turn frees the packed_ih object via its DataPtr
# deleter. This is bizarre and should probably get fixed.
# @torch._jit_internal.torch.jit.script_method
@torch.jit.script_method
def _unpack(self):
if self.dtype == torch.int8:
packed_weights = self._packed_weights
quantized_weights = self._quantized_weights
assert len(packed_weights) == len(quantized_weights)
for i in range(len(packed_weights)):
packed = packed_weights[i]
quantized = quantized_weights[i]
packed.set_(torch.fbgemm_pack_quantized_matrix(quantized))
else:
packed_weights = self._packed_weights
orig_weights = self._orig_weights
assert len(packed_weights) == len(orig_weights)
for i in range(len(packed_weights)):
packed = packed_weights[i]
orig_weight = orig_weights[i]
packed.set_(torch.fbgemm_pack_gemm_matrix_fp16(
orig_weight))
@torch.jit.script_method
def _pack(self):
for weight in self._packed_weights:
weight.set_(torch.zeros(torch.jit.annotate(List[int], []),
dtype=torch.uint8).detach())
class QuantizedLSTM(QuantizedRNNBase):
__overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
def __init__(self, other, dtype):
super(QuantizedLSTM, self).__init__(other, dtype)
@torch.jit.script_method
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
if hx is None:
num_directions = 2 if self.bidirectional else 1
zeros = torch.zeros(self.num_layers * num_directions,
max_batch_size, self.hidden_size,
dtype=input.dtype, device=input.device)
hx = (zeros, zeros)
else:
# Each batch of the hidden state should match the input sequence that
# the user believes he/she is passing in.
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
assert batch_sizes is None
result = _VF.quantized_lstm(input, hx, self.all_weights, self.bias, self.num_layers,
float(self.dropout), self.training, self.bidirectional,
self.batch_first, dtype=self.dtype, use_dynamic=False)
output = result[0]
hidden = result[1:]
return output, hidden
@torch.jit.script_method
def forward_tensor(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)
@torch.jit.script_method
def forward_packed(self, input, hx=None):
# type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)
@torch.jit.script_method
def permute_hidden(self, hx, permutation):
# type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
if permutation is None:
return hx
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
@torch.jit.script_method
def check_forward_args(self, input, hidden, batch_sizes):
# type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
self.check_hidden_size(hidden[0], expected_hidden_size,
'Expected hidden[0] size {}, got {}')
self.check_hidden_size(hidden[1], expected_hidden_size,
'Expected hidden[1] size {}, got {}')
def forward(self, input, hx=None):
if isinstance(input, PackedSequence):
return self.forward_packed(input, hx)
else:
return self.forward_tensor(input, hx)
class QuantizedGRU(QuantizedRNNBase):
__overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
@torch.jit.script_method
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
# type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.zeros(self.num_layers * num_directions,
max_batch_size, self.hidden_size,
dtype=input.dtype, device=input.device)
else:
# Each batch of the hidden state should match the input sequence that
# the user believes he/she is passing in.
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
result = _VF.quantized_gru(input, hx, self.all_weights, self.bias, self.num_layers,
float(self.dropout), self.training, self.bidirectional,
self.batch_first)
else:
result = _VF.quantized_gru(input, batch_sizes, hx, self.all_weights, self.bias, self.num_layers,
float(self.dropout), self.training, self.bidirectional)
output = result[0]
hidden = result[1]
return output, hidden
@torch.jit.script_method
def forward_tensor(self, input, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)
@torch.jit.script_method
def forward_packed(self, input, hx=None):
# type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)
output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)
def forward(self, input, hx=None):
if isinstance(input, PackedSequence):
return self.forward_packed(input, hx)
else:
return self.forward_tensor(input, hx)
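
# The helpers below recursively walk a module hierarchy and replace supported
# float modules (nn.RNNCell/LSTMCell/GRUCell, nn.Linear, nn.LSTM/GRU) with the
# quantized ScriptModule equivalents defined above; everything else is left
# untouched.
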
def quantize_rnn_cell_modules(module):
reassign = {}
for name, mod in module.named_modules():
if mod is module:
continue
new_mod = quantize_rnn_cell_modules(mod)
if new_mod is not mod:
reassign[name] = new_mod
for name, mod in reassign.items():
setattr(module, name, mod)
if isinstance(module, torch.nn.LSTMCell):
return QuantizedLSTMCell(module)
if isinstance(module, torch.nn.GRUCell):
return QuantizedGRUCell(module)
if isinstance(module, torch.nn.RNNCell):
return QuantizedRNNCell(module)
return module
def quantize_linear_modules(module, dtype=torch.int8):
reassign = {}
for name, mod in module.named_modules():
if mod is module:
continue
new_mod = quantize_linear_modules(mod, dtype)
if new_mod is not mod:
reassign[name] = new_mod
for name, mod in reassign.items():
setattr(module, name, mod)
if isinstance(module, torch.nn.Linear):
if dtype == torch.int8:
return QuantizedLinear(module)
elif dtype == torch.float16:
return QuantizedLinearFP16(module)
else:
raise RuntimeError(
"Unsupported dtype: {}".format(dtype))
return module
def quantize_rnn_modules(module, dtype=torch.int8):
reassign = {}
for name, mod in module.named_modules():
if mod is module:
continue
new_mod = quantize_rnn_modules(mod, dtype)
if new_mod is not mod:
reassign[name] = new_mod
for name, mod in reassign.items():
setattr(module, name, mod)
if isinstance(module, torch.nn.LSTM):
if dtype != torch.int8 and dtype != torch.float16:
raise RuntimeError("Unsupported dtype: {}".format(dtype))
return QuantizedLSTM(module, dtype)
if isinstance(module, torch.nn.GRU):
return QuantizedGRU(module)
return module
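
# Illustrative usage (a sketch, not part of this file; `MyModel` and
# `example_input` are hypothetical):
#
#   model = MyModel()  # float model containing nn.Linear / nn.LSTM submodules
#   model = quantize_linear_modules(model, torch.int8)
#   model = quantize_rnn_modules(model, torch.int8)
#   traced = torch.jit.trace(model, example_input)  # e.g. trace or script as usual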