#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"

namespace torch {
namespace jit {

// Runs one LSTM step through the JIT interpreter (via the graph returned by
// build_lstm()) and checks that the interpreter's two outputs are exactly
// equal to the hidden/cell state produced by calling lstm(...) directly on the
// same inputs.
//
// NOTE(review): every tensor is allocated with at::kCUDA and there is no
// device check in this function; presumably the test registry only invokes
// it on CUDA-enabled builds — confirm against the test list.
void testInterp() {
  constexpr int batch_size = 4;
  constexpr int input_size = 256;
  constexpr int seq_len = 32;
  // A compile-time constant like the dimensions above (previously a mutable
  // `int` for no reason).
  constexpr int hidden_size = 2 * input_size;

  auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
  auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
  auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
  // Weight matrices have 4 * hidden_size rows (presumably one slice per LSTM
  // gate). t_def is a helper from the test headers — TODO confirm exactly
  // what it does to the tensor before it is passed in.
  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));

  auto lstm_g = build_lstm();
  Code lstm_function(lstm_g, "");
  InterpreterState lstm_interp(lstm_function);

  // Only the first time step (input[0]) is exercised. The interpreter must run
  // before the reference call below, because that call overwrites hx and cx.
  auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);

  // Strict comparison (exactlyEqual rather than the almostEqual helper):
  // interpreted and direct execution are expected to agree bit-for-bit here.
  ASSERT_TRUE(exactlyEqual(outputs[0], hx));
  ASSERT_TRUE(exactlyEqual(outputs[1], cx));
}

} // namespace jit
} // namespace torch