mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38211 Just because the annotations are inline doesn't mean the files type check; most of the newly annotated files have type errors and I added exclusions for them in mypy.ini. The payoff of moving all of these modules inline is I can delete the relevant code generation logic for the pyi files (which was added ignore annotations that weren't actually relevant anymore.) For the most part the translation was completely mechanical, but there were two hairy issues. First, I needed to work around a Python 3.6 and earlier bug where Generic has a nontrivial metaclass. This fix is in torch/jit/__init__.py. Second, module.py, we need to apply the same fix for avoiding contravariance checks that the pyi file used to have; this is done by declaring forward as a variable (rather than a function), which appears to be sufficient enough to get mypy to not contravariantly check input arguments. Because we aren't actually typechecking these modules in most cases, it is inevitable that some of these type annotations are wrong. I slavishly copied the old annotations from the pyi files unless there was an obvious correction I could make. These annotations will probably need fixing up later. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Differential Revision: D21497397 Pulled By: ezyang fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
1048 lines
47 KiB
Python
1048 lines
47 KiB
Python
import math
|
|
import warnings
|
|
import numbers
|
|
from typing import List, Tuple, Optional, overload
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from .module import Module
|
|
from ..parameter import Parameter
|
|
from ..utils.rnn import PackedSequence
|
|
from .. import init
|
|
from ... import _VF
|
|
|
|
_rnn_impls = {
|
|
'RNN_TANH': _VF.rnn_tanh,
|
|
'RNN_RELU': _VF.rnn_relu,
|
|
}
|
|
|
|
|
|
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    """Reorder ``tensor`` along ``dim`` using the index tensor ``permutation``.

    Slice ``i`` of the result (along ``dim``) is slice ``permutation[i]`` of
    ``tensor``. ``dim`` defaults to 1.
    """
    return torch.index_select(tensor, dim, permutation)
|
|
|
|
|
|
class RNNBase(Module):
    """Common base for the recurrent modules :class:`RNN`, :class:`LSTM` and :class:`GRU`.

    The constructor registers one set of parameters per layer and direction,
    named ``weight_ih_l{layer}{suffix}`` and ``weight_hh_l{layer}{suffix}``
    and, when ``bias`` is true, ``bias_ih_l{layer}{suffix}`` and
    ``bias_hh_l{layer}{suffix}``. Here ``suffix`` is ``'_reverse'`` for the
    second direction of a bidirectional module and empty otherwise.

    The same tensors are also cached, in registration order, in
    ``self._flat_weights``, so they can be passed to the ``_VF`` kernels as a
    single list.

    Args:
        mode: one of ``'LSTM'``, ``'GRU'``, ``'RNN_TANH'`` or ``'RNN_RELU'``.
        input_size: number of features in each input element.
        hidden_size: number of features in the hidden state.
        num_layers: number of stacked recurrent layers.
        bias: whether the bias parameters are created and used.
        batch_first: if true, the batch dimension of the input is dim 0
            instead of dim 1.
        dropout: dropout probability in ``[0, 1]`` forwarded to the kernels.
        bidirectional: if true, each layer also has a reverse direction.

    Raises:
        ValueError: if ``dropout`` is not a non-bool number in ``[0, 1]``, or
            if ``mode`` is not recognized.
    """
    __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
                     'batch_first', 'dropout', 'bidirectional']

    mode: str
    input_size: int
    hidden_size: int
    num_layers: int
    bias: bool
    batch_first: bool
    dropout: float
    bidirectional: bool

    def __init__(self, mode: str, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True, batch_first: bool = False,
                 dropout: float = 0., bidirectional: bool = False) -> None:
        super(RNNBase, self).__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1

        # Validate *before* converting. float() on a non-number such as None
        # would otherwise raise a TypeError and hide this explanatory message.
        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        self.dropout = float(dropout)
        if dropout > 0 and num_layers == 1:
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        # Number of rows in each weight matrix: one hidden_size-wide block per gate.
        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        elif mode == 'RNN_TANH':
            gate_size = hidden_size
        elif mode == 'RNN_RELU':
            gate_size = hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        self._flat_weights_names = []
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                # Layers after the first consume the concatenated outputs of
                # all directions of the previous layer.
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                b_ih = Parameter(torch.Tensor(gate_size))
                # Second bias vector included for CuDNN compatibility. Only one
                # bias vector is needed in standard definition.
                b_hh = Parameter(torch.Tensor(gate_size))
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                # zip() stops at the shorter sequence, so the bias tensors are
                # dropped (never registered) when bias=False.
                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._flat_weights_names.extend(param_names)
                self._all_weights.append(param_names)

        self._init_flat_weights()
        self.flatten_parameters()
        self.reset_parameters()

    def _init_flat_weights(self) -> None:
        """Rebuild ``self._flat_weights`` from the attributes named in
        ``self._flat_weights_names``. Names that are not set map to ``None``."""
        self._flat_weights = [getattr(self, wn, None) for wn in self._flat_weights_names]

    def __setattr__(self, attr, value):
        if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
            # keep self._flat_weights up to date if you do self.weight = ...
            idx = self._flat_weights_names.index(attr)
            self._flat_weights[idx] = value
        super(RNNBase, self).__setattr__(attr, value)

    def flatten_parameters(self) -> None:
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is enabled.
        Otherwise, it's a no-op.
        """
        # Short-circuits if _flat_weights is only partially instantiated
        if len(self._flat_weights) != len(self._flat_weights_names):
            return

        for w in self._flat_weights:
            if not isinstance(w, Tensor):
                return
        # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
        # or the tensors in _flat_weights are of different dtypes

        first_fw = self._flat_weights[0]
        dtype = first_fw.dtype
        for fw in self._flat_weights:
            if (not isinstance(fw.data, Tensor) or not (fw.data.dtype == dtype) or
                    not fw.data.is_cuda or
                    not torch.backends.cudnn.is_acceptable(fw.data)):
                return

        # If any parameters alias, we fall back to the slower, copying code path. This is
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
        if len(unique_data_ptrs) != len(self._flat_weights):
            return

        with torch.cuda.device_of(first_fw):
            import torch.backends.cudnn.rnn as rnn

            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on self._flat_weights
            with torch.no_grad():
                if torch._use_cudnn_rnn_flatten_weight():
                    torch._cudnn_rnn_flatten_weight(
                        self._flat_weights, (4 if self.bias else 2),
                        self.input_size, rnn.get_cudnn_mode(self.mode), self.hidden_size, self.num_layers,
                        self.batch_first, bool(self.bidirectional))

    def _apply(self, fn):
        ret = super(RNNBase, self)._apply(fn)

        # Resets _flat_weights
        # Note: be v. careful before removing this, as 3rd party device types
        # likely rely on this behavior to properly .to() modules like LSTM.
        self._init_flat_weights()
        # Flattens params (on CUDA)
        self.flatten_parameters()

        return ret

    def reset_parameters(self) -> None:
        """Initialize every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        """Raise RuntimeError unless ``input`` has the expected rank and feature size.

        The expected rank is 2 for packed data (``batch_sizes`` given) and 3 otherwise.
        """
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
        """Return ``(num_layers * num_directions, mini_batch, hidden_size)`` for ``input``.

        For packed data the mini-batch is ``batch_sizes[0]``; otherwise it is
        read from dim 0 (``batch_first``) or dim 1 of ``input``.
        """
        if batch_sizes is not None:
            mini_batch = batch_sizes[0]
            mini_batch = int(mini_batch)
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
                          msg: str = 'Expected hidden size {}, got {}') -> None:
        """Raise RuntimeError (formatted with ``msg``) if ``hx`` has the wrong size."""
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))

    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
        """Validate ``input`` and a single hidden-state tensor before calling the kernel."""
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        """Reorder ``hx`` along dim 1 by ``permutation``; return it unchanged if None."""
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # NOTE(review): ``input`` may also be a PackedSequence (handled below),
        # in which case a PackedSequence is returned; the annotation does not
        # express that.
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            hx = torch.zeros(self.num_layers * num_directions,
                             max_batch_size, self.hidden_size,
                             dtype=input.dtype, device=input.device)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        _impl = _rnn_impls[self.mode]
        if batch_sizes is None:
            result = _impl(input, hx, self._flat_weights, self.bias, self.num_layers,
                           self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _impl(input, batch_sizes, hx, self._flat_weights, self.bias,
                           self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1]

        if is_packed:
            output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def extra_repr(self) -> str:
        s = '{input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __setstate__(self, d):
        super(RNNBase, self).__setstate__(d)
        # Older pickles stored the weight-name lists under 'all_weights'.
        if 'all_weights' in d:
            self._all_weights = d['all_weights']

        # Already a list of name lists: nothing to rebuild.
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._flat_weights_names = []
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                    self._flat_weights_names.extend(weights)
                else:
                    self._all_weights += [weights[:2]]
                    self._flat_weights_names.extend(weights[:2])
        self._init_flat_weights()

    @property
    def all_weights(self) -> List[List[Parameter]]:
        """Parameters grouped per (layer, direction), in registration order."""
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]

    def _replicate_for_data_parallel(self):
        replica = super(RNNBase, self)._replicate_for_data_parallel()
        # Need to copy these caches, otherwise the replica will share the same
        # flat weights list.
        replica._flat_weights = replica._flat_weights[:]
        replica._flat_weights_names = replica._flat_weights_names[:]
        return replica
|
|
|
|
|
|
class RNN(RNNBase):
    r"""Applies a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` non-linearity to an
    input sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    same layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``.
            May be given as a keyword or as the fourth positional argument.
            Default: ``'tanh'``
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)`. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          or :func:`torch.nn.utils.rnn.pack_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features (`h_t`) from the last layer of the RNN,
          for each `t`.  If a :class:`torch.nn.utils.rnn.PackedSequence` has
          been given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Shape:
        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
        - Input2: :math:`(S, N, H_{out})` tensor
          containing the initial hidden state for each element in the batch,
          where :math:`H_{out}=\text{hidden\_size}` and
          :math:`S=\text{num\_layers} * \text{num\_directions}`.
          Defaults to zero if not provided.
          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
            `(hidden_size, num_directions * hidden_size)`
        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size, hidden_size)`
        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
            of shape `(hidden_size)`
        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
            of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.RNN(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    def __init__(self, *args, **kwargs):
        # The documented argument order places ``nonlinearity`` fourth, but the
        # fourth parameter of RNNBase is ``bias``. A str is never a meaningful
        # ``bias``, so a positional str in that slot is taken as the
        # nonlinearity. Any other value keeps its legacy meaning as ``bias``.
        if len(args) > 3 and isinstance(args[3], str):
            if 'nonlinearity' in kwargs:
                raise TypeError("__init__() got multiple values for argument 'nonlinearity'")
            self.nonlinearity = args[3]
            args = args[:3] + args[4:]
        else:
            self.nonlinearity = kwargs.pop('nonlinearity', 'tanh')
        if self.nonlinearity == 'tanh':
            mode = 'RNN_TANH'
        elif self.nonlinearity == 'relu':
            mode = 'RNN_RELU'
        else:
            raise ValueError("Unknown nonlinearity '{}'".format(self.nonlinearity))
        super(RNN, self).__init__(mode, *args, **kwargs)
|
|
|
|
|
|
# XXX: The LSTM and GRU implementations differ from RNNBase because:
# 1. We want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
#    its current state does not support the Python Union type or Any type.
# 2. TorchScript static typing does not allow a Function or Callable type in
#    Dict values, so we have to call _VF separately instead of using _rnn_impls.
# 3. This is temporary, a transitional state so that the change makes it in
#    time for the release.
#
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
#
# TODO: remove the overriding implementations for LSTM and GRU once TorchScript
# supports expressing these two modules generically.
|
|
class LSTM(RNNBase):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
    sequence.


    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)`. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence.
          The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          If the LSTM is bidirectional, num_directions should be 2, else it should be 1.
        - **c_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial cell state for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.


    Outputs: output, (h_n, c_n)
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features `(h_t)` from the last layer of the LSTM,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.

          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.
          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`.

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)`` and similarly for *c_n*.
        - **c_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the cell state for `t = seq_len`.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """

    def __init__(self, *args, **kwargs):
        # All constructor arguments are forwarded unchanged to RNNBase with mode 'LSTM'.
        super(LSTM, self).__init__('LSTM', *args, **kwargs)

    def check_forward_args(self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]) -> None:
        """Validate ``input`` and both tensors of the ``(h, c)`` hidden tuple.

        Raises RuntimeError (via the RNNBase checks) on a rank, feature-size or
        hidden-size mismatch.
        """
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden[0], expected_hidden_size,
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], expected_hidden_size,
                               'Expected hidden[1] size {}, got {}')

    def permute_hidden(self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
        """Apply ``permutation`` (along dim 1) to both tensors of ``hx``.

        Returns ``hx`` unchanged when ``permutation`` is None.
        """
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

    # The two stubs below declare the Tensor and PackedSequence signatures for
    # both typing.overload and TorchScript's method overloading. The real
    # implementation follows them.
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
                ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811
        # Keep the caller's object so the isinstance checks below can be
        # repeated after ``input`` is rebound to the packed data tensor.
        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            max_batch_size = int(max_batch_size)
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            # The same zero tensor serves as both the initial h and the initial c.
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        # Packed data calls the _VF.lstm overload that takes batch_sizes and no batch_first.
        if batch_sizes is None:
            result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        # Everything after the output is the hidden-state tuple handed to
        # permute_hidden (presumably (h_n, c_n); confirm against _VF.lstm).
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            return output, self.permute_hidden(hidden, unsorted_indices)
|
|
|
|
|
|
class GRU(RNNBase):
|
|
r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
|
|
|
|
|
|
For each element in the input sequence, each layer computes the following
|
|
function:
|
|
|
|
.. math::
|
|
\begin{array}{ll}
|
|
r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
|
|
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
|
|
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
|
|
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
|
|
\end{array}
|
|
|
|
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
|
|
at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
|
|
at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
|
|
:math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
|
|
:math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.
|
|
|
|
In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
|
|
(:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
|
|
dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
|
|
variable which is :math:`0` with probability :attr:`dropout`.
|
|
|
|
Args:
|
|
input_size: The number of expected features in the input `x`
|
|
hidden_size: The number of features in the hidden state `h`
|
|
num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
|
|
would mean stacking two GRUs together to form a `stacked GRU`,
|
|
with the second GRU taking in outputs of the first GRU and
|
|
computing the final results. Default: 1
|
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
|
|
Default: ``True``
|
|
batch_first: If ``True``, then the input and output tensors are provided
|
|
as (batch, seq, feature). Default: ``False``
|
|
dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
|
|
GRU layer except the last layer, with dropout probability equal to
|
|
:attr:`dropout`. Default: 0
|
|
bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
|
|
|
|
Inputs: input, h_0
|
|
- **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
|
|
of the input sequence. The input can also be a packed variable length
|
|
sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
|
|
for details.
|
|
- **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
|
|
containing the initial hidden state for each element in the batch.
|
|
Defaults to zero if not provided. If the RNN is bidirectional,
|
|
num_directions should be 2, else it should be 1.
|
|
|
|
Outputs: output, h_n
|
|
- **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
|
|
containing the output features h_t from the last layer of the GRU,
|
|
for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
|
|
given as the input, the output will also be a packed sequence.
|
|
For the unpacked case, the directions can be separated
|
|
using ``output.view(seq_len, batch, num_directions, hidden_size)``,
|
|
with forward and backward being direction `0` and `1` respectively.
|
|
|
|
Similarly, the directions can be separated in the packed case.
|
|
- **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
|
|
containing the hidden state for `t = seq_len`
|
|
|
|
Like *output*, the layers can be separated using
|
|
``h_n.view(num_layers, num_directions, batch, hidden_size)``.
|
|
|
|
Shape:
|
|
- Input1: :math:`(L, N, H_{in})` tensor containing input features where
|
|
:math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
|
|
- Input2: :math:`(S, N, H_{out})` tensor
|
|
containing the initial hidden state for each element in the batch.
|
|
:math:`H_{out}=\text{hidden\_size}`
|
|
Defaults to zero if not provided. where :math:`S=\text{num\_layers} * \text{num\_directions}`
|
|
If the RNN is bidirectional, num_directions should be 2, else it should be 1.
|
|
- Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
|
|
- Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
|
|
for each element in the batch
|
|
|
|
Attributes:
|
|
weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
|
|
(W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
|
|
Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
|
|
weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
|
|
(W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
|
|
bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
|
|
(b_ir|b_iz|b_in), of shape `(3*hidden_size)`
|
|
bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
|
|
(b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
|
|
|
|
.. note::
|
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
|
|
where :math:`k = \frac{1}{\text{hidden\_size}}`
|
|
|
|
.. include:: ../cudnn_persistent_rnn.rst
|
|
|
|
Examples::
|
|
|
|
>>> rnn = nn.GRU(10, 20, 2)
|
|
>>> input = torch.randn(5, 3, 10)
|
|
>>> h0 = torch.randn(2, 3, 20)
|
|
>>> output, hn = rnn(input, h0)
|
|
"""
|
|
|
|
def __init__(self, *args, **kwargs):
    # GRU takes no options of its own: every positional and keyword argument
    # is handed straight to the base-class constructor, preceded by the mode
    # string 'GRU'.
    mode = 'GRU'
    super().__init__(mode, *args, **kwargs)
|
|
|
|
@overload
@torch._jit_internal._overload_method  # noqa: F811
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:  # noqa: F811
    # Overload declaration only (no logic): a plain Tensor input yields an
    # (output Tensor, h_n Tensor) pair. `@overload` exposes this signature to
    # static type checkers; `_overload_method` is the matching TorchScript-side
    # declaration (NOTE(review): inferred from its torch._jit_internal home).
    # Keep the body a bare `pass`; the real implementation is the undecorated
    # `forward` defined after the overloads.
    pass
|
|
|
|
@overload
@torch._jit_internal._overload_method  # noqa: F811
def forward(self, input: PackedSequence, hx: Optional[Tensor] = None) -> Tuple[PackedSequence, Tensor]:  # noqa: F811
    # Overload declaration only (no logic): a PackedSequence input yields a
    # (PackedSequence output, h_n Tensor) pair, i.e. the output is re-packed.
    # Keep the body a bare `pass`; the real implementation is the undecorated
    # `forward` defined after the overloads.
    pass
|
|
|
|
def forward(self, input, hx=None):  # noqa: F811
    """Run the GRU over a whole input sequence.

    Args:
        input: either a :class:`PackedSequence`, or a Tensor whose batch
            dimension is 0 when ``self.batch_first`` is set and 1 otherwise.
        hx: optional initial hidden state. When ``None``, a zero tensor of
            shape ``(num_layers * num_directions, max_batch_size, hidden_size)``
            with ``input``'s dtype and device is used instead.

    Returns:
        ``(output, h_n)``. ``output`` is re-wrapped as a PackedSequence when
        ``input`` was one and is a plain Tensor otherwise. ``h_n`` is the
        second element of the ``_VF.gru`` result, passed through
        ``permute_hidden`` with ``unsorted_indices``.
    """
    orig_input = input
    # xxx: isinstance check needs to be in conditional for TorchScript to compile
    if isinstance(orig_input, PackedSequence):
        # Unpack the four PackedSequence fields; ``input`` now names the packed data.
        input, batch_sizes, sorted_indices, unsorted_indices = input
        # The first batch_sizes entry is taken as the maximum batch size and
        # converted to a Python int for use as a tensor dimension below.
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)
    else:
        # Plain tensor: there is no packing metadata, so the batch size comes
        # from the layout and the index tensors are left as None.
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

    if hx is None:
        # Default initial state: zeros, one slice per (layer, direction) pair.
        num_directions = 2 if self.bidirectional else 1
        hx = torch.zeros(self.num_layers * num_directions,
                         max_batch_size, self.hidden_size,
                         dtype=input.dtype, device=input.device)
    else:
        # Each batch of the hidden state should match the input sequence that
        # the user believes they are passing in.
        hx = self.permute_hidden(hx, sorted_indices)

    self.check_forward_args(input, hx, batch_sizes)
    # The unpacked and packed kernel calls take different argument lists:
    # the packed form takes batch_sizes and has no batch_first flag.
    if batch_sizes is None:
        result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
                         self.dropout, self.training, self.bidirectional, self.batch_first)
    else:
        result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
                         self.num_layers, self.dropout, self.training, self.bidirectional)
    output = result[0]
    hidden = result[1]

    # xxx: isinstance check needs to be in conditional for TorchScript to compile
    if isinstance(orig_input, PackedSequence):
        output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        return output_packed, self.permute_hidden(hidden, unsorted_indices)
    else:
        return output, self.permute_hidden(hidden, unsorted_indices)
|
|
|
|
|
|
class RNNCellBase(Module):
    """Common base of the single-step recurrent cells (RNNCell, LSTMCell, GRUCell).

    Allocates ``weight_ih`` of shape ``(num_chunks * hidden_size, input_size)``,
    ``weight_hh`` of shape ``(num_chunks * hidden_size, hidden_size)`` and, when
    ``bias`` is true, ``bias_ih``/``bias_hh`` of length ``num_chunks * hidden_size``.
    All parameters are then drawn uniformly from
    ``[-1/sqrt(hidden_size), 1/sqrt(hidden_size)]``. It also provides the
    input/hidden shape checks that subclasses run at the start of ``forward``.
    """

    __constants__ = ['input_size', 'hidden_size', 'bias']

    input_size: int
    hidden_size: int
    bias: bool
    weight_ih: Tensor
    weight_hh: Tensor
    # WARNING: bias_ih and bias_hh are deliberately not annotated here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(self, input_size: int, hidden_size: int, bias: bool, num_chunks: int) -> None:
        super(RNNCellBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Every weight/bias stacks `num_chunks` blocks of `hidden_size` rows.
        gate_rows = num_chunks * hidden_size
        self.weight_ih = Parameter(torch.Tensor(gate_rows, input_size))
        self.weight_hh = Parameter(torch.Tensor(gate_rows, hidden_size))
        if not bias:
            # Register explicit None slots so the attributes always exist.
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        else:
            self.bias_ih = Parameter(torch.Tensor(gate_rows))
            self.bias_hh = Parameter(torch.Tensor(gate_rows))

        # The storage above is uninitialized; fill it now.
        self.reset_parameters()

    def extra_repr(self) -> str:
        attrs = self.__dict__
        pieces = ['{input_size}', '{hidden_size}']
        # Only mention settings that differ from their defaults.
        if 'bias' in attrs and self.bias is not True:
            pieces.append('bias={bias}')
        if 'nonlinearity' in attrs and self.nonlinearity != "tanh":
            pieces.append('nonlinearity={nonlinearity}')
        return ', '.join(pieces).format(**attrs)

    def check_forward_input(self, input: Tensor) -> None:
        features = input.size(1)
        if features != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    features, self.input_size))

    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
        # Batch sizes are compared first; the feature width of `hx` is only
        # inspected once the batch sizes agree.
        input_batch = input.size(0)
        hidden_batch = hx.size(0)
        if input_batch != hidden_batch:
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input_batch, hidden_label, hidden_batch))

        hidden_features = hx.size(1)
        if hidden_features != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hidden_features, self.hidden_size))

    def reset_parameters(self) -> None:
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            init.uniform_(param, a=-bound, b=bound)
|
|
|
|
|
|
class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    .. math::

        h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})

    If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``.
            Any other value is stored as given, and calling the cell then raises a
            ``RuntimeError``. Default: ``'tanh'``

    Inputs: input, hidden
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
          Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Shape:
        - Input1: :math:`(N, H_{in})` tensor containing input features where
          :math:`H_{in}` = `input_size`
        - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
          state for each element in the batch where :math:`H_{out}` = `hidden_size`.
          Defaults to zero if not provided.
        - Output: :math:`(N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Examples::

        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
    nonlinearity: str

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh") -> None:
        # A plain RNN has a single weight block, hence num_chunks=1.
        super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
        # Stored unchecked; an unsupported value is only rejected in forward().
        self.nonlinearity = nonlinearity

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        """Advance the cell by one step and return the next hidden state ``h'``."""
        self.check_forward_input(input)
        if hx is None:
            batch = input.size(0)
            hx = torch.zeros(batch, self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')

        if self.nonlinearity == "tanh":
            next_hidden = _VF.rnn_tanh_cell(
                input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            next_hidden = _VF.rnn_relu_cell(
                input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,
            )
        else:
            # Placeholder so every branch defines `next_hidden` for the JIT.
            next_hidden = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
        return next_hidden
|
|
|
|
|
|
class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **h_0** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
        - **c_0** of shape `(batch, hidden_size)`: tensor containing the initial cell state
          for each element in the batch.

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.

    Outputs: (h_1, c_1)
        - **h_1** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch
        - **c_1** of shape `(batch, hidden_size)`: tensor containing the next cell state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(4*hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(4*hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Examples::

        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        # num_chunks=4: one stacked weight block for each of the four gates above.
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)

    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        """Advance the cell by one step and return ``(h_1, c_1)``."""
        self.check_forward_input(input)
        if hx is None:
            # h_0 and c_0 share one zero tensor.
            blank = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (blank, blank)
        # Check h_0 and c_0 separately so an error message names the offending one.
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return _VF.lstm_cell(
            input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,
        )
|
|
|
|
|
|
class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell.

    .. math::

        \begin{array}{ll}
        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        h' = (1 - z) * n + z * h
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, hidden
        - **input** of shape `(batch, input_size)`: tensor containing input features
        - **hidden** of shape `(batch, hidden_size)`: tensor containing the initial hidden
          state for each element in the batch.
          Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Shape:
        - Input1: :math:`(N, H_{in})` tensor containing input features where
          :math:`H_{in}` = `input_size`
        - Input2: :math:`(N, H_{out})` tensor containing the initial hidden
          state for each element in the batch where :math:`H_{out}` = `hidden_size`.
          Defaults to zero if not provided.
        - Output: :math:`(N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(3*hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(3*hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Examples::

        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        # num_chunks=3: one stacked weight block for each of the three gates above.
        super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        """Advance the cell by one step and return the next hidden state ``h'``."""
        self.check_forward_input(input)
        if hx is None:
            batch = input.size(0)
            hx = torch.zeros(batch, self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        return _VF.gru_cell(
            input, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh,
        )
|