pytorch/test/cpp/jit/test_interpreter.h
Elias Ellison 3baf99bea7 Breakup test misc pt2 (#18191)
Summary:
Further break up test_misc.h. The remaining tests don't directly map to a jit file, so I left them in test_misc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18191

Differential Revision: D14533442

Pulled By: eellison

fbshipit-source-id: 7f538ce0aea208b6b55a4716dfcf039548305041
2019-03-19 19:41:22 -07:00

36 lines
1.1 KiB
C++

#pragma once
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
namespace torch {
namespace jit {
namespace test {
// Checks that the JIT interpreter reproduces the reference LSTM computation.
//
// The graph returned by build_lstm() is compiled into a Code object and run
// through an InterpreterState; its first two outputs are then compared with
// the (hx, cx) pair obtained by calling lstm() directly on the same tensors.
//
// Every tensor is created with at::kCUDA, so this presumably requires a
// CUDA-enabled build and device -- TODO confirm how the test runner gates it.
void testInterp() {
  constexpr int batch_size = 4;
  constexpr int input_size = 256;
  constexpr int seq_len = 32;
  constexpr int hidden_size = 2 * input_size;

  // Creation order is kept as-is so the sequence of random draws is unchanged.
  auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
  auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
  auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));

  auto lstm_g = build_lstm();
  Code lstm_function(lstm_g);
  InterpreterState lstm_interp(lstm_function);

  // Only the first slice of `input` (input[0]) is fed to both implementations.
  auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});

  // Reference result; overwrites hx/cx, which is safe because the interpreter
  // run above has already consumed the original values.
  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);

  // Guard the indexing below: without this check a short output list would be
  // an out-of-range access rather than a clean test failure.
  ASSERT_TRUE(outputs.size() >= 2);
  ASSERT_TRUE(exactlyEqual(outputs[0], hx));
  ASSERT_TRUE(exactlyEqual(outputs[1], cx));
}
} // namespace test
} // namespace jit
} // namespace torch