pytorch/caffe2/python/test/inference_lstm_op_test.py
Ahmed Aly f8778aef78 Implement a Caffe2 standalone LSTM operator (#17726)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17726

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17461

Implementing a standalone LSTM operator in Caffe2, adapted from the ATen implementation at diffusion/FBS/browse/master/fbcode/caffe2/aten/src/ATen/native/RNN.cpp. The trickiest part of this exercise was that caffe2::Tensor has no copy constructor, which made it necessary to implement a custom templated copy helper for the different Tensor containers used in the code. There was also no easy way to reuse off-the-shelf C2 operators, so I had to copy some code that does basic matmul, cat, split, transpose, and linear as utility functions.
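As an illustration of the copy-constructor workaround described above, a minimal sketch of such a templated copy helper might look like the following (hypothetical C++, not the actual diff; the use of Tensor::UnsafeSharedInstance() to share storage is an assumption):

#include <vector>
#include "caffe2/core/tensor.h"

namespace caffe2 {

// Default case: cheaply copyable types just use their own copy constructor.
template <typename T>
T copy_ctor(const T& x) {
  return x;
}

// caffe2::Tensor's copy constructor is deleted, so "copying" a Tensor here
// means creating a new handle that shares the same underlying storage.
// ASSUMPTION: UnsafeSharedInstance() is the storage-sharing call.
template <>
Tensor copy_ctor(const Tensor& x) {
  return x.UnsafeSharedInstance();
}

// Containers of tensors recurse element-wise through the helpers above.
template <typename T>
std::vector<T> copy_ctor(const std::vector<T>& xs) {
  std::vector<T> ys;
  ys.reserve(xs.size());
  for (const auto& x : xs) {
    ys.push_back(copy_ctor(x));
  }
  return ys;
}

} // namespace caffe2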

Two things are still missing:

- Profiling this implementation against the current ONNXified LSTM op
- Making this operator available for use in PyTorch

Reviewed By: dzhulgakov

Differential Revision: D14351575

fbshipit-source-id: 3b99b53212cf593c7a49e45580b5a07b90809e64
2019-03-07 01:08:49 -08:00


#!/usr/bin/env python3
import inspect

import hypothesis.strategies as st
import numpy as np
import torch
from caffe2.python import core, workspace
from caffe2.python.test_util import TestCase
from hypothesis import given
from torch import nn


class TestC2LSTM(TestCase):
    @given(
        bsz=st.integers(1, 5),
        seq_lens=st.integers(1, 6),
        emb_lens=st.integers(5, 10),
        hidden_size=st.integers(3, 7),
        num_layers=st.integers(1, 4),
        has_biases=st.booleans(),
        is_bidirectional=st.booleans(),
        batch_first=st.booleans(),
    )
    def test_c2_lstm(
        self,
        bsz,
        seq_lens,
        emb_lens,
        hidden_size,
        num_layers,
        has_biases,
        is_bidirectional,
        batch_first,
    ):
        net = core.Net("test_net")
        num_directions = 2 if is_bidirectional else 1
        # Reference LSTM from PyTorch with the same hyperparameters.
        py_lstm = nn.LSTM(
            emb_lens,
            hidden_size,
            batch_first=batch_first,
            bidirectional=is_bidirectional,
            bias=has_biases,
            num_layers=num_layers,
        )
        # Zero-initialized hidden and cell states.
        hx = np.zeros(
            (num_layers * num_directions, bsz, hidden_size), dtype=np.float32
        )
        if batch_first:
            inputs = np.random.randn(bsz, seq_lens, emb_lens).astype(np.float32)
        else:
            inputs = np.random.randn(seq_lens, bsz, emb_lens).astype(np.float32)
        py_results = py_lstm(torch.from_numpy(inputs))
        # Inputs to the Caffe2 op: input tensor, h0, c0, followed by the
        # flattened weights of the reference nn.LSTM, so that both
        # implementations run with identical parameters.
        lstm_in = [
            torch.from_numpy(inputs),
            torch.from_numpy(hx),
            torch.from_numpy(hx),
        ] + [param.detach() for param in py_lstm._flat_weights]
        c2_results = torch.ops._caffe2.InferenceLSTM(
            lstm_in, num_layers, has_biases, batch_first, is_bidirectional
        )
        # Output, final hidden state, and final cell state must all match
        # the PyTorch reference.
        np.testing.assert_array_almost_equal(
            py_results[0].detach().numpy(), c2_results[0].detach().numpy()
        )
        np.testing.assert_array_almost_equal(
            py_results[1][0].detach().numpy(), c2_results[1].detach().numpy()
        )
        np.testing.assert_array_almost_equal(
            py_results[1][1].detach().numpy(), c2_results[2].detach().numpy()
        )