mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Further breakup test_misc.h. The remaining tests don't directly map to a jit file so I left them in test_misc. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18191 Differential Revision: D14533442 Pulled By: eellison fbshipit-source-id: 7f538ce0aea208b6b55a4716dfcf039548305041
36 lines
1.1 KiB
C++
#pragma once
|
|
|
|
#include "test/cpp/jit/test_base.h"
|
|
#include "test/cpp/jit/test_utils.h"
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace test {
|
|
|
|
void testInterp() {
|
|
constexpr int batch_size = 4;
|
|
constexpr int input_size = 256;
|
|
constexpr int seq_len = 32;
|
|
|
|
int hidden_size = 2 * input_size;
|
|
|
|
auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
|
|
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
|
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
|
|
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
|
|
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
|
|
|
|
auto lstm_g = build_lstm();
|
|
Code lstm_function(lstm_g);
|
|
InterpreterState lstm_interp(lstm_function);
|
|
auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
|
|
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
|
|
|
|
// std::cout << almostEqual(outputs[0],hx) << "\n";
|
|
ASSERT_TRUE(exactlyEqual(outputs[0], hx));
|
|
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
|
|
}
|
|
} // namespace test
|
|
} // namespace jit
|
|
} // namespace torch
|