Add typing annotations for torch.nn.quantized.dynamic.modules.rnn (#43186)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/43185
xref: gh-43072 (https://github.com/pytorch/pytorch/issues/43072)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43186

Reviewed By: ezyang

Differential Revision: D23441259

Pulled By: malfet

fbshipit-source-id: 80265ae7f3a70f0087e620969dbd4aa8ca17c317
This commit is contained in:
parent 8ca3913f47
commit 63a0bb0ab9

mypy.ini | 3 ---
torch/nn/quantized/dynamic/modules/rnn.py
--- a/mypy.ini
+++ b/mypy.ini
@@ -170,9 +170,6 @@ ignore_errors = True
 [mypy-torch.nn.qat.modules.conv]
 ignore_errors = True
 
-[mypy-torch.nn.quantized.dynamic.modules.rnn]
-ignore_errors = True
-
 [mypy-torch.nn.quantized.dynamic.modules.linear]
 ignore_errors = True
 
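Note: with the ignore_errors entry removed, mypy now type-checks torch.nn.quantized.dynamic.modules.rnn, which is why the file below converts its mypy type comments to inline annotations. A minimal sketch of that conversion pattern (the scale functions are hypothetical, for illustration only):

from torch import Tensor

# Old style: the signature lives in a mypy type comment
def scale(t, factor=2.0):
    # type: (Tensor, float) -> Tensor
    return t * factor

# New style: inline PEP 484 annotations, which mypy checks directly
def scale_typed(t: Tensor, factor: float = 2.0) -> Tensor:
    return t * factor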
--- a/torch/nn/quantized/dynamic/modules/rnn.py
+++ b/torch/nn/quantized/dynamic/modules/rnn.py
@@ -6,12 +6,11 @@ import warnings
 import torch
 import torch.nn as nn
 from torch import Tensor # noqa: F401
-from torch._jit_internal import Tuple, Optional, List # noqa: F401
+from torch._jit_internal import Tuple, Optional, List, Union, Dict # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
 from torch.nn.quantized.modules.utils import _quantize_weight
 
-def apply_permutation(tensor, permutation, dim=1):
-    # type: (Tensor, Tensor, int) -> Tensor
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
     return tensor.index_select(dim, permutation)
 
 class PackedParameter(torch.nn.Module):
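Note: apply_permutation is a thin wrapper around Tensor.index_select, used to reorder the batch dimension of hidden states. A quick sketch of what it does along the default dim=1:

import torch

t = torch.arange(6.).reshape(2, 3)    # columns play the role of batch entries
perm = torch.tensor([2, 0, 1])
out = t.index_select(1, perm)         # reorder columns to [c2, c0, c1]
print(out)                            # tensor([[2., 0., 1.], [5., 3., 4.]])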
@@ -53,12 +52,14 @@ class RNNBase(torch.nn.Module):
         self.training = False
         num_directions = 2 if bidirectional else 1
 
-        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
-                isinstance(dropout, bool):
+        # "type: ignore" is required since ints and Numbers are not fully comparable
+        # https://github.com/python/mypy/issues/8566
+        if not isinstance(dropout, numbers.Number) \
+                or not 0 <= dropout <= 1 or isinstance(dropout, bool): # type: ignore
             raise ValueError("dropout should be a number in range [0, 1] "
                              "representing the probability of an element being "
                              "zeroed")
-        if dropout > 0 and num_layers == 1:
+        if dropout > 0 and num_layers == 1: # type: ignore
             warnings.warn("dropout option adds dropout after all but last "
                           "recurrent layer, so non-zero dropout expects "
                           "num_layers greater than 1, but got dropout={} and "
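Note: the separate isinstance(dropout, bool) test exists because bool is a subclass of int, so a boolean passes both the numbers.Number check and the range check; the # type: ignore only silences mypy's complaint that Number and int are not comparable (the linked mypy issue). A quick demonstration:

import numbers

# True is a numbers.Number and compares as 1, so only the explicit
# bool check can reject it
assert isinstance(True, numbers.Number)
assert 0 <= True <= 1
assert isinstance(True, bool)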
@@ -149,8 +150,7 @@ class RNNBase(torch.nn.Module):
         main_str += ')'
         return main_str
 
-    def check_input(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> None
+    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
         expected_input_dim = 2 if batch_sizes is not None else 3
         if input.dim() != expected_input_dim:
             raise RuntimeError(
@@ -161,11 +161,9 @@ class RNNBase(torch.nn.Module):
                 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                     self.input_size, input.size(-1)))
 
-    def get_expected_hidden_size(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
+    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
         if batch_sizes is not None:
-            mini_batch = batch_sizes[0]
-            mini_batch = int(mini_batch)
+            mini_batch = int(batch_sizes[0])
         else:
             mini_batch = input.size(0) if self.batch_first else input.size(1)
         num_directions = 2 if self.bidirectional else 1
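Note: the method returns (num_layers * num_directions, mini_batch, hidden_size). A worked sketch of the shape arithmetic, with hypothetical sizes chosen only for illustration:

import torch

num_layers, hidden_size, bidirectional, batch_first = 2, 8, True, False
input = torch.randn(5, 3, 4)          # (seq_len, batch, input_size)

num_directions = 2 if bidirectional else 1
mini_batch = input.size(0) if batch_first else input.size(1)
expected_hidden_size = (num_layers * num_directions, mini_batch, hidden_size)
print(expected_hidden_size)           # (4, 3, 8)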
@@ -173,21 +171,21 @@ class RNNBase(torch.nn.Module):
                                 mini_batch, self.hidden_size)
         return expected_hidden_size
 
-    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
-        # type: (Tensor, Tuple[int, int, int], str) -> None
+    def check_hidden_size(
+        self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
+        msg: str = 'Expected hidden size {}, got {}'
+    ) -> None:
         if hx.size() != expected_hidden_size:
             raise RuntimeError(msg.format(
                 expected_hidden_size, list(hx.size())))
 
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tensor, Optional[Tensor]) -> None
+    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
         self.check_hidden_size(hidden, expected_hidden_size,
                                msg='Expected hidden size {}, got {}')
 
-    def permute_hidden(self, hx, permutation):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
         if permutation is None:
             return hx
         return apply_permutation(hx, permutation)
@@ -287,7 +285,7 @@ class RNNBase(torch.nn.Module):
 
     def _weight_bias(self):
         # Returns a dict of weights and biases
-        weight_bias_dict = {'weight' : {}, 'bias' : {}}
+        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
         count = 0
         num_directions = 2 if self.bidirectional else 1
         for layer in range(self.num_layers):
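Note: the Dict[str, Dict] annotation is needed because the inner dict literals are empty, leaving mypy nothing to infer their value types from. A minimal sketch of the pattern:

from typing import Dict

# Annotating the variable keeps the empty inner dicts open-ended,
# so later assignments of arbitrary value types still type-check
weight_bias_dict: Dict[str, Dict] = {'weight': {}, 'bias': {}}
weight_bias_dict['weight']['weight_ih'] = 'any value'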
@@ -337,8 +335,11 @@ class LSTM(RNNBase):
     def _get_name(self):
         return 'DynamicQuantizedLSTM'
 
-    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
+    def forward_impl(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
+        batch_sizes: Optional[Tensor], max_batch_size: int,
+        sorted_indices: Optional[Tensor]
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
             zeros = torch.zeros(self.num_layers * num_directions,
@@ -367,8 +368,9 @@ class LSTM(RNNBase):
         return output, hidden
 
     @torch.jit.export
-    def forward_tensor(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward_tensor(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         batch_sizes = None
         max_batch_size = input.size(0) if self.batch_first else input.size(1)
         sorted_indices = None
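Note: forward_tensor is the plain-Tensor entry point. A hedged usage sketch, assuming the module is produced by torch.quantization.quantize_dynamic (shapes are illustrative):

import torch
import torch.nn as nn

# quantize_dynamic swaps nn.LSTM for the DynamicQuantizedLSTM shown here
fp32 = nn.LSTM(input_size=4, hidden_size=8, num_layers=1)
qlstm = torch.quantization.quantize_dynamic(fp32, {nn.LSTM}, dtype=torch.qint8)

seq = torch.randn(3, 2, 4)            # (seq_len, batch, input_size)
out, (h, c) = qlstm(seq)              # dispatches to forward_tensor
print(out.shape, h.shape)             # torch.Size([3, 2, 8]) torch.Size([1, 2, 8])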
@@ -380,27 +382,32 @@ class LSTM(RNNBase):
         return output, self.permute_hidden(hidden, unsorted_indices)
 
     @torch.jit.export
-    def forward_packed(self, input, hx=None):
-        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
-        input, batch_sizes, sorted_indices, unsorted_indices = input
+    def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa
+        input_, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)
 
-        output, hidden = self.forward_impl(
-            input, hx, batch_sizes, max_batch_size, sorted_indices)
+        output_, hidden = self.forward_impl(
+            input_, hx, batch_sizes, max_batch_size, sorted_indices)
 
-        output = PackedSequence(output, batch_sizes,
+        output = PackedSequence(output_, batch_sizes,
                                 sorted_indices, unsorted_indices)
         return output, self.permute_hidden(hidden, unsorted_indices)
 
-    def permute_hidden(self, hx, permutation):
-        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+    # "type: ignore" is required due to issue #43072
+    def permute_hidden(  # type: ignore
+        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
         if permutation is None:
             return hx
         return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
 
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor])->None
+    # "type: ignore" is required due to issue #43072
+    def check_forward_args(  # type: ignore
+        self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
+    ) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
 
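Note: the renames to input_ and output_ exist because the parameter is declared as a PackedSequence; unpacking it back into input would rebind the name to a Tensor, a re-typing that mypy rejects. A hedged sketch of the PackedSequence path, under the same quantize_dynamic assumption as above:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

fp32 = nn.LSTM(input_size=4, hidden_size=8)
qlstm = torch.quantization.quantize_dynamic(fp32, {nn.LSTM}, dtype=torch.qint8)

seq = torch.randn(3, 2, 4)                    # (seq_len, batch, input_size)
packed = pack_padded_sequence(seq, torch.tensor([3, 2]))
out, (h, c) = qlstm(packed)                   # dispatches to forward_packed
padded, lengths = pad_packed_sequence(out)    # back to a padded Tensor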
@@ -483,8 +490,7 @@ class RNNCellBase(torch.nn.Module):
                 "input has inconsistent input_size: got {}, expected {}".format(
                     input.size(1), self.input_size))
 
-    def check_forward_hidden(self, input, hx, hidden_label=''):
-        # type: (Tensor, Tensor, str) -> None
+    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -518,6 +524,8 @@ class RNNCellBase(torch.nn.Module):
         if dtype not in supported_scalar_types:
             raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
 
+        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
+
         if type(mod) == torch.nn.LSTMCell:
             qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
         elif type(mod) == torch.nn.GRUCell:
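Note: the qRNNCellBase annotation gives mypy one declared type for a variable that each branch of from_float assigns a different cell class to; Union comes from the widened import at the top of the file. A hedged usage sketch via torch.quantization.quantize_dynamic, which invokes from_float internally (assuming LSTMCell is in the default dynamic mapping):

import torch
import torch.nn as nn

cell = nn.LSTMCell(4, 8)
qcell = torch.quantization.quantize_dynamic(cell, {nn.LSTMCell}, dtype=torch.qint8)

x = torch.randn(2, 4)                 # (batch, input_size)
h0, c0 = torch.zeros(2, 8), torch.zeros(2, 8)
h1, c1 = qcell(x, (h0, c0))           # DynamicQuantizedLSTMCell.forward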
@@ -561,7 +569,7 @@ class RNNCellBase(torch.nn.Module):
 
     def _weight_bias(self):
         # Returns a dict of weights and biases
-        weight_bias_dict = {'weight' : {}, 'bias' : {}}
+        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
         w1, b1 = self._packed_weight_ih.__getstate__()[0]
         w2, b2 = self._packed_weight_hh.__getstate__()[0]
         weight_bias_dict['weight']['weight_ih'] = w1
@@ -614,8 +622,7 @@ class RNNCell(RNNCellBase):
     def _get_name(self):
         return 'DynamicQuantizedRNNCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -661,13 +668,12 @@ class LSTMCell(RNNCellBase):
     """
 
     def __init__(self, *args, **kwargs):
-        super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)
+        super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs) # type: ignore
 
     def _get_name(self):
         return 'DynamicQuantizedLSTMCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
         self.check_forward_input(input)
         if hx is None:
             zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -707,8 +713,7 @@ class GRUCell(RNNCellBase):
     def _get_name(self):
         return 'DynamicQuantizedGRUCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)