pytorch/caffe2/python/rnn/rnn_cell_test_util.py
Robert Verkuil 48bd102b95 Moved sigmoid, tanh, and _prepare_lstm (renamed) to a util file.
Summary:
Moved sigmoid, tanh, and _prepare_lstm (renamed) to a util file.
Also renamed _prepare_lstm to _prepare_rnn since it is being used for setting up both an LSTM and a GRU model.

The reason for this commit is to allow the creation of the GRU op and its testing code without copying and pasting the code for sigmoid, tanh, and setting up an RNN unit op model.

Reviewed By: jamesr66a

Differential Revision: D5363675

fbshipit-source-id: 352bd70378031f1d81606c9267e625c6728b18fd
2017-07-10 17:52:22 -07:00

73 lines
2.3 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import workspace, scope
from caffe2.python.model_helper import ModelHelper
import numpy as np
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    ``x`` may be a scalar or a NumPy array; the computation is elementwise
    via ``np.exp``.
    """
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator
def tanh(x):
    """Return the elementwise hyperbolic tangent of ``x``.

    ``x`` may be a scalar or a NumPy array. ``np.tanh`` is used instead of
    the algebraically equivalent ``2 * sigmoid(2 * x) - 1``: that form
    subtracts two nearly equal numbers when ``x`` is close to zero (losing
    most significant digits) and overflows inside ``exp`` for large negative
    ``x``.
    """
    return np.tanh(x)
def _prepare_rnn(t, n, dim_in, create_rnn, outputs_with_grads,
                 forget_bias, memory_optim=False,
                 forward_only=False, drop_states=False, T=None,
                 two_d_initial_states=None, dim_out=None):
    """Build a model around an RNN produced by ``create_rnn`` and feed its inputs.

    For every entry of ``dim_out`` a ``hidden_init_<i>`` / ``cell_init_<i>``
    pair of external inputs is added to the net and fed with random float32
    data. The RNN itself is created inside the ``test_name_scope`` name
    scope, the parameter-init net is run once, and random sequence lengths
    are fed. ``input_blob`` is only declared here; this function does not
    feed it.

    Args:
        t: Inclusive upper bound of the random sequence lengths (each is
            drawn from ``[1, t]``).
        n: Number of sequence lengths drawn; also the ``n`` dimension of the
            initial state arrays.
        dim_in: Forwarded to ``create_rnn`` as ``dim_in``; also used as the
            single entry of ``dim_out`` when that is ``None``.
        create_rnn: Callable invoked as ``create_rnn(model, input_blob,
            seq_lengths, states, **options)``; its return value is passed
            back unchanged.
        outputs_with_grads: Forwarded to ``create_rnn``.
        forget_bias: Forwarded to ``create_rnn``.
        memory_optim: Forwarded to ``create_rnn`` as ``memory_optimization``.
        forward_only: Forwarded to ``create_rnn``.
        drop_states: Forwarded to ``create_rnn``.
        T: Forwarded to ``create_rnn`` as ``static_rnn_unroll_size``.
        two_d_initial_states: If truthy, initial states have shape
            ``(n, d)``; otherwise ``(1, n, d)``. When ``None`` the choice is
            made at random.
        dim_out: Iterable of per-layer state sizes; defaults to ``[dim_in]``.

    Returns:
        A tuple ``(outputs, net, inputs)`` where ``outputs`` is whatever
        ``create_rnn`` returned, ``net`` is the model's net, and ``inputs``
        is the list of state blobs followed by ``input_blob``.
    """
    dim_out = [dim_in] if dim_out is None else dim_out
    print("Dims: ", t, n, dim_in, dim_out)

    helper = ModelHelper(name='external')

    # No explicit choice from the caller: randomly exercise either the 2-D
    # or the 3-D initial-state layout.
    if two_d_initial_states is None:
        two_d_initial_states = np.random.randint(2)

    def generate_input_state(batch, width):
        """Return a random float32 initial state in the selected layout."""
        if two_d_initial_states:
            shape = (batch, width)
        else:
            shape = (1, batch, width)
        return np.random.randn(*shape).astype(np.float32)

    states = []
    for layer_id, state_dim in enumerate(dim_out):
        hidden_blob, cell_blob = helper.net.AddExternalInputs(
            "hidden_init_{}".format(layer_id),
            "cell_init_{}".format(layer_id),
        )
        states += [hidden_blob, cell_blob]
        # Hidden state is generated and fed before the cell state so the
        # sequence of random draws stays the same.
        for state_blob in (hidden_blob, cell_blob):
            initial_value = generate_input_state(n, state_dim)
            workspace.FeedBlob(state_blob, initial_value.astype(np.float32))

    # Due to convoluted RNN scoping logic we make sure that things
    # work from a namescope
    with scope.NameScope("test_name_scope"):
        input_blob, seq_lengths = helper.net.AddScopedExternalInputs(
            'input_blob', 'seq_lengths')

        rnn_outputs = create_rnn(
            helper, input_blob, seq_lengths, states,
            dim_in=dim_in,
            dim_out=dim_out,
            scope="external/recurrent",
            outputs_with_grads=outputs_with_grads,
            memory_optimization=memory_optim,
            forget_bias=forget_bias,
            forward_only=forward_only,
            drop_states=drop_states,
            static_rnn_unroll_size=T,
        )

    workspace.RunNetOnce(helper.param_init_net)

    # randint's upper bound is exclusive, hence t + 1 to allow length t.
    lengths = np.random.randint(1, t + 1, size=(n,)).astype(np.int32)
    workspace.FeedBlob(seq_lengths, lengths)

    return rnn_outputs, helper.net, states + [input_blob]