Summary:
This implements only the forward pass of the GRU unit op, not the full RNN cell.
The helper functions were written to mirror the LSTM implementation as closely as possible.
Implementation is defined in GRU_unit_op.{h, cc}.
Tests are in gru_cell_test.py, which imports the sigmoid, tanh, and _prepare_rnn functions from rnn_cell_test_util.py.
Reviewed By: jamesr66a
Differential Revision: D5363697
fbshipit-source-id: f9ba9fe0be01ffc868dd22027be8be4975b84998
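For reference, the state update the unit computes (and which the NumPy reference in the test below mirrors) is: with update gate u = sigmoid(g_u) and candidate output o = tanh(g_o), both taken from the 3*D-wide gates tensor,

    h_t = u * h_prev + (1 - u) * o

and once timestep t reaches a sequence's length, the unit either carries h_prev through unchanged (drop_states=False) or zeroes the state (drop_states=True). A minimal sketch of that masking step, assuming NumPy inputs shaped like the test's (the helper name is ours, illustrative only):

    import numpy as np

    def mask_past_seq_end(h_new, h_prev, t, seq_lengths, drop_states=False):
        # valid is 1 while timestep t is inside the sequence, 0 afterwards
        valid = (t < seq_lengths).astype(h_new.dtype)
        return h_new * valid + h_prev * (1 - valid) * (1 - drop_states)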
174 lines · 5.4 KiB · Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import workspace, scope
from caffe2.python.model_helper import ModelHelper
from caffe2.python.rnn.rnn_cell_test_util import sigmoid, tanh
import caffe2.python.hypothesis_test_util as hu

from functools import partial
from hypothesis import given
from hypothesis import settings as ht_settings
import hypothesis.strategies as st
import numpy as np

def gru_unit(hidden_t_prev, gates_out_t,
             seq_lengths, timestep, drop_states=False):
    '''
    Implements one GRU unit, for one time step

    Shapes:
    hidden_t_prev.shape = (1, N, D)
    gates_out_t.shape   = (1, N, G)
    seq_lengths.shape   = (N,)
    '''
    N = hidden_t_prev.shape[1]
    D = hidden_t_prev.shape[2]
    G = gates_out_t.shape[2]
    t = (timestep * np.ones(shape=(N, D))).astype(np.int32)
    assert t.shape == (N, D)
    seq_lengths = (np.ones(shape=(N, D)) *
                   seq_lengths.reshape(N, 1)).astype(np.int32)
    assert seq_lengths.shape == (N, D)
    assert G == 3 * D

    # Calculate reset, update, and output gates separately
    # because output gate depends on reset gate.
    gates_out_t = gates_out_t.reshape(N, 3, D)
    reset_gate_t = gates_out_t[:, 0, :].reshape(N, D)
    update_gate_t = gates_out_t[:, 1, :].reshape(N, D)
    output_gate_t = gates_out_t[:, 2, :].reshape(N, D)

    # Calculate gate outputs.
    reset_gate_t = sigmoid(reset_gate_t)
    update_gate_t = sigmoid(update_gate_t)
    output_gate_t = tanh(output_gate_t)

    # Steps past a sequence's end either copy the previous hidden state
    # through (drop_states=False) or zero it out (drop_states=True).
    valid = (t < seq_lengths).astype(np.int32)
    assert valid.shape == (N, D)
    hidden_t = update_gate_t * hidden_t_prev + (1 - update_gate_t) * output_gate_t
    hidden_t = hidden_t * valid + hidden_t_prev * (1 - valid) * (1 - drop_states)
    hidden_t = hidden_t.reshape(1, N, D)

    return (hidden_t, )
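
# A shape-level sanity check for the reference above (hypothetical values,
# not part of the test suite):
#
#   h_prev = np.zeros((1, 2, 3), dtype=np.float32)         # (1, N=2, D=3)
#   gates = np.random.randn(1, 2, 9).astype(np.float32)    # (1, N, 3 * D)
#   (h_t,) = gru_unit(h_prev, gates,
#                     seq_lengths=np.array([5, 5]), timestep=0)
#   assert h_t.shape == (1, 2, 3)
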
def gru_unit_op_input():
    '''
    Create input tensor where each dimension is from 1 to 4, ndim=3 and
    last dimension size is a multiple of 3

    gates_out_t.shape = (1, N, G), where G = 3 * D
    '''
    dims_ = st.tuples(
        st.integers(min_value=1, max_value=1),  # 1, one timestep
        st.integers(min_value=1, max_value=4),  # n
        st.integers(min_value=1, max_value=4),  # d
    )

    def create_input(dims):
        dims = list(dims)
        dims[2] *= 3
        return hu.arrays(dims)

    return dims_.flatmap(create_input)
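
# Example of a hypothetical draw: dims (1, 2, 4) yields an array of shape
# (1, 2, 12), i.e. one timestep, N = 2, and G = 12 = 3 * D with D = 4.
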
def _prepare_gru_unit_op(n, d, outputs_with_grads,
                         forward_only=False, drop_states=False,
                         two_d_initial_states=None):
    print("Dims: (n,d) = ({},{})".format(n, d))

    def generate_input_state(n, d):
        if two_d_initial_states:
            return np.random.randn(n, d).astype(np.float32)
        else:
            return np.random.randn(1, n, d).astype(np.float32)

    model = ModelHelper(name='external')

    with scope.NameScope("test_name_scope"):
        hidden_t_prev, gates_t, seq_lengths, timestep = \
            model.net.AddScopedExternalInputs(
                "hidden_t_prev",
                "gates_t",
                "seq_lengths",
                "timestep",
            )
        workspace.FeedBlob(
            hidden_t_prev,
            generate_input_state(n, d).astype(np.float32)
        )
        workspace.FeedBlob(
            gates_t,
            generate_input_state(n, 3 * d).astype(np.float32)
        )

        hidden_t = model.net.GRUUnit(
            [
                hidden_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            ['hidden_t'],
            forget_bias=0.0,
            drop_states=drop_states,
        )
        model.net.AddExternalOutputs(hidden_t)
        workspace.RunNetOnce(model.param_init_net)

        # 10 is used as a magic number to simulate some reasonable timestep
        # and generate some reasonable seq. lengths
        workspace.FeedBlob(
            seq_lengths,
            np.random.randint(1, 10, size=(n,)).astype(np.int32)
        )
        workspace.FeedBlob(
            timestep,
            np.random.randint(1, 10, size=(1,)).astype(np.int32)
        )

    return hidden_t, model.net
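
# One way to exercise the net this helper builds (a hypothetical sketch, not
# what the test below does; the test checks the op against the NumPy
# reference instead):
#
#   hidden_t, net = _prepare_gru_unit_op(n=2, d=3, outputs_with_grads=[0])
#   workspace.RunNetOnce(net)
#   out = workspace.FetchBlob(hidden_t)   # shape (1, 2, 3)
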
class GRUCellTest(hu.HypothesisTestCase):

    # Make the test just for GRUUnitOp
    @given(
        input_tensor=gru_unit_op_input(),
        drop_states=st.booleans(),
    )
    @ht_settings(max_examples=15)
    def test_gru_unit_op(self, input_tensor, drop_states, **kwargs):
        outputs_with_grads = [0]
        ref = partial(gru_unit, drop_states=drop_states)

        t, n, d = input_tensor.shape
        assert d % 3 == 0
        d = d // 3

        net = _prepare_gru_unit_op(n, d,
                                   outputs_with_grads=outputs_with_grads,
                                   forward_only=True,
                                   drop_states=drop_states)[1]
        # Here we don't provide a real input for the net but just for one of
        # its ops (RecurrentNetworkOp), so we have to hardcode this name.
        workspace.FeedBlob("test_name_scope/external/recurrent/i2h",
                           input_tensor)
        op = net._net.op[-1]
        inputs = [workspace.FetchBlob(name) for name in op.input]

        self.assertReferenceChecks(
            hu.cpu_do,
            op,
            inputs,
            ref,
            outputs_to_check=[0],
        )
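
# The excerpt ends here; a typical Caffe2 test module would close with a
# runner like the following (an assumption on our part, not shown in the
# original diff):

if __name__ == "__main__":
    import unittest
    unittest.main()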